pub mod types;
pub use types::*;
use crate::agent::backend::LlmBackend;
use crate::agent::{Message, Role, TokenUsage, ToolCallRecord, ToolCallRequest, ToolResultMessage};
use crate::tools::ToolRegistry;
use futures::future::join_all;
use std::sync::Arc;
use std::time::Instant;
use tokio::time::timeout;
/// Converts a coordinator-level [`ConversationMessage`] into the backend [`Message`] shape.
///
/// Role, content and tool calls are copied across unchanged. For a `Role::Tool` message that
/// carries a `tool_call_id`, a [`ToolResultMessage`] is attached whose `content` is the message
/// text parsed as JSON; text that is not valid JSON is wrapped as a JSON string instead.
/// Every other message gets `tool_result: None`.
fn to_backend_message(msg: &ConversationMessage) -> Message {
    let tool_result = if msg.role == Role::Tool {
        msg.tool_call_id.as_ref().map(|id| ToolResultMessage {
            tool_call_id: id.clone(),
            // Lazy fallback: the raw text is only cloned when JSON parsing actually fails.
            content: serde_json::from_str(&msg.content)
                .unwrap_or_else(|_| serde_json::Value::String(msg.content.clone())),
            // NOTE(review): always reported as successful; nothing read here distinguishes a
            // failed tool result from a successful one -- confirm the backend does not rely on it.
            success: true,
        })
    } else {
        None
    };

    Message {
        role: msg.role.clone(),
        content: msg.content.clone(),
        tool_calls: msg.tool_calls.clone(),
        tool_result,
    }
}
/// Runs the model/tool loop: asks the backend for a reply, executes any tool calls the reply
/// requests through the registry, and feeds the results back until a stop condition is hit.
pub struct ToolCoordinator {
// Backend used to generate each assistant reply.
backend: Arc<dyn LlmBackend>,
// Registry supplying tool definitions, existence checks and execution.
registry: Arc<ToolRegistry>,
// Loop settings read by the coordinator: max_iterations, parallel_execution,
// tool_timeout and stop_on_error.
config: ToolCallingConfig,
}
impl ToolCoordinator {
pub fn new(
backend: Arc<dyn LlmBackend>,
registry: Arc<ToolRegistry>,
config: ToolCallingConfig,
) -> Self {
Self {
backend,
registry,
config,
}
}
/// Starts a fresh conversation and runs the tool-calling loop on it.
///
/// The initial history is the optional system message (when `system_prompt` is `Some`)
/// followed by the user message; it is then handed to `execute_with_history`.
///
/// # Errors
///
/// Returns whatever error `execute_with_history` returns.
pub async fn execute(
    &self,
    system_prompt: Option<&str>,
    user_prompt: &str,
) -> crate::Result<CoordinatorResult> {
    let history: Vec<ConversationMessage> = system_prompt
        .map(|sys| ConversationMessage::system(sys))
        .into_iter()
        .chain(std::iter::once(ConversationMessage::user(user_prompt)))
        .collect();

    self.execute_with_history(history).await
}
/// Runs the tool-calling loop starting from an existing conversation history.
///
/// Each iteration converts the history to backend messages, asks the backend for a reply,
/// adds the reported token usage to the running total and appends the assistant reply to the
/// history. If the reply requests tools, they are executed and one tool-result message per
/// call is appended before the next iteration.
///
/// The returned `finish_reason` is:
/// - `FinishReason::Stop` when a reply contains no tool calls (its content is returned);
/// - `FinishReason::UnknownTool(name)` when a reply names a tool the registry does not have;
///   nothing from that batch is executed;
/// - `FinishReason::Error(msg)` when `stop_on_error` is set and a tool record in the batch
///   failed; `msg` is the record's `"error"` string, or `"tool error"` if it has none;
/// - `FinishReason::MaxIterations` when `max_iterations` replies were processed without
///   stopping; `content` is then the content of the last message in the history.
///
/// # Errors
///
/// Propagates errors returned by the backend's `generate` call and by tool execution.
pub async fn execute_with_history(
    &self,
    mut messages: Vec<ConversationMessage>,
) -> crate::Result<CoordinatorResult> {
    let tool_defs = self.registry.get_definitions();
    let mut all_tool_calls: Vec<ToolCallRecord> = Vec::new();
    let mut total_usage = TokenUsage::default();

    for iteration in 0..self.config.max_iterations {
        let backend_messages: Vec<Message> =
            messages.iter().map(to_backend_message).collect();
        let response = self
            .backend
            .generate(&backend_messages, &tool_defs, None)
            .await?;

        if let Some(usage) = &response.usage {
            total_usage.prompt_tokens += usage.prompt_tokens;
            total_usage.completion_tokens += usage.completion_tokens;
            total_usage.total_tokens += usage.total_tokens;
            total_usage.reasoning_tokens += usage.reasoning_tokens;
            total_usage.action_tokens += usage.action_tokens;
        }

        messages.push(ConversationMessage::assistant(
            &response.content,
            response.tool_calls.clone(),
        ));

        // A reply without tool requests is final. This also covers a completely empty reply,
        // so no separate "empty content and no tool calls" check is needed.
        if response.tool_calls.is_empty() {
            return Ok(CoordinatorResult {
                content: response.content,
                tool_calls: all_tool_calls,
                iterations: iteration + 1,
                finish_reason: FinishReason::Stop,
                total_usage,
                message_history: messages,
            });
        }

        // Reject the whole batch before running anything if any requested tool is unknown.
        if let Some(unknown) = response
            .tool_calls
            .iter()
            .find(|tc| !self.registry.has_tool(&tc.name))
        {
            let finish_reason = FinishReason::UnknownTool(unknown.name.clone());
            return Ok(CoordinatorResult {
                content: response.content,
                tool_calls: all_tool_calls,
                iterations: iteration + 1,
                finish_reason,
                total_usage,
                message_history: messages,
            });
        }

        let records = self.execute_tool_calls(&response.tool_calls).await?;

        // Work out the stop-on-error message first, while `records` is still borrowed.
        let first_error = if self.config.stop_on_error {
            records.iter().find(|r| !r.success).map(|failed| {
                failed
                    .result
                    .get("error")
                    .and_then(|v| v.as_str())
                    .unwrap_or("tool error")
                    .to_string()
            })
        } else {
            None
        };

        // Record every call that actually ran -- including a failing one -- so the result
        // reflects what was executed and the history has a tool result for each call.
        for record in records {
            messages.push(ConversationMessage::tool_result(&record.id, &record.result));
            all_tool_calls.push(record);
        }

        if let Some(err_msg) = first_error {
            return Ok(CoordinatorResult {
                content: response.content,
                tool_calls: all_tool_calls,
                iterations: iteration + 1,
                finish_reason: FinishReason::Error(err_msg),
                total_usage,
                message_history: messages,
            });
        }
    }

    // Iteration budget exhausted without a final reply.
    Ok(CoordinatorResult {
        content: messages
            .last()
            .map(|m| m.content.clone())
            .unwrap_or_default(),
        tool_calls: all_tool_calls,
        iterations: self.config.max_iterations,
        finish_reason: FinishReason::MaxIterations,
        total_usage,
        message_history: messages,
    })
}
/// Executes one batch of tool calls, choosing the strategy from `config.parallel_execution`:
/// `execute_parallel` when it is set, `execute_sequential` otherwise.
///
/// # Errors
///
/// Returns whatever error the chosen strategy returns.
async fn execute_tool_calls(
    &self,
    calls: &[ToolCallRequest],
) -> crate::Result<Vec<ToolCallRecord>> {
    if !self.config.parallel_execution {
        return self.execute_sequential(calls).await;
    }
    self.execute_parallel(calls).await
}
/// Runs every call in `calls` through `execute_single_tool` via `join_all` and returns one
/// record per call, in the same order as `calls`.
///
/// An `Err` from a single call is returned as-is when `stop_on_error` is set (the first one
/// in call order wins); otherwise it is converted into a failed record carrying the error
/// text under `"error"` and a `duration_ms` of 0.
async fn execute_parallel(
    &self,
    calls: &[ToolCallRequest],
) -> crate::Result<Vec<ToolCallRecord>> {
    let pending = calls.iter().map(|call| self.execute_single_tool(call));
    let outcomes = join_all(pending).await;

    // Pair each outcome with its originating call so the fallback record can copy the
    // request's id, name and arguments without indexing.
    calls
        .iter()
        .zip(outcomes)
        .map(|(call, outcome)| {
            outcome.or_else(|e| {
                if self.config.stop_on_error {
                    Err(e)
                } else {
                    Ok(ToolCallRecord {
                        id: call.id.clone(),
                        name: call.name.clone(),
                        arguments: call.arguments.clone(),
                        result: serde_json::json!({"error": e.to_string()}),
                        success: false,
                        duration_ms: 0,
                    })
                }
            })
        })
        .collect()
}
/// Runs the calls one after another, in order, and returns the records produced.
///
/// When `stop_on_error` is set, execution halts after the first failed call: the failed
/// record is included in the result and the remaining calls are not run. This matters
/// because `execute_single_tool` reports tool failures and timeouts as `Ok` records with
/// `success == false`, so checking only for `Err` would let later tools keep running.
///
/// # Errors
///
/// With `stop_on_error` set, an `Err` from `execute_single_tool` is returned immediately.
/// Without it, such an error is converted into a failed record carrying the error text
/// under `"error"` and a `duration_ms` of 0.
async fn execute_sequential(
    &self,
    calls: &[ToolCallRequest],
) -> crate::Result<Vec<ToolCallRecord>> {
    let mut records = Vec::with_capacity(calls.len());
    for call in calls {
        let record = match self.execute_single_tool(call).await {
            Ok(record) => record,
            Err(e) if self.config.stop_on_error => return Err(e),
            Err(e) => ToolCallRecord {
                id: call.id.clone(),
                name: call.name.clone(),
                arguments: call.arguments.clone(),
                result: serde_json::json!({"error": e.to_string()}),
                success: false,
                duration_ms: 0,
            },
        };

        // Decide before the record is moved into the vector.
        let halt = self.config.stop_on_error && !record.success;
        records.push(record);
        if halt {
            break;
        }
    }
    Ok(records)
}
async fn execute_single_tool(&self, call: &ToolCallRequest) -> crate::Result<ToolCallRecord> {
let start = Instant::now();
let result = timeout(
self.config.tool_timeout,
self.registry.execute(&call.name, call.arguments.clone()),
)
.await;
let duration_ms = start.elapsed().as_millis() as u64;
match result {
Ok(Ok(value)) => Ok(ToolCallRecord {
id: call.id.clone(),
name: call.name.clone(),
arguments: call.arguments.clone(),
result: value,
success: true,
duration_ms,
}),
Ok(Err(e)) => Ok(ToolCallRecord {
id: call.id.clone(),
name: call.name.clone(),
arguments: call.arguments.clone(),
result: serde_json::json!({"error": e.to_string()}),
success: false,
duration_ms,
}),
Err(_elapsed) => Ok(ToolCallRecord {
id: call.id.clone(),
name: call.name.clone(),
arguments: call.arguments.clone(),
result: serde_json::json!({"error": "tool execution timed out"}),
success: false,
duration_ms,
}),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
/// A plain-text reply with an empty registry ends the loop after one iteration with
/// `FinishReason::Stop` and no tool calls.
#[tokio::test]
async fn execute_with_empty_registry_returns_model_response() {
    use crate::agent::backend::mock::MockBackend;

    let coordinator = ToolCoordinator::new(
        Arc::new(MockBackend::with_text("Hello, world!")),
        Arc::new(ToolRegistry::new()),
        ToolCallingConfig::default(),
    );

    let outcome = coordinator
        .execute(None, "Say hello")
        .await
        .expect("coordinator should not error");

    assert_eq!(outcome.content, "Hello, world!");
    assert_eq!(outcome.finish_reason, FinishReason::Stop);
    assert_eq!(outcome.iterations, 1);
    assert!(outcome.tool_calls.is_empty());
    // No system prompt was given, so the history is the user prompt plus the assistant reply.
    assert_eq!(outcome.message_history.len(), 2);
}
/// Pins the default `ToolCallingConfig` values so an accidental change is caught.
#[test]
fn tool_calling_config_defaults_are_sensible() {
    use std::time::Duration;

    let defaults = ToolCallingConfig::default();
    let expected_timeout = Duration::from_secs(30);

    assert_eq!(
        defaults.max_iterations, 10,
        "max_iterations default changed"
    );
    assert!(
        defaults.parallel_execution,
        "parallel_execution should default to true"
    );
    assert_eq!(
        defaults.tool_timeout, expected_timeout,
        "tool_timeout default changed"
    );
    assert!(
        !defaults.stop_on_error,
        "stop_on_error should default to false"
    );
}
/// A backend that keeps requesting a registered tool must be cut off after `max_iterations`
/// rounds, with `FinishReason::MaxIterations` and one successful tool record per round.
#[tokio::test]
async fn coordinator_result_captures_finish_reason_max_iterations() {
    use crate::agent::backend::mock::{MockBackend, MockResponse};
    use crate::tools::Tool;
    use async_trait::async_trait;
    use serde_json::Value;

    /// Tool that accepts any arguments and always succeeds.
    struct NoOpTool;

    #[async_trait]
    impl Tool for NoOpTool {
        fn name(&self) -> &str {
            "noop"
        }
        fn description(&self) -> &str {
            "does nothing"
        }
        fn parameters_schema(&self) -> Value {
            serde_json::json!({"type": "object", "properties": {}})
        }
        async fn execute(&self, _args: Value) -> crate::Result<Value> {
            Ok(serde_json::json!({"ok": true}))
        }
    }

    // Script more tool-call replies (15) than the iteration limit (3) allows.
    let scripted: Vec<MockResponse> =
        std::iter::repeat_with(|| MockResponse::tool_call("noop", serde_json::json!({})))
            .take(15)
            .collect();

    let mut tools = ToolRegistry::new();
    tools.register(Arc::new(NoOpTool));

    let coordinator = ToolCoordinator::new(
        Arc::new(MockBackend::new(scripted)),
        Arc::new(tools),
        ToolCallingConfig {
            max_iterations: 3,
            parallel_execution: false,
            ..ToolCallingConfig::default()
        },
    );

    let outcome = coordinator
        .execute(None, "loop forever")
        .await
        .expect("coordinator should not hard-error");

    assert_eq!(
        outcome.finish_reason,
        FinishReason::MaxIterations,
        "expected MaxIterations, got {:?}",
        outcome.finish_reason
    );
    assert_eq!(outcome.iterations, 3);
    assert_eq!(outcome.tool_calls.len(), 3);
    assert!(outcome.tool_calls.iter().all(|tc| tc.success));
}
}