1pub mod types;
31pub use types::*;
32
33use crate::agent::backend::LlmBackend;
34use crate::agent::{Message, Role, ToolCallRecord, ToolCallRequest, ToolResultMessage, TokenUsage};
35use crate::tools::ToolRegistry;
36use futures::future::join_all;
37use std::sync::Arc;
38use std::time::Instant;
39use tokio::time::timeout;
40
41fn to_backend_message(msg: &ConversationMessage) -> Message {
56 let tool_result = if msg.role == Role::Tool {
57 msg.tool_call_id.as_ref().map(|id| ToolResultMessage {
58 tool_call_id: id.clone(),
59 content: serde_json::from_str(&msg.content).unwrap_or(serde_json::Value::String(msg.content.clone())),
60 success: true,
61 })
62 } else {
63 None
64 };
65
66 Message {
67 role: msg.role.clone(),
68 content: msg.content.clone(),
69 tool_calls: msg.tool_calls.clone(),
70 tool_result,
71 }
72}
73
74pub struct ToolCoordinator {
99 backend: Arc<dyn LlmBackend>,
100 registry: Arc<ToolRegistry>,
101 config: ToolCallingConfig,
102}
103
104impl ToolCoordinator {
105 pub fn new(
107 backend: Arc<dyn LlmBackend>,
108 registry: Arc<ToolRegistry>,
109 config: ToolCallingConfig,
110 ) -> Self {
111 Self { backend, registry, config }
112 }
113
114 pub async fn execute(
118 &self,
119 system_prompt: Option<&str>,
120 user_prompt: &str,
121 ) -> crate::Result<CoordinatorResult> {
122 let mut messages: Vec<ConversationMessage> = Vec::new();
123 if let Some(sys) = system_prompt {
124 messages.push(ConversationMessage::system(sys));
125 }
126 messages.push(ConversationMessage::user(user_prompt));
127 self.execute_with_history(messages).await
128 }
129
130 pub async fn execute_with_history(
136 &self,
137 mut messages: Vec<ConversationMessage>,
138 ) -> crate::Result<CoordinatorResult> {
139 let tool_defs = self.registry.get_definitions();
140 let mut all_tool_calls: Vec<ToolCallRecord> = Vec::new();
141 let mut total_usage = TokenUsage::default();
142
143 for iteration in 0..self.config.max_iterations {
144 let backend_messages: Vec<Message> =
146 messages.iter().map(to_backend_message).collect();
147
148 let response = self
150 .backend
151 .generate(&backend_messages, &tool_defs, None)
152 .await?;
153
154 if let Some(usage) = &response.usage {
156 total_usage.prompt_tokens += usage.prompt_tokens;
157 total_usage.completion_tokens += usage.completion_tokens;
158 total_usage.total_tokens += usage.total_tokens;
159 total_usage.reasoning_tokens += usage.reasoning_tokens;
160 total_usage.action_tokens += usage.action_tokens;
161 }
162
163 messages.push(ConversationMessage::assistant(
165 &response.content,
166 response.tool_calls.clone(),
167 ));
168
169 if response.tool_calls.is_empty() {
171 return Ok(CoordinatorResult {
172 content: response.content,
173 tool_calls: all_tool_calls,
174 iterations: iteration + 1,
175 finish_reason: FinishReason::Stop,
176 total_usage,
177 message_history: messages,
178 });
179 }
180
181 if response.content.is_empty() && response.tool_calls.is_empty() {
183 return Ok(CoordinatorResult {
184 content: String::new(),
185 tool_calls: all_tool_calls,
186 iterations: iteration + 1,
187 finish_reason: FinishReason::Stop,
188 total_usage,
189 message_history: messages,
190 });
191 }
192
193 for tc in &response.tool_calls {
195 if !self.registry.has_tool(&tc.name) {
196 return Ok(CoordinatorResult {
197 content: response.content,
198 tool_calls: all_tool_calls,
199 iterations: iteration + 1,
200 finish_reason: FinishReason::UnknownTool(tc.name.clone()),
201 total_usage,
202 message_history: messages,
203 });
204 }
205 }
206
207 let records = self.execute_tool_calls(&response.tool_calls).await?;
209
210 if self.config.stop_on_error {
212 if let Some(failed) = records.iter().find(|r| !r.success) {
213 let err_msg = failed
214 .result
215 .get("error")
216 .and_then(|v| v.as_str())
217 .unwrap_or("tool error")
218 .to_string();
219 return Ok(CoordinatorResult {
220 content: response.content,
221 tool_calls: all_tool_calls,
222 iterations: iteration + 1,
223 finish_reason: FinishReason::Error(err_msg),
224 total_usage,
225 message_history: messages,
226 });
227 }
228 }
229
230 for record in records {
232 messages.push(ConversationMessage::tool_result(&record.id, &record.result));
233 all_tool_calls.push(record);
234 }
235 }
236
237 Ok(CoordinatorResult {
239 content: messages
240 .last()
241 .map(|m| m.content.clone())
242 .unwrap_or_default(),
243 tool_calls: all_tool_calls,
244 iterations: self.config.max_iterations,
245 finish_reason: FinishReason::MaxIterations,
246 total_usage,
247 message_history: messages,
248 })
249 }
250
251 async fn execute_tool_calls(
256 &self,
257 calls: &[ToolCallRequest],
258 ) -> crate::Result<Vec<ToolCallRecord>> {
259 if self.config.parallel_execution {
260 self.execute_parallel(calls).await
261 } else {
262 self.execute_sequential(calls).await
263 }
264 }
265
266 async fn execute_parallel(
267 &self,
268 calls: &[ToolCallRequest],
269 ) -> crate::Result<Vec<ToolCallRecord>> {
270 let futures = calls.iter().map(|c| self.execute_single_tool(c));
271 let results = join_all(futures).await;
272
273 let mut records = Vec::with_capacity(results.len());
274 for (i, res) in results.into_iter().enumerate() {
275 match res {
276 Ok(record) => records.push(record),
277 Err(e) if self.config.stop_on_error => return Err(e),
278 Err(e) => {
279 let call = &calls[i];
281 records.push(ToolCallRecord {
282 id: call.id.clone(),
283 name: call.name.clone(),
284 arguments: call.arguments.clone(),
285 result: serde_json::json!({"error": e.to_string()}),
286 success: false,
287 duration_ms: 0,
288 });
289 }
290 }
291 }
292 Ok(records)
293 }
294
295 async fn execute_sequential(
296 &self,
297 calls: &[ToolCallRequest],
298 ) -> crate::Result<Vec<ToolCallRecord>> {
299 let mut records = Vec::with_capacity(calls.len());
300 for call in calls {
301 match self.execute_single_tool(call).await {
302 Ok(record) => records.push(record),
303 Err(e) if self.config.stop_on_error => return Err(e),
304 Err(e) => {
305 records.push(ToolCallRecord {
306 id: call.id.clone(),
307 name: call.name.clone(),
308 arguments: call.arguments.clone(),
309 result: serde_json::json!({"error": e.to_string()}),
310 success: false,
311 duration_ms: 0,
312 });
313 }
314 }
315 }
316 Ok(records)
317 }
318
319 async fn execute_single_tool(&self, call: &ToolCallRequest) -> crate::Result<ToolCallRecord> {
320 let start = Instant::now();
321
322 let result = timeout(
323 self.config.tool_timeout,
324 self.registry.execute(&call.name, call.arguments.clone()),
325 )
326 .await;
327
328 let duration_ms = start.elapsed().as_millis() as u64;
329
330 match result {
331 Ok(Ok(value)) => Ok(ToolCallRecord {
332 id: call.id.clone(),
333 name: call.name.clone(),
334 arguments: call.arguments.clone(),
335 result: value,
336 success: true,
337 duration_ms,
338 }),
339 Ok(Err(e)) => Ok(ToolCallRecord {
340 id: call.id.clone(),
341 name: call.name.clone(),
342 arguments: call.arguments.clone(),
343 result: serde_json::json!({"error": e.to_string()}),
344 success: false,
345 duration_ms,
346 }),
347 Err(_elapsed) => Ok(ToolCallRecord {
348 id: call.id.clone(),
349 name: call.name.clone(),
350 arguments: call.arguments.clone(),
351 result: serde_json::json!({"error": "tool execution timed out"}),
352 success: false,
353 duration_ms,
354 }),
355 }
356 }
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// With no tools registered, a plain text reply ends the loop after a single turn.
    #[tokio::test]
    async fn execute_with_empty_registry_returns_model_response() {
        use crate::agent::backend::mock::MockBackend;

        let coordinator = ToolCoordinator::new(
            Arc::new(MockBackend::with_text("Hello, world!")),
            Arc::new(ToolRegistry::new()),
            ToolCallingConfig::default(),
        );

        let outcome = coordinator
            .execute(None, "Say hello")
            .await
            .expect("coordinator should not error");

        assert_eq!(outcome.content, "Hello, world!");
        assert_eq!(outcome.finish_reason, FinishReason::Stop);
        assert_eq!(outcome.iterations, 1);
        assert!(outcome.tool_calls.is_empty());
        // The history holds the user prompt followed by the assistant reply.
        assert_eq!(outcome.message_history.len(), 2);
    }

    /// Pins the documented defaults of `ToolCallingConfig`.
    #[test]
    fn tool_calling_config_defaults_are_sensible() {
        use std::time::Duration;

        let defaults = ToolCallingConfig::default();
        assert_eq!(defaults.max_iterations, 10, "max_iterations default changed");
        assert!(defaults.parallel_execution, "parallel_execution should default to true");
        assert_eq!(defaults.tool_timeout, Duration::from_secs(30), "tool_timeout default changed");
        assert!(!defaults.stop_on_error, "stop_on_error should default to false");
    }

    /// A model that keeps requesting tools is cut off at `max_iterations`.
    #[tokio::test]
    async fn coordinator_result_captures_finish_reason_max_iterations() {
        use crate::agent::backend::mock::{MockBackend, MockResponse};
        use crate::tools::Tool;
        use async_trait::async_trait;
        use serde_json::{json, Value};

        /// Tool that ignores its arguments and always succeeds.
        struct NoOpTool;

        #[async_trait]
        impl Tool for NoOpTool {
            fn name(&self) -> &str {
                "noop"
            }

            fn description(&self) -> &str {
                "does nothing"
            }

            fn parameters_schema(&self) -> Value {
                json!({"type": "object", "properties": {}})
            }

            async fn execute(&self, _args: Value) -> crate::Result<Value> {
                Ok(json!({"ok": true}))
            }
        }

        // Script far more tool-call turns than the iteration cap allows.
        let scripted: Vec<MockResponse> =
            std::iter::repeat_with(|| MockResponse::tool_call("noop", json!({})))
                .take(15)
                .collect();

        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(NoOpTool));

        let coordinator = ToolCoordinator::new(
            Arc::new(MockBackend::new(scripted)),
            Arc::new(registry),
            ToolCallingConfig {
                max_iterations: 3,
                parallel_execution: false,
                ..ToolCallingConfig::default()
            },
        );

        let outcome = coordinator
            .execute(None, "loop forever")
            .await
            .expect("coordinator should not hard-error");

        assert_eq!(
            outcome.finish_reason,
            FinishReason::MaxIterations,
            "expected MaxIterations, got {:?}",
            outcome.finish_reason
        );
        assert_eq!(outcome.iterations, 3);
        // One successful noop call per iteration.
        assert_eq!(outcome.tool_calls.len(), 3);
        assert!(outcome.tool_calls.iter().all(|record| record.success));
    }
}