Skip to main content

ares/llm/
coordinator.rs

//! Generic Tool Coordinator for Multi-Turn Tool Calling
//!
//! This module provides a provider-agnostic `ToolCoordinator` that works with any
//! `LLMClient` implementation. It handles the complete tool calling loop:
//!
//! 1. Send the prompt with the available tools to the LLM
//! 2. If the model requests tool calls, execute them
//! 3. Send the tool results back to the model
//! 4. Repeat until the model stops requesting tools or the iteration limit is reached
//!
//! # Example
//!
//! ```rust,ignore
//! use ares::llm::coordinator::{ToolCoordinator, ToolCallingConfig};
//! use ares::llm::Provider;
//! use ares::tools::ToolRegistry;
//! use std::sync::Arc;
//!
//! let client = Provider::from_env()?.create_client().await?;
//! let registry = Arc::new(ToolRegistry::new());
//! let coordinator = ToolCoordinator::new(client, registry, ToolCallingConfig::default());
//!
//! let result = coordinator.execute(
//!     Some("You are a helpful assistant."),
//!     "What's 2 + 2?"
//! ).await?;
//!
//! println!("Response: {}", result.content);
//! println!("Tool calls made: {}", result.tool_calls.len());
//! ```

32use crate::llm::client::{LLMClient, TokenUsage};
33use crate::tools::registry::ToolRegistry;
34use crate::types::{Result, ToolCall};
35use futures::future::join_all;
36use serde::{Deserialize, Serialize};
37use std::sync::Arc;
38use std::time::{Duration, Instant};
39use tokio::time::timeout;
40
/// Tuning knobs for a tool-coordination session.
///
/// Governs how many LLM round-trips are allowed, whether the tool calls
/// requested in one model turn run concurrently, how long a single tool may
/// run, and how tool failures are treated.
#[derive(Clone, Debug)]
pub struct ToolCallingConfig {
    /// Upper bound on LLM round-trips (not on individual tool calls).
    ///
    /// One iteration is one request/response exchange with the model.
    pub max_iterations: usize,

    /// `true`: run all tool calls from one model turn concurrently.
    /// `false`: run them one after another.
    pub parallel_execution: bool,

    /// Deadline applied to each individual tool execution.
    pub tool_timeout: Duration,

    /// Whether tool results should be part of the final response context.
    ///
    /// NOTE(review): the coordinator in this module does not read this flag —
    /// confirm whether any other consumer does before relying on it.
    pub include_tool_results: bool,

    /// `true`: halt on the first tool error.
    /// `false`: keep going with the remaining tool calls.
    pub stop_on_error: bool,
}

impl Default for ToolCallingConfig {
    /// Defaults:
    /// - 10 iterations
    /// - parallel execution
    /// - a 30-second tool timeout
    /// - tool results included
    /// - tool errors do not stop processing
    fn default() -> Self {
        let tool_timeout = Duration::from_secs(30);
        Self {
            parallel_execution: true,
            stop_on_error: false,
            include_tool_results: true,
            max_iterations: 10,
            tool_timeout,
        }
    }
}

77/// Record of a single tool call execution.
78///
79/// Captures all details about a tool invocation including timing,
80/// success status, and any errors that occurred.
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct ToolCallRecord {
83    /// Unique identifier for this tool call (from the LLM).
84    pub id: String,
85    /// Name of the tool that was called.
86    pub name: String,
87    /// Arguments passed to the tool.
88    pub arguments: serde_json::Value,
89    /// Result returned by the tool (or error object).
90    pub result: serde_json::Value,
91    /// Whether the tool execution was successful.
92    pub success: bool,
93    /// Time taken to execute the tool in milliseconds.
94    pub duration_ms: u64,
95    /// Error message if the tool failed.
96    pub error: Option<String>,
97}
98
99/// Reason why a tool coordination session ended.
100#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
101pub enum FinishReason {
102    /// Model decided to stop (no more tool calls).
103    Stop,
104    /// Hit the maximum iterations limit.
105    MaxIterations,
106    /// An unrecoverable error occurred.
107    Error(String),
108    /// Model tried to call an unknown tool.
109    UnknownTool(String),
110}
111
112impl std::fmt::Display for FinishReason {
113    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114        match self {
115            FinishReason::Stop => write!(f, "stop"),
116            FinishReason::MaxIterations => write!(f, "max_iterations"),
117            FinishReason::Error(e) => write!(f, "error: {}", e),
118            FinishReason::UnknownTool(t) => write!(f, "unknown_tool: {}", t),
119        }
120    }
121}
122
123/// A message in a tool-calling conversation.
124///
125/// Represents all message types that can appear in a multi-turn
126/// conversation with tool calling.
127#[derive(Debug, Clone, Serialize, Deserialize)]
128pub struct ConversationMessage {
129    /// The role of the message sender.
130    pub role: MessageRole,
131    /// The text content of the message.
132    pub content: String,
133    /// Tool calls requested by the assistant (only for Assistant role).
134    #[serde(default, skip_serializing_if = "Vec::is_empty")]
135    pub tool_calls: Vec<ToolCall>,
136    /// Tool result content (only for Tool role).
137    #[serde(skip_serializing_if = "Option::is_none")]
138    pub tool_call_id: Option<String>,
139}
140
141/// Role of a message sender in a tool-calling conversation.
142#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
143#[serde(rename_all = "lowercase")]
144pub enum MessageRole {
145    /// System instructions.
146    System,
147    /// User message.
148    User,
149    /// Assistant response.
150    Assistant,
151    /// Tool execution result.
152    Tool,
153}
154
155impl ConversationMessage {
156    /// Create a system message.
157    pub fn system(content: impl Into<String>) -> Self {
158        Self {
159            role: MessageRole::System,
160            content: content.into(),
161            tool_calls: Vec::new(),
162            tool_call_id: None,
163        }
164    }
165
166    /// Create a user message.
167    pub fn user(content: impl Into<String>) -> Self {
168        Self {
169            role: MessageRole::User,
170            content: content.into(),
171            tool_calls: Vec::new(),
172            tool_call_id: None,
173        }
174    }
175
176    /// Create an assistant message with optional tool calls.
177    pub fn assistant(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
178        Self {
179            role: MessageRole::Assistant,
180            content: content.into(),
181            tool_calls,
182            tool_call_id: None,
183        }
184    }
185
186    /// Create a tool result message.
187    pub fn tool_result(tool_call_id: impl Into<String>, result: &serde_json::Value) -> Self {
188        Self {
189            role: MessageRole::Tool,
190            content: serde_json::to_string(result).unwrap_or_else(|_| "{}".to_string()),
191            tool_calls: Vec::new(),
192            tool_call_id: Some(tool_call_id.into()),
193        }
194    }
195
196    /// Convert to the simple (role, content) format for LLMClient::generate_with_history.
197    pub fn to_role_content(&self) -> (String, String) {
198        let role = match self.role {
199            MessageRole::System => "system",
200            MessageRole::User => "user",
201            MessageRole::Assistant => "assistant",
202            MessageRole::Tool => "tool",
203        };
204        (role.to_string(), self.content.clone())
205    }
206}
207
208/// Result of a complete tool coordination session.
209///
210/// Contains all information about what happened during the multi-turn
211/// conversation, including the final response, all tool calls made,
212/// token usage, and message history.
213#[derive(Debug, Clone, Serialize, Deserialize)]
214pub struct CoordinatorResult {
215    /// Final text response from the model.
216    pub content: String,
217
218    /// All tool calls made during the session.
219    pub tool_calls: Vec<ToolCallRecord>,
220
221    /// Number of LLM iterations (round-trips) performed.
222    pub iterations: usize,
223
224    /// Why the session ended.
225    pub finish_reason: FinishReason,
226
227    /// Accumulated token usage across all iterations.
228    pub total_usage: TokenUsage,
229
230    /// Full message history (useful for debugging and training data).
231    pub message_history: Vec<ConversationMessage>,
232}
233
234/// Generic tool coordinator that works with any LLMClient.
235///
236/// Manages multi-turn tool calling conversations by:
237/// 1. Sending prompts with tool definitions to the LLM
238/// 2. Parsing tool call requests from the response
239/// 3. Executing tools and collecting results
240/// 4. Sending results back to the LLM
241/// 5. Repeating until the LLM produces a final response
242///
243/// # Type Parameters
244///
245/// The coordinator is generic over the LLMClient, but typically you'll use
246/// it with `Box<dyn LLMClient>` for maximum flexibility.
247pub struct ToolCoordinator {
248    client: Box<dyn LLMClient>,
249    registry: Arc<ToolRegistry>,
250    config: ToolCallingConfig,
251}
252
253impl ToolCoordinator {
254    /// Create a new ToolCoordinator with the given client, registry, and config.
255    pub fn new(
256        client: Box<dyn LLMClient>,
257        registry: Arc<ToolRegistry>,
258        config: ToolCallingConfig,
259    ) -> Self {
260        Self {
261            client,
262            registry,
263            config,
264        }
265    }
266
267    /// Create a new ToolCoordinator with default configuration.
268    pub fn with_defaults(client: Box<dyn LLMClient>, registry: Arc<ToolRegistry>) -> Self {
269        Self::new(client, registry, ToolCallingConfig::default())
270    }
271
272    /// Execute a complete tool-calling conversation loop.
273    ///
274    /// This method handles the full tool calling loop:
275    /// 1. Send the initial prompt with available tools
276    /// 2. If the model requests tool calls, execute them
277    /// 3. Send tool results back to the model
278    /// 4. Repeat until the model produces a final response or max iterations reached
279    ///
280    /// # Arguments
281    ///
282    /// * `system` - Optional system prompt
283    /// * `prompt` - The user's prompt
284    ///
285    /// # Returns
286    ///
287    /// A `CoordinatorResult` containing the final response, all tool calls made,
288    /// and execution metadata.
289    pub async fn execute(&self, system: Option<&str>, prompt: &str) -> Result<CoordinatorResult> {
290        let tools = self.registry.get_tool_definitions();
291        let mut messages: Vec<ConversationMessage> = Vec::new();
292        let mut all_tool_calls: Vec<ToolCallRecord> = Vec::new();
293        let mut total_usage = TokenUsage::default();
294
295        // Add system message if provided
296        if let Some(sys) = system {
297            messages.push(ConversationMessage::system(sys));
298        }
299
300        // Add user message
301        messages.push(ConversationMessage::user(prompt));
302
303        for iteration in 0..self.config.max_iterations {
304            // Call LLM with tools
305            let response = self
306                .client
307                .generate_with_tools_and_history(&messages, &tools)
308                .await?;
309
310            // Accumulate usage
311            if let Some(usage) = &response.usage {
312                total_usage = TokenUsage::new(
313                    total_usage.prompt_tokens + usage.prompt_tokens,
314                    total_usage.completion_tokens + usage.completion_tokens,
315                );
316            }
317
318            // Add assistant message to history
319            messages.push(ConversationMessage::assistant(
320                &response.content,
321                response.tool_calls.clone(),
322            ));
323
324            // Check if we're done (no tool calls)
325            if response.tool_calls.is_empty() {
326                return Ok(CoordinatorResult {
327                    content: response.content,
328                    tool_calls: all_tool_calls,
329                    iterations: iteration + 1,
330                    finish_reason: FinishReason::Stop,
331                    total_usage,
332                    message_history: messages,
333                });
334            }
335
336            // Validate that all requested tools exist
337            for tool_call in &response.tool_calls {
338                if !self.registry.has_tool(&tool_call.name) {
339                    return Ok(CoordinatorResult {
340                        content: response.content,
341                        tool_calls: all_tool_calls,
342                        iterations: iteration + 1,
343                        finish_reason: FinishReason::UnknownTool(tool_call.name.clone()),
344                        total_usage,
345                        message_history: messages,
346                    });
347                }
348            }
349
350            // Execute tool calls
351            let tool_results = self.execute_tool_calls(&response.tool_calls).await?;
352
353            // Record tool calls and add results to message history
354            for record in tool_results {
355                // Add tool result to messages
356                messages.push(ConversationMessage::tool_result(&record.id, &record.result));
357                all_tool_calls.push(record);
358            }
359        }
360
361        // Hit max iterations
362        Ok(CoordinatorResult {
363            content: messages
364                .last()
365                .map(|m| m.content.clone())
366                .unwrap_or_default(),
367            tool_calls: all_tool_calls,
368            iterations: self.config.max_iterations,
369            finish_reason: FinishReason::MaxIterations,
370            total_usage,
371            message_history: messages,
372        })
373    }
374
375    /// Execute tool calls, either in parallel or sequentially based on config.
376    async fn execute_tool_calls(&self, calls: &[ToolCall]) -> Result<Vec<ToolCallRecord>> {
377        if self.config.parallel_execution {
378            self.execute_parallel(calls).await
379        } else {
380            self.execute_sequential(calls).await
381        }
382    }
383
384    /// Execute tool calls in parallel.
385    async fn execute_parallel(&self, calls: &[ToolCall]) -> Result<Vec<ToolCallRecord>> {
386        let futures = calls.iter().map(|call| self.execute_single_tool(call));
387        let results = join_all(futures).await;
388
389        let mut records = Vec::with_capacity(results.len());
390        for result in results {
391            match result {
392                Ok(record) => records.push(record),
393                Err(e) if self.config.stop_on_error => return Err(e),
394                Err(e) => {
395                    // Create an error record for failed tools
396                    records.push(ToolCallRecord {
397                        id: "error".to_string(),
398                        name: "unknown".to_string(),
399                        arguments: serde_json::Value::Null,
400                        result: serde_json::json!({"error": e.to_string()}),
401                        success: false,
402                        duration_ms: 0,
403                        error: Some(e.to_string()),
404                    });
405                }
406            }
407        }
408        Ok(records)
409    }
410
411    /// Execute tool calls sequentially.
412    async fn execute_sequential(&self, calls: &[ToolCall]) -> Result<Vec<ToolCallRecord>> {
413        let mut records = Vec::with_capacity(calls.len());
414        for call in calls {
415            match self.execute_single_tool(call).await {
416                Ok(record) => records.push(record),
417                Err(e) if self.config.stop_on_error => return Err(e),
418                Err(e) => {
419                    records.push(ToolCallRecord {
420                        id: call.id.clone(),
421                        name: call.name.clone(),
422                        arguments: call.arguments.clone(),
423                        result: serde_json::json!({"error": e.to_string()}),
424                        success: false,
425                        duration_ms: 0,
426                        error: Some(e.to_string()),
427                    });
428                }
429            }
430        }
431        Ok(records)
432    }
433
434    /// Execute a single tool call with timeout.
435    async fn execute_single_tool(&self, call: &ToolCall) -> Result<ToolCallRecord> {
436        let start = Instant::now();
437
438        let result = timeout(
439            self.config.tool_timeout,
440            self.registry.execute(&call.name, call.arguments.clone()),
441        )
442        .await;
443
444        let duration_ms = start.elapsed().as_millis() as u64;
445
446        match result {
447            Ok(Ok(value)) => Ok(ToolCallRecord {
448                id: call.id.clone(),
449                name: call.name.clone(),
450                arguments: call.arguments.clone(),
451                result: value,
452                success: true,
453                duration_ms,
454                error: None,
455            }),
456            Ok(Err(e)) => Ok(ToolCallRecord {
457                id: call.id.clone(),
458                name: call.name.clone(),
459                arguments: call.arguments.clone(),
460                result: serde_json::json!({"error": e.to_string()}),
461                success: false,
462                duration_ms,
463                error: Some(e.to_string()),
464            }),
465            Err(_) => Ok(ToolCallRecord {
466                id: call.id.clone(),
467                name: call.name.clone(),
468                arguments: call.arguments.clone(),
469                result: serde_json::json!({"error": "Tool execution timed out"}),
470                success: false,
471                duration_ms,
472                error: Some("Tool execution timed out".to_string()),
473            }),
474        }
475    }
476
477    /// Get a reference to the underlying LLM client.
478    pub fn client(&self) -> &dyn LLMClient {
479        self.client.as_ref()
480    }
481
482    /// Get a reference to the tool registry.
483    pub fn registry(&self) -> &Arc<ToolRegistry> {
484        &self.registry
485    }
486
487    /// Get a reference to the configuration.
488    pub fn config(&self) -> &ToolCallingConfig {
489        &self.config
490    }
491}
492
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: a single calculator call, reused by the assistant-message test.
    fn sample_calculator_call() -> Vec<ToolCall> {
        vec![ToolCall {
            id: "call_1".to_string(),
            name: "calculator".to_string(),
            arguments: serde_json::json!({"a": 1, "b": 2}),
        }]
    }

    #[test]
    fn test_tool_calling_config_default() {
        let ToolCallingConfig {
            max_iterations,
            parallel_execution,
            tool_timeout,
            include_tool_results,
            stop_on_error,
        } = ToolCallingConfig::default();
        assert_eq!(max_iterations, 10);
        assert!(parallel_execution);
        assert_eq!(tool_timeout, Duration::from_secs(30));
        assert!(include_tool_results);
        assert!(!stop_on_error);
    }

    #[test]
    fn test_conversation_message_system() {
        let text = "You are a helpful assistant.";
        let msg = ConversationMessage::system(text);
        assert_eq!(msg.role, MessageRole::System);
        assert_eq!(msg.content, text);
        assert!(msg.tool_calls.is_empty());
        assert_eq!(msg.tool_call_id, None);
    }

    #[test]
    fn test_conversation_message_user() {
        let msg = ConversationMessage::user("Hello!");
        assert_eq!(msg.role, MessageRole::User);
        assert_eq!(msg.content.as_str(), "Hello!");
    }

    #[test]
    fn test_conversation_message_assistant_with_tool_calls() {
        let msg = ConversationMessage::assistant("Let me calculate that.", sample_calculator_call());
        assert_eq!(msg.role, MessageRole::Assistant);
        assert_eq!(msg.tool_calls.len(), 1);
        assert_eq!(
            msg.tool_calls.first().map(|call| call.name.as_str()),
            Some("calculator")
        );
    }

    #[test]
    fn test_conversation_message_tool_result() {
        let payload = serde_json::json!({"result": 42});
        let msg = ConversationMessage::tool_result("call_1", &payload);
        assert_eq!(msg.role, MessageRole::Tool);
        assert_eq!(msg.tool_call_id.as_deref(), Some("call_1"));
        assert!(msg.content.contains("42"));
    }

    #[test]
    fn test_finish_reason_display() {
        let cases = [
            (FinishReason::Stop, "stop"),
            (FinishReason::MaxIterations, "max_iterations"),
            (
                FinishReason::Error("test error".to_string()),
                "error: test error",
            ),
            (
                FinishReason::UnknownTool("unknown".to_string()),
                "unknown_tool: unknown",
            ),
        ];
        for (reason, expected) in cases.iter() {
            assert_eq!(reason.to_string(), *expected);
        }
    }

    #[test]
    fn test_tool_call_record_serialization() {
        let record = ToolCallRecord {
            id: "call_1".to_string(),
            name: "test_tool".to_string(),
            arguments: serde_json::json!({"input": "test"}),
            result: serde_json::json!({"output": "result"}),
            success: true,
            duration_ms: 100,
            error: None,
        };

        let encoded = serde_json::to_string(&record).expect("ToolCallRecord should serialize");
        assert!(encoded.contains("test_tool"));
        assert!(encoded.contains("\"success\":true"));
    }

    #[test]
    fn test_message_to_role_content() {
        let cases = [
            (ConversationMessage::user("Hello"), "user", "Hello"),
            (
                ConversationMessage::system("System prompt"),
                "system",
                "System prompt",
            ),
        ];
        for (msg, role, content) in cases.iter() {
            assert_eq!(
                msg.to_role_content(),
                (role.to_string(), content.to_string())
            );
        }
    }
}
584        let (role, content) = msg.to_role_content();
585        assert_eq!(role, "system");
586        assert_eq!(content, "System prompt");
587    }
588}