1use crate::agent::backend::LlmBackend;
31use crate::agent::{Message, Role, ToolCallRecord, ToolCallRequest, ToolResultMessage, TokenUsage};
32use crate::tools::ToolRegistry;
33use futures::future::join_all;
34use serde::{Deserialize, Serialize};
35use std::sync::Arc;
36use std::time::{Duration, Instant};
37use tokio::time::timeout;
38
/// Tunable limits and policies for the tool-calling loop.
#[derive(Debug, Clone)]
pub struct ToolCallingConfig {
    /// Upper bound on the number of model round-trips performed in one run.
    pub max_iterations: usize,

    /// When `true`, the tool calls of a single model response are awaited
    /// together; when `false`, they are run one after another.
    pub parallel_execution: bool,

    /// Time limit applied to each individual tool execution.
    pub tool_timeout: Duration,

    /// When `true`, a failed tool call ends the run with an error finish
    /// reason instead of feeding the failure back to the model.
    pub stop_on_error: bool,
}

impl Default for ToolCallingConfig {
    /// Ten iterations, parallel execution, a 30 second tool timeout, and
    /// failures reported back to the model rather than stopping the run.
    fn default() -> Self {
        ToolCallingConfig {
            stop_on_error: false,
            parallel_execution: true,
            tool_timeout: Duration::from_secs(30),
            max_iterations: 10,
        }
    }
}
75
76#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
78pub enum FinishReason {
79 Stop,
81 MaxIterations,
83 Error(String),
86 UnknownTool(String),
88}
89
90impl std::fmt::Display for FinishReason {
91 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92 match self {
93 FinishReason::Stop => write!(f, "stop"),
94 FinishReason::MaxIterations => write!(f, "max_iterations"),
95 FinishReason::Error(e) => write!(f, "error: {}", e),
96 FinishReason::UnknownTool(t) => write!(f, "unknown_tool: {}", t),
97 }
98 }
99}
100
/// Author of a [`ConversationMessage`].
///
/// Serialized as the lowercase variant name (`"system"`, `"user"`,
/// `"assistant"`, `"tool"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    /// System prompt.
    System,
    /// Message supplied by the user.
    User,
    /// Message produced by the model; may carry tool call requests.
    Assistant,
    /// Result of a tool call.
    Tool,
}
115
116#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct ConversationMessage {
124 pub role: MessageRole,
126
127 pub content: String,
129
130 #[serde(default, skip_serializing_if = "Vec::is_empty")]
132 pub tool_calls: Vec<ToolCallRequest>,
133
134 #[serde(default, skip_serializing_if = "Option::is_none")]
137 pub tool_call_id: Option<String>,
138}
139
140impl ConversationMessage {
141 pub fn system(content: impl Into<String>) -> Self {
143 Self {
144 role: MessageRole::System,
145 content: content.into(),
146 tool_calls: Vec::new(),
147 tool_call_id: None,
148 }
149 }
150
151 pub fn user(content: impl Into<String>) -> Self {
153 Self {
154 role: MessageRole::User,
155 content: content.into(),
156 tool_calls: Vec::new(),
157 tool_call_id: None,
158 }
159 }
160
161 pub fn assistant(content: impl Into<String>, tool_calls: Vec<ToolCallRequest>) -> Self {
163 Self {
164 role: MessageRole::Assistant,
165 content: content.into(),
166 tool_calls,
167 tool_call_id: None,
168 }
169 }
170
171 pub fn tool_result(tool_call_id: impl Into<String>, result: &serde_json::Value) -> Self {
174 Self {
175 role: MessageRole::Tool,
176 content: serde_json::to_string(result).unwrap_or_else(|_| "{}".to_string()),
177 tool_calls: Vec::new(),
178 tool_call_id: Some(tool_call_id.into()),
179 }
180 }
181}
182
/// Outcome of a coordinator run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoordinatorResult {
    /// Text content reported for the run.
    pub content: String,

    /// Records of tool calls collected by the coordinator during the run.
    pub tool_calls: Vec<ToolCallRecord>,

    /// Number of model round-trips counted for the run.
    pub iterations: usize,

    /// Why the run ended.
    pub finish_reason: FinishReason,

    /// Token usage summed over the backend responses that reported usage.
    pub total_usage: TokenUsage,

    /// Conversation history: the input messages followed by the messages the
    /// coordinator appended.
    pub message_history: Vec<ConversationMessage>,
}
209
210fn to_backend_message(msg: &ConversationMessage) -> Message {
225 let role = match msg.role {
226 MessageRole::System => Role::System,
227 MessageRole::User => Role::User,
228 MessageRole::Assistant => Role::Assistant,
229 MessageRole::Tool => Role::Tool,
230 };
231
232 let tool_result = if msg.role == MessageRole::Tool {
233 msg.tool_call_id.as_ref().map(|id| ToolResultMessage {
234 tool_call_id: id.clone(),
235 content: serde_json::from_str(&msg.content).unwrap_or(serde_json::Value::String(msg.content.clone())),
236 success: true,
237 })
238 } else {
239 None
240 };
241
242 Message {
243 role,
244 content: msg.content.clone(),
245 tool_calls: msg.tool_calls.clone(),
246 tool_result,
247 }
248}
249
250pub struct ToolCoordinator {
275 backend: Arc<dyn LlmBackend>,
276 registry: Arc<ToolRegistry>,
277 config: ToolCallingConfig,
278}
279
280impl ToolCoordinator {
281 pub fn new(
283 backend: Arc<dyn LlmBackend>,
284 registry: Arc<ToolRegistry>,
285 config: ToolCallingConfig,
286 ) -> Self {
287 Self { backend, registry, config }
288 }
289
290 pub async fn execute(
294 &self,
295 system_prompt: Option<&str>,
296 user_prompt: &str,
297 ) -> crate::Result<CoordinatorResult> {
298 let mut messages: Vec<ConversationMessage> = Vec::new();
299 if let Some(sys) = system_prompt {
300 messages.push(ConversationMessage::system(sys));
301 }
302 messages.push(ConversationMessage::user(user_prompt));
303 self.execute_with_history(messages).await
304 }
305
306 pub async fn execute_with_history(
312 &self,
313 mut messages: Vec<ConversationMessage>,
314 ) -> crate::Result<CoordinatorResult> {
315 let tool_defs = self.registry.get_definitions();
316 let mut all_tool_calls: Vec<ToolCallRecord> = Vec::new();
317 let mut total_usage = TokenUsage::default();
318
319 for iteration in 0..self.config.max_iterations {
320 let backend_messages: Vec<Message> =
322 messages.iter().map(to_backend_message).collect();
323
324 let response = self
326 .backend
327 .generate(&backend_messages, &tool_defs, None)
328 .await?;
329
330 if let Some(usage) = &response.usage {
332 total_usage.prompt_tokens += usage.prompt_tokens;
333 total_usage.completion_tokens += usage.completion_tokens;
334 total_usage.total_tokens += usage.total_tokens;
335 total_usage.reasoning_tokens += usage.reasoning_tokens;
336 total_usage.action_tokens += usage.action_tokens;
337 }
338
339 messages.push(ConversationMessage::assistant(
341 &response.content,
342 response.tool_calls.clone(),
343 ));
344
345 if response.tool_calls.is_empty() {
347 return Ok(CoordinatorResult {
348 content: response.content,
349 tool_calls: all_tool_calls,
350 iterations: iteration + 1,
351 finish_reason: FinishReason::Stop,
352 total_usage,
353 message_history: messages,
354 });
355 }
356
357 if response.content.is_empty() && response.tool_calls.is_empty() {
359 return Ok(CoordinatorResult {
360 content: String::new(),
361 tool_calls: all_tool_calls,
362 iterations: iteration + 1,
363 finish_reason: FinishReason::Stop,
364 total_usage,
365 message_history: messages,
366 });
367 }
368
369 for tc in &response.tool_calls {
371 if !self.registry.has_tool(&tc.name) {
372 return Ok(CoordinatorResult {
373 content: response.content,
374 tool_calls: all_tool_calls,
375 iterations: iteration + 1,
376 finish_reason: FinishReason::UnknownTool(tc.name.clone()),
377 total_usage,
378 message_history: messages,
379 });
380 }
381 }
382
383 let records = self.execute_tool_calls(&response.tool_calls).await?;
385
386 if self.config.stop_on_error {
388 if let Some(failed) = records.iter().find(|r| !r.success) {
389 let err_msg = failed
390 .result
391 .get("error")
392 .and_then(|v| v.as_str())
393 .unwrap_or("tool error")
394 .to_string();
395 return Ok(CoordinatorResult {
396 content: response.content,
397 tool_calls: all_tool_calls,
398 iterations: iteration + 1,
399 finish_reason: FinishReason::Error(err_msg),
400 total_usage,
401 message_history: messages,
402 });
403 }
404 }
405
406 for record in records {
408 messages.push(ConversationMessage::tool_result(&record.id, &record.result));
409 all_tool_calls.push(record);
410 }
411 }
412
413 Ok(CoordinatorResult {
415 content: messages
416 .last()
417 .map(|m| m.content.clone())
418 .unwrap_or_default(),
419 tool_calls: all_tool_calls,
420 iterations: self.config.max_iterations,
421 finish_reason: FinishReason::MaxIterations,
422 total_usage,
423 message_history: messages,
424 })
425 }
426
427 async fn execute_tool_calls(
432 &self,
433 calls: &[ToolCallRequest],
434 ) -> crate::Result<Vec<ToolCallRecord>> {
435 if self.config.parallel_execution {
436 self.execute_parallel(calls).await
437 } else {
438 self.execute_sequential(calls).await
439 }
440 }
441
442 async fn execute_parallel(
443 &self,
444 calls: &[ToolCallRequest],
445 ) -> crate::Result<Vec<ToolCallRecord>> {
446 let futures = calls.iter().map(|c| self.execute_single_tool(c));
447 let results = join_all(futures).await;
448
449 let mut records = Vec::with_capacity(results.len());
450 for (i, res) in results.into_iter().enumerate() {
451 match res {
452 Ok(record) => records.push(record),
453 Err(e) if self.config.stop_on_error => return Err(e),
454 Err(e) => {
455 let call = &calls[i];
457 records.push(ToolCallRecord {
458 id: call.id.clone(),
459 name: call.name.clone(),
460 arguments: call.arguments.clone(),
461 result: serde_json::json!({"error": e.to_string()}),
462 success: false,
463 duration_ms: 0,
464 });
465 }
466 }
467 }
468 Ok(records)
469 }
470
471 async fn execute_sequential(
472 &self,
473 calls: &[ToolCallRequest],
474 ) -> crate::Result<Vec<ToolCallRecord>> {
475 let mut records = Vec::with_capacity(calls.len());
476 for call in calls {
477 match self.execute_single_tool(call).await {
478 Ok(record) => records.push(record),
479 Err(e) if self.config.stop_on_error => return Err(e),
480 Err(e) => {
481 records.push(ToolCallRecord {
482 id: call.id.clone(),
483 name: call.name.clone(),
484 arguments: call.arguments.clone(),
485 result: serde_json::json!({"error": e.to_string()}),
486 success: false,
487 duration_ms: 0,
488 });
489 }
490 }
491 }
492 Ok(records)
493 }
494
495 async fn execute_single_tool(&self, call: &ToolCallRequest) -> crate::Result<ToolCallRecord> {
496 let start = Instant::now();
497
498 let result = timeout(
499 self.config.tool_timeout,
500 self.registry.execute(&call.name, call.arguments.clone()),
501 )
502 .await;
503
504 let duration_ms = start.elapsed().as_millis() as u64;
505
506 match result {
507 Ok(Ok(value)) => Ok(ToolCallRecord {
508 id: call.id.clone(),
509 name: call.name.clone(),
510 arguments: call.arguments.clone(),
511 result: value,
512 success: true,
513 duration_ms,
514 }),
515 Ok(Err(e)) => Ok(ToolCallRecord {
516 id: call.id.clone(),
517 name: call.name.clone(),
518 arguments: call.arguments.clone(),
519 result: serde_json::json!({"error": e.to_string()}),
520 success: false,
521 duration_ms,
522 }),
523 Err(_elapsed) => Ok(ToolCallRecord {
524 id: call.id.clone(),
525 name: call.name.clone(),
526 arguments: call.arguments.clone(),
527 result: serde_json::json!({"error": "tool execution timed out"}),
528 success: false,
529 duration_ms,
530 }),
531 }
532 }
533}
534
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // The default values are asserted with messages so a failure names the
    // default that changed.
    #[test]
    fn tool_calling_config_default_values() {
        let cfg = ToolCallingConfig::default();
        assert_eq!(cfg.max_iterations, 10, "max_iterations default changed");
        assert!(
            cfg.parallel_execution,
            "parallel_execution should default to true"
        );
        assert_eq!(
            cfg.tool_timeout,
            Duration::from_secs(30),
            "tool_timeout default changed"
        );
        assert!(!cfg.stop_on_error, "stop_on_error should default to false");
    }

    #[test]
    fn finish_reason_display_matches_snake_case_contract() {
        assert_eq!(FinishReason::Stop.to_string(), "stop");
        assert_eq!(FinishReason::MaxIterations.to_string(), "max_iterations");
        assert_eq!(
            FinishReason::Error("boom".into()).to_string(),
            "error: boom"
        );
        assert_eq!(
            FinishReason::UnknownTool("ghost".into()).to_string(),
            "unknown_tool: ghost"
        );
    }

    #[test]
    fn finish_reason_round_trips_through_json() {
        for variant in [
            FinishReason::Stop,
            FinishReason::MaxIterations,
            FinishReason::Error("oops".into()),
            FinishReason::UnknownTool("nope".into()),
        ] {
            let encoded = serde_json::to_string(&variant).unwrap();
            let decoded: FinishReason = serde_json::from_str(&encoded).unwrap();
            assert_eq!(
                decoded, variant,
                "{} did not round-trip through JSON",
                variant
            );
        }
    }

    #[test]
    fn message_role_serializes_as_lowercase_string() {
        assert_eq!(
            serde_json::to_string(&MessageRole::System).unwrap(),
            "\"system\""
        );
        assert_eq!(
            serde_json::to_string(&MessageRole::User).unwrap(),
            "\"user\""
        );
        assert_eq!(
            serde_json::to_string(&MessageRole::Assistant).unwrap(),
            "\"assistant\""
        );
        assert_eq!(
            serde_json::to_string(&MessageRole::Tool).unwrap(),
            "\"tool\""
        );
    }

    #[test]
    fn conversation_message_system_builder_sets_role_and_content() {
        let msg = ConversationMessage::system("you are an assistant");
        assert_eq!(msg.role, MessageRole::System);
        assert_eq!(msg.content, "you are an assistant");
        assert!(msg.tool_calls.is_empty());
        assert!(msg.tool_call_id.is_none());
    }

    #[test]
    fn conversation_message_user_builder_sets_role_and_content() {
        let msg = ConversationMessage::user("what is 2 + 2?");
        assert_eq!(msg.role, MessageRole::User);
        assert_eq!(msg.content, "what is 2 + 2?");
        assert!(msg.tool_calls.is_empty());
        assert!(msg.tool_call_id.is_none());
    }

    #[test]
    fn conversation_message_assistant_builder_preserves_tool_calls() {
        let calls = vec![ToolCallRequest {
            id: "call_1".into(),
            name: "search".into(),
            arguments: json!({"q": "rust"}),
        }];
        let msg = ConversationMessage::assistant("let me search", calls.clone());
        assert_eq!(msg.role, MessageRole::Assistant);
        assert_eq!(msg.content, "let me search");
        assert_eq!(msg.tool_calls.len(), 1);
        assert_eq!(msg.tool_calls[0].id, "call_1");
        assert_eq!(msg.tool_calls[0].name, "search");
        assert!(msg.tool_call_id.is_none());
    }

    #[test]
    fn conversation_message_tool_result_serializes_result_into_content() {
        let result = json!({"answer": 42, "units": "none"});
        let msg = ConversationMessage::tool_result("call_1", &result);
        assert_eq!(msg.role, MessageRole::Tool);
        assert_eq!(msg.tool_call_id.as_deref(), Some("call_1"));
        assert!(msg.tool_calls.is_empty());
        let parsed: serde_json::Value = serde_json::from_str(&msg.content).unwrap();
        assert_eq!(parsed, result);
    }

    // A JSON null serializes successfully, so the content is "null" and the
    // "{}" fallback of `tool_result` is not involved.
    #[test]
    fn conversation_message_tool_result_serializes_null_as_json_null() {
        let msg = ConversationMessage::tool_result("call_1", &json!(null));
        assert_eq!(msg.content, "null");
    }

    #[test]
    fn conversation_message_serde_skips_empty_tool_calls_and_none_id() {
        let msg = ConversationMessage::user("hi");
        let encoded = serde_json::to_string(&msg).unwrap();
        assert!(!encoded.contains("tool_calls"));
        assert!(!encoded.contains("tool_call_id"));
        assert!(encoded.contains("\"role\":\"user\""));
        assert!(encoded.contains("\"content\":\"hi\""));
    }

    #[test]
    fn coordinator_result_round_trips_through_json() {
        let result = CoordinatorResult {
            content: "done".into(),
            tool_calls: vec![ToolCallRecord {
                id: "call_1".into(),
                name: "echo".into(),
                arguments: json!({"text": "hi"}),
                result: json!({"text": "hi"}),
                success: true,
                duration_ms: 12,
            }],
            iterations: 2,
            finish_reason: FinishReason::Stop,
            total_usage: TokenUsage {
                prompt_tokens: 100,
                completion_tokens: 20,
                total_tokens: 120,
                reasoning_tokens: 0,
                action_tokens: 20,
            },
            message_history: vec![
                ConversationMessage::system("be brief"),
                ConversationMessage::user("echo hi"),
                ConversationMessage::assistant(
                    "",
                    vec![ToolCallRequest {
                        id: "call_1".into(),
                        name: "echo".into(),
                        arguments: json!({"text": "hi"}),
                    }],
                ),
                ConversationMessage::tool_result("call_1", &json!({"text": "hi"})),
                ConversationMessage::assistant("done", vec![]),
            ],
        };
        let encoded = serde_json::to_string(&result).unwrap();
        let decoded: CoordinatorResult = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.content, "done");
        assert_eq!(decoded.iterations, 2);
        assert_eq!(decoded.finish_reason, FinishReason::Stop);
        assert_eq!(decoded.tool_calls.len(), 1);
        assert_eq!(decoded.tool_calls[0].id, "call_1");
        assert_eq!(decoded.message_history.len(), 5);
        assert_eq!(decoded.total_usage.total_tokens, 120);
    }

    // A text-only reply ends the run after one iteration; the history then
    // holds the user prompt and the assistant reply.
    #[tokio::test]
    async fn execute_with_empty_registry_returns_model_response() {
        use crate::agent::backend::mock::MockBackend;

        let backend = Arc::new(MockBackend::with_text("Hello, world!"));
        let registry = Arc::new(ToolRegistry::new());
        let coordinator = ToolCoordinator::new(backend, registry, ToolCallingConfig::default());

        let result = coordinator
            .execute(None, "Say hello")
            .await
            .expect("coordinator should not error");

        assert_eq!(result.content, "Hello, world!");
        assert_eq!(result.finish_reason, FinishReason::Stop);
        assert_eq!(result.iterations, 1);
        assert!(result.tool_calls.is_empty());
        assert_eq!(result.message_history.len(), 2);
    }

    // The backend keeps requesting the same tool, so the run can only end by
    // hitting the iteration limit.
    #[tokio::test]
    async fn coordinator_result_captures_finish_reason_max_iterations() {
        use crate::agent::backend::mock::{MockBackend, MockResponse};
        use crate::tools::Tool;
        use async_trait::async_trait;
        use serde_json::Value;

        struct NoOpTool;

        #[async_trait]
        impl Tool for NoOpTool {
            fn name(&self) -> &str {
                "noop"
            }
            fn description(&self) -> &str {
                "does nothing"
            }
            fn parameters_schema(&self) -> Value {
                json!({"type": "object", "properties": {}})
            }
            async fn execute(&self, _args: Value) -> crate::Result<Value> {
                Ok(json!({"ok": true}))
            }
        }

        // More scripted responses than the iteration limit allows.
        let responses: Vec<MockResponse> = (0..15)
            .map(|_| MockResponse::tool_call("noop", json!({})))
            .collect();
        let backend = Arc::new(MockBackend::new(responses));

        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(NoOpTool));
        let registry = Arc::new(registry);

        let config = ToolCallingConfig {
            max_iterations: 3,
            parallel_execution: false,
            ..ToolCallingConfig::default()
        };
        let coordinator = ToolCoordinator::new(backend, registry, config);

        let result = coordinator
            .execute(None, "loop forever")
            .await
            .expect("coordinator should not hard-error");

        assert_eq!(
            result.finish_reason,
            FinishReason::MaxIterations,
            "expected MaxIterations, got {:?}",
            result.finish_reason
        );
        assert_eq!(result.iterations, 3);
        assert_eq!(result.tool_calls.len(), 3);
        assert!(result.tool_calls.iter().all(|tc| tc.success));
    }
}