1use crate::llm::client::{LLMClient, TokenUsage};
33use crate::tools::registry::ToolRegistry;
34use crate::types::{Result, ToolCall};
35use futures::future::join_all;
36use serde::{Deserialize, Serialize};
37use std::sync::Arc;
38use std::time::{Duration, Instant};
39use tokio::time::timeout;
40
/// Tunables for the tool-calling loop run by [`ToolCoordinator`].
///
/// See the `Default` impl for the out-of-the-box values.
#[derive(Debug, Clone)]
pub struct ToolCallingConfig {
    /// Upper bound on LLM round-trips performed by `ToolCoordinator::execute`.
    pub max_iterations: usize,

    /// When `true`, the tool calls from a single model response are run
    /// concurrently; otherwise they run one after another in request order.
    pub parallel_execution: bool,

    /// Time limit applied to each individual tool call; a call that exceeds it
    /// is recorded as a failed call rather than aborting the run.
    pub tool_timeout: Duration,

    /// NOTE(review): this flag is not read by `ToolCoordinator` in this file —
    /// confirm the intended semantics before relying on it.
    pub include_tool_results: bool,

    /// Whether a tool failure should stop the run rather than being reported
    /// back to the model as a failed result.
    pub stop_on_error: bool,
}
64
65impl Default for ToolCallingConfig {
66 fn default() -> Self {
67 Self {
68 max_iterations: 10,
69 parallel_execution: true,
70 tool_timeout: Duration::from_secs(30),
71 include_tool_results: true,
72 stop_on_error: false,
73 }
74 }
75}
76
/// Outcome of one tool invocation made during a coordinator run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRecord {
    /// Identifier of the tool call this record answers.
    pub id: String,
    /// Name of the tool that was invoked.
    pub name: String,
    /// Arguments the model supplied for the call.
    pub arguments: serde_json::Value,
    /// The tool's output on success; on failure, a JSON object of the form
    /// `{"error": "<message>"}`. This value is what gets sent back to the model.
    pub result: serde_json::Value,
    /// `true` only when the tool returned successfully within the timeout.
    pub success: bool,
    /// Elapsed time spent executing the call, in milliseconds.
    pub duration_ms: u64,
    /// Failure message; `None` on success.
    pub error: Option<String>,
}
98
/// Why a coordinator run ended.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum FinishReason {
    /// The model produced a response without any tool calls.
    Stop,
    /// The configured `max_iterations` was reached while the model was still
    /// requesting tools.
    MaxIterations,
    /// The run ended because of an error; carries a human-readable description.
    Error(String),
    /// The model requested a tool the registry does not provide; carries the
    /// requested tool name.
    UnknownTool(String),
}
111
112impl std::fmt::Display for FinishReason {
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 match self {
115 FinishReason::Stop => write!(f, "stop"),
116 FinishReason::MaxIterations => write!(f, "max_iterations"),
117 FinishReason::Error(e) => write!(f, "error: {}", e),
118 FinishReason::UnknownTool(t) => write!(f, "unknown_tool: {}", t),
119 }
120 }
121}
122
/// One message in the conversation exchanged with the LLM client.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationMessage {
    /// Who authored the message.
    pub role: MessageRole,
    /// Message text; for tool messages built via `ConversationMessage::tool_result`,
    /// the JSON-serialized tool result.
    pub content: String,
    /// Tool calls requested by an assistant message. It is populated by
    /// `ConversationMessage::assistant`; the other constructors leave it empty.
    /// The field is omitted from serialized output when empty and defaults to
    /// empty when absent.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCall>,
    /// For tool messages, the id of the tool call being answered.
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
140
/// Author of a [`ConversationMessage`]. Serialized in lowercase
/// (`"system"`, `"user"`, `"assistant"`, `"tool"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    /// Instructions that frame the conversation.
    System,
    /// Input from the end user.
    User,
    /// A reply produced by the model.
    Assistant,
    /// The result of a tool invocation.
    Tool,
}
154
155impl ConversationMessage {
156 pub fn system(content: impl Into<String>) -> Self {
158 Self {
159 role: MessageRole::System,
160 content: content.into(),
161 tool_calls: Vec::new(),
162 tool_call_id: None,
163 }
164 }
165
166 pub fn user(content: impl Into<String>) -> Self {
168 Self {
169 role: MessageRole::User,
170 content: content.into(),
171 tool_calls: Vec::new(),
172 tool_call_id: None,
173 }
174 }
175
176 pub fn assistant(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
178 Self {
179 role: MessageRole::Assistant,
180 content: content.into(),
181 tool_calls,
182 tool_call_id: None,
183 }
184 }
185
186 pub fn tool_result(tool_call_id: impl Into<String>, result: &serde_json::Value) -> Self {
188 Self {
189 role: MessageRole::Tool,
190 content: serde_json::to_string(result).unwrap_or_else(|_| "{}".to_string()),
191 tool_calls: Vec::new(),
192 tool_call_id: Some(tool_call_id.into()),
193 }
194 }
195
196 pub fn to_role_content(&self) -> (String, String) {
198 let role = match self.role {
199 MessageRole::System => "system",
200 MessageRole::User => "user",
201 MessageRole::Assistant => "assistant",
202 MessageRole::Tool => "tool",
203 };
204 (role.to_string(), self.content.clone())
205 }
206}
207
/// Summary of a finished `ToolCoordinator::execute` run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoordinatorResult {
    /// Text handed back to the caller. When `finish_reason` is
    /// [`FinishReason::Stop`], this is the model's final reply.
    pub content: String,

    /// Records for every tool call executed during the run, in the order
    /// they were recorded.
    pub tool_calls: Vec<ToolCallRecord>,

    /// Number of LLM round-trips performed.
    pub iterations: usize,

    /// Why the run ended.
    pub finish_reason: FinishReason,

    /// Token usage summed over all responses that reported usage.
    pub total_usage: TokenUsage,

    /// The conversation as it stood when the run ended: the optional system
    /// prompt, the user prompt, assistant replies and tool results.
    pub message_history: Vec<ConversationMessage>,
}
233
/// Runs a multi-turn tool-calling loop. It sends the conversation to an LLM
/// client, executes the tools the model requests through a [`ToolRegistry`],
/// and feeds the results back until the model stops or a limit is reached.
pub struct ToolCoordinator {
    /// Client used for every model round-trip.
    client: Box<dyn LLMClient>,
    /// Registry that provides tool definitions and executes tool calls.
    registry: Arc<ToolRegistry>,
    /// Loop behaviour settings.
    config: ToolCallingConfig,
}
252
253impl ToolCoordinator {
254 pub fn new(
256 client: Box<dyn LLMClient>,
257 registry: Arc<ToolRegistry>,
258 config: ToolCallingConfig,
259 ) -> Self {
260 Self {
261 client,
262 registry,
263 config,
264 }
265 }
266
267 pub fn with_defaults(client: Box<dyn LLMClient>, registry: Arc<ToolRegistry>) -> Self {
269 Self::new(client, registry, ToolCallingConfig::default())
270 }
271
272 pub async fn execute(&self, system: Option<&str>, prompt: &str) -> Result<CoordinatorResult> {
290 let tools = self.registry.get_tool_definitions();
291 let mut messages: Vec<ConversationMessage> = Vec::new();
292 let mut all_tool_calls: Vec<ToolCallRecord> = Vec::new();
293 let mut total_usage = TokenUsage::default();
294
295 if let Some(sys) = system {
297 messages.push(ConversationMessage::system(sys));
298 }
299
300 messages.push(ConversationMessage::user(prompt));
302
303 for iteration in 0..self.config.max_iterations {
304 let response = self
306 .client
307 .generate_with_tools_and_history(&messages, &tools)
308 .await?;
309
310 if let Some(usage) = &response.usage {
312 total_usage = TokenUsage::new(
313 total_usage.prompt_tokens + usage.prompt_tokens,
314 total_usage.completion_tokens + usage.completion_tokens,
315 );
316 }
317
318 messages.push(ConversationMessage::assistant(
320 &response.content,
321 response.tool_calls.clone(),
322 ));
323
324 if response.tool_calls.is_empty() {
326 return Ok(CoordinatorResult {
327 content: response.content,
328 tool_calls: all_tool_calls,
329 iterations: iteration + 1,
330 finish_reason: FinishReason::Stop,
331 total_usage,
332 message_history: messages,
333 });
334 }
335
336 for tool_call in &response.tool_calls {
338 if !self.registry.has_tool(&tool_call.name) {
339 return Ok(CoordinatorResult {
340 content: response.content,
341 tool_calls: all_tool_calls,
342 iterations: iteration + 1,
343 finish_reason: FinishReason::UnknownTool(tool_call.name.clone()),
344 total_usage,
345 message_history: messages,
346 });
347 }
348 }
349
350 let tool_results = self.execute_tool_calls(&response.tool_calls).await?;
352
353 for record in tool_results {
355 messages.push(ConversationMessage::tool_result(&record.id, &record.result));
357 all_tool_calls.push(record);
358 }
359 }
360
361 Ok(CoordinatorResult {
363 content: messages
364 .last()
365 .map(|m| m.content.clone())
366 .unwrap_or_default(),
367 tool_calls: all_tool_calls,
368 iterations: self.config.max_iterations,
369 finish_reason: FinishReason::MaxIterations,
370 total_usage,
371 message_history: messages,
372 })
373 }
374
375 async fn execute_tool_calls(&self, calls: &[ToolCall]) -> Result<Vec<ToolCallRecord>> {
377 if self.config.parallel_execution {
378 self.execute_parallel(calls).await
379 } else {
380 self.execute_sequential(calls).await
381 }
382 }
383
384 async fn execute_parallel(&self, calls: &[ToolCall]) -> Result<Vec<ToolCallRecord>> {
386 let futures = calls.iter().map(|call| self.execute_single_tool(call));
387 let results = join_all(futures).await;
388
389 let mut records = Vec::with_capacity(results.len());
390 for result in results {
391 match result {
392 Ok(record) => records.push(record),
393 Err(e) if self.config.stop_on_error => return Err(e),
394 Err(e) => {
395 records.push(ToolCallRecord {
397 id: "error".to_string(),
398 name: "unknown".to_string(),
399 arguments: serde_json::Value::Null,
400 result: serde_json::json!({"error": e.to_string()}),
401 success: false,
402 duration_ms: 0,
403 error: Some(e.to_string()),
404 });
405 }
406 }
407 }
408 Ok(records)
409 }
410
411 async fn execute_sequential(&self, calls: &[ToolCall]) -> Result<Vec<ToolCallRecord>> {
413 let mut records = Vec::with_capacity(calls.len());
414 for call in calls {
415 match self.execute_single_tool(call).await {
416 Ok(record) => records.push(record),
417 Err(e) if self.config.stop_on_error => return Err(e),
418 Err(e) => {
419 records.push(ToolCallRecord {
420 id: call.id.clone(),
421 name: call.name.clone(),
422 arguments: call.arguments.clone(),
423 result: serde_json::json!({"error": e.to_string()}),
424 success: false,
425 duration_ms: 0,
426 error: Some(e.to_string()),
427 });
428 }
429 }
430 }
431 Ok(records)
432 }
433
434 async fn execute_single_tool(&self, call: &ToolCall) -> Result<ToolCallRecord> {
436 let start = Instant::now();
437
438 let result = timeout(
439 self.config.tool_timeout,
440 self.registry.execute(&call.name, call.arguments.clone()),
441 )
442 .await;
443
444 let duration_ms = start.elapsed().as_millis() as u64;
445
446 match result {
447 Ok(Ok(value)) => Ok(ToolCallRecord {
448 id: call.id.clone(),
449 name: call.name.clone(),
450 arguments: call.arguments.clone(),
451 result: value,
452 success: true,
453 duration_ms,
454 error: None,
455 }),
456 Ok(Err(e)) => Ok(ToolCallRecord {
457 id: call.id.clone(),
458 name: call.name.clone(),
459 arguments: call.arguments.clone(),
460 result: serde_json::json!({"error": e.to_string()}),
461 success: false,
462 duration_ms,
463 error: Some(e.to_string()),
464 }),
465 Err(_) => Ok(ToolCallRecord {
466 id: call.id.clone(),
467 name: call.name.clone(),
468 arguments: call.arguments.clone(),
469 result: serde_json::json!({"error": "Tool execution timed out"}),
470 success: false,
471 duration_ms,
472 error: Some("Tool execution timed out".to_string()),
473 }),
474 }
475 }
476
477 pub fn client(&self) -> &dyn LLMClient {
479 self.client.as_ref()
480 }
481
482 pub fn registry(&self) -> &Arc<ToolRegistry> {
484 &self.registry
485 }
486
487 pub fn config(&self) -> &ToolCallingConfig {
489 &self.config
490 }
491}
492
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tool_calling_config_default() {
        let cfg = ToolCallingConfig::default();
        assert_eq!(cfg.max_iterations, 10);
        assert_eq!(cfg.tool_timeout, Duration::from_secs(30));
        assert!(cfg.parallel_execution);
        assert!(cfg.include_tool_results);
        assert!(!cfg.stop_on_error);
    }

    #[test]
    fn test_conversation_message_system() {
        let text = "You are a helpful assistant.";
        let message = ConversationMessage::system(text);
        assert_eq!(message.role, MessageRole::System);
        assert_eq!(message.content, text);
        assert!(message.tool_calls.is_empty());
        assert_eq!(message.tool_call_id, None);
    }

    #[test]
    fn test_conversation_message_user() {
        let message = ConversationMessage::user("Hello!");
        assert_eq!(
            (message.role, message.content.as_str()),
            (MessageRole::User, "Hello!")
        );
    }

    #[test]
    fn test_conversation_message_assistant_with_tool_calls() {
        let call = ToolCall {
            id: "call_1".to_string(),
            name: "calculator".to_string(),
            arguments: serde_json::json!({"a": 1, "b": 2}),
        };
        let message = ConversationMessage::assistant("Let me calculate that.", vec![call]);
        assert_eq!(message.role, MessageRole::Assistant);
        assert_eq!(message.tool_calls.len(), 1);
        assert_eq!(message.tool_calls[0].name, "calculator");
    }

    #[test]
    fn test_conversation_message_tool_result() {
        let payload = serde_json::json!({"result": 42});
        let message = ConversationMessage::tool_result("call_1", &payload);
        assert_eq!(message.role, MessageRole::Tool);
        assert_eq!(message.tool_call_id.as_deref(), Some("call_1"));
        assert!(message.content.contains("42"));
    }

    #[test]
    fn test_finish_reason_display() {
        let cases = [
            (FinishReason::Stop, "stop"),
            (FinishReason::MaxIterations, "max_iterations"),
            (FinishReason::Error("test error".to_string()), "error: test error"),
            (FinishReason::UnknownTool("unknown".to_string()), "unknown_tool: unknown"),
        ];
        for (reason, expected) in cases {
            assert_eq!(reason.to_string(), expected);
        }
    }

    #[test]
    fn test_tool_call_record_serialization() {
        let record = ToolCallRecord {
            id: "call_1".to_string(),
            name: "test_tool".to_string(),
            arguments: serde_json::json!({"input": "test"}),
            result: serde_json::json!({"output": "result"}),
            success: true,
            duration_ms: 100,
            error: None,
        };

        let json = serde_json::to_string(&record).expect("ToolCallRecord should serialize");
        for needle in ["test_tool", "\"success\":true"] {
            assert!(json.contains(needle), "missing {} in {}", needle, json);
        }
    }

    #[test]
    fn test_message_to_role_content() {
        let cases = [
            (ConversationMessage::user("Hello"), "user", "Hello"),
            (ConversationMessage::system("System prompt"), "system", "System prompt"),
        ];
        for (message, expected_role, expected_content) in cases {
            let (role, content) = message.to_role_content();
            assert_eq!(role, expected_role);
            assert_eq!(content, expected_content);
        }
    }
}