pub mod types;
pub use types::*;
use crate::agent::backend::LlmBackend;
use crate::agent::{Message, Role, TokenUsage, ToolCallRecord, ToolCallRequest, ToolResultMessage};
use crate::tools::ToolRegistry;
use async_trait::async_trait;
use futures::future::join_all;
use serde_json::Value;
use std::collections::HashSet;
use std::sync::Arc;
use std::time::Instant;
use tokio::time::timeout;
/// Converts a coordinator-level [`ConversationMessage`] into the backend
/// [`Message`] shape.
///
/// For messages whose role is [`Role::Tool`] and that carry a `tool_call_id`,
/// a [`ToolResultMessage`] is attached. Its content is the message text parsed
/// as JSON; text that is not valid JSON is forwarded as a JSON string instead.
/// All other messages get `tool_result: None`.
///
/// NOTE(review): `success` is always set to `true` here, even when the tool
/// result text describes an error — no success flag is read from the
/// conversation message in this function. Confirm whether the backend relies
/// on this field.
fn to_backend_message(msg: &ConversationMessage) -> Message {
    let tool_result = if msg.role == Role::Tool {
        msg.tool_call_id.as_ref().map(|id| ToolResultMessage {
            tool_call_id: id.clone(),
            // Build the string fallback lazily: the eager `unwrap_or` form
            // cloned the content on every call, even when parsing succeeded.
            content: serde_json::from_str(&msg.content)
                .unwrap_or_else(|_| serde_json::Value::String(msg.content.clone())),
            success: true,
        })
    } else {
        None
    };

    Message {
        role: msg.role.clone(),
        content: msg.content.clone(),
        tool_calls: msg.tool_calls.clone(),
        tool_result,
    }
}
/// Runs the tool-calling loop: asks the backend for a response, executes any
/// tool calls it requests through the registry, and feeds the results back
/// until the backend stops requesting tools or a limit is reached.
pub struct ToolCoordinator {
/// Backend used to generate each assistant turn.
backend: Arc<dyn LlmBackend>,
/// Registry that supplies tool definitions and executes tool calls.
registry: Arc<ToolRegistry>,
/// Loop settings read by the coordinator: `max_iterations`,
/// `parallel_execution`, `tool_timeout` and `stop_on_error`.
config: ToolCallingConfig,
}
impl ToolCoordinator {
pub fn new(
backend: Arc<dyn LlmBackend>,
registry: Arc<ToolRegistry>,
config: ToolCallingConfig,
) -> Self {
Self {
backend,
registry,
config,
}
}
pub async fn execute(
&self,
system_prompt: Option<&str>,
user_prompt: &str,
) -> crate::Result<CoordinatorResult> {
let mut messages: Vec<ConversationMessage> = Vec::new();
if let Some(sys) = system_prompt {
messages.push(ConversationMessage::system(sys));
}
messages.push(ConversationMessage::user(user_prompt));
self.execute_with_history(messages).await
}
pub async fn execute_with_history(
&self,
mut messages: Vec<ConversationMessage>,
) -> crate::Result<CoordinatorResult> {
let tool_defs = self.registry.get_definitions();
let mut all_tool_calls: Vec<ToolCallRecord> = Vec::new();
let mut total_usage = TokenUsage::default();
for iteration in 0..self.config.max_iterations {
let backend_messages: Vec<Message> = messages.iter().map(to_backend_message).collect();
let response = self
.backend
.generate(&backend_messages, &tool_defs, None)
.await?;
if let Some(usage) = &response.usage {
total_usage.prompt_tokens += usage.prompt_tokens;
total_usage.completion_tokens += usage.completion_tokens;
total_usage.total_tokens += usage.total_tokens;
total_usage.reasoning_tokens += usage.reasoning_tokens;
total_usage.action_tokens += usage.action_tokens;
}
messages.push(ConversationMessage::assistant(
&response.content,
response.tool_calls.clone(),
));
if response.tool_calls.is_empty() {
return Ok(CoordinatorResult {
content: response.content,
tool_calls: all_tool_calls,
iterations: iteration + 1,
finish_reason: FinishReason::Stop,
total_usage,
message_history: messages,
});
}
if response.content.is_empty() && response.tool_calls.is_empty() {
return Ok(CoordinatorResult {
content: String::new(),
tool_calls: all_tool_calls,
iterations: iteration + 1,
finish_reason: FinishReason::Stop,
total_usage,
message_history: messages,
});
}
for tc in &response.tool_calls {
if !self.registry.has_tool(&tc.name) {
return Ok(CoordinatorResult {
content: response.content,
tool_calls: all_tool_calls,
iterations: iteration + 1,
finish_reason: FinishReason::UnknownTool(tc.name.clone()),
total_usage,
message_history: messages,
});
}
}
let records = self.execute_tool_calls(&response.tool_calls).await?;
if self.config.stop_on_error {
if let Some(failed) = records.iter().find(|r| !r.success) {
let err_msg = failed
.result
.get("error")
.and_then(|v| v.as_str())
.unwrap_or("tool error")
.to_string();
return Ok(CoordinatorResult {
content: response.content,
tool_calls: all_tool_calls,
iterations: iteration + 1,
finish_reason: FinishReason::Error(err_msg),
total_usage,
message_history: messages,
});
}
}
for record in records {
messages.push(ConversationMessage::tool_result(&record.id, &record.result));
all_tool_calls.push(record);
}
}
Ok(CoordinatorResult {
content: messages
.last()
.map(|m| m.content.clone())
.unwrap_or_default(),
tool_calls: all_tool_calls,
iterations: self.config.max_iterations,
finish_reason: FinishReason::MaxIterations,
total_usage,
message_history: messages,
})
}
async fn execute_tool_calls(
&self,
calls: &[ToolCallRequest],
) -> crate::Result<Vec<ToolCallRecord>> {
if self.config.parallel_execution {
self.execute_parallel(calls).await
} else {
self.execute_sequential(calls).await
}
}
async fn execute_parallel(
&self,
calls: &[ToolCallRequest],
) -> crate::Result<Vec<ToolCallRecord>> {
let futures = calls.iter().map(|c| self.execute_single_tool(c));
let results = join_all(futures).await;
let mut records = Vec::with_capacity(results.len());
for (i, res) in results.into_iter().enumerate() {
match res {
Ok(record) => records.push(record),
Err(e) if self.config.stop_on_error => return Err(e),
Err(e) => {
let call = &calls[i];
records.push(ToolCallRecord {
id: call.id.clone(),
name: call.name.clone(),
arguments: call.arguments.clone(),
result: serde_json::json!({"error": e.to_string()}),
success: false,
duration_ms: 0,
});
}
}
}
Ok(records)
}
async fn execute_sequential(
&self,
calls: &[ToolCallRequest],
) -> crate::Result<Vec<ToolCallRecord>> {
let mut records = Vec::with_capacity(calls.len());
for call in calls {
match self.execute_single_tool(call).await {
Ok(record) => records.push(record),
Err(e) if self.config.stop_on_error => return Err(e),
Err(e) => {
records.push(ToolCallRecord {
id: call.id.clone(),
name: call.name.clone(),
arguments: call.arguments.clone(),
result: serde_json::json!({"error": e.to_string()}),
success: false,
duration_ms: 0,
});
}
}
}
Ok(records)
}
async fn execute_single_tool(&self, call: &ToolCallRequest) -> crate::Result<ToolCallRecord> {
let start = Instant::now();
let result = timeout(
self.config.tool_timeout,
self.registry.execute(&call.name, call.arguments.clone()),
)
.await;
let duration_ms = start.elapsed().as_millis() as u64;
match result {
Ok(Ok(value)) => Ok(ToolCallRecord {
id: call.id.clone(),
name: call.name.clone(),
arguments: call.arguments.clone(),
result: value,
success: true,
duration_ms,
}),
Ok(Err(e)) => Ok(ToolCallRecord {
id: call.id.clone(),
name: call.name.clone(),
arguments: call.arguments.clone(),
result: serde_json::json!({"error": e.to_string()}),
success: false,
duration_ms,
}),
Err(_elapsed) => Ok(ToolCallRecord {
id: call.id.clone(),
name: call.name.clone(),
arguments: call.arguments.clone(),
result: serde_json::json!({"error": "tool execution timed out"}),
success: false,
duration_ms,
}),
}
}
}
/// Agent type identifiers accepted by [`validate_task_schedule`].
pub const KNOWN_AGENT_TYPES: &[&str] = &[
    "explore",
    "plan",
    "task",
    "reviewer",
    "designer",
    "librarian",
];

/// One unit of work to be dispatched to an agent.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScheduledTask {
    /// Identifier of the task; must be unique within a schedule.
    pub id: String,
    /// Agent type to run the task; must be one of [`KNOWN_AGENT_TYPES`].
    pub agent_type: String,
    /// Description of the work; must contain non-whitespace text.
    pub assignment: String,
}

/// Reasons a task schedule is rejected.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ScheduleError {
    /// The schedule contained no tasks.
    EmptyTaskList,
    /// A task named an agent type that is not in [`KNOWN_AGENT_TYPES`];
    /// carries the offending agent type.
    InvalidAgentType(String),
    /// Two tasks shared the same id; carries the duplicated id.
    DuplicateTaskId(String),
    /// A task's assignment was empty or whitespace-only; carries the task id.
    EmptyAssignment(String),
}

impl std::fmt::Display for ScheduleError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ScheduleError::EmptyTaskList => write!(f, "task list must not be empty"),
            ScheduleError::InvalidAgentType(agent) => write!(
                f,
                "unknown agent type '{agent}'. Valid types: {}",
                KNOWN_AGENT_TYPES.join(", ")
            ),
            ScheduleError::DuplicateTaskId(id) => write!(f, "duplicate task id '{id}'"),
            ScheduleError::EmptyAssignment(id) => {
                write!(f, "task '{id}' has an empty assignment")
            }
        }
    }
}

/// Checks that a schedule is non-empty and that every task is well-formed.
///
/// Tasks are checked in order and the first problem found is returned. For
/// each task the checks run as: known agent type, unique id, non-blank
/// assignment.
///
/// # Errors
/// * [`ScheduleError::EmptyTaskList`] if `tasks` is empty.
/// * [`ScheduleError::InvalidAgentType`] if a task's agent type is unknown.
/// * [`ScheduleError::DuplicateTaskId`] if a task id repeats an earlier one.
/// * [`ScheduleError::EmptyAssignment`] if a task's assignment is empty or
///   whitespace-only. This was previously reported as `InvalidAgentType`,
///   which rendered as a misleading "unknown agent type" message.
pub fn validate_task_schedule(tasks: &[ScheduledTask]) -> Result<(), ScheduleError> {
    if tasks.is_empty() {
        return Err(ScheduleError::EmptyTaskList);
    }

    // Borrow the ids: the set lives no longer than `tasks`, so no per-task
    // allocation is needed for the duplicate check.
    let mut seen_ids: HashSet<&str> = HashSet::with_capacity(tasks.len());
    for task in tasks {
        if !KNOWN_AGENT_TYPES.contains(&task.agent_type.as_str()) {
            return Err(ScheduleError::InvalidAgentType(task.agent_type.clone()));
        }
        if !seen_ids.insert(task.id.as_str()) {
            return Err(ScheduleError::DuplicateTaskId(task.id.clone()));
        }
        if task.assignment.trim().is_empty() {
            return Err(ScheduleError::EmptyAssignment(task.id.clone()));
        }
    }
    Ok(())
}
/// Output of one successfully run scheduled task.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScheduledTaskResult {
/// Id of the task that produced this output.
pub id: String,
/// JSON value returned by the task runner for that task.
pub output: Value,
}
/// Abstraction over whatever actually runs a scheduled task; implemented by
/// a mock in this file's tests.
#[async_trait]
pub trait TaskRunner: Send + Sync {
/// Runs `task` and returns its JSON output, or a crate error on failure.
async fn run(&self, task: &ScheduledTask) -> crate::Result<Value>;
}
/// Validates a task schedule and dispatches its tasks through a runner.
pub struct TaskScheduleCoordinator<R> {
/// Shared runner that executes each task.
runner: Arc<R>,
}
impl<R: TaskRunner> TaskScheduleCoordinator<R> {
    /// Wraps the given runner in a coordinator.
    pub fn new(runner: Arc<R>) -> Self {
        Self { runner }
    }

    /// Validates `tasks`, then runs them one at a time in the given order.
    ///
    /// Nothing is dispatched if validation fails. The first runner failure
    /// aborts the schedule; results of tasks that already completed are
    /// discarded.
    ///
    /// # Errors
    /// Returns whatever `validate_task_schedule` rejects the schedule with,
    /// or an error describing the first task whose runner call failed.
    ///
    /// NOTE(review): runner failures are wrapped in
    /// `ScheduleError::InvalidAgentType`, so their `Display` output reads
    /// "unknown agent type 'task '<id>' failed: ...'". A dedicated variant
    /// would describe this case more accurately.
    pub async fn schedule(
        &self,
        tasks: &[ScheduledTask],
    ) -> Result<Vec<ScheduledTaskResult>, ScheduleError> {
        validate_task_schedule(tasks)?;

        let mut completed = Vec::with_capacity(tasks.len());
        for task in tasks {
            match self.runner.run(task).await {
                Ok(output) => completed.push(ScheduledTaskResult {
                    id: task.id.clone(),
                    output,
                }),
                Err(e) => {
                    return Err(ScheduleError::InvalidAgentType(format!(
                        "task '{}' failed: {e}",
                        task.id
                    )));
                }
            }
        }
        Ok(completed)
    }
}
// Unit tests for the tool-calling coordinator and the task-schedule
// validator/coordinator defined in this file.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
// A text-only backend reply ends the loop after one iteration; the history
// then holds the user prompt plus the assistant reply.
#[tokio::test]
async fn execute_with_empty_registry_returns_model_response() {
use crate::agent::backend::mock::MockBackend;
let backend = Arc::new(MockBackend::with_text("Hello, world!"));
let registry = Arc::new(ToolRegistry::new());
let coordinator = ToolCoordinator::new(backend, registry, ToolCallingConfig::default());
let result = coordinator
.execute(None, "Say hello")
.await
.expect("coordinator should not error");
assert_eq!(result.content, "Hello, world!");
assert_eq!(result.finish_reason, FinishReason::Stop);
assert_eq!(result.iterations, 1);
assert!(result.tool_calls.is_empty());
assert_eq!(result.message_history.len(), 2);
}
// Pins the default values of `ToolCallingConfig` so changes are deliberate.
#[test]
fn tool_calling_config_defaults_are_sensible() {
use std::time::Duration;
let cfg = ToolCallingConfig::default();
assert_eq!(cfg.max_iterations, 10, "max_iterations default changed");
assert!(
cfg.parallel_execution,
"parallel_execution should default to true"
);
assert_eq!(
cfg.tool_timeout,
Duration::from_secs(30),
"tool_timeout default changed"
);
assert!(!cfg.stop_on_error, "stop_on_error should default to false");
}
// A backend that keeps requesting the `noop` tool exhausts `max_iterations`
// (set to 3 here) and the result reports `MaxIterations`.
#[tokio::test]
async fn coordinator_result_captures_finish_reason_max_iterations() {
use crate::agent::backend::mock::{MockBackend, MockResponse};
use crate::tools::Tool;
use async_trait::async_trait;
use serde_json::Value;
struct NoOpTool;
#[async_trait]
impl Tool for NoOpTool {
fn name(&self) -> &str {
"noop"
}
fn description(&self) -> &str {
"does nothing"
}
fn parameters_schema(&self) -> Value {
serde_json::json!({"type": "object", "properties": {}})
}
async fn execute(&self, _args: Value) -> crate::Result<Value> {
Ok(serde_json::json!({"ok": true}))
}
}
// More queued responses than iterations, so the limit is what stops the loop.
let responses: Vec<MockResponse> = (0..15)
.map(|_| MockResponse::tool_call("noop", serde_json::json!({})))
.collect();
let backend = Arc::new(MockBackend::new(responses));
let mut registry = ToolRegistry::new();
registry.register(std::sync::Arc::new(NoOpTool));
let registry = Arc::new(registry);
let config = ToolCallingConfig {
max_iterations: 3,
parallel_execution: false,
..ToolCallingConfig::default()
};
let coordinator = ToolCoordinator::new(backend, registry, config);
let result = coordinator
.execute(None, "loop forever")
.await
.expect("coordinator should not hard-error");
assert_eq!(
result.finish_reason,
FinishReason::MaxIterations,
"expected MaxIterations, got {:?}",
result.finish_reason
);
assert_eq!(result.iterations, 3);
assert_eq!(result.tool_calls.len(), 3);
assert!(result.tool_calls.iter().all(|tc| tc.success));
}
// A request for a tool missing from the registry yields `UnknownTool` as an
// `Ok` result, with no tool executed.
#[tokio::test]
async fn test_unknown_tool_validation_returns_unknown_tool_finish_reason() {
use crate::agent::backend::mock::MockBackend;
let backend = Arc::new(MockBackend::with_tool_call(
"call_ghost",
"definitely_not_registered",
serde_json::json!({}),
"should not reach this",
));
let registry = Arc::new(ToolRegistry::new());
let coordinator = ToolCoordinator::new(backend, registry, ToolCallingConfig::default());
let result = coordinator
.execute(None, "use a ghost tool")
.await
.expect("unknown tool should surface as a coordinator result, not a hard error");
assert_eq!(
result.finish_reason,
FinishReason::UnknownTool("definitely_not_registered".into())
);
assert!(
result.tool_calls.is_empty(),
"unknown tool must not be executed"
);
assert_eq!(result.iterations, 1);
}
// With `stop_on_error` set, a failing tool ends the loop in the first
// iteration with `FinishReason::Error` carrying the tool's message.
#[tokio::test]
async fn test_stop_on_error_halts_on_failed_tool_execution() {
use crate::agent::backend::mock::{MockBackend, MockResponse};
use crate::tools::Tool;
use async_trait::async_trait;
use serde_json::Value;
struct FailingTool;
#[async_trait]
impl Tool for FailingTool {
fn name(&self) -> &str {
"fail_me"
}
fn description(&self) -> &str {
"always fails"
}
fn parameters_schema(&self) -> Value {
serde_json::json!({"type": "object", "properties": {}})
}
async fn execute(&self, _args: Value) -> crate::Result<Value> {
Err(crate::PawanError::Tool("intentional failure".into()))
}
}
let backend = Arc::new(MockBackend::new(vec![
MockResponse::tool_call("fail_me", serde_json::json!({})),
MockResponse::text("unreachable"),
]));
let mut registry = ToolRegistry::new();
registry.register(Arc::new(FailingTool));
let registry = Arc::new(registry);
let config = ToolCallingConfig {
stop_on_error: true,
parallel_execution: false,
..ToolCallingConfig::default()
};
let coordinator = ToolCoordinator::new(backend, registry, config);
let result = coordinator
.execute(None, "trigger failure")
.await
.expect("stop_on_error should return Ok with Error finish reason");
match &result.finish_reason {
FinishReason::Error(msg) => {
assert!(
msg.contains("intentional failure"),
"error message should propagate from tool, got: {}",
msg
);
}
other => panic!("expected FinishReason::Error, got {:?}", other),
}
assert_eq!(result.iterations, 1);
}
// A tool that sleeps 2s against a 50ms timeout is recorded as a failed call
// with the timeout message; the loop then continues to the next response.
#[tokio::test]
async fn test_tool_timeout_records_failed_tool_call() {
use crate::agent::backend::mock::{MockBackend, MockResponse};
use crate::tools::Tool;
use async_trait::async_trait;
use serde_json::Value;
use std::time::Duration;
struct SlowTool;
#[async_trait]
impl Tool for SlowTool {
fn name(&self) -> &str {
"slow_tool"
}
fn description(&self) -> &str {
"sleeps longer than the coordinator timeout"
}
fn parameters_schema(&self) -> Value {
serde_json::json!({"type": "object", "properties": {}})
}
async fn execute(&self, _args: Value) -> crate::Result<Value> {
tokio::time::sleep(Duration::from_secs(2)).await;
Ok(serde_json::json!({"ok": true}))
}
}
let backend = Arc::new(MockBackend::new(vec![
MockResponse::tool_call("slow_tool", serde_json::json!({})),
MockResponse::text("done after timeout"),
]));
let mut registry = ToolRegistry::new();
registry.register(Arc::new(SlowTool));
let registry = Arc::new(registry);
let config = ToolCallingConfig {
tool_timeout: Duration::from_millis(50),
parallel_execution: false,
..ToolCallingConfig::default()
};
let coordinator = ToolCoordinator::new(backend, registry, config);
let result = coordinator
.execute(None, "run slow tool")
.await
.expect("timeout should be absorbed into a failed tool record");
assert_eq!(result.tool_calls.len(), 1);
let record = &result.tool_calls[0];
assert!(
!record.success,
"timed-out tool must be marked unsuccessful"
);
assert_eq!(
record.result.get("error").and_then(|v| v.as_str()),
Some("tool execution timed out")
);
assert_eq!(result.finish_reason, FinishReason::Stop);
assert_eq!(result.iterations, 2);
}
// A supplied system prompt becomes the first history entry, ahead of the
// user prompt and the assistant reply.
#[tokio::test]
async fn test_execute_with_system_prompt_prepends_system_message() {
use crate::agent::backend::mock::MockBackend;
let backend = Arc::new(MockBackend::with_text("acknowledged"));
let registry = Arc::new(ToolRegistry::new());
let coordinator = ToolCoordinator::new(backend, registry, ToolCallingConfig::default());
let result = coordinator
.execute(Some("be concise"), "hello")
.await
.expect("execute should succeed");
assert_eq!(result.message_history.len(), 3);
assert_eq!(result.message_history[0].role, Role::System);
assert_eq!(result.message_history[0].content, "be concise");
assert_eq!(result.message_history[1].role, Role::User);
assert_eq!(result.message_history[1].content, "hello");
assert_eq!(result.message_history[2].role, Role::Assistant);
}
// Token usage reported by the backend is accumulated into `total_usage`; the
// asserted totals equal the usage attached to the final text response.
#[tokio::test]
async fn test_token_usage_captured_from_backend_response() {
use crate::agent::backend::mock::{MockBackend, MockResponse};
use crate::tools::Tool;
use async_trait::async_trait;
use serde_json::Value;
struct NoOpTool;
#[async_trait]
impl Tool for NoOpTool {
fn name(&self) -> &str {
"noop"
}
fn description(&self) -> &str {
"does nothing"
}
fn parameters_schema(&self) -> Value {
serde_json::json!({"type": "object", "properties": {}})
}
async fn execute(&self, _args: Value) -> crate::Result<Value> {
Ok(serde_json::json!({"ok": true}))
}
}
let backend = Arc::new(MockBackend::new(vec![
MockResponse::tool_call("noop", serde_json::json!({})),
MockResponse::TextWithUsage {
text: "done".into(),
usage: TokenUsage {
prompt_tokens: 20,
completion_tokens: 8,
total_tokens: 28,
reasoning_tokens: 3,
action_tokens: 5,
},
},
]));
let mut registry = ToolRegistry::new();
registry.register(Arc::new(NoOpTool));
let registry = Arc::new(registry);
let coordinator = ToolCoordinator::new(backend, registry, ToolCallingConfig::default());
let result = coordinator
.execute(None, "count tokens")
.await
.expect("execute should succeed");
assert_eq!(result.total_usage.prompt_tokens, 20);
assert_eq!(result.total_usage.completion_tokens, 8);
assert_eq!(result.total_usage.total_tokens, 28);
assert_eq!(result.total_usage.reasoning_tokens, 3);
assert_eq!(result.total_usage.action_tokens, 5);
assert_eq!(result.iterations, 2);
}
// Two tool calls requested in one turn are both executed and recorded when
// parallel execution is enabled.
#[tokio::test]
async fn test_parallel_execution_dispatches_multiple_tools_in_one_turn() {
use crate::agent::backend::mock::MockBackend;
use crate::tools::Tool;
use async_trait::async_trait;
use serde_json::Value;
// The tool's registered name is its `suffix`, so two instances act as two tools.
struct EchoTool {
suffix: &'static str,
}
#[async_trait]
impl Tool for EchoTool {
fn name(&self) -> &str {
self.suffix
}
fn description(&self) -> &str {
"echoes a suffix"
}
fn parameters_schema(&self) -> Value {
serde_json::json!({"type": "object", "properties": {}})
}
async fn execute(&self, _args: Value) -> crate::Result<Value> {
Ok(serde_json::json!({ "tool": self.suffix }))
}
}
let backend = Arc::new(MockBackend::with_multiple_tool_calls(vec![
("call_a", "echo_a", serde_json::json!({})),
("call_b", "echo_b", serde_json::json!({})),
]));
let mut registry = ToolRegistry::new();
registry.register(Arc::new(EchoTool { suffix: "echo_a" }));
registry.register(Arc::new(EchoTool { suffix: "echo_b" }));
let registry = Arc::new(registry);
let config = ToolCallingConfig {
parallel_execution: true,
..ToolCallingConfig::default()
};
let coordinator = ToolCoordinator::new(backend, registry, config);
let result = coordinator
.execute(None, "run both")
.await
.expect("parallel tool execution should succeed");
assert_eq!(result.tool_calls.len(), 2);
assert!(result.tool_calls.iter().all(|r| r.success));
let names: Vec<&str> = result.tool_calls.iter().map(|r| r.name.as_str()).collect();
assert!(names.contains(&"echo_a"));
assert!(names.contains(&"echo_b"));
assert_eq!(result.finish_reason, FinishReason::Stop);
assert_eq!(result.iterations, 2);
}
use async_trait::async_trait;
use serde_json::json;
use std::sync::Mutex;
// Test double for `TaskRunner`: records the id of every task it is asked to
// run and echoes the task's fields back as JSON.
struct MockTaskRunner {
dispatched: Mutex<Vec<String>>,
}
impl MockTaskRunner {
fn new() -> Self {
Self {
dispatched: Mutex::new(Vec::new()),
}
}
// Snapshot of the ids dispatched so far, in dispatch order.
fn dispatched_ids(&self) -> Vec<String> {
self.dispatched.lock().unwrap().clone()
}
}
#[async_trait]
impl TaskRunner for MockTaskRunner {
async fn run(&self, task: &ScheduledTask) -> crate::Result<Value> {
self.dispatched.lock().unwrap().push(task.id.clone());
Ok(json!({
"id": task.id,
"agent": task.agent_type,
"assignment": task.assignment,
}))
}
}
// An empty schedule fails validation and nothing reaches the runner.
#[tokio::test]
async fn schedule_empty_task_list_rejects_without_dispatch() {
let runner = Arc::new(MockTaskRunner::new());
let coordinator = TaskScheduleCoordinator::new(runner.clone());
let err = coordinator
.schedule(&[])
.await
.expect_err("empty task list should fail validation");
assert_eq!(err, ScheduleError::EmptyTaskList);
assert!(runner.dispatched_ids().is_empty());
}
// An unknown agent type fails validation and nothing reaches the runner.
#[tokio::test]
async fn schedule_invalid_agent_type_rejects_without_dispatch() {
let runner = Arc::new(MockTaskRunner::new());
let coordinator = TaskScheduleCoordinator::new(runner.clone());
let tasks = [ScheduledTask {
id: "AuthProbe".into(),
agent_type: "not_a_real_agent".into(),
assignment: "probe auth".into(),
}];
let err = coordinator
.schedule(&tasks)
.await
.expect_err("invalid agent type should fail validation");
assert_eq!(
err,
ScheduleError::InvalidAgentType("not_a_real_agent".into())
);
assert!(runner.dispatched_ids().is_empty());
}
// A repeated task id fails validation and nothing reaches the runner, not
// even the first task with that id.
#[tokio::test]
async fn schedule_duplicate_task_ids_rejects_without_dispatch() {
let runner = Arc::new(MockTaskRunner::new());
let coordinator = TaskScheduleCoordinator::new(runner.clone());
let tasks = [
ScheduledTask {
id: "DupId".into(),
agent_type: "explore".into(),
assignment: "first".into(),
},
ScheduledTask {
id: "DupId".into(),
agent_type: "plan".into(),
assignment: "second".into(),
},
];
let err = coordinator
.schedule(&tasks)
.await
.expect_err("duplicate ids should fail validation");
assert_eq!(err, ScheduleError::DuplicateTaskId("DupId".into()));
assert!(runner.dispatched_ids().is_empty());
}
// A valid schedule is dispatched in order and each result carries the id and
// output of its task.
#[tokio::test]
async fn schedule_valid_tasks_dispatches_via_mock_runner() {
let runner = Arc::new(MockTaskRunner::new());
let coordinator = TaskScheduleCoordinator::new(runner.clone());
let tasks = [
ScheduledTask {
id: "Alpha".into(),
agent_type: "explore".into(),
assignment: "scan src/".into(),
},
ScheduledTask {
id: "Beta".into(),
agent_type: "plan".into(),
assignment: "draft refactor".into(),
},
];
let results = coordinator
.schedule(&tasks)
.await
.expect("valid schedule should succeed");
assert_eq!(runner.dispatched_ids(), vec!["Alpha", "Beta"]);
assert_eq!(results.len(), 2);
assert_eq!(results[0].id, "Alpha");
assert_eq!(results[0].output["agent"], "explore");
assert_eq!(results[1].id, "Beta");
assert_eq!(results[1].output["agent"], "plan");
}
}