use async_trait::async_trait;
use schemars::JsonSchema;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::Value;
use crate::Result;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolSchema {
pub name: String,
pub description: String,
pub parameters: Value,
}
#[async_trait]
pub trait Tool: Send + Sync + 'static {
type Input: DeserializeOwned + JsonSchema + Send;
type Output: Serialize + Send;
fn name(&self) -> &str;
fn description(&self) -> &str;
fn schema(&self) -> ToolSchema {
let schema = schemars::schema_for!(Self::Input);
ToolSchema {
name: self.name().to_string(),
description: self.description().to_string(),
parameters: serde_json::to_value(schema).unwrap_or_default(),
}
}
async fn execute(&self, input: Self::Input) -> Result<Self::Output>;
}
#[async_trait]
pub trait ErasedTool: Send + Sync + 'static {
fn name(&self) -> &str;
fn description(&self) -> &str;
fn schema(&self) -> ToolSchema;
async fn execute_json(&self, input: Value) -> Result<Value>;
}
#[async_trait]
impl<T: Tool> ErasedTool for T {
fn name(&self) -> &str {
Tool::name(self)
}
fn description(&self) -> &str {
Tool::description(self)
}
fn schema(&self) -> ToolSchema {
Tool::schema(self)
}
async fn execute_json(&self, input: Value) -> Result<Value> {
let typed_input: T::Input = serde_json::from_value(input).map_err(|e| {
crate::Error::tool_execution(self.name(), format!("Invalid input: {e}"))
})?;
let output = self.execute(typed_input).await?;
serde_json::to_value(output).map_err(|e| {
crate::Error::tool_execution(self.name(), format!("Failed to serialize output: {e}"))
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[derive(Deserialize, JsonSchema)]
struct AddInput {
a: i64,
b: i64,
}
#[derive(Serialize, Deserialize, PartialEq, Debug)]
struct AddOutput {
sum: i64,
}
struct AddTool;
#[async_trait]
#[allow(clippy::unnecessary_literal_bound)]
impl Tool for AddTool {
type Input = AddInput;
type Output = AddOutput;
fn name(&self) -> &str {
"add"
}
fn description(&self) -> &str {
"Add two numbers"
}
async fn execute(&self, input: Self::Input) -> Result<Self::Output> {
Ok(AddOutput {
sum: input.a + input.b,
})
}
}
#[tokio::test]
async fn test_tool_execute_typed() {
let tool = AddTool;
let result = tool.execute(AddInput { a: 3, b: 4 }).await.unwrap();
assert_eq!(result.sum, 7);
}
#[test]
fn test_tool_name_and_description() {
let tool = AddTool;
assert_eq!(Tool::name(&tool), "add");
assert_eq!(Tool::description(&tool), "Add two numbers");
}
#[test]
fn test_schema_generation() {
let tool = AddTool;
let schema = Tool::schema(&tool);
assert_eq!(schema.name, "add");
assert_eq!(schema.description, "Add two numbers");
let params = &schema.parameters;
let props = params
.get("properties")
.expect("schema should have properties");
assert!(props.get("a").is_some(), "schema missing 'a' property");
assert!(props.get("b").is_some(), "schema missing 'b' property");
}
#[test]
fn test_tool_schema_serializes_to_json() {
let tool = AddTool;
let schema = Tool::schema(&tool);
let json = serde_json::to_value(&schema).unwrap();
assert_eq!(json["name"], "add");
assert_eq!(json["description"], "Add two numbers");
assert!(json["parameters"].is_object());
}
#[test]
fn test_erased_tool_in_vec() {
let tools: Vec<std::sync::Arc<dyn ErasedTool>> = vec![std::sync::Arc::new(AddTool)];
assert_eq!(tools.len(), 1);
assert_eq!(tools[0].name(), "add");
assert_eq!(tools[0].description(), "Add two numbers");
}
#[tokio::test]
async fn test_erased_tool_json_round_trip() {
let tool: std::sync::Arc<dyn ErasedTool> = std::sync::Arc::new(AddTool);
let input = serde_json::json!({"a": 10, "b": 20});
let output = tool.execute_json(input).await.unwrap();
let result: AddOutput = serde_json::from_value(output).unwrap();
assert_eq!(result.sum, 30);
}
#[tokio::test]
async fn test_erased_tool_invalid_input_returns_error() {
let tool: std::sync::Arc<dyn ErasedTool> = std::sync::Arc::new(AddTool);
let bad_input = serde_json::json!({"x": "not a number"});
let result = tool.execute_json(bad_input).await;
assert!(result.is_err());
let err = result.unwrap_err();
assert!(
err.to_string().contains("add"),
"error should mention tool name"
);
assert!(
err.to_string().contains("Invalid input"),
"error should say invalid input"
);
}
#[test]
fn test_erased_tool_schema_matches_tool_schema() {
let tool = AddTool;
let direct_schema = Tool::schema(&tool);
let erased_schema = ErasedTool::schema(&tool);
assert_eq!(direct_schema.name, erased_schema.name);
assert_eq!(direct_schema.description, erased_schema.description);
assert_eq!(direct_schema.parameters, erased_schema.parameters);
}
}