use crate::agent::backend::LlmBackend;
use crate::agent::{Message, Role, ToolCallRecord, ToolCallRequest, ToolResultMessage, TokenUsage};
use crate::tools::ToolRegistry;
use futures::future::join_all;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::time::timeout;
/// Limits and policies that govern a tool-calling run.
#[derive(Debug, Clone)]
pub struct ToolCallingConfig {
    /// Maximum number of model round-trips before the run is cut off.
    pub max_iterations: usize,
    /// Whether the tool calls of a single model response are run concurrently
    /// rather than one after another.
    pub parallel_execution: bool,
    /// Time budget applied to each individual tool execution.
    pub tool_timeout: Duration,
    /// When `true`, a failed tool call ends the run with an error finish
    /// reason instead of feeding the failure back to the model.
    pub stop_on_error: bool,
}

impl Default for ToolCallingConfig {
    /// Ten iterations, concurrent tool execution, a 30 second per-tool
    /// timeout, and tool failures reported back to the model.
    fn default() -> Self {
        let tool_timeout = Duration::from_secs(30);
        ToolCallingConfig {
            stop_on_error: false,
            tool_timeout,
            parallel_execution: true,
            max_iterations: 10,
        }
    }
}
/// Why a coordinator run ended.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum FinishReason {
/// The model replied without requesting any tool calls.
Stop,
/// The iteration budget ran out while the model was still requesting tools.
MaxIterations,
/// A tool call failed while `stop_on_error` was enabled; carries the error text.
Error(String),
/// The model requested a tool the registry does not have; carries the tool name.
UnknownTool(String),
}
impl std::fmt::Display for FinishReason {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
FinishReason::Stop => write!(f, "stop"),
FinishReason::MaxIterations => write!(f, "max_iterations"),
FinishReason::Error(e) => write!(f, "error: {}", e),
FinishReason::UnknownTool(t) => write!(f, "unknown_tool: {}", t),
}
}
}
/// Author of a conversation message.
///
/// Serialized as the lowercase variant name (`"system"`, `"user"`,
/// `"assistant"`, `"tool"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
/// System prompt.
System,
/// User input.
User,
/// Model output, optionally carrying tool-call requests.
Assistant,
/// Result of a tool call.
Tool,
}
/// One entry of the conversation history kept by the coordinator.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationMessage {
/// Who produced the message.
pub role: MessageRole,
/// Message text; for tool results this is the JSON-encoded result.
pub content: String,
/// Tool calls attached to the message; omitted from the serialized form when empty.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub tool_calls: Vec<ToolCallRequest>,
/// Id of the tool call this message answers; omitted from the serialized form when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
impl ConversationMessage {
pub fn system(content: impl Into<String>) -> Self {
Self {
role: MessageRole::System,
content: content.into(),
tool_calls: Vec::new(),
tool_call_id: None,
}
}
pub fn user(content: impl Into<String>) -> Self {
Self {
role: MessageRole::User,
content: content.into(),
tool_calls: Vec::new(),
tool_call_id: None,
}
}
pub fn assistant(content: impl Into<String>, tool_calls: Vec<ToolCallRequest>) -> Self {
Self {
role: MessageRole::Assistant,
content: content.into(),
tool_calls,
tool_call_id: None,
}
}
pub fn tool_result(tool_call_id: impl Into<String>, result: &serde_json::Value) -> Self {
Self {
role: MessageRole::Tool,
content: serde_json::to_string(result).unwrap_or_else(|_| "{}".to_string()),
tool_calls: Vec::new(),
tool_call_id: Some(tool_call_id.into()),
}
}
}
/// Outcome of a coordinator run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoordinatorResult {
/// Text content reported for the run; when the run finishes with `Stop`
/// this is the content of the model's final response.
pub content: String,
/// Tool-call records accumulated across iterations.
pub tool_calls: Vec<ToolCallRecord>,
/// Number of model round-trips counted for the run.
pub iterations: usize,
/// Why the run ended.
pub finish_reason: FinishReason,
/// Token usage summed over every model response that reported usage.
pub total_usage: TokenUsage,
/// The input messages plus everything appended during the run.
pub message_history: Vec<ConversationMessage>,
}
/// Converts a stored conversation message into the backend's message type.
///
/// Role, content and tool calls are copied over. A `tool_result` payload is
/// attached only for `Tool` messages that carry a `tool_call_id`; its content
/// is the message text parsed as JSON, or the raw text wrapped in a JSON
/// string when the text is not valid JSON.
fn to_backend_message(msg: &ConversationMessage) -> Message {
    let role = match msg.role {
        MessageRole::System => Role::System,
        MessageRole::User => Role::User,
        MessageRole::Assistant => Role::Assistant,
        MessageRole::Tool => Role::Tool,
    };

    let tool_result = match (&msg.role, &msg.tool_call_id) {
        (MessageRole::Tool, Some(id)) => Some(ToolResultMessage {
            tool_call_id: id.clone(),
            // Build the string fallback lazily so the content is only cloned
            // when parsing actually fails.
            content: serde_json::from_str(&msg.content)
                .unwrap_or_else(|_| serde_json::Value::String(msg.content.clone())),
            // NOTE(review): `ConversationMessage` does not record whether the
            // tool call succeeded, so failed results are also forwarded with
            // `success: true` — confirm the backend does not rely on this flag.
            success: true,
        }),
        _ => None,
    };

    Message {
        role,
        content: msg.content.clone(),
        tool_calls: msg.tool_calls.clone(),
        tool_result,
    }
}
/// Drives a model/tool loop: asks the backend for a response, runs the tools
/// it requests through the registry, and feeds the results back.
pub struct ToolCoordinator {
// Backend used to generate model responses.
backend: Arc<dyn LlmBackend>,
// Registry used to look up and execute tools.
registry: Arc<ToolRegistry>,
// Iteration, timeout and error-handling policy for the loop.
config: ToolCallingConfig,
}
impl ToolCoordinator {
pub fn new(
backend: Arc<dyn LlmBackend>,
registry: Arc<ToolRegistry>,
config: ToolCallingConfig,
) -> Self {
Self { backend, registry, config }
}
pub async fn execute(
&self,
system_prompt: Option<&str>,
user_prompt: &str,
) -> crate::Result<CoordinatorResult> {
let mut messages: Vec<ConversationMessage> = Vec::new();
if let Some(sys) = system_prompt {
messages.push(ConversationMessage::system(sys));
}
messages.push(ConversationMessage::user(user_prompt));
self.execute_with_history(messages).await
}
pub async fn execute_with_history(
&self,
mut messages: Vec<ConversationMessage>,
) -> crate::Result<CoordinatorResult> {
let tool_defs = self.registry.get_definitions();
let mut all_tool_calls: Vec<ToolCallRecord> = Vec::new();
let mut total_usage = TokenUsage::default();
for iteration in 0..self.config.max_iterations {
let backend_messages: Vec<Message> =
messages.iter().map(to_backend_message).collect();
let response = self
.backend
.generate(&backend_messages, &tool_defs, None)
.await?;
if let Some(usage) = &response.usage {
total_usage.prompt_tokens += usage.prompt_tokens;
total_usage.completion_tokens += usage.completion_tokens;
total_usage.total_tokens += usage.total_tokens;
total_usage.reasoning_tokens += usage.reasoning_tokens;
total_usage.action_tokens += usage.action_tokens;
}
messages.push(ConversationMessage::assistant(
&response.content,
response.tool_calls.clone(),
));
if response.tool_calls.is_empty() {
return Ok(CoordinatorResult {
content: response.content,
tool_calls: all_tool_calls,
iterations: iteration + 1,
finish_reason: FinishReason::Stop,
total_usage,
message_history: messages,
});
}
if response.content.is_empty() && response.tool_calls.is_empty() {
return Ok(CoordinatorResult {
content: String::new(),
tool_calls: all_tool_calls,
iterations: iteration + 1,
finish_reason: FinishReason::Stop,
total_usage,
message_history: messages,
});
}
for tc in &response.tool_calls {
if !self.registry.has_tool(&tc.name) {
return Ok(CoordinatorResult {
content: response.content,
tool_calls: all_tool_calls,
iterations: iteration + 1,
finish_reason: FinishReason::UnknownTool(tc.name.clone()),
total_usage,
message_history: messages,
});
}
}
let records = self.execute_tool_calls(&response.tool_calls).await?;
if self.config.stop_on_error {
if let Some(failed) = records.iter().find(|r| !r.success) {
let err_msg = failed
.result
.get("error")
.and_then(|v| v.as_str())
.unwrap_or("tool error")
.to_string();
return Ok(CoordinatorResult {
content: response.content,
tool_calls: all_tool_calls,
iterations: iteration + 1,
finish_reason: FinishReason::Error(err_msg),
total_usage,
message_history: messages,
});
}
}
for record in records {
messages.push(ConversationMessage::tool_result(&record.id, &record.result));
all_tool_calls.push(record);
}
}
Ok(CoordinatorResult {
content: messages
.last()
.map(|m| m.content.clone())
.unwrap_or_default(),
tool_calls: all_tool_calls,
iterations: self.config.max_iterations,
finish_reason: FinishReason::MaxIterations,
total_usage,
message_history: messages,
})
}
async fn execute_tool_calls(
&self,
calls: &[ToolCallRequest],
) -> crate::Result<Vec<ToolCallRecord>> {
if self.config.parallel_execution {
self.execute_parallel(calls).await
} else {
self.execute_sequential(calls).await
}
}
async fn execute_parallel(
&self,
calls: &[ToolCallRequest],
) -> crate::Result<Vec<ToolCallRecord>> {
let futures = calls.iter().map(|c| self.execute_single_tool(c));
let results = join_all(futures).await;
let mut records = Vec::with_capacity(results.len());
for (i, res) in results.into_iter().enumerate() {
match res {
Ok(record) => records.push(record),
Err(e) if self.config.stop_on_error => return Err(e),
Err(e) => {
let call = &calls[i];
records.push(ToolCallRecord {
id: call.id.clone(),
name: call.name.clone(),
arguments: call.arguments.clone(),
result: serde_json::json!({"error": e.to_string()}),
success: false,
duration_ms: 0,
});
}
}
}
Ok(records)
}
async fn execute_sequential(
&self,
calls: &[ToolCallRequest],
) -> crate::Result<Vec<ToolCallRecord>> {
let mut records = Vec::with_capacity(calls.len());
for call in calls {
match self.execute_single_tool(call).await {
Ok(record) => records.push(record),
Err(e) if self.config.stop_on_error => return Err(e),
Err(e) => {
records.push(ToolCallRecord {
id: call.id.clone(),
name: call.name.clone(),
arguments: call.arguments.clone(),
result: serde_json::json!({"error": e.to_string()}),
success: false,
duration_ms: 0,
});
}
}
}
Ok(records)
}
async fn execute_single_tool(&self, call: &ToolCallRequest) -> crate::Result<ToolCallRecord> {
let start = Instant::now();
let result = timeout(
self.config.tool_timeout,
self.registry.execute(&call.name, call.arguments.clone()),
)
.await;
let duration_ms = start.elapsed().as_millis() as u64;
match result {
Ok(Ok(value)) => Ok(ToolCallRecord {
id: call.id.clone(),
name: call.name.clone(),
arguments: call.arguments.clone(),
result: value,
success: true,
duration_ms,
}),
Ok(Err(e)) => Ok(ToolCallRecord {
id: call.id.clone(),
name: call.name.clone(),
arguments: call.arguments.clone(),
result: serde_json::json!({"error": e.to_string()}),
success: false,
duration_ms,
}),
Err(_elapsed) => Ok(ToolCallRecord {
id: call.id.clone(),
name: call.name.clone(),
arguments: call.arguments.clone(),
result: serde_json::json!({"error": "tool execution timed out"}),
success: false,
duration_ms,
}),
}
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for the configuration, message and result types, plus two
    //! end-to-end runs of the coordinator against the mock backend.

    use super::*;
    use serde_json::json;

    /// Pins the documented defaults so an accidental change is caught.
    #[test]
    fn tool_calling_config_default_values() {
        let cfg = ToolCallingConfig::default();
        assert_eq!(cfg.max_iterations, 10, "max_iterations default changed");
        assert!(cfg.parallel_execution, "parallel_execution should default to true");
        assert_eq!(
            cfg.tool_timeout,
            Duration::from_secs(30),
            "tool_timeout default changed"
        );
        assert!(!cfg.stop_on_error, "stop_on_error should default to false");
    }

    #[test]
    fn finish_reason_display_matches_snake_case_contract() {
        assert_eq!(FinishReason::Stop.to_string(), "stop");
        assert_eq!(FinishReason::MaxIterations.to_string(), "max_iterations");
        assert_eq!(
            FinishReason::Error("boom".into()).to_string(),
            "error: boom"
        );
        assert_eq!(
            FinishReason::UnknownTool("ghost".into()).to_string(),
            "unknown_tool: ghost"
        );
    }

    #[test]
    fn finish_reason_round_trips_through_json() {
        for variant in [
            FinishReason::Stop,
            FinishReason::MaxIterations,
            FinishReason::Error("oops".into()),
            FinishReason::UnknownTool("nope".into()),
        ] {
            let encoded = serde_json::to_string(&variant).unwrap();
            let decoded: FinishReason = serde_json::from_str(&encoded).unwrap();
            assert_eq!(
                decoded, variant,
                "{} did not round-trip through JSON",
                variant
            );
        }
    }

    #[test]
    fn message_role_serializes_as_lowercase_string() {
        assert_eq!(
            serde_json::to_string(&MessageRole::System).unwrap(),
            "\"system\""
        );
        assert_eq!(
            serde_json::to_string(&MessageRole::User).unwrap(),
            "\"user\""
        );
        assert_eq!(
            serde_json::to_string(&MessageRole::Assistant).unwrap(),
            "\"assistant\""
        );
        assert_eq!(
            serde_json::to_string(&MessageRole::Tool).unwrap(),
            "\"tool\""
        );
    }

    #[test]
    fn conversation_message_system_builder_sets_role_and_content() {
        let msg = ConversationMessage::system("you are an assistant");
        assert_eq!(msg.role, MessageRole::System);
        assert_eq!(msg.content, "you are an assistant");
        assert!(msg.tool_calls.is_empty());
        assert!(msg.tool_call_id.is_none());
    }

    #[test]
    fn conversation_message_user_builder_sets_role_and_content() {
        let msg = ConversationMessage::user("what is 2 + 2?");
        assert_eq!(msg.role, MessageRole::User);
        assert_eq!(msg.content, "what is 2 + 2?");
        assert!(msg.tool_calls.is_empty());
        assert!(msg.tool_call_id.is_none());
    }

    #[test]
    fn conversation_message_assistant_builder_preserves_tool_calls() {
        let calls = vec![ToolCallRequest {
            id: "call_1".into(),
            name: "search".into(),
            arguments: json!({"q": "rust"}),
        }];
        let msg = ConversationMessage::assistant("let me search", calls.clone());
        assert_eq!(msg.role, MessageRole::Assistant);
        assert_eq!(msg.content, "let me search");
        assert_eq!(msg.tool_calls.len(), 1);
        assert_eq!(msg.tool_calls[0].id, "call_1");
        assert_eq!(msg.tool_calls[0].name, "search");
        assert!(msg.tool_call_id.is_none());
    }

    #[test]
    fn conversation_message_tool_result_serializes_result_into_content() {
        let result = json!({"answer": 42, "units": "none"});
        let msg = ConversationMessage::tool_result("call_1", &result);
        assert_eq!(msg.role, MessageRole::Tool);
        assert_eq!(msg.tool_call_id.as_deref(), Some("call_1"));
        assert!(msg.tool_calls.is_empty());
        let parsed: serde_json::Value = serde_json::from_str(&msg.content).unwrap();
        assert_eq!(parsed, result);
    }

    /// A JSON `null` result serializes successfully, so the content is the
    /// literal `null` and not the `"{}"` fallback.
    #[test]
    fn conversation_message_tool_result_serializes_null_as_json_literal() {
        let msg = ConversationMessage::tool_result("call_1", &json!(null));
        assert_eq!(msg.content, "null");
    }

    #[test]
    fn conversation_message_serde_skips_empty_tool_calls_and_none_id() {
        let msg = ConversationMessage::user("hi");
        let encoded = serde_json::to_string(&msg).unwrap();
        assert!(!encoded.contains("tool_calls"));
        assert!(!encoded.contains("tool_call_id"));
        assert!(encoded.contains("\"role\":\"user\""));
        assert!(encoded.contains("\"content\":\"hi\""));
    }

    #[test]
    fn coordinator_result_round_trips_through_json() {
        let result = CoordinatorResult {
            content: "done".into(),
            tool_calls: vec![ToolCallRecord {
                id: "call_1".into(),
                name: "echo".into(),
                arguments: json!({"text": "hi"}),
                result: json!({"text": "hi"}),
                success: true,
                duration_ms: 12,
            }],
            iterations: 2,
            finish_reason: FinishReason::Stop,
            total_usage: TokenUsage {
                prompt_tokens: 100,
                completion_tokens: 20,
                total_tokens: 120,
                reasoning_tokens: 0,
                action_tokens: 20,
            },
            message_history: vec![
                ConversationMessage::system("be brief"),
                ConversationMessage::user("echo hi"),
                ConversationMessage::assistant(
                    "",
                    vec![ToolCallRequest {
                        id: "call_1".into(),
                        name: "echo".into(),
                        arguments: json!({"text": "hi"}),
                    }],
                ),
                ConversationMessage::tool_result("call_1", &json!({"text": "hi"})),
                ConversationMessage::assistant("done", vec![]),
            ],
        };
        let encoded = serde_json::to_string(&result).unwrap();
        let decoded: CoordinatorResult = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.content, "done");
        assert_eq!(decoded.iterations, 2);
        assert_eq!(decoded.finish_reason, FinishReason::Stop);
        assert_eq!(decoded.tool_calls.len(), 1);
        assert_eq!(decoded.tool_calls[0].id, "call_1");
        assert_eq!(decoded.message_history.len(), 5);
        assert_eq!(decoded.total_usage.total_tokens, 120);
    }

    /// A plain text reply ends the run after one iteration; the history holds
    /// the user prompt and the assistant reply.
    #[tokio::test]
    async fn execute_with_empty_registry_returns_model_response() {
        use crate::agent::backend::mock::MockBackend;

        let backend = Arc::new(MockBackend::with_text("Hello, world!"));
        let registry = Arc::new(ToolRegistry::new());
        let coordinator = ToolCoordinator::new(backend, registry, ToolCallingConfig::default());
        let result = coordinator
            .execute(None, "Say hello")
            .await
            .expect("coordinator should not error");
        assert_eq!(result.content, "Hello, world!");
        assert_eq!(result.finish_reason, FinishReason::Stop);
        assert_eq!(result.iterations, 1);
        assert!(result.tool_calls.is_empty());
        assert_eq!(result.message_history.len(), 2);
    }

    /// A model that keeps requesting a tool is cut off once the iteration
    /// budget is spent.
    #[tokio::test]
    async fn coordinator_result_captures_finish_reason_max_iterations() {
        use crate::agent::backend::mock::{MockBackend, MockResponse};
        use crate::tools::Tool;
        use async_trait::async_trait;
        use serde_json::Value;

        struct NoOpTool;

        #[async_trait]
        impl Tool for NoOpTool {
            fn name(&self) -> &str {
                "noop"
            }
            fn description(&self) -> &str {
                "does nothing"
            }
            fn parameters_schema(&self) -> Value {
                json!({"type": "object", "properties": {}})
            }
            async fn execute(&self, _args: Value) -> crate::Result<Value> {
                Ok(json!({"ok": true}))
            }
        }

        // More scripted responses than the budget allows, so the limit (not
        // the script) is what ends the run.
        let responses: Vec<MockResponse> = (0..15)
            .map(|_| MockResponse::tool_call("noop", json!({})))
            .collect();
        let backend = Arc::new(MockBackend::new(responses));
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(NoOpTool));
        let registry = Arc::new(registry);
        let config = ToolCallingConfig {
            max_iterations: 3,
            parallel_execution: false,
            ..ToolCallingConfig::default()
        };
        let coordinator = ToolCoordinator::new(backend, registry, config);
        let result = coordinator
            .execute(None, "loop forever")
            .await
            .expect("coordinator should not hard-error");
        assert_eq!(
            result.finish_reason,
            FinishReason::MaxIterations,
            "expected MaxIterations, got {:?}",
            result.finish_reason
        );
        assert_eq!(result.iterations, 3);
        assert_eq!(result.tool_calls.len(), 3);
        assert!(result.tool_calls.iter().all(|tc| tc.success));
    }
}