use std::collections::HashMap;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::types::{
ContentPart, FunctionCall, FunctionTool, InputItem, MessageRole, OutputItem, Response, Tool,
ToolCall, ToolCallDelta, ToolType,
};
/// Name under which the built-in clarifying-question tool is advertised and
/// matched (used by `user_ask_tool` and `is_user_ask_tool_call`).
pub const USER_ASK_TOOL_NAME: &str = "user_ask";
/// One selectable answer offered to the user by a `user_ask` call.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct UserAskOption {
/// Text shown for the option; `UserAskArgs::validate` rejects blank labels.
pub label: String,
/// Optional extra explanation; omitted from serialized JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
}
/// Arguments the model supplies when calling the `user_ask` tool.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct UserAskArgs {
/// The question to show the user; must not be blank (see `validate`).
pub question: String,
/// Multiple-choice options; defaults to empty when absent and is omitted
/// from serialized JSON when empty.
#[serde(skip_serializing_if = "Vec::is_empty", default)]
pub options: Vec<UserAskOption>,
/// Whether a typed answer is accepted. Validation treats `None` as `true`;
/// `Some(false)` requires at least one option.
#[serde(skip_serializing_if = "Option::is_none")]
pub allow_freeform: Option<bool>,
}
/// The user's answer, serialized as the `user_ask` tool result.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct UserAskResponse {
/// The answer text; `serialize_user_ask_result` rejects blank answers.
pub answer: String,
/// `true` when the user typed a custom answer, `false` when an option was chosen.
pub is_freeform: bool,
}
/// Convenience accessors for tool calls carried by a [`Response`].
pub trait ResponseExt {
    /// Returns `true` if any output message carries at least one tool call.
    fn has_tool_calls(&self) -> bool;
    /// Returns the first tool call across all output messages, in output order.
    fn first_tool_call(&self) -> Option<&ToolCall>;
}

impl ResponseExt for Response {
    fn has_tool_calls(&self) -> bool {
        // A non-empty tool-call list exists exactly when there is a first call.
        self.first_tool_call().is_some()
    }

    fn first_tool_call(&self) -> Option<&ToolCall> {
        // Messages whose tool-call list is absent or empty are skipped.
        self.output.iter().find_map(|item| match item {
            OutputItem::Message { tool_calls, .. } => {
                tool_calls.as_ref().and_then(|calls| calls.first())
            }
        })
    }
}
/// Builds a tool-result input item that answers the call `tool_call_id`.
pub fn tool_result_message(
    tool_call_id: impl Into<String>,
    result: impl Into<String>,
) -> InputItem {
    InputItem::tool_result(tool_call_id, result)
}

/// Like [`tool_result_message`], but serializes `result` to a JSON string first.
///
/// # Errors
///
/// Propagates the `serde_json` error if `result` cannot be serialized.
pub fn tool_result_message_json<T: serde::Serialize>(
    tool_call_id: impl Into<String>,
    result: &T,
) -> Result<InputItem, serde_json::Error> {
    serde_json::to_string(result).map(|json| tool_result_message(tool_call_id, json))
}

/// Builds the tool-result item answering `call`.
pub fn respond_to_tool_call(call: &ToolCall, result: impl Into<String>) -> InputItem {
    tool_result_message(call.id.as_str(), result)
}

/// Builds the tool-result item answering `call` with `result` serialized as JSON.
///
/// # Errors
///
/// Propagates the `serde_json` error if `result` cannot be serialized.
pub fn respond_to_tool_call_json<T: serde::Serialize>(
    call: &ToolCall,
    result: &T,
) -> Result<InputItem, serde_json::Error> {
    tool_result_message_json(call.id.as_str(), result)
}
/// Returns the function-tool definition for the built-in `user_ask` tool.
pub fn user_ask_tool() -> Tool {
    let description = "Ask the user a clarifying question.".to_string();
    Tool::function(USER_ASK_TOOL_NAME, Some(description), Some(user_ask_schema()))
}

/// Returns `true` if `call` is a function call targeting the `user_ask` tool.
pub fn is_user_ask_tool_call(call: &ToolCall) -> bool {
    if call.kind != ToolType::Function {
        return false;
    }
    matches!(&call.function, Some(func) if func.name == USER_ASK_TOOL_NAME)
}

/// Parses and validates the arguments of a `user_ask` call.
///
/// # Errors
///
/// Returns a [`ToolArgsError`] if the JSON is malformed or fails
/// `UserAskArgs` validation.
pub fn parse_user_ask_args(call: &ToolCall) -> Result<UserAskArgs, ToolArgsError> {
    parse_and_validate_tool_args::<UserAskArgs>(call)
}
/// Builds the [`ToolArgsError`] reported by the `user_ask` result helpers.
///
/// These errors are not tied to a specific tool call, so the call id and raw
/// arguments are left empty.
fn user_ask_result_error(message: String) -> ToolArgsError {
    ToolArgsError {
        message,
        tool_call_id: String::new(),
        tool_name: USER_ASK_TOOL_NAME.to_string(),
        raw_arguments: String::new(),
    }
}

/// Serializes a [`UserAskResponse`] into the JSON string returned to the model
/// as the `user_ask` tool result.
///
/// # Errors
///
/// Returns a [`ToolArgsError`] when `answer` is empty or whitespace-only, or
/// when JSON serialization fails.
pub fn serialize_user_ask_result(result: &UserAskResponse) -> Result<String, ToolArgsError> {
    if result.answer.trim().is_empty() {
        return Err(user_ask_result_error(
            "user_ask answer is required".to_string(),
        ));
    }
    // Name the tool exactly as registered (`user_ask`) so the message matches
    // the tool the model actually called.
    serde_json::to_string(result).map_err(|e| {
        user_ask_result_error(format!(
            "failed to serialize {} result: {e}",
            USER_ASK_TOOL_NAME
        ))
    })
}

/// Serializes a typed (freeform) answer as a `user_ask` result.
///
/// # Errors
///
/// See [`serialize_user_ask_result`].
pub fn user_ask_result_freeform(answer: impl Into<String>) -> Result<String, ToolArgsError> {
    serialize_user_ask_result(&UserAskResponse {
        answer: answer.into(),
        is_freeform: true,
    })
}

/// Serializes a chosen option as a `user_ask` result.
///
/// # Errors
///
/// See [`serialize_user_ask_result`].
pub fn user_ask_result_choice(answer: impl Into<String>) -> Result<String, ToolArgsError> {
    serialize_user_ask_result(&UserAskResponse {
        answer: answer.into(),
        is_freeform: false,
    })
}
/// JSON Schema for the `user_ask` tool's parameters (mirrors [`UserAskArgs`]).
fn user_ask_schema() -> Value {
    // Schema for a single multiple-choice entry.
    let option_schema = serde_json::json!({
        "type": "object",
        "properties": {
            "label": { "type": "string", "minLength": 1 },
            "description": { "type": "string" }
        },
        "required": ["label"]
    });

    // Key order matches the original literal so serialized output is unchanged.
    let properties = serde_json::json!({
        "question": {
            "type": "string",
            "minLength": 1,
            "description": "The question to ask the user."
        },
        "options": {
            "type": "array",
            "items": option_schema,
            "description": "Optional multiple choice options."
        },
        "allow_freeform": {
            "type": "boolean",
            "default": true,
            "description": "Allow user to type a custom response."
        }
    });

    serde_json::json!({
        "type": "object",
        "properties": properties,
        "required": ["question"]
    })
}
/// Builds an assistant message with text `content` and the given `tool_calls`.
///
/// Useful for replaying the model's tool-calling turn back into the
/// conversation before the matching tool results.
pub fn assistant_message_with_tool_calls(
    content: impl Into<String>,
    tool_calls: Vec<ToolCall>,
) -> InputItem {
    let text_part = ContentPart::text(content);
    InputItem::Message {
        role: MessageRole::Assistant,
        tool_call_id: None,
        tool_calls: Some(tool_calls),
        content: vec![text_part],
    }
}
#[derive(Debug, Default)]
pub struct ToolCallAccumulator {
calls: HashMap<u32, ToolCall>,
}
impl ToolCallAccumulator {
pub fn new() -> Self {
Self {
calls: HashMap::new(),
}
}
pub fn process_delta(&mut self, delta: &ToolCallDelta) -> bool {
if let Some(existing) = self.calls.get_mut(&delta.index) {
if let Some(ref func_delta) = delta.function {
if let Some(ref mut func) = existing.function {
if let Some(ref name) = func_delta.name {
func.name = name.clone();
}
if let Some(ref args) = func_delta.arguments {
func.arguments.push_str(args);
}
}
}
false
} else {
let function = delta.function.as_ref().map(|f| FunctionCall {
name: f.name.clone().unwrap_or_default(),
arguments: f.arguments.clone().unwrap_or_default(),
});
self.calls.insert(
delta.index,
ToolCall {
id: delta.id.clone().unwrap_or_default(),
kind: delta
.type_
.as_ref()
.map(|t| match t.as_str() {
"function" => ToolType::Function,
"x_search" => ToolType::XSearch,
"code_execution" => ToolType::CodeExecution,
_ => ToolType::Function,
})
.unwrap_or(ToolType::Function),
function,
},
);
true
}
}
pub fn get_tool_calls(&self) -> Vec<ToolCall> {
if self.calls.is_empty() {
return Vec::new();
}
let max_idx = self.calls.keys().max().copied().unwrap_or(0);
let mut result = Vec::with_capacity(self.calls.len());
for i in 0..=max_idx {
if let Some(call) = self.calls.get(&i) {
result.push(call.clone());
}
}
result
}
pub fn get_tool_call(&self, index: u32) -> Option<&ToolCall> {
self.calls.get(&index)
}
pub fn reset(&mut self) {
self.calls.clear();
}
}
/// Error produced when a tool call's arguments cannot be parsed or validated.
#[derive(Debug, Clone)]
pub struct ToolArgsError {
    /// Human-readable description; this is also the `Display` output.
    pub message: String,
    /// Id of the offending tool call (empty when the error is not tied to a call).
    pub tool_call_id: String,
    /// Name of the tool the arguments were intended for.
    pub tool_name: String,
    /// The raw argument string as received (empty when unavailable).
    pub raw_arguments: String,
}

impl std::fmt::Display for ToolArgsError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str(&self.message)
    }
}

impl std::error::Error for ToolArgsError {}
/// Deserializes a tool call's JSON arguments into `T`.
///
/// An empty argument string is treated as `{}`, so argument-less tools parse
/// into types whose fields are all optional or defaulted.
///
/// # Errors
///
/// Returns a [`ToolArgsError`] carrying the call id, tool name and raw
/// arguments when deserialization fails.
pub fn parse_tool_args<T>(call: &ToolCall) -> Result<T, ToolArgsError>
where
    T: serde::de::DeserializeOwned,
{
    // Borrow name/arguments; owned copies are only made when building an error.
    let (tool_name, raw_args) = match &call.function {
        Some(func) => (func.name.as_str(), func.arguments.as_str()),
        None => ("", ""),
    };
    let json_str = if raw_args.is_empty() { "{}" } else { raw_args };
    serde_json::from_str(json_str).map_err(|err| ToolArgsError {
        message: format!("failed to parse arguments for tool '{}': {}", tool_name, err),
        tool_call_id: call.id.clone(),
        tool_name: tool_name.to_string(),
        raw_arguments: raw_args.to_string(),
    })
}
/// Shorthand for the result of tool-argument parsing.
pub type ParseResult<T> = Result<T, ToolArgsError>;

/// Semantic validation applied after tool arguments deserialize successfully.
pub trait ValidateArgs {
    /// Returns `Err(reason)` when the arguments are well-formed but invalid.
    fn validate(&self) -> Result<(), String>;
}

impl ValidateArgs for UserAskArgs {
    fn validate(&self) -> Result<(), String> {
        if self.question.trim().is_empty() {
            return Err("user_ask question is required".to_string());
        }
        let has_blank_label = self
            .options
            .iter()
            .any(|option| option.label.trim().is_empty());
        if has_blank_label {
            return Err("user_ask option label is required".to_string());
        }
        // Freeform answers are allowed unless explicitly disabled; without
        // them the user needs at least one option to choose from.
        let freeform_disabled = self.allow_freeform == Some(false);
        if freeform_disabled && self.options.is_empty() {
            return Err("user_ask requires options when allow_freeform=false".to_string());
        }
        Ok(())
    }
}
/// Parses a tool call's arguments (see [`parse_tool_args`]) and then runs
/// [`ValidateArgs::validate`] on the result.
///
/// # Errors
///
/// Returns the parse error unchanged, or a [`ToolArgsError`] whose message
/// starts with `invalid arguments for tool` when validation fails.
pub fn parse_and_validate_tool_args<T>(call: &ToolCall) -> Result<T, ToolArgsError>
where
    T: serde::de::DeserializeOwned + ValidateArgs,
{
    let args: T = parse_tool_args(call)?;
    match args.validate() {
        Ok(()) => Ok(args),
        Err(reason) => {
            let (tool_name, raw_args) = match &call.function {
                Some(func) => (func.name.clone(), func.arguments.clone()),
                None => (String::new(), String::new()),
            };
            Err(ToolArgsError {
                message: format!("invalid arguments for tool '{}': {}", tool_name, reason),
                tool_call_id: call.id.clone(),
                tool_name,
                raw_arguments: raw_args,
            })
        }
    }
}
/// Error for a tool call naming a tool that has no registered handler.
#[derive(Debug, Clone)]
pub struct UnknownToolError {
    /// The tool name the call asked for.
    pub tool_name: String,
    /// Names of the tools that are registered, listed in the message.
    pub available: Vec<String>,
}

impl std::fmt::Display for UnknownToolError {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.available.as_slice() {
            [] => write!(
                out,
                "unknown tool: '{}'. No tools registered.",
                self.tool_name
            ),
            names => write!(
                out,
                "unknown tool: '{}'. Available: {}",
                self.tool_name,
                names.join(", ")
            ),
        }
    }
}

impl std::error::Error for UnknownToolError {}
/// Outcome of running one tool call through a [`ToolRegistry`].
#[derive(Debug, Clone)]
pub struct ToolExecutionResult {
    /// Id of the tool call this result answers.
    pub tool_call_id: String,
    /// Name of the tool that was invoked (empty if the call had no function).
    pub tool_name: String,
    /// Handler output on success.
    pub result: Option<Value>,
    /// Error message on failure.
    pub error: Option<String>,
    /// Whether the failure looks fixable by the model correcting its
    /// arguments (set by `ToolRegistry::execute`).
    pub is_retryable: bool,
}

impl ToolExecutionResult {
    /// Returns `true` when no error was recorded.
    pub fn is_ok(&self) -> bool {
        !self.is_err()
    }

    /// Returns `true` when an error was recorded.
    pub fn is_err(&self) -> bool {
        matches!(self.error, Some(_))
    }
}
/// A boxed, pinned future that is `Send` and valid for `'a`, resolving to `T`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
/// Shared, thread-safe async tool handler.
///
/// `ToolRegistry::execute` calls it with the call's parsed JSON arguments and
/// the originating [`ToolCall`]. It resolves to the tool's JSON result or an
/// error message string.
pub type ToolHandler =
Arc<dyn Fn(Value, ToolCall) -> BoxFuture<'static, Result<Value, String>> + Send + Sync>;
pub struct ToolRegistry {
handlers: HashMap<String, ToolHandler>,
}
impl Default for ToolRegistry {
fn default() -> Self {
Self::new()
}
}
impl ToolRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            handlers: HashMap::new(),
        }
    }

    /// Registers `handler` under `name` (builder style), replacing any
    /// existing handler with that name.
    pub fn register(mut self, name: impl Into<String>, handler: ToolHandler) -> Self {
        self.handlers.insert(name.into(), handler);
        self
    }

    /// Registers `handler` under `name` in place, replacing any existing
    /// handler with that name.
    pub fn register_mut(&mut self, name: impl Into<String>, handler: ToolHandler) -> &mut Self {
        self.handlers.insert(name.into(), handler);
        self
    }

    /// Removes the handler for `name`; returns `true` if one was registered.
    pub fn unregister(&mut self, name: &str) -> bool {
        self.handlers.remove(name).is_some()
    }

    /// Returns `true` if a handler is registered for `name`.
    pub fn has(&self, name: &str) -> bool {
        self.handlers.contains_key(name)
    }

    /// Returns the registered tool names, sorted.
    ///
    /// Sorting makes this list deterministic, and with it the
    /// `UnknownToolError` message built from it. Unsorted, it would follow
    /// `HashMap` iteration order.
    pub fn registered_tools(&self) -> Vec<String> {
        let mut names: Vec<String> = self.handlers.keys().cloned().collect();
        names.sort_unstable();
        names
    }

    /// Builds a failed [`ToolExecutionResult`] answering `call`.
    fn failed_result(
        call: &ToolCall,
        tool_name: String,
        error: String,
        is_retryable: bool,
    ) -> ToolExecutionResult {
        ToolExecutionResult {
            tool_call_id: call.id.clone(),
            tool_name,
            result: None,
            error: Some(error),
            is_retryable,
        }
    }

    /// Runs the handler registered for `call`'s function name.
    ///
    /// Never fails outright. Unknown tools, unparseable arguments and handler
    /// errors are all reported through [`ToolExecutionResult::error`].
    /// Empty or missing arguments are passed to the handler as `{}`.
    /// Argument-parse failures are marked retryable. Handler errors are marked
    /// retryable when their message looks like an argument problem.
    pub async fn execute(&self, call: &ToolCall) -> ToolExecutionResult {
        let tool_name = call
            .function
            .as_ref()
            .map(|f| f.name.clone())
            .unwrap_or_default();

        let Some(handler) = self.handlers.get(&tool_name) else {
            let error_msg = UnknownToolError {
                tool_name: tool_name.clone(),
                available: self.registered_tools(),
            }
            .to_string();
            return Self::failed_result(call, tool_name, error_msg, false);
        };

        let args: Value = match call.function.as_ref().map(|f| f.arguments.as_str()) {
            None | Some("") => Value::Object(Default::default()),
            Some(raw) => match serde_json::from_str::<Value>(raw) {
                Ok(value) => value,
                Err(e) => {
                    // Malformed JSON from the model is worth a retry.
                    return Self::failed_result(
                        call,
                        tool_name,
                        format!("failed to parse tool arguments: {}", e),
                        true,
                    );
                }
            },
        };

        match handler(args, call.clone()).await {
            Ok(result) => ToolExecutionResult {
                tool_call_id: call.id.clone(),
                tool_name,
                result: Some(result),
                error: None,
                is_retryable: false,
            },
            Err(e) => {
                // Heuristic: these prefixes match the messages produced by
                // parse_tool_args / parse_and_validate_tool_args and the
                // ToolBuilder wrappers.
                let is_retryable = e.starts_with("invalid arguments")
                    || e.starts_with("failed to parse")
                    || e.contains("validation");
                Self::failed_result(call, tool_name, e, is_retryable)
            }
        }
    }

    /// Executes all `calls` concurrently; results are in the same order as `calls`.
    pub async fn execute_all(&self, calls: &[ToolCall]) -> Vec<ToolExecutionResult> {
        let futures: Vec<_> = calls.iter().map(|call| self.execute(call)).collect();
        futures::future::join_all(futures).await
    }

    /// Converts execution results into tool-result messages for the model.
    ///
    /// Errors are rendered as `Error: <message>`. String results are sent
    /// verbatim, and other JSON values are serialized.
    pub fn results_to_messages(&self, results: &[ToolExecutionResult]) -> Vec<InputItem> {
        results
            .iter()
            .map(|r| {
                let content = match (&r.error, &r.result) {
                    (Some(error), _) => format!("Error: {}", error),
                    (None, Some(Value::String(s))) => s.clone(),
                    (None, Some(other)) => serde_json::to_string(other).unwrap_or_default(),
                    (None, None) => String::new(),
                };
                tool_result_message(&r.tool_call_id, content)
            })
            .collect()
    }
}
/// Wraps a closure `|args, call| <future>` into a `$crate::tools::ToolHandler`
/// by boxing and pinning the future it returns.
///
/// The closure's future must be `Send + 'static`, because `ToolHandler`
/// returns a `BoxFuture<'static, _>`.
#[macro_export]
macro_rules! tool_handler {
($closure:expr) => {{
use std::sync::Arc;
let handler: $crate::tools::ToolHandler =
Arc::new(move |args, call| Box::pin($closure(args, call)));
handler
}};
}
pub fn sync_handler<F>(f: F) -> ToolHandler
where
F: Fn(Value, ToolCall) -> Result<Value, String> + Send + Sync + 'static,
{
Arc::new(move |args, call| {
let result = f(args, call);
Box::pin(async move { result })
})
}
/// Builds a function [`Tool`] whose parameter schema is derived from `T`.
///
/// If the derived schema fails to serialize, the tool is advertised without
/// parameters rather than reporting an error.
pub fn function_tool_from_type<T: schemars::JsonSchema>(
    name: impl Into<String>,
    description: impl Into<String>,
) -> Tool {
    let parameters = serde_json::to_value(schemars::schema_for!(T)).ok();
    let function = FunctionTool {
        name: name.into(),
        description: Some(description.into()),
        parameters,
    };
    Tool {
        kind: ToolType::Function,
        function: Some(function),
        x_search: None,
        code_execution: None,
    }
}
/// A function tool whose arguments deserialize into `T` and whose JSON schema
/// is derived from `T`.
pub struct TypedTool<T> {
    name: String,
    description: String,
    _marker: PhantomData<T>,
}

/// A [`ToolCall`] paired with its arguments parsed into `T`.
#[derive(Debug)]
pub struct TypedToolCall<T> {
    /// The original call.
    pub call: ToolCall,
    /// The parsed arguments.
    pub args: T,
}

impl<T> TypedTool<T>
where
    T: schemars::JsonSchema + serde::de::DeserializeOwned,
{
    /// Creates a typed tool named `name` with the given description.
    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
        let name = name.into();
        let description = description.into();
        Self {
            name,
            description,
            _marker: PhantomData,
        }
    }

    /// Returns the function-tool definition advertised to the model.
    pub fn definition(&self) -> Tool {
        function_tool_from_type::<T>(self.name.as_str(), self.description.as_str())
    }

    /// Returns the tool's name.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Checks that `call` targets this tool and parses its arguments into `T`.
    ///
    /// # Errors
    ///
    /// Returns a [`ToolArgsError`] when the call has no function, names a
    /// different tool, or has arguments that fail to deserialize.
    pub fn parse_call(&self, call: &ToolCall) -> Result<TypedToolCall<T>, ToolArgsError> {
        let Some(func) = call.function.as_ref() else {
            return Err(ToolArgsError {
                message: "tool call missing function".to_string(),
                tool_call_id: call.id.clone(),
                tool_name: self.name.clone(),
                raw_arguments: String::new(),
            });
        };
        if func.name != self.name {
            return Err(ToolArgsError {
                message: format!("expected tool '{}', got '{}'", self.name, func.name),
                tool_call_id: call.id.clone(),
                tool_name: self.name.clone(),
                raw_arguments: func.arguments.clone(),
            });
        }
        let args = parse_tool_args::<T>(call)?;
        Ok(TypedToolCall {
            call: call.clone(),
            args,
        })
    }
}
/// Lets any `schemars::JsonSchema` type build its own function-tool definition.
pub trait ToolSchema: schemars::JsonSchema + Sized {
    /// Builds a function [`Tool`] named `name` whose parameters schema is
    /// derived from `Self`.
    fn as_tool(name: impl Into<String>, description: impl Into<String>) -> Tool {
        function_tool_from_type::<Self>(name, description)
    }
}

// Blanket impl: every (sized) JSON-schema type gets `as_tool` for free.
impl<T> ToolSchema for T where T: schemars::JsonSchema {}
pub struct ToolBuilder {
definitions: Vec<Tool>,
registry: ToolRegistry,
}
impl Default for ToolBuilder {
fn default() -> Self {
Self::new()
}
}
impl ToolBuilder {
pub fn new() -> Self {
Self {
definitions: Vec::new(),
registry: ToolRegistry::new(),
}
}
pub fn add_sync<T, F>(
mut self,
name: impl Into<String>,
description: impl Into<String>,
handler: F,
) -> Self
where
T: schemars::JsonSchema + serde::de::DeserializeOwned + Send + 'static,
F: Fn(T, ToolCall) -> Result<Value, String> + Send + Sync + 'static,
{
let name_str = name.into();
let desc_str = description.into();
let tool = function_tool_from_type::<T>(&name_str, &desc_str);
self.definitions.push(tool);
let handler: ToolHandler = Arc::new(move |args: Value, call: ToolCall| {
let parsed: Result<T, _> = serde_json::from_value(args);
match parsed {
Ok(typed_args) => {
let result = handler(typed_args, call);
Box::pin(async move { result })
}
Err(e) => Box::pin(async move { Err(format!("failed to parse arguments: {}", e)) }),
}
});
self.registry.register_mut(&name_str, handler);
self
}
pub fn add_async<T, F>(
mut self,
name: impl Into<String>,
description: impl Into<String>,
handler: F,
) -> Self
where
T: schemars::JsonSchema + serde::de::DeserializeOwned + Send + 'static,
F: Fn(T, ToolCall) -> BoxFuture<'static, Result<Value, String>> + Send + Sync + 'static,
{
let name_str = name.into();
let desc_str = description.into();
let tool = function_tool_from_type::<T>(&name_str, &desc_str);
self.definitions.push(tool);
let handler = Arc::new(handler);
let async_handler: ToolHandler = Arc::new(move |args: Value, call: ToolCall| {
let parsed: Result<T, _> = serde_json::from_value(args);
let handler = handler.clone();
match parsed {
Ok(typed_args) => handler(typed_args, call),
Err(e) => Box::pin(async move { Err(format!("failed to parse arguments: {}", e)) }),
}
});
self.registry.register_mut(&name_str, async_handler);
self
}
pub fn add_raw(
mut self,
name: impl Into<String>,
description: impl Into<String>,
parameters: Option<Value>,
handler: ToolHandler,
) -> Self {
let name_str = name.into();
let desc_str = description.into();
let tool = Tool {
kind: ToolType::Function,
function: Some(FunctionTool {
name: name_str.clone(),
description: Some(desc_str),
parameters,
}),
x_search: None,
code_execution: None,
};
self.definitions.push(tool);
self.registry.register_mut(&name_str, handler);
self
}
pub fn definitions(&self) -> Vec<Tool> {
self.definitions.clone()
}
pub fn registry(&self) -> &ToolRegistry {
&self.registry
}
pub fn build(self) -> (Vec<Tool>, ToolRegistry) {
(self.definitions, self.registry)
}
}
/// Returns `true` when `result` failed with an error marked retryable.
fn is_retryable_failure(result: &ToolExecutionResult) -> bool {
    result.is_retryable && result.error.is_some()
}

/// Renders a failed tool execution as text for the model.
///
/// Retryable failures get an extra hint asking the model to fix its arguments.
pub fn format_tool_error_for_model(result: &ToolExecutionResult) -> String {
    let error_msg = match &result.error {
        Some(error) => error.as_str(),
        None => "unknown error",
    };
    let hint = if result.is_retryable {
        "\n\nPlease correct the arguments and try again."
    } else {
        ""
    };
    format!(
        "Tool call error for '{}': {}{}",
        result.tool_name, error_msg, hint
    )
}

/// Returns `true` if any result failed with a retryable error.
pub fn has_retryable_errors(results: &[ToolExecutionResult]) -> bool {
    results.iter().any(is_retryable_failure)
}

/// Returns the results that failed with a retryable error, in input order.
pub fn get_retryable_errors(results: &[ToolExecutionResult]) -> Vec<&ToolExecutionResult> {
    results
        .iter()
        .filter(|result| is_retryable_failure(result))
        .collect()
}

/// Builds one tool-result message per retryable failure, describing the error
/// to the model.
pub fn create_retry_messages(results: &[ToolExecutionResult]) -> Vec<InputItem> {
    get_retryable_errors(results)
        .into_iter()
        .map(|failed| {
            tool_result_message(&failed.tool_call_id, format_tool_error_for_model(failed))
        })
        .collect()
}
/// Configuration for `execute_with_retry`.
pub struct RetryOptions<
    F: Fn(Vec<InputItem>, usize) -> BoxFuture<'static, Result<Vec<ToolCall>, String>>,
> {
    /// Maximum number of retry rounds after the initial execution.
    pub max_retries: usize,
    /// Called with error messages for the model and the 1-based attempt
    /// number. Returns replacement tool calls; an empty list or an error
    /// stops retrying.
    pub on_retry: F,
}
/// Records `result` in `ordered`, keyed by its `tool_call_id`.
///
/// If that id was already recorded, the new result replaces the earlier entry
/// in place, so output order follows the order ids were first seen.
fn record_retry_result(
    ordered: &mut Vec<ToolExecutionResult>,
    positions: &mut HashMap<String, usize>,
    result: ToolExecutionResult,
) {
    let existing = positions.get(&result.tool_call_id).copied();
    match existing {
        Some(index) => ordered[index] = result,
        None => {
            positions.insert(result.tool_call_id.clone(), ordered.len());
            ordered.push(result);
        }
    }
}

/// Executes `tool_calls`, re-prompting through `options.on_retry` while any
/// call fails with a retryable error.
///
/// Each round keeps successful and non-retryable results. If retryable
/// failures remain and fewer than `options.max_retries` retries have run, the
/// failures are formatted with `create_retry_messages` and passed to
/// `on_retry` with the 1-based attempt number. The tool calls it returns are
/// executed in the next round.
///
/// Retrying stops when:
/// - no retryable failures remain,
/// - the retry budget is spent, or
/// - `on_retry` returns an error or an empty list.
///
/// The outstanding failures are then included in the output.
///
/// Results are returned in the order their `tool_call_id`s were first seen; a
/// later result with the same id replaces the earlier one. Retryable failures
/// superseded by a retry round are not returned.
pub async fn execute_with_retry<F>(
    registry: &ToolRegistry,
    tool_calls: Vec<ToolCall>,
    options: RetryOptions<F>,
) -> Vec<ToolExecutionResult>
where
    F: Fn(Vec<InputItem>, usize) -> BoxFuture<'static, Result<Vec<ToolCall>, String>>,
{
    let mut current_calls = tool_calls;
    let mut attempt = 0;
    // Ordered output plus an id -> position index. A HashMap alone would
    // return results in arbitrary order.
    let mut ordered: Vec<ToolExecutionResult> = Vec::new();
    let mut positions: HashMap<String, usize> = HashMap::new();

    loop {
        let results = registry.execute_all(&current_calls).await;
        let (settled, retryable): (Vec<_>, Vec<_>) = results
            .into_iter()
            .partition(|r| r.error.is_none() || !r.is_retryable);

        for result in settled {
            record_retry_result(&mut ordered, &mut positions, result);
        }

        if retryable.is_empty() || attempt >= options.max_retries {
            for result in retryable {
                record_retry_result(&mut ordered, &mut positions, result);
            }
            return ordered;
        }

        let error_messages = create_retry_messages(&retryable);
        attempt += 1;
        match (options.on_retry)(error_messages, attempt).await {
            Ok(new_calls) if !new_calls.is_empty() => current_calls = new_calls,
            // An empty replacement list or a callback error ends retrying;
            // surface the outstanding failures as they are.
            Ok(_) | Err(_) => {
                for result in retryable {
                    record_retry_result(&mut ordered, &mut positions, result);
                }
                return ordered;
            }
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::FunctionCallDelta;
#[test]
fn test_function_tool_creation() {
let tool = Tool::function("get_weather", Some("Get the weather".into()), None);
assert_eq!(tool.kind, ToolType::Function);
assert_eq!(tool.function.as_ref().unwrap().name, "get_weather");
}
#[test]
fn test_tool_result_message() {
let msg = tool_result_message("call_123", "sunny");
match msg {
InputItem::Message {
role,
content,
tool_call_id,
..
} => {
assert_eq!(role, crate::types::MessageRole::Tool);
let text = content
.iter()
.filter_map(|p| match p {
ContentPart::Text { text } => Some(text.as_str()),
ContentPart::File { .. } => None,
})
.collect::<String>();
assert_eq!(text, "sunny");
assert_eq!(tool_call_id.as_deref(), Some("call_123"));
}
}
}
#[test]
fn test_tool_call_accumulator() {
let mut acc = ToolCallAccumulator::new();
let delta1 = ToolCallDelta {
index: 0,
id: Some("call_1".to_string()),
type_: Some("function".to_string()),
function: Some(FunctionCallDelta {
name: Some("get_weather".to_string()),
arguments: Some("{\"loc".to_string()),
}),
};
assert!(acc.process_delta(&delta1));
let delta2 = ToolCallDelta {
index: 0,
id: None,
type_: None,
function: Some(FunctionCallDelta {
name: None,
arguments: Some("ation\":\"NYC\"}".to_string()),
}),
};
assert!(!acc.process_delta(&delta2));
let calls = acc.get_tool_calls();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].id, "call_1");
assert_eq!(
calls[0].function.as_ref().unwrap().arguments,
"{\"location\":\"NYC\"}"
);
}
#[test]
fn test_tool_registry_has_and_registered_tools() {
let registry = ToolRegistry::new()
.register("tool_a", sync_handler(|_, _| Ok(Value::Null)))
.register("tool_b", sync_handler(|_, _| Ok(Value::Null)));
assert!(registry.has("tool_a"));
assert!(registry.has("tool_b"));
assert!(!registry.has("tool_c"));
let tools = registry.registered_tools();
assert_eq!(tools.len(), 2);
assert!(tools.contains(&"tool_a".to_string()));
assert!(tools.contains(&"tool_b".to_string()));
}
#[test]
fn test_tool_registry_unregister() {
let mut registry = ToolRegistry::new();
registry.register_mut("tool_a", sync_handler(|_, _| Ok(Value::Null)));
assert!(registry.has("tool_a"));
assert!(registry.unregister("tool_a"));
assert!(!registry.has("tool_a"));
assert!(!registry.unregister("tool_a")); }
#[tokio::test]
async fn test_tool_registry_execute_success() {
let registry = ToolRegistry::new().register(
"get_weather",
sync_handler(|args, _call| {
let location = args
.get("location")
.and_then(|v| v.as_str())
.unwrap_or("unknown");
Ok(serde_json::json!({ "temp": 72, "location": location }))
}),
);
let call = ToolCall {
id: "call_123".to_string(),
kind: ToolType::Function,
function: Some(FunctionCall {
name: "get_weather".to_string(),
arguments: r#"{"location":"NYC"}"#.to_string(),
}),
};
let result = registry.execute(&call).await;
assert!(result.is_ok());
assert_eq!(result.tool_call_id, "call_123");
assert_eq!(result.tool_name, "get_weather");
let value = result.result.unwrap();
assert_eq!(value.get("temp").unwrap(), 72);
assert_eq!(value.get("location").unwrap(), "NYC");
}
#[tokio::test]
async fn test_tool_registry_execute_unknown_tool() {
let registry =
ToolRegistry::new().register("known_tool", sync_handler(|_, _| Ok(Value::Null)));
let call = ToolCall {
id: "call_456".to_string(),
kind: ToolType::Function,
function: Some(FunctionCall {
name: "unknown_tool".to_string(),
arguments: "{}".to_string(),
}),
};
let result = registry.execute(&call).await;
assert!(result.is_err());
assert!(result.error.as_ref().unwrap().contains("unknown tool"));
assert!(result.error.as_ref().unwrap().contains("unknown_tool"));
assert!(result.error.as_ref().unwrap().contains("known_tool"));
}
#[tokio::test]
async fn test_tool_registry_execute_handler_error() {
let registry = ToolRegistry::new().register(
"failing_tool",
sync_handler(|_, _| Err("something went wrong".to_string())),
);
let call = ToolCall {
id: "call_789".to_string(),
kind: ToolType::Function,
function: Some(FunctionCall {
name: "failing_tool".to_string(),
arguments: "{}".to_string(),
}),
};
let result = registry.execute(&call).await;
assert!(result.is_err());
assert_eq!(result.error.as_ref().unwrap(), "something went wrong");
}
#[tokio::test]
async fn test_tool_registry_execute_malformed_json() {
let registry =
ToolRegistry::new().register("my_tool", sync_handler(|_, _| Ok(Value::Null)));
let call = ToolCall {
id: "call_bad".to_string(),
kind: ToolType::Function,
function: Some(FunctionCall {
name: "my_tool".to_string(),
arguments: "{not valid json".to_string(),
}),
};
let result = registry.execute(&call).await;
assert!(result.is_err());
assert!(result
.error
.as_ref()
.unwrap()
.contains("failed to parse tool arguments"));
}
#[tokio::test]
async fn test_tool_registry_execute_all() {
let registry = ToolRegistry::new()
.register(
"tool_a",
sync_handler(|_, _| Ok(serde_json::json!("result_a"))),
)
.register(
"tool_b",
sync_handler(|_, _| Ok(serde_json::json!("result_b"))),
);
let calls = vec![
ToolCall {
id: "call_1".to_string(),
kind: ToolType::Function,
function: Some(FunctionCall {
name: "tool_a".to_string(),
arguments: "{}".to_string(),
}),
},
ToolCall {
id: "call_2".to_string(),
kind: ToolType::Function,
function: Some(FunctionCall {
name: "tool_b".to_string(),
arguments: "{}".to_string(),
}),
},
];
let results = registry.execute_all(&calls).await;
assert_eq!(results.len(), 2);
assert_eq!(results[0].tool_call_id, "call_1");
assert_eq!(results[1].tool_call_id, "call_2");
assert!(results[0].is_ok());
assert!(results[1].is_ok());
}
#[tokio::test]
async fn test_tool_registry_results_to_messages() {
let registry = ToolRegistry::new()
.register(
"success_tool",
sync_handler(|_, _| Ok(serde_json::json!({"data": "success"}))),
)
.register("error_tool", sync_handler(|_, _| Err("failed".to_string())));
let calls = vec![
ToolCall {
id: "call_1".to_string(),
kind: ToolType::Function,
function: Some(FunctionCall {
name: "success_tool".to_string(),
arguments: "{}".to_string(),
}),
},
ToolCall {
id: "call_2".to_string(),
kind: ToolType::Function,
function: Some(FunctionCall {
name: "error_tool".to_string(),
arguments: "{}".to_string(),
}),
},
];
let results = registry.execute_all(&calls).await;
let messages = registry.results_to_messages(&results);
assert_eq!(messages.len(), 2);
for (idx, expected_call_id) in [("call_1", "success"), ("call_2", "Error:")]
.into_iter()
.enumerate()
{
let msg = &messages[idx];
match msg {
InputItem::Message {
role,
content,
tool_call_id,
..
} => {
assert_eq!(*role, crate::types::MessageRole::Tool);
assert_eq!(tool_call_id.as_deref(), Some(expected_call_id.0));
let text = content
.iter()
.filter_map(|p| match p {
ContentPart::Text { text } => Some(text.as_str()),
ContentPart::File { .. } => None,
})
.collect::<String>();
assert!(
text.contains(expected_call_id.1) || text.starts_with(expected_call_id.1),
"unexpected content: {}",
text
);
}
}
}
}
#[test]
fn test_unknown_tool_error_display() {
let err = UnknownToolError {
tool_name: "foo".to_string(),
available: vec!["bar".to_string(), "baz".to_string()],
};
assert_eq!(err.to_string(), "unknown tool: 'foo'. Available: bar, baz");
let err_empty = UnknownToolError {
tool_name: "foo".to_string(),
available: vec![],
};
assert_eq!(
err_empty.to_string(),
"unknown tool: 'foo'. No tools registered."
);
}
#[derive(Debug, serde::Deserialize, PartialEq)]
struct WeatherArgs {
location: String,
#[serde(default)]
unit: Option<String>,
}
#[test]
fn test_parse_tool_args_success() {
let call = ToolCall {
id: "call_1".to_string(),
kind: ToolType::Function,
function: Some(FunctionCall {
name: "get_weather".to_string(),
arguments: r#"{"location":"NYC","unit":"celsius"}"#.to_string(),
}),
};
let args: WeatherArgs = parse_tool_args(&call).unwrap();
assert_eq!(args.location, "NYC");
assert_eq!(args.unit, Some("celsius".to_string()));
}
#[test]
fn test_parse_tool_args_with_defaults() {
let call = ToolCall {
id: "call_2".to_string(),
kind: ToolType::Function,
function: Some(FunctionCall {
name: "get_weather".to_string(),
arguments: r#"{"location":"London"}"#.to_string(),
}),
};
let args: WeatherArgs = parse_tool_args(&call).unwrap();
assert_eq!(args.location, "London");
assert_eq!(args.unit, None); }
#[test]
fn test_parse_tool_args_empty_arguments() {
let call = ToolCall {
id: "call_3".to_string(),
kind: ToolType::Function,
function: Some(FunctionCall {
name: "list_items".to_string(),
arguments: "".to_string(),
}),
};
let args: std::collections::HashMap<String, String> = parse_tool_args(&call).unwrap();
assert!(args.is_empty());
}
#[test]
fn test_parse_tool_args_invalid_json() {
let call = ToolCall {
id: "call_4".to_string(),
kind: ToolType::Function,
function: Some(FunctionCall {
name: "get_weather".to_string(),
arguments: "{not valid json".to_string(),
}),
};
let err = parse_tool_args::<WeatherArgs>(&call).unwrap_err();
assert!(err.message.contains("failed to parse arguments"));
assert!(err.message.contains("get_weather"));
assert_eq!(err.tool_call_id, "call_4");
assert_eq!(err.tool_name, "get_weather");
}
#[test]
fn test_parse_tool_args_missing_required_field() {
let call = ToolCall {
id: "call_5".to_string(),
kind: ToolType::Function,
function: Some(FunctionCall {
name: "get_weather".to_string(),
arguments: r#"{"unit":"celsius"}"#.to_string(), }),
};
let err = parse_tool_args::<WeatherArgs>(&call).unwrap_err();
assert!(err.message.contains("failed to parse arguments"));
}
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
struct ReadFileArgs {
path: String,
}
#[test]
fn test_typed_tool_parse_call_success() {
let tool = TypedTool::<ReadFileArgs>::new("read_file", "Read a file");
let call = ToolCall {
id: "call_typed_1".to_string(),
kind: ToolType::Function,
function: Some(FunctionCall {
name: "read_file".to_string(),
arguments: r#"{"path":"/tmp/config.json"}"#.to_string(),
}),
};
let typed = tool.parse_call(&call).unwrap();
assert_eq!(typed.args.path, "/tmp/config.json");
assert_eq!(typed.call.id, "call_typed_1");
}
#[test]
fn test_typed_tool_parse_call_name_mismatch() {
let tool = TypedTool::<ReadFileArgs>::new("read_file", "Read a file");
let call = ToolCall {
id: "call_typed_2".to_string(),
kind: ToolType::Function,
function: Some(FunctionCall {
name: "other_tool".to_string(),
arguments: r#"{"path":"/tmp/config.json"}"#.to_string(),
}),
};
let err = tool.parse_call(&call).unwrap_err();
assert!(err.message.contains("expected tool 'read_file'"));
}
#[test]
fn test_typed_tool_parse_call_missing_function() {
let tool = TypedTool::<ReadFileArgs>::new("read_file", "Read a file");
let call = ToolCall {
id: "call_typed_3".to_string(),
kind: ToolType::Function,
function: None,
};
let err = tool.parse_call(&call).unwrap_err();
assert!(err.message.contains("tool call missing function"));
}
#[derive(Debug, serde::Deserialize)]
struct ValidatedArgs {
value: i32,
}
impl ValidateArgs for ValidatedArgs {
fn validate(&self) -> Result<(), String> {
if self.value < 0 {
return Err("value must be non-negative".to_string());
}
if self.value > 100 {
return Err("value must be at most 100".to_string());
}
Ok(())
}
}
#[test]
fn test_parse_and_validate_tool_args_success() {
let call = ToolCall {
id: "call_6".to_string(),
kind: ToolType::Function,
function: Some(FunctionCall {
name: "set_value".to_string(),
arguments: r#"{"value":50}"#.to_string(),
}),
};
let args: ValidatedArgs = parse_and_validate_tool_args(&call).unwrap();
assert_eq!(args.value, 50);
}
#[test]
fn test_parse_and_validate_tool_args_validation_failure() {
let call = ToolCall {
id: "call_7".to_string(),
kind: ToolType::Function,
function: Some(FunctionCall {
name: "set_value".to_string(),
arguments: r#"{"value":-5}"#.to_string(),
}),
};
let err = parse_and_validate_tool_args::<ValidatedArgs>(&call).unwrap_err();
assert!(err.message.contains("invalid arguments"));
assert!(err.message.contains("value must be non-negative"));
}
#[test]
fn test_tool_args_error_display() {
let err = ToolArgsError {
message: "test error message".to_string(),
tool_call_id: "call_123".to_string(),
tool_name: "my_tool".to_string(),
raw_arguments: "{}".to_string(),
};
assert_eq!(err.to_string(), "test error message");
}
#[test]
fn test_format_tool_error_for_model_retryable() {
let result = ToolExecutionResult {
tool_call_id: "call_1".to_string(),
tool_name: "my_tool".to_string(),
result: None,
error: Some("failed to parse arguments".to_string()),
is_retryable: true,
};
let formatted = format_tool_error_for_model(&result);
assert!(formatted.contains("Tool call error for 'my_tool'"));
assert!(formatted.contains("failed to parse arguments"));
assert!(formatted.contains("Please correct the arguments and try again"));
}
/// A non-retryable failure is reported without the "correct the arguments"
/// instruction, since retrying would not help.
#[test]
fn test_format_tool_error_for_model_not_retryable() {
    let formatted = format_tool_error_for_model(&ToolExecutionResult {
        tool_call_id: "call_1".to_string(),
        tool_name: "my_tool".to_string(),
        result: None,
        error: Some("internal error".to_string()),
        is_retryable: false,
    });
    for fragment in ["Tool call error for 'my_tool'", "internal error"] {
        assert!(formatted.contains(fragment));
    }
    assert!(!formatted.contains("Please correct the arguments"));
}
/// `has_retryable_errors` is true only when at least one result is a failure
/// flagged `is_retryable`.
#[test]
fn test_has_retryable_errors() {
    let ok_result = || ToolExecutionResult {
        tool_call_id: "call_1".to_string(),
        tool_name: "tool_a".to_string(),
        result: Some(Value::String("ok".to_string())),
        error: None,
        is_retryable: false,
    };
    let failed_result =
        |id: &str, tool: &str, message: &str, retryable: bool| ToolExecutionResult {
            tool_call_id: id.to_string(),
            tool_name: tool.to_string(),
            result: None,
            error: Some(message.to_string()),
            is_retryable: retryable,
        };

    // One retryable failure in the batch is enough.
    let mixed = vec![
        ok_result(),
        failed_result("call_2", "tool_b", "parse error", true),
    ];
    assert!(has_retryable_errors(&mixed));

    // Only successes: nothing to retry.
    let all_ok = vec![ok_result()];
    assert!(!has_retryable_errors(&all_ok));

    // A failure that is not flagged retryable does not count.
    let hard_failure = vec![failed_result("call_1", "tool_a", "internal error", false)];
    assert!(!has_retryable_errors(&hard_failure));
}
/// `get_retryable_errors` returns exactly the retryable failures, in their
/// original order.
#[test]
fn test_get_retryable_errors() {
    // (call id, tool name, error text if failed, retryable flag)
    let table = vec![
        ("call_1", "tool_a", None, false),
        ("call_2", "tool_b", Some("parse error"), true),
        ("call_3", "tool_c", Some("validation error"), true),
    ];
    let results: Vec<ToolExecutionResult> = table
        .into_iter()
        .map(|(id, tool, failure, retryable)| ToolExecutionResult {
            tool_call_id: id.to_string(),
            tool_name: tool.to_string(),
            result: match failure {
                None => Some(Value::String("ok".to_string())),
                Some(_) => None,
            },
            error: failure.map(|text: &str| text.to_string()),
            is_retryable: retryable,
        })
        .collect();

    let retryable = get_retryable_errors(&results);
    let retryable_ids: Vec<&str> = retryable
        .iter()
        .map(|r| r.tool_call_id.as_str())
        .collect();
    assert_eq!(retryable_ids, ["call_2", "call_3"]);
}
/// `create_retry_messages` emits one tool-role message per retryable failure,
/// addressed to the failing call id and carrying the retry instruction text.
#[test]
fn test_create_retry_messages() {
    let succeeded = ToolExecutionResult {
        tool_call_id: "call_1".to_string(),
        tool_name: "tool_a".to_string(),
        result: Some(Value::String("ok".to_string())),
        error: None,
        is_retryable: false,
    };
    let failed = ToolExecutionResult {
        tool_call_id: "call_2".to_string(),
        tool_name: "tool_b".to_string(),
        result: None,
        error: Some("parse error".to_string()),
        is_retryable: true,
    };
    let messages = create_retry_messages(&vec![succeeded, failed]);
    assert_eq!(messages.len(), 1);

    match &messages[0] {
        InputItem::Message {
            role,
            content,
            tool_call_id,
            ..
        } => {
            assert_eq!(role, &crate::types::MessageRole::Tool);
            assert_eq!(tool_call_id.as_deref(), Some("call_2"));

            // Concatenate only the textual parts; file parts carry no text.
            let mut text = String::new();
            for part in content.iter() {
                match part {
                    ContentPart::Text { text: chunk } => text.push_str(chunk),
                    ContentPart::File { .. } => {}
                }
            }
            assert!(text.contains("Tool call error"));
            assert!(text.contains("Please correct the arguments"));
        }
    }
}
/// Malformed JSON in `arguments` must yield a retryable error that mentions
/// the parse failure.
#[tokio::test]
async fn test_execute_sets_is_retryable_for_json_parse_error() {
    // The handler itself always succeeds, so any error here comes from the
    // malformed `arguments` string.
    let handler = sync_handler(|_, _| Ok(Value::Null));
    let registry = ToolRegistry::new().register("my_tool", handler);
    let function = FunctionCall {
        name: "my_tool".to_string(),
        arguments: "{invalid json".to_string(),
    };
    let call = ToolCall {
        id: "call_1".to_string(),
        kind: ToolType::Function,
        function: Some(function),
    };

    let outcome = registry.execute(&call).await;
    assert!(outcome.is_err());
    assert!(outcome.is_retryable);
    let message = outcome
        .error
        .as_deref()
        .expect("a failed execution should carry an error message");
    assert!(message.contains("failed to parse"));
}
/// A handler error that starts with "invalid arguments" is classified as
/// retryable by `ToolRegistry::execute`.
#[tokio::test]
async fn test_execute_sets_is_retryable_for_validation_error() {
    let handler = sync_handler(|_, _| Err("invalid arguments: missing field".to_string()));
    let registry = ToolRegistry::new().register("my_tool", handler);
    let function = FunctionCall {
        name: "my_tool".to_string(),
        arguments: "{}".to_string(),
    };
    let call = ToolCall {
        id: "call_1".to_string(),
        kind: ToolType::Function,
        function: Some(function),
    };

    let outcome = registry.execute(&call).await;
    assert!(outcome.is_err());
    assert!(outcome.is_retryable);
}
/// Handler errors unrelated to argument problems, such as a timeout, are
/// reported as failures but not marked retryable.
#[tokio::test]
async fn test_execute_not_retryable_for_other_errors() {
    let handler = sync_handler(|_, _| Err("network timeout".to_string()));
    let registry = ToolRegistry::new().register("my_tool", handler);
    let function = FunctionCall {
        name: "my_tool".to_string(),
        arguments: "{}".to_string(),
    };
    let call = ToolCall {
        id: "call_1".to_string(),
        kind: ToolType::Function,
        function: Some(function),
    };

    let outcome = registry.execute(&call).await;
    assert!(outcome.is_err());
    assert!(!outcome.is_retryable);
}
/// When the only initial call succeeds, `execute_with_retry` returns that one
/// successful result. The `on_retry` callback here would yield no new calls.
#[tokio::test]
async fn test_execute_with_retry_no_errors() {
    let registry = ToolRegistry::new().register(
        "my_tool",
        sync_handler(|_, _| Ok(serde_json::json!("success"))),
    );
    let calls = vec![ToolCall {
        id: "call_1".to_string(),
        kind: ToolType::Function,
        function: Some(FunctionCall {
            name: "my_tool".to_string(),
            arguments: "{}".to_string(),
        }),
    }];
    let results = execute_with_retry(
        &registry,
        calls,
        RetryOptions {
            max_retries: 2,
            on_retry: |_, _| Box::pin(async { Ok(vec![]) }),
        },
    )
    .await;
    assert_eq!(results.len(), 1);
    assert!(results[0].is_ok());
}
/// A call with malformed JSON arguments triggers exactly one `on_retry`
/// invocation. The corrected call it returns succeeds, leaving a single
/// successful result.
#[tokio::test]
async fn test_execute_with_retry_retries_on_parse_error() {
    use std::sync::atomic::{AtomicUsize, Ordering};
    // Counts how many times the retry callback is invoked.
    let retry_count = Arc::new(AtomicUsize::new(0));
    let retry_count_clone = retry_count.clone();
    let registry = ToolRegistry::new().register(
        "my_tool",
        sync_handler(|_, _| Ok(serde_json::json!("success"))),
    );
    let initial_calls = vec![ToolCall {
        id: "call_1".to_string(),
        kind: ToolType::Function,
        function: Some(FunctionCall {
            name: "my_tool".to_string(),
            arguments: "{invalid".to_string(),
        }),
    }];
    let results = execute_with_retry(
        &registry,
        initial_calls,
        RetryOptions {
            max_retries: 2,
            on_retry: move |_messages, _attempt| {
                retry_count_clone.fetch_add(1, Ordering::SeqCst);
                // Simulate the model resubmitting the call with valid JSON.
                Box::pin(async {
                    Ok(vec![ToolCall {
                        id: "call_1_retry".to_string(),
                        kind: ToolType::Function,
                        function: Some(FunctionCall {
                            name: "my_tool".to_string(),
                            arguments: "{}".to_string(),
                        }),
                    }])
                })
            },
        },
    )
    .await;
    assert_eq!(retry_count.load(Ordering::SeqCst), 1);
    assert_eq!(results.len(), 1);
    assert!(results[0].is_ok());
}
/// If every retry still has malformed arguments, `on_retry` is called exactly
/// `max_retries` times. The remaining result is a failure.
#[tokio::test]
async fn test_execute_with_retry_respects_max_retries() {
    use std::sync::atomic::{AtomicUsize, Ordering};
    // Counts how many times the retry callback is invoked.
    let retry_count = Arc::new(AtomicUsize::new(0));
    let retry_count_clone = retry_count.clone();
    let registry = ToolRegistry::new().register(
        "my_tool",
        sync_handler(|_, _| Ok(serde_json::json!("success"))),
    );
    let initial_calls = vec![ToolCall {
        id: "call_1".to_string(),
        kind: ToolType::Function,
        function: Some(FunctionCall {
            name: "my_tool".to_string(),
            arguments: "{invalid".to_string(),
        }),
    }];
    let results = execute_with_retry(
        &registry,
        initial_calls,
        RetryOptions {
            max_retries: 2,
            on_retry: move |_messages, _attempt| {
                retry_count_clone.fetch_add(1, Ordering::SeqCst);
                // Every "corrected" call is still malformed, so retries never converge.
                Box::pin(async {
                    Ok(vec![ToolCall {
                        id: "call_retry".to_string(),
                        kind: ToolType::Function,
                        function: Some(FunctionCall {
                            name: "my_tool".to_string(),
                            arguments: "{still invalid".to_string(),
                        }),
                    }])
                })
            },
        },
    )
    .await;
    assert_eq!(retry_count.load(Ordering::SeqCst), 2);
    assert!(
        results[0].is_err(),
        "result should still be a failure once retries are exhausted"
    );
}
/// When one call in a batch succeeds and another needs a retry, the original
/// successful result must survive the retry. The corrected call's result
/// appears under its new id.
#[tokio::test]
async fn test_execute_with_retry_preserves_successful_results() {
    use std::sync::atomic::{AtomicUsize, Ordering};
    // Counts how many times the retry callback is invoked.
    let retry_count = Arc::new(AtomicUsize::new(0));
    let retry_count_clone = retry_count.clone();
    let registry = ToolRegistry::new()
        .register(
            "success_tool",
            sync_handler(|_, _| Ok(serde_json::json!("success_result"))),
        )
        .register(
            "failing_tool",
            sync_handler(|_, _| Ok(serde_json::json!("fixed_result"))),
        );
    let initial_calls = vec![
        ToolCall {
            id: "call_success".to_string(),
            kind: ToolType::Function,
            function: Some(FunctionCall {
                name: "success_tool".to_string(),
                arguments: "{}".to_string(),
            }),
        },
        // Fails only because its arguments are malformed JSON; the handler succeeds.
        ToolCall {
            id: "call_fail".to_string(),
            kind: ToolType::Function,
            function: Some(FunctionCall {
                name: "failing_tool".to_string(),
                arguments: "{invalid".to_string(),
            }),
        },
    ];
    let results = execute_with_retry(
        &registry,
        initial_calls,
        RetryOptions {
            max_retries: 2,
            on_retry: move |_messages, _attempt| {
                retry_count_clone.fetch_add(1, Ordering::SeqCst);
                // Only the failed call is resubmitted, with valid JSON.
                Box::pin(async {
                    Ok(vec![ToolCall {
                        id: "call_fail_retry".to_string(),
                        kind: ToolType::Function,
                        function: Some(FunctionCall {
                            name: "failing_tool".to_string(),
                            arguments: "{}".to_string(),
                        }),
                    }])
                })
            },
        },
    )
    .await;
    assert_eq!(retry_count.load(Ordering::SeqCst), 1);
    assert_eq!(results.len(), 2);
    let original_success = results
        .iter()
        .find(|r| r.tool_call_id == "call_success")
        .expect("original successful result was lost during retry");
    assert!(original_success.is_ok());
    assert_eq!(
        original_success.result.as_ref().unwrap(),
        &serde_json::json!("success_result")
    );
    let retry_success = results
        .iter()
        .find(|r| r.tool_call_id == "call_fail_retry")
        .expect("retried result not found");
    assert!(retry_success.is_ok());
    assert_eq!(
        retry_success.result.as_ref().unwrap(),
        &serde_json::json!("fixed_result")
    );
}
/// `parse_user_ask_args` decodes the `user_ask` tool's JSON arguments into
/// `UserAskArgs`. Optional fields omitted from the JSON stay unset.
#[test]
fn test_parse_user_ask_args() {
    let call = ToolCall {
        id: "call_user_ask".to_string(),
        kind: ToolType::Function,
        function: Some(FunctionCall {
            name: USER_ASK_TOOL_NAME.to_string(),
            arguments: r#"{"question":"Pick one","options":[{"label":"A"}]}"#.to_string(),
        }),
    };
    let args = parse_user_ask_args(&call).expect("parse user_ask args");
    assert_eq!(args.question, "Pick one");
    assert_eq!(args.options.len(), 1);
    assert_eq!(args.options[0].label, "A");
    // `description` was omitted from the option JSON.
    assert_eq!(args.options[0].description, None);
}
/// `serialize_user_ask_result` must produce JSON carrying both the answer and
/// the `is_freeform` flag with their actual values, not just the key names.
#[test]
fn test_serialize_user_ask_result() {
    let out = serialize_user_ask_result(&UserAskResponse {
        answer: "PostgreSQL".to_string(),
        is_freeform: false,
    })
    .expect("serialize user_ask result");
    assert!(out.contains("PostgreSQL"));
    assert!(out.contains("is_freeform"));
    // Parse the output so the field values are pinned; a substring check on
    // the key alone would pass even with the wrong flag value.
    let parsed: Value =
        serde_json::from_str(&out).expect("user_ask result should be valid JSON");
    assert_eq!(parsed["answer"], "PostgreSQL");
    assert_eq!(parsed["is_freeform"], false);
}
}