Skip to main content

lellm_agent/tools/
runtime.rs

1//! Agent Runtime — LLM ↔ 工具调用闭环。
2//!
3//! 负责 LLM 返回 tool_calls → 执行工具 → 结果注入 → 再次调用 LLM 的循环,
4//! 直到 LLM 返回纯文本或达到最大轮次。
5
6use lellm_core::{ChatRequest, ChatResponse, LlmError, Message};
7use lellm_provider::ResolvedModel;
8
9use super::executor::ToolExecutor;
10use super::{AgentEvent, AgentStream, StopReason};
11
/// Outcome of executing a single tool call.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare outcomes directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolCallResult {
    /// The tool succeeded; carries its textual output.
    Ok(String),
    /// The tool failed; carries the error message.
    Err(String),
}
18
19/// ToolUseLoop 执行结果
20#[derive(Debug, Clone)]
21pub struct ToolUseResult {
22    pub stop_reason: StopReason,
23    pub response: ChatResponse,
24    pub messages: Vec<Message>,
25    pub iterations: usize,
26    /// 执行过程中调用的工具总次数
27    pub tool_calls_executed: usize,
28}
29
30/// 管理 LLM 与工具调用闭环
31pub struct ToolUseLoop {
32    model: ResolvedModel,
33    executor: ToolExecutor,
34    max_iterations: usize,
35}
36
37impl ToolUseLoop {
38    pub fn new(model: ResolvedModel, executor: ToolExecutor) -> Self {
39        Self {
40            model,
41            executor,
42            max_iterations: 15,
43        }
44    }
45
46    pub fn set_max_iterations(mut self, max: usize) -> Self {
47        self.max_iterations = max;
48        self
49    }
50
51    /// 非流式执行
52    ///
53    /// 语义:
54    /// - `Ok(ToolUseResult)` — Agent 层完成(含 MaxIterationsReached)
55    /// - `Err(LlmError)` — Provider 调用失败
56    pub async fn execute(self, messages: Vec<Message>) -> Result<ToolUseResult, LlmError> {
57        let mut req = ChatRequest {
58            model: self.model.model.clone(),
59            messages,
60            ..Default::default()
61        };
62
63        let mut tool_calls_executed = 0usize;
64        let mut last_response: Option<ChatResponse> = None;
65
66        for iteration in 1..=self.max_iterations {
67            let response = self.model.provider.call(&req).await?;
68            last_response = Some(response);
69
70            if last_response.as_ref().unwrap().tool_calls.is_empty() {
71                return Ok(ToolUseResult {
72                    stop_reason: StopReason::Complete,
73                    response: last_response.unwrap(),
74                    messages: req.messages,
75                    iterations: iteration,
76                    tool_calls_executed,
77                });
78            }
79
80            let tool_calls = last_response.as_ref().unwrap().tool_calls.clone();
81            tool_calls_executed += tool_calls.len();
82
83            req.messages.push(Message::Assistant {
84                content: last_response.as_ref().unwrap().content.clone(),
85            });
86
87            let tool_results = self.executor.execute_batch(&tool_calls).await;
88            req.messages.extend(tool_results);
89
90            tracing::debug!(
91                iteration,
92                tool_calls = tool_calls.len(),
93                "tool-use loop iteration"
94            );
95        }
96
97        // 达到最大轮次 — Agent 层正常终止,不是 Provider 错误
98        Ok(ToolUseResult {
99            stop_reason: StopReason::MaxIterationsReached,
100            response: last_response.unwrap(),
101            messages: req.messages,
102            iterations: self.max_iterations,
103            tool_calls_executed,
104        })
105    }
106
107    /// 流式执行,返回事件接收器
108    ///
109    /// 终态契约:
110    /// - 正常结束:`LoopEnd` 恰好一次,然后 channel 关闭
111    /// - 异常结束:`LoopError` 恰好一次,然后 channel 关闭
112    /// - 终态事件后不再发送任何事件
113    /// - 绝不会发送伪造的 `ToolEnd { tool_call_id: "", .. }`
114    pub fn execute_stream(self, messages: Vec<Message>) -> AgentStream {
115        let (tx, rx) = tokio::sync::mpsc::channel(32);
116        let model = self.model.clone();
117        let executor = self.executor;
118        let max_iterations = self.max_iterations;
119
120        tokio::spawn(async move {
121            let mut req = ChatRequest {
122                model: model.model.clone(),
123                messages,
124                ..Default::default()
125            };
126
127            let mut tool_calls_executed = 0usize;
128            let mut last_response: Option<ChatResponse> = None;
129            let mut completed = false;
130
131            for iteration in 1..=max_iterations {
132                let _ = tx
133                    .send(AgentEvent::Provider(lellm_provider::ProviderEvent::Start {
134                        model: model.model.clone(),
135                    }))
136                    .await;
137
138                match model.provider.stream(&req).await {
139                    Ok(stream) => {
140                        use futures_util::StreamExt;
141                        let mut stream = stream;
142                        let mut text_buffer = String::new();
143                        let mut pending_tool_calls = Vec::<lellm_core::ToolCall>::new();
144
145                        let mut iteration_over = false;
146
147                        while let Some(event) = stream.next().await {
148                            match event {
149                                Ok(lellm_provider::ProviderEvent::Start { .. }) => {
150                                    let _ = tx
151                                        .send(AgentEvent::Provider(
152                                            lellm_provider::ProviderEvent::Start {
153                                                model: model.model.clone(),
154                                            },
155                                        ))
156                                        .await;
157                                }
158                                Ok(lellm_provider::ProviderEvent::Token { token }) => {
159                                    text_buffer.push_str(&token);
160                                    let _ = tx
161                                        .send(AgentEvent::Provider(
162                                            lellm_provider::ProviderEvent::Token { token },
163                                        ))
164                                        .await;
165                                }
166                                Ok(lellm_provider::ProviderEvent::Done { tool_calls, usage }) => {
167                                    pending_tool_calls = tool_calls;
168                                    let usage_val = usage.unwrap_or_default();
169
170                                    // 统一构建 ChatResponse — 无论有无 tool_calls
171                                    let content: Vec<lellm_core::ContentBlock> =
172                                        lellm_core::text_block(text_buffer.clone())
173                                            .into_iter()
174                                            .chain(pending_tool_calls.iter().map(|tc| {
175                                                lellm_core::ContentBlock::ToolCall(tc.clone())
176                                            }))
177                                            .collect();
178
179                                    let response = ChatResponse::new(
180                                        content,
181                                        usage_val,
182                                        serde_json::json!(null),
183                                    );
184
185                                    if !pending_tool_calls.is_empty() {
186                                        req.messages.push(Message::Assistant {
187                                            content: response.content.clone(),
188                                        });
189                                        tool_calls_executed += pending_tool_calls.len();
190
191                                        let mut tool_results = Vec::new();
192                                        for tc in &pending_tool_calls {
193                                            let _ = tx
194                                                .send(AgentEvent::ToolStart {
195                                                    tool_call_id: tc.id.clone(),
196                                                    name: tc.name.clone(),
197                                                })
198                                                .await;
199
200                                            let result = executor.execute(tc).await;
201
202                                            let _ = tx
203                                                .send(AgentEvent::ToolEnd {
204                                                    tool_call_id: tc.id.clone(),
205                                                    result: result.clone(),
206                                                })
207                                                .await;
208
209                                            let content_str = match &result {
210                                                ToolCallResult::Ok(s) => s.clone(),
211                                                ToolCallResult::Err(e) => {
212                                                    format!("tool error: {e}")
213                                                }
214                                            };
215
216                                            tool_results.push(Message::ToolResult {
217                                                tool_call_id: tc.id.clone(),
218                                                content: lellm_core::text_block(content_str),
219                                            });
220                                        }
221                                        req.messages.extend(tool_results);
222
223                                        // 保存为 last_response,供 MaxIterationsReached 使用
224                                        last_response = Some(response);
225
226                                        tracing::debug!(
227                                            iteration,
228                                            tool_calls = pending_tool_calls.len(),
229                                            "tool-use stream iteration"
230                                        );
231                                    } else {
232                                        let _ = tx
233                                            .send(AgentEvent::Provider(
234                                                lellm_provider::ProviderEvent::Done {
235                                                    tool_calls: Vec::new(),
236                                                    usage: Some(response.usage),
237                                                },
238                                            ))
239                                            .await;
240
241                                        let _ = tx
242                                            .send(AgentEvent::LoopEnd {
243                                                result: ToolUseResult {
244                                                    stop_reason: StopReason::Complete,
245                                                    response,
246                                                    messages: req.messages.clone(),
247                                                    iterations: iteration,
248                                                    tool_calls_executed,
249                                                },
250                                            })
251                                            .await;
252
253                                        completed = true;
254                                        iteration_over = true;
255                                        break;
256                                    }
257                                }
258                                Err(e) => {
259                                    let _ = tx
260                                        .send(AgentEvent::LoopError {
261                                            error: e,
262                                            iterations: iteration,
263                                            messages: req.messages.clone(),
264                                        })
265                                        .await;
266                                    iteration_over = true;
267                                    break;
268                                }
269                            }
270                        }
271
272                        if iteration_over {
273                            break;
274                        }
275                    }
276                    Err(e) => {
277                        let _ = tx
278                            .send(AgentEvent::LoopError {
279                                error: e,
280                                iterations: iteration,
281                                messages: req.messages.clone(),
282                            })
283                            .await;
284                        break;
285                    }
286                }
287            }
288
289            // 达到最大轮次 — 仅在未完成时发送
290            if !completed {
291                let response = last_response.unwrap_or_else(|| {
292                    ChatResponse::new(
293                        lellm_core::text_block(String::new()),
294                        lellm_core::TokenUsage::default(),
295                        serde_json::Value::Null,
296                    )
297                });
298                let _ = tx
299                    .send(AgentEvent::LoopEnd {
300                        result: ToolUseResult {
301                            stop_reason: StopReason::MaxIterationsReached,
302                            response,
303                            messages: req.messages,
304                            iterations: max_iterations,
305                            tool_calls_executed,
306                        },
307                    })
308                    .await;
309            }
310        });
311
312        rx
313    }
314}