lellm_agent/tools/runtime.rs
1//! Agent Runtime — LLM ↔ 工具调用闭环。
2//!
3//! 负责 LLM 返回 tool_calls → 执行工具 → 结果注入 → 再次调用 LLM 的循环,
4//! 直到 LLM 返回纯文本或达到最大轮次。
5
6use lellm_core::{ChatRequest, ChatResponse, LlmError, Message};
7use lellm_provider::ResolvedModel;
8
9use super::executor::ToolExecutor;
10use super::{AgentEvent, AgentStream, StopReason};
11
/// Outcome of executing a single tool call.
#[derive(Clone, Debug)]
pub enum ToolCallResult {
    /// The tool succeeded; carries its textual output.
    Ok(String),
    /// The tool failed; carries a description of the error.
    Err(String),
}
18
19/// ToolUseLoop 执行结果
20#[derive(Debug, Clone)]
21pub struct ToolUseResult {
22 pub stop_reason: StopReason,
23 pub response: ChatResponse,
24 pub messages: Vec<Message>,
25 pub iterations: usize,
26 /// 执行过程中调用的工具总次数
27 pub tool_calls_executed: usize,
28}
29
30/// 管理 LLM 与工具调用闭环
31pub struct ToolUseLoop {
32 model: ResolvedModel,
33 executor: ToolExecutor,
34 max_iterations: usize,
35}
36
37impl ToolUseLoop {
38 pub fn new(model: ResolvedModel, executor: ToolExecutor) -> Self {
39 Self {
40 model,
41 executor,
42 max_iterations: 15,
43 }
44 }
45
46 pub fn set_max_iterations(mut self, max: usize) -> Self {
47 self.max_iterations = max;
48 self
49 }
50
51 /// 非流式执行
52 ///
53 /// 语义:
54 /// - `Ok(ToolUseResult)` — Agent 层完成(含 MaxIterationsReached)
55 /// - `Err(LlmError)` — Provider 调用失败
56 pub async fn execute(self, messages: Vec<Message>) -> Result<ToolUseResult, LlmError> {
57 let mut req = ChatRequest {
58 model: self.model.model.clone(),
59 messages,
60 ..Default::default()
61 };
62
63 let mut tool_calls_executed = 0usize;
64 let mut last_response: Option<ChatResponse> = None;
65
66 for iteration in 1..=self.max_iterations {
67 let response = self.model.provider.call(&req).await?;
68 last_response = Some(response);
69
70 if last_response.as_ref().unwrap().tool_calls.is_empty() {
71 return Ok(ToolUseResult {
72 stop_reason: StopReason::Complete,
73 response: last_response.unwrap(),
74 messages: req.messages,
75 iterations: iteration,
76 tool_calls_executed,
77 });
78 }
79
80 let tool_calls = last_response.as_ref().unwrap().tool_calls.clone();
81 tool_calls_executed += tool_calls.len();
82
83 req.messages.push(Message::Assistant {
84 content: last_response.as_ref().unwrap().content.clone(),
85 });
86
87 let tool_results = self.executor.execute_batch(&tool_calls).await;
88 req.messages.extend(tool_results);
89
90 tracing::debug!(
91 iteration,
92 tool_calls = tool_calls.len(),
93 "tool-use loop iteration"
94 );
95 }
96
97 // 达到最大轮次 — Agent 层正常终止,不是 Provider 错误
98 Ok(ToolUseResult {
99 stop_reason: StopReason::MaxIterationsReached,
100 response: last_response.unwrap(),
101 messages: req.messages,
102 iterations: self.max_iterations,
103 tool_calls_executed,
104 })
105 }
106
107 /// 流式执行,返回事件接收器
108 ///
109 /// 终态契约:
110 /// - 正常结束:`LoopEnd` 恰好一次,然后 channel 关闭
111 /// - 异常结束:`LoopError` 恰好一次,然后 channel 关闭
112 /// - 终态事件后不再发送任何事件
113 /// - 绝不会发送伪造的 `ToolEnd { tool_call_id: "", .. }`
114 pub fn execute_stream(self, messages: Vec<Message>) -> AgentStream {
115 let (tx, rx) = tokio::sync::mpsc::channel(32);
116 let model = self.model.clone();
117 let executor = self.executor;
118 let max_iterations = self.max_iterations;
119
120 tokio::spawn(async move {
121 let mut req = ChatRequest {
122 model: model.model.clone(),
123 messages,
124 ..Default::default()
125 };
126
127 let mut tool_calls_executed = 0usize;
128 let mut last_response: Option<ChatResponse> = None;
129 let mut completed = false;
130
131 for iteration in 1..=max_iterations {
132 let _ = tx
133 .send(AgentEvent::Provider(lellm_provider::ProviderEvent::Start {
134 model: model.model.clone(),
135 }))
136 .await;
137
138 match model.provider.stream(&req).await {
139 Ok(stream) => {
140 use futures_util::StreamExt;
141 let mut stream = stream;
142 let mut text_buffer = String::new();
143 let mut pending_tool_calls = Vec::<lellm_core::ToolCall>::new();
144
145 let mut iteration_over = false;
146
147 while let Some(event) = stream.next().await {
148 match event {
149 Ok(lellm_provider::ProviderEvent::Start { .. }) => {
150 let _ = tx
151 .send(AgentEvent::Provider(
152 lellm_provider::ProviderEvent::Start {
153 model: model.model.clone(),
154 },
155 ))
156 .await;
157 }
158 Ok(lellm_provider::ProviderEvent::Token { token }) => {
159 text_buffer.push_str(&token);
160 let _ = tx
161 .send(AgentEvent::Provider(
162 lellm_provider::ProviderEvent::Token { token },
163 ))
164 .await;
165 }
166 Ok(lellm_provider::ProviderEvent::Done { tool_calls, usage }) => {
167 pending_tool_calls = tool_calls;
168 let usage_val = usage.unwrap_or_default();
169
170 // 统一构建 ChatResponse — 无论有无 tool_calls
171 let content: Vec<lellm_core::ContentBlock> =
172 lellm_core::text_block(text_buffer.clone())
173 .into_iter()
174 .chain(pending_tool_calls.iter().map(|tc| {
175 lellm_core::ContentBlock::ToolCall(tc.clone())
176 }))
177 .collect();
178
179 let response = ChatResponse::new(
180 content,
181 usage_val,
182 serde_json::json!(null),
183 );
184
185 if !pending_tool_calls.is_empty() {
186 req.messages.push(Message::Assistant {
187 content: response.content.clone(),
188 });
189 tool_calls_executed += pending_tool_calls.len();
190
191 let mut tool_results = Vec::new();
192 for tc in &pending_tool_calls {
193 let _ = tx
194 .send(AgentEvent::ToolStart {
195 tool_call_id: tc.id.clone(),
196 name: tc.name.clone(),
197 })
198 .await;
199
200 let result = executor.execute(tc).await;
201
202 let _ = tx
203 .send(AgentEvent::ToolEnd {
204 tool_call_id: tc.id.clone(),
205 result: result.clone(),
206 })
207 .await;
208
209 let content_str = match &result {
210 ToolCallResult::Ok(s) => s.clone(),
211 ToolCallResult::Err(e) => {
212 format!("tool error: {e}")
213 }
214 };
215
216 tool_results.push(Message::ToolResult {
217 tool_call_id: tc.id.clone(),
218 content: lellm_core::text_block(content_str),
219 });
220 }
221 req.messages.extend(tool_results);
222
223 // 保存为 last_response,供 MaxIterationsReached 使用
224 last_response = Some(response);
225
226 tracing::debug!(
227 iteration,
228 tool_calls = pending_tool_calls.len(),
229 "tool-use stream iteration"
230 );
231 } else {
232 let _ = tx
233 .send(AgentEvent::Provider(
234 lellm_provider::ProviderEvent::Done {
235 tool_calls: Vec::new(),
236 usage: Some(response.usage),
237 },
238 ))
239 .await;
240
241 let _ = tx
242 .send(AgentEvent::LoopEnd {
243 result: ToolUseResult {
244 stop_reason: StopReason::Complete,
245 response,
246 messages: req.messages.clone(),
247 iterations: iteration,
248 tool_calls_executed,
249 },
250 })
251 .await;
252
253 completed = true;
254 iteration_over = true;
255 break;
256 }
257 }
258 Err(e) => {
259 let _ = tx
260 .send(AgentEvent::LoopError {
261 error: e,
262 iterations: iteration,
263 messages: req.messages.clone(),
264 })
265 .await;
266 iteration_over = true;
267 break;
268 }
269 }
270 }
271
272 if iteration_over {
273 break;
274 }
275 }
276 Err(e) => {
277 let _ = tx
278 .send(AgentEvent::LoopError {
279 error: e,
280 iterations: iteration,
281 messages: req.messages.clone(),
282 })
283 .await;
284 break;
285 }
286 }
287 }
288
289 // 达到最大轮次 — 仅在未完成时发送
290 if !completed {
291 let response = last_response.unwrap_or_else(|| {
292 ChatResponse::new(
293 lellm_core::text_block(String::new()),
294 lellm_core::TokenUsage::default(),
295 serde_json::Value::Null,
296 )
297 });
298 let _ = tx
299 .send(AgentEvent::LoopEnd {
300 result: ToolUseResult {
301 stop_reason: StopReason::MaxIterationsReached,
302 response,
303 messages: req.messages,
304 iterations: max_iterations,
305 tool_calls_executed,
306 },
307 })
308 .await;
309 }
310 });
311
312 rx
313 }
314}