use crate::llm::client::{LLMClient, TokenUsage};
use crate::tools::registry::ToolRegistry;
use crate::types::{Result, ToolCall};
use futures::future::join_all;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::time::timeout;
/// Tuning knobs for [`ToolCoordinator`].
#[derive(Debug, Clone)]
pub struct ToolCallingConfig {
    /// Maximum number of LLM round-trips performed by a single `execute` call.
    pub max_iterations: usize,
    /// Run the tool calls of one model turn concurrently (`true`) or one after another (`false`).
    pub parallel_execution: bool,
    /// Time limit applied to each individual tool invocation.
    pub tool_timeout: Duration,
    /// NOTE(review): not read anywhere in this file — confirm whether any caller relies on it.
    pub include_tool_results: bool,
    /// Intended to make a failed tool call end the run instead of being reported back to the
    /// model; how it is honoured is decided by `ToolCoordinator`.
    pub stop_on_error: bool,
}

impl Default for ToolCallingConfig {
    /// Ten iterations, concurrent tool execution, a 30-second per-tool timeout,
    /// `include_tool_results` on, and `stop_on_error` off.
    fn default() -> Self {
        let tool_timeout = Duration::from_secs(30);
        ToolCallingConfig {
            max_iterations: 10,
            tool_timeout,
            parallel_execution: true,
            stop_on_error: false,
            include_tool_results: true,
        }
    }
}
/// Outcome of one tool invocation performed by [`ToolCoordinator`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRecord {
    /// Id of the tool call this record answers; normally copied from the originating
    /// `ToolCall::id`.
    pub id: String,
    /// Name of the tool that was invoked.
    pub name: String,
    /// Arguments that were passed to the tool.
    pub arguments: serde_json::Value,
    /// Tool output on success; on failure an object of the form `{"error": "<message>"}`.
    pub result: serde_json::Value,
    /// `true` only when the tool returned a value before the timeout elapsed.
    pub success: bool,
    /// Elapsed time measured around the tool invocation, in milliseconds (0 when no
    /// measurement was taken).
    pub duration_ms: u64,
    /// Failure message when `success` is `false`; `None` on success.
    pub error: Option<String>,
}
/// Why a [`ToolCoordinator::execute`] run ended.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum FinishReason {
    /// The model replied without requesting any tool calls.
    Stop,
    /// The iteration budget was exhausted.
    MaxIterations,
    /// The run ended because of an error; the payload describes it.
    Error(String),
    /// The model requested a tool that is not registered; the payload is the tool name.
    UnknownTool(String),
}

impl std::fmt::Display for FinishReason {
    /// Renders a snake_case label, followed by `": <detail>"` for variants that carry a payload.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (label, detail) = match self {
            FinishReason::Stop => ("stop", None),
            FinishReason::MaxIterations => ("max_iterations", None),
            FinishReason::Error(message) => ("error", Some(message)),
            FinishReason::UnknownTool(tool) => ("unknown_tool", Some(tool)),
        };
        match detail {
            Some(extra) => write!(f, "{}: {}", label, extra),
            None => f.write_str(label),
        }
    }
}
/// A single message in the conversation history sent to the LLM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationMessage {
    /// Who authored the message.
    pub role: MessageRole,
    /// Message text; for messages built with `tool_result`, the JSON-encoded tool output.
    pub content: String,
    /// Tool calls requested by an assistant message. Omitted from serialized output when
    /// empty, and defaults to empty when absent on input.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCall>,
    /// For tool-result messages, the id of the tool call being answered. Omitted from
    /// serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
/// Author of a [`ConversationMessage`]; serialized in lowercase (e.g. `"assistant"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    /// System prompt.
    System,
    /// User prompt.
    User,
    /// Model reply, possibly carrying tool calls.
    Assistant,
    /// Result of a tool invocation, answering an earlier tool call.
    Tool,
}
impl ConversationMessage {
pub fn system(content: impl Into<String>) -> Self {
Self {
role: MessageRole::System,
content: content.into(),
tool_calls: Vec::new(),
tool_call_id: None,
}
}
pub fn user(content: impl Into<String>) -> Self {
Self {
role: MessageRole::User,
content: content.into(),
tool_calls: Vec::new(),
tool_call_id: None,
}
}
pub fn assistant(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
Self {
role: MessageRole::Assistant,
content: content.into(),
tool_calls,
tool_call_id: None,
}
}
pub fn tool_result(tool_call_id: impl Into<String>, result: &serde_json::Value) -> Self {
Self {
role: MessageRole::Tool,
content: serde_json::to_string(result).unwrap_or_else(|_| "{}".to_string()),
tool_calls: Vec::new(),
tool_call_id: Some(tool_call_id.into()),
}
}
pub fn to_role_content(&self) -> (String, String) {
let role = match self.role {
MessageRole::System => "system",
MessageRole::User => "user",
MessageRole::Assistant => "assistant",
MessageRole::Tool => "tool",
};
(role.to_string(), self.content.clone())
}
}
/// Summary of a completed [`ToolCoordinator::execute`] run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoordinatorResult {
    /// Final text of the run; when `finish_reason` is `Stop` this is the model's last reply.
    pub content: String,
    /// Every tool invocation executed during the run, in execution order.
    pub tool_calls: Vec<ToolCallRecord>,
    /// Number of LLM round-trips performed.
    pub iterations: usize,
    /// Why the run ended.
    pub finish_reason: FinishReason,
    /// Prompt and completion tokens summed over every response that reported usage.
    pub total_usage: TokenUsage,
    /// The full conversation: optional system prompt, user prompt, assistant replies and
    /// tool results.
    pub message_history: Vec<ConversationMessage>,
}
/// Drives a multi-turn tool-calling loop between an LLM client and a tool registry.
pub struct ToolCoordinator {
    /// LLM backend used for every round-trip.
    client: Box<dyn LLMClient>,
    /// Source of tool definitions and the executor for tool calls.
    registry: Arc<ToolRegistry>,
    /// Loop limits, timeouts and execution mode.
    config: ToolCallingConfig,
}
impl ToolCoordinator {
pub fn new(
client: Box<dyn LLMClient>,
registry: Arc<ToolRegistry>,
config: ToolCallingConfig,
) -> Self {
Self {
client,
registry,
config,
}
}
pub fn with_defaults(client: Box<dyn LLMClient>, registry: Arc<ToolRegistry>) -> Self {
Self::new(client, registry, ToolCallingConfig::default())
}
pub async fn execute(&self, system: Option<&str>, prompt: &str) -> Result<CoordinatorResult> {
let tools = self.registry.get_tool_definitions();
let mut messages: Vec<ConversationMessage> = Vec::new();
let mut all_tool_calls: Vec<ToolCallRecord> = Vec::new();
let mut total_usage = TokenUsage::default();
if let Some(sys) = system {
messages.push(ConversationMessage::system(sys));
}
messages.push(ConversationMessage::user(prompt));
for iteration in 0..self.config.max_iterations {
let response = self
.client
.generate_with_tools_and_history(&messages, &tools)
.await?;
if let Some(usage) = &response.usage {
total_usage = TokenUsage::new(
total_usage.prompt_tokens + usage.prompt_tokens,
total_usage.completion_tokens + usage.completion_tokens,
);
}
messages.push(ConversationMessage::assistant(
&response.content,
response.tool_calls.clone(),
));
if response.tool_calls.is_empty() {
return Ok(CoordinatorResult {
content: response.content,
tool_calls: all_tool_calls,
iterations: iteration + 1,
finish_reason: FinishReason::Stop,
total_usage,
message_history: messages,
});
}
for tool_call in &response.tool_calls {
if !self.registry.has_tool(&tool_call.name) {
return Ok(CoordinatorResult {
content: response.content,
tool_calls: all_tool_calls,
iterations: iteration + 1,
finish_reason: FinishReason::UnknownTool(tool_call.name.clone()),
total_usage,
message_history: messages,
});
}
}
let tool_results = self.execute_tool_calls(&response.tool_calls).await?;
for record in tool_results {
messages.push(ConversationMessage::tool_result(&record.id, &record.result));
all_tool_calls.push(record);
}
}
Ok(CoordinatorResult {
content: messages
.last()
.map(|m| m.content.clone())
.unwrap_or_default(),
tool_calls: all_tool_calls,
iterations: self.config.max_iterations,
finish_reason: FinishReason::MaxIterations,
total_usage,
message_history: messages,
})
}
async fn execute_tool_calls(&self, calls: &[ToolCall]) -> Result<Vec<ToolCallRecord>> {
if self.config.parallel_execution {
self.execute_parallel(calls).await
} else {
self.execute_sequential(calls).await
}
}
async fn execute_parallel(&self, calls: &[ToolCall]) -> Result<Vec<ToolCallRecord>> {
let futures = calls.iter().map(|call| self.execute_single_tool(call));
let results = join_all(futures).await;
let mut records = Vec::with_capacity(results.len());
for result in results {
match result {
Ok(record) => records.push(record),
Err(e) if self.config.stop_on_error => return Err(e),
Err(e) => {
records.push(ToolCallRecord {
id: "error".to_string(),
name: "unknown".to_string(),
arguments: serde_json::Value::Null,
result: serde_json::json!({"error": e.to_string()}),
success: false,
duration_ms: 0,
error: Some(e.to_string()),
});
}
}
}
Ok(records)
}
async fn execute_sequential(&self, calls: &[ToolCall]) -> Result<Vec<ToolCallRecord>> {
let mut records = Vec::with_capacity(calls.len());
for call in calls {
match self.execute_single_tool(call).await {
Ok(record) => records.push(record),
Err(e) if self.config.stop_on_error => return Err(e),
Err(e) => {
records.push(ToolCallRecord {
id: call.id.clone(),
name: call.name.clone(),
arguments: call.arguments.clone(),
result: serde_json::json!({"error": e.to_string()}),
success: false,
duration_ms: 0,
error: Some(e.to_string()),
});
}
}
}
Ok(records)
}
async fn execute_single_tool(&self, call: &ToolCall) -> Result<ToolCallRecord> {
let start = Instant::now();
let result = timeout(
self.config.tool_timeout,
self.registry.execute(&call.name, call.arguments.clone()),
)
.await;
let duration_ms = start.elapsed().as_millis() as u64;
match result {
Ok(Ok(value)) => Ok(ToolCallRecord {
id: call.id.clone(),
name: call.name.clone(),
arguments: call.arguments.clone(),
result: value,
success: true,
duration_ms,
error: None,
}),
Ok(Err(e)) => Ok(ToolCallRecord {
id: call.id.clone(),
name: call.name.clone(),
arguments: call.arguments.clone(),
result: serde_json::json!({"error": e.to_string()}),
success: false,
duration_ms,
error: Some(e.to_string()),
}),
Err(_) => Ok(ToolCallRecord {
id: call.id.clone(),
name: call.name.clone(),
arguments: call.arguments.clone(),
result: serde_json::json!({"error": "Tool execution timed out"}),
success: false,
duration_ms,
error: Some("Tool execution timed out".to_string()),
}),
}
}
pub fn client(&self) -> &dyn LLMClient {
self.client.as_ref()
}
pub fn registry(&self) -> &Arc<ToolRegistry> {
&self.registry
}
pub fn config(&self) -> &ToolCallingConfig {
&self.config
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tool_calling_config_default() {
        let config = ToolCallingConfig::default();
        assert_eq!(config.max_iterations, 10);
        assert!(config.parallel_execution);
        assert_eq!(config.tool_timeout, Duration::from_secs(30));
        assert!(config.include_tool_results);
        assert!(!config.stop_on_error);
    }

    #[test]
    fn test_conversation_message_system() {
        let msg = ConversationMessage::system("You are a helpful assistant.");
        assert_eq!(msg.role, MessageRole::System);
        assert_eq!(msg.content, "You are a helpful assistant.");
        assert!(msg.tool_calls.is_empty());
        assert!(msg.tool_call_id.is_none());
    }

    #[test]
    fn test_conversation_message_user() {
        let msg = ConversationMessage::user("Hello!");
        assert_eq!(msg.role, MessageRole::User);
        assert_eq!(msg.content, "Hello!");
    }

    #[test]
    fn test_conversation_message_assistant_with_tool_calls() {
        let tool_calls = vec![ToolCall {
            id: "call_1".to_string(),
            name: "calculator".to_string(),
            arguments: serde_json::json!({"a": 1, "b": 2}),
        }];
        let msg = ConversationMessage::assistant("Let me calculate that.", tool_calls.clone());
        assert_eq!(msg.role, MessageRole::Assistant);
        assert_eq!(msg.tool_calls.len(), 1);
        assert_eq!(msg.tool_calls[0].name, "calculator");
    }

    #[test]
    fn test_conversation_message_tool_result() {
        let result = serde_json::json!({"result": 42});
        let msg = ConversationMessage::tool_result("call_1", &result);
        assert_eq!(msg.role, MessageRole::Tool);
        assert_eq!(msg.tool_call_id, Some("call_1".to_string()));
        assert!(msg.content.contains("42"));
    }

    #[test]
    fn test_finish_reason_display() {
        assert_eq!(FinishReason::Stop.to_string(), "stop");
        assert_eq!(FinishReason::MaxIterations.to_string(), "max_iterations");
        assert_eq!(
            FinishReason::Error("test error".to_string()).to_string(),
            "error: test error"
        );
        assert_eq!(
            FinishReason::UnknownTool("unknown".to_string()).to_string(),
            "unknown_tool: unknown"
        );
    }

    #[test]
    fn test_finish_reason_serde_round_trip() {
        let reasons = vec![
            FinishReason::Stop,
            FinishReason::MaxIterations,
            FinishReason::Error("boom".to_string()),
            FinishReason::UnknownTool("nope".to_string()),
        ];
        for reason in reasons {
            let json = serde_json::to_string(&reason).unwrap();
            let back: FinishReason = serde_json::from_str(&json).unwrap();
            assert_eq!(back, reason);
        }
    }

    #[test]
    fn test_tool_call_record_serialization() {
        let record = ToolCallRecord {
            id: "call_1".to_string(),
            name: "test_tool".to_string(),
            arguments: serde_json::json!({"input": "test"}),
            result: serde_json::json!({"output": "result"}),
            success: true,
            duration_ms: 100,
            error: None,
        };
        let json = serde_json::to_string(&record).unwrap();
        assert!(json.contains("test_tool"));
        assert!(json.contains("\"success\":true"));
    }

    #[test]
    fn test_message_to_role_content() {
        let msg = ConversationMessage::user("Hello");
        let (role, content) = msg.to_role_content();
        assert_eq!(role, "user");
        assert_eq!(content, "Hello");
        let msg = ConversationMessage::system("System prompt");
        let (role, content) = msg.to_role_content();
        assert_eq!(role, "system");
        assert_eq!(content, "System prompt");
    }

    #[test]
    fn test_message_to_role_content_assistant_and_tool() {
        let (role, content) = ConversationMessage::assistant("ok", Vec::new()).to_role_content();
        assert_eq!(role, "assistant");
        assert_eq!(content, "ok");
        let result = serde_json::json!({"v": 1});
        let (role, content) = ConversationMessage::tool_result("call_1", &result).to_role_content();
        assert_eq!(role, "tool");
        assert_eq!(content, r#"{"v":1}"#);
    }

    #[test]
    fn test_message_role_serializes_lowercase() {
        assert_eq!(serde_json::to_string(&MessageRole::System).unwrap(), "\"system\"");
        assert_eq!(serde_json::to_string(&MessageRole::User).unwrap(), "\"user\"");
        assert_eq!(serde_json::to_string(&MessageRole::Assistant).unwrap(), "\"assistant\"");
        assert_eq!(serde_json::to_string(&MessageRole::Tool).unwrap(), "\"tool\"");
        let role: MessageRole = serde_json::from_str("\"assistant\"").unwrap();
        assert_eq!(role, MessageRole::Assistant);
    }

    #[test]
    fn test_conversation_message_skips_empty_optional_fields() {
        let json = serde_json::to_string(&ConversationMessage::user("Hi")).unwrap();
        assert_eq!(json, r#"{"role":"user","content":"Hi"}"#);
    }

    #[test]
    fn test_conversation_message_deserializes_without_optional_fields() {
        let msg: ConversationMessage =
            serde_json::from_str(r#"{"role":"user","content":"x"}"#).unwrap();
        assert_eq!(msg.role, MessageRole::User);
        assert!(msg.tool_calls.is_empty());
        assert!(msg.tool_call_id.is_none());

        let msg: ConversationMessage =
            serde_json::from_str(r#"{"role":"tool","content":"{}","tool_call_id":"call_9"}"#)
                .unwrap();
        assert_eq!(msg.role, MessageRole::Tool);
        assert!(msg.tool_calls.is_empty());
        assert_eq!(msg.tool_call_id.as_deref(), Some("call_9"));
    }
}