use crate::client::{
CompletionRequest, CompletionResponse, LlmClient, Message, Role, ToolChoice, ToolRequest,
ToolUseBlock,
};
use crate::error::Error;
use futures::future::BoxFuture;
use std::collections::HashMap;
use tracing::{error, warn};
/// Error returned by a tool handler.
///
/// The `message` is forwarded verbatim to the model as the tool result (see
/// `ToolRegistry::dispatch`), so it should be phrased for the model to read and act on
/// (e.g. "order not found") rather than as an internal diagnostic.
#[derive(Debug, Clone)]
pub struct ToolError {
    /// Model-facing description of what went wrong.
    pub message: String,
}

impl std::fmt::Display for ToolError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

/// Makes `ToolError` usable wherever a `std::error::Error` is expected
/// (e.g. `Box<dyn std::error::Error>`). It wraps no underlying cause, so `source()` is `None`.
impl std::error::Error for ToolError {}
/// A tool the model may call: the name, description and parameter schema advertised to the
/// model, plus the async handler that executes it.
///
/// Build `handler` with [`make_handler`] to adapt an ordinary async closure.
pub struct ToolDef {
    /// Name the model uses to request this tool; also the key it is registered under.
    pub name: String,
    /// Description sent to the model alongside the schema.
    pub description: String,
    /// Schema for the handler's input, sent to the model verbatim.
    pub parameters_schema: serde_json::Value,
    /// Executes the tool on the model-supplied input. An `Err(ToolError)` is reported back to
    /// the model as the tool result rather than aborting dispatch.
    pub handler: Box<
        dyn Fn(serde_json::Value) -> BoxFuture<'static, Result<serde_json::Value, ToolError>>
            + Send
            + Sync,
    >,
}

impl std::fmt::Debug for ToolDef {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The boxed handler is an opaque closure with no `Debug` impl, so it is elided.
        f.debug_struct("ToolDef")
            .field("name", &self.name)
            .field("description", &self.description)
            .field("parameters_schema", &self.parameters_schema)
            .finish_non_exhaustive()
    }
}
/// Adapts an ordinary async function or closure into the boxed, type-erased handler stored in
/// [`ToolDef::handler`].
///
/// Every invocation calls `f` with the tool input and pins the returned future on the heap, so
/// handlers with different concrete future types can share a single field type.
pub fn make_handler<F, Fut>(
    f: F,
) -> Box<
    dyn Fn(serde_json::Value) -> BoxFuture<'static, Result<serde_json::Value, ToolError>>
        + Send
        + Sync,
>
where
    F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
{
    // Named only to keep the adapter's return annotation short.
    type PinnedResult = BoxFuture<'static, Result<serde_json::Value, ToolError>>;

    let adapter = move |input: serde_json::Value| -> PinnedResult {
        let pending = f(input);
        Box::pin(pending)
    };
    Box::new(adapter)
}
pub struct ToolRegistry {
tools: HashMap<String, ToolDef>,
max_iterations: u32,
}
impl ToolRegistry {
pub fn new(max_iterations: u32) -> Self {
Self {
tools: HashMap::new(),
max_iterations,
}
}
pub fn with_default_iterations() -> Self {
Self::new(10)
}
pub fn register(&mut self, tool: ToolDef) {
self.tools.insert(tool.name.clone(), tool);
}
fn build_request(&self, messages: Vec<Message>) -> CompletionRequest {
let tool_requests: Vec<ToolRequest> = self
.tools
.values()
.map(|t| ToolRequest {
name: t.name.clone(),
description: t.description.clone(),
parameters_schema: t.parameters_schema.clone(),
})
.collect();
CompletionRequest {
system: None,
messages,
max_tokens: 4096,
model_override: None,
schema: None,
tools: if tool_requests.is_empty() {
None
} else {
Some(tool_requests)
},
tool_choice: Some(ToolChoice::Auto),
}
}
fn result_to_message(block_id: &str, result: Result<serde_json::Value, ToolError>) -> Message {
let content = match result {
Ok(value) => value.to_string(),
Err(te) => te.message,
};
Message {
role: Role::Tool,
content,
tool_call_id: Some(block_id.to_string()),
}
}
pub async fn dispatch(
&self,
mut messages: Vec<Message>,
client: &dyn LlmClient,
) -> Result<Vec<Message>, Error> {
for iteration in 0..=self.max_iterations {
if iteration == 5 && self.max_iterations > 5 {
warn!(
iteration,
max = self.max_iterations,
"tool dispatch at iteration 5"
);
}
if iteration == self.max_iterations {
error!(
max_iterations = self.max_iterations,
"tool dispatch hit iteration limit"
);
return Err(Error::ToolIterationLimit(self.max_iterations));
}
let request = self.build_request(messages.clone());
let response = client.complete_with_tools(request).await?;
match response {
CompletionResponse::Text(text) => {
messages.push(Message {
role: Role::Assistant,
content: text,
tool_call_id: None,
});
return Ok(messages);
}
CompletionResponse::ToolUse {
blocks,
assistant_content,
} => {
messages.push(Message {
role: Role::Assistant,
content: assistant_content,
tool_call_id: None,
});
for block in &blocks {
let result = self.call_tool(block).await;
messages.push(Self::result_to_message(&block.id, result));
}
}
}
}
unreachable!()
}
async fn call_tool(&self, block: &ToolUseBlock) -> Result<serde_json::Value, ToolError> {
match self.tools.get(&block.name) {
None => Err(ToolError {
message: format!("tool '{}' is not registered", block.name),
}),
Some(tool) => (tool.handler)(block.input.clone()).await,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::{CompletionRequest, TokenStream};
    use async_trait::async_trait;
    use std::sync::{
        atomic::{AtomicU32, Ordering},
        Arc,
    };

    /// Scripted client for `dispatch` tests.
    ///
    /// For the first `stop_after` calls, `complete_with_tools` answers with a single `ToolUse`
    /// block for `tool_name` (ids `call_0`, `call_1`, ...). After that it answers
    /// `Text("done")`. Every other `LlmClient` method is unsupported.
    struct LoopingClient {
        calls: Arc<AtomicU32>,
        stop_after: u32,
        tool_name: String,
    }

    #[async_trait]
    impl LlmClient for LoopingClient {
        fn default_model(&self) -> &str {
            "test"
        }
        async fn complete(&self, _: CompletionRequest) -> Result<String, Error> {
            Err(Error::Unsupported)
        }
        async fn complete_stream(&self, _: CompletionRequest) -> Result<TokenStream, Error> {
            Err(Error::Unsupported)
        }
        async fn embed(&self, _: &str) -> Result<Vec<f32>, Error> {
            Err(Error::Unsupported)
        }
        async fn complete_with_tools(
            &self,
            _: CompletionRequest,
        ) -> Result<CompletionResponse, Error> {
            let n = self.calls.fetch_add(1, Ordering::SeqCst);
            if n >= self.stop_after {
                Ok(CompletionResponse::Text("done".into()))
            } else {
                Ok(CompletionResponse::ToolUse {
                    blocks: vec![ToolUseBlock {
                        id: format!("call_{n}"),
                        name: self.tool_name.clone(),
                        input: serde_json::json!({}),
                    }],
                    assistant_content: format!(
                        r#"[{{"type":"tool_use","id":"call_{n}","name":"{}","input":{{}}}}]"#,
                        self.tool_name
                    ),
                })
            }
        }
    }

    /// Builds a `LoopingClient` plus a shared handle to its call counter, so tests can check
    /// how many model round-trips `dispatch` actually made.
    fn looping_client(stop_after: u32, tool_name: &str) -> (LoopingClient, Arc<AtomicU32>) {
        let calls = Arc::new(AtomicU32::new(0));
        let client = LoopingClient {
            calls: Arc::clone(&calls),
            stop_after,
            tool_name: tool_name.into(),
        };
        (client, calls)
    }

    #[tokio::test]
    async fn tool_def_construction() {
        let schema = serde_json::json!({"type": "object", "properties": {"x": {"type": "string"}}});
        let def = ToolDef {
            name: "my_tool".into(),
            description: "does a thing".into(),
            parameters_schema: schema.clone(),
            handler: make_handler(
                |_input| async move { Ok(serde_json::json!({"result": "done"})) },
            ),
        };
        assert_eq!(def.name, "my_tool");
        assert_eq!(def.description, "does a thing");
        assert_eq!(def.parameters_schema, schema);
        let output = (def.handler)(serde_json::json!({}))
            .await
            .expect("handler should succeed");
        assert_eq!(output, serde_json::json!({"result": "done"}));
    }

    #[test]
    fn tool_error_is_model_legible() {
        let err = ToolError {
            message: "domain message".into(),
        };
        assert_eq!(format!("{err}"), "domain message");
        let debug_str = format!("{err:?}");
        assert!(debug_str.contains("domain message"));
    }

    #[test]
    fn tool_registry_requires_max_iterations() {
        let r1 = ToolRegistry::new(3);
        assert_eq!(r1.max_iterations, 3);
        let r2 = ToolRegistry::with_default_iterations();
        assert_eq!(r2.max_iterations, 10);
    }

    #[tokio::test]
    async fn tool_registry_enforces_max_iterations() {
        let registry = ToolRegistry::new(3);
        let (client, calls) = looping_client(99, "no_op");
        let result = registry.dispatch(vec![], &client).await;
        assert!(
            matches!(result, Err(Error::ToolIterationLimit(3))),
            "expected ToolIterationLimit(3), got {result:?}"
        );
        assert_eq!(
            calls.load(Ordering::SeqCst),
            3,
            "the cap must bound the number of model calls"
        );
    }

    #[tokio::test]
    async fn dispatch_with_zero_iterations_never_calls_model() {
        let registry = ToolRegistry::new(0);
        let (client, calls) = looping_client(0, "no_op");
        let result = registry.dispatch(vec![], &client).await;
        assert!(
            matches!(result, Err(Error::ToolIterationLimit(0))),
            "expected ToolIterationLimit(0), got {result:?}"
        );
        assert_eq!(
            calls.load(Ordering::SeqCst),
            0,
            "model must not be called with a zero cap"
        );
    }

    #[tokio::test]
    async fn dispatch_returns_on_text() {
        let registry = ToolRegistry::new(5);
        let (client, calls) = looping_client(0, "no_op");
        let messages = registry
            .dispatch(vec![], &client)
            .await
            .expect("dispatch should return the text answer");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert!(
            messages
                .iter()
                .any(|m| matches!(m.role, Role::Assistant) && m.content == "done"),
            "expected assistant message with 'done'"
        );
    }

    #[tokio::test]
    async fn dispatch_surfaces_tool_error() {
        let mut registry = ToolRegistry::new(5);
        registry.register(ToolDef {
            name: "failing_tool".into(),
            description: "always fails".into(),
            parameters_schema: serde_json::json!({}),
            handler: make_handler(|_| async move {
                Err(ToolError {
                    message: "order not found".into(),
                })
            }),
        });
        let (client, _calls) = looping_client(1, "failing_tool");
        let messages = registry
            .dispatch(vec![], &client)
            .await
            .expect("dispatch must complete even after tool error");
        let content = &messages
            .iter()
            .find(|m| matches!(m.role, Role::Tool))
            .expect("expected a Role::Tool result message")
            .content;
        assert!(
            content.contains("order not found"),
            "ToolError message must appear in tool result, got: {content}"
        );
        assert!(
            !content.contains("panicked at"),
            "tool result must not contain panic text"
        );
    }

    #[tokio::test]
    async fn dispatch_includes_assistant_turn_before_tool_results() {
        let mut registry = ToolRegistry::new(5);
        registry.register(ToolDef {
            name: "echo".into(),
            description: "echoes input".into(),
            parameters_schema: serde_json::json!({}),
            handler: make_handler(|_| async move { Ok(serde_json::json!({"result": "ok"})) }),
        });
        let (client, _calls) = looping_client(1, "echo");
        let messages = registry
            .dispatch(vec![], &client)
            .await
            .expect("dispatch should finish within the cap");
        let assistant_pos = messages
            .iter()
            .position(|m| matches!(m.role, Role::Assistant) && m.content.contains("tool_use"))
            .expect("must have an assistant turn with tool_use content");
        let tool_result_pos = messages
            .iter()
            .position(|m| matches!(m.role, Role::Tool))
            .expect("must have a tool result message");
        assert!(
            assistant_pos < tool_result_pos,
            "assistant tool-use turn (pos {assistant_pos}) must precede tool result (pos {tool_result_pos})"
        );
        let tool_msg = &messages[tool_result_pos];
        assert!(
            tool_msg.tool_call_id.is_some(),
            "tool result message must carry tool_call_id"
        );
        assert!(
            !tool_msg.content.contains("call_"),
            "tool_call_id must not be embedded in content string, got: {}",
            tool_msg.content
        );
    }

    #[tokio::test]
    async fn dispatch_runs_multiple_rounds_in_order() {
        let mut registry = ToolRegistry::new(5);
        registry.register(ToolDef {
            name: "echo".into(),
            description: "returns its input".into(),
            parameters_schema: serde_json::json!({}),
            handler: make_handler(|input| async move { Ok(input) }),
        });
        let (client, calls) = looping_client(2, "echo");
        let messages = registry
            .dispatch(vec![], &client)
            .await
            .expect("dispatch should finish within the cap");
        // Two tool rounds followed by the final text answer.
        assert_eq!(calls.load(Ordering::SeqCst), 3);
        let is_tool: Vec<bool> = messages
            .iter()
            .map(|m| matches!(m.role, Role::Tool))
            .collect();
        assert_eq!(is_tool, [false, true, false, true, false]);
        let tool_ids: Vec<Option<&str>> = messages
            .iter()
            .map(|m| m.tool_call_id.as_deref())
            .collect();
        assert_eq!(
            tool_ids,
            vec![None, Some("call_0"), None, Some("call_1"), None]
        );
        assert!(matches!(
            messages.last().map(|m| &m.role),
            Some(Role::Assistant)
        ));
        assert_eq!(messages.last().map(|m| m.content.as_str()), Some("done"));
    }

    #[tokio::test]
    async fn dispatch_surfaces_unknown_tool_as_tool_error() {
        let registry = ToolRegistry::new(5);
        let (client, _calls) = looping_client(1, "nonexistent_tool");
        let result = registry.dispatch(vec![], &client).await;
        assert!(
            result.is_ok(),
            "dispatch must not abort for unknown tool; got {result:?}"
        );
        let messages = result.expect("checked above");
        let tool_msg = messages
            .iter()
            .find(|m| matches!(m.role, Role::Tool))
            .expect("must have a tool result message for the unknown tool");
        assert!(
            tool_msg.content.contains("not registered"),
            "unknown tool error must surface to LLM as a message, got: {}",
            tool_msg.content
        );
    }
}