use schemars::{JsonSchema, schema_for};
use serde_json::Value;
use crate::types::Tool;
/// A tool whose parameter type is also the source of its JSON Schema.
///
/// Implementors provide a name and a description. The schema generated from
/// the implementing type (via `schemars`) becomes the tool's parameters.
pub trait TypedTool: JsonSchema + serde::Serialize + for<'de> serde::Deserialize<'de> {
    /// Tool name passed to [`Tool::new`] by [`create_tool`](Self::create_tool).
    fn name() -> &'static str;

    /// Tool description passed to [`Tool::new`] by [`create_tool`](Self::create_tool).
    fn description() -> &'static str;

    /// Builds a [`Tool`] from [`name`](Self::name), [`description`](Self::description)
    /// and the schema returned by [`get_schema`](Self::get_schema).
    ///
    /// The schema is obtained through `get_schema`, not by regenerating it here.
    /// That way an implementor that overrides `get_schema` gets a matching `Tool`.
    fn create_tool() -> Tool {
        Tool::new(Self::name(), Self::description(), Self::get_schema())
    }

    /// Returns the JSON Schema generated for `Self` as a `serde_json::Value`.
    ///
    /// If converting the schema to JSON fails, this returns `Value::Null`
    /// rather than an error.
    fn get_schema() -> serde_json::Value {
        // NOTE(review): the Null fallback hides conversion failures; confirm callers
        // tolerate a Null schema before relying on it.
        serde_json::to_value(schema_for!(Self)).unwrap_or(Value::Null)
    }
}
/// JSON conversion helpers (and a validation hook) for typed tool parameters.
///
/// The blanket impl below gives every [`TypedTool`] this trait automatically.
pub trait TypedToolParams: TypedTool {
    /// Validation hook for parameter values.
    ///
    /// The default implementation always returns `Ok(())`.
    ///
    /// NOTE(review): two caveats apply.
    /// - Because of the blanket impl for every `TypedTool`, no type can write
    ///   its own `TypedToolParams` impl, so this default cannot be overridden.
    /// - `from_json_value` does not call it.
    fn validate(&self) -> Result<(), String> {
        Result::Ok(())
    }

    /// Deserializes `Self` from an owned JSON value.
    ///
    /// # Errors
    /// Returns the `serde_json::Error` produced when `value` does not match `Self`.
    fn from_json_value(value: serde_json::Value) -> Result<Self, serde_json::Error> {
        serde_json::value::from_value::<Self>(value)
    }

    /// Serializes `self` into a JSON value.
    ///
    /// # Errors
    /// Returns the `serde_json::Error` produced by serialization.
    fn to_json_value(&self) -> Result<serde_json::Value, serde_json::Error> {
        serde_json::value::to_value(self)
    }
}

// Every `TypedTool` gets the default JSON helpers above.
impl<T> TypedToolParams for T where T: TypedTool {}
#[cfg(test)]
mod tests {
    use super::*;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};

    /// Returns the subschema that `node` points at via a local `$ref`
    /// (e.g. `#/definitions/X` or `#/$defs/X`), or `node` itself when it has none.
    ///
    /// schemars may inline a nested type's schema or move it into a definitions
    /// table. Resolving here keeps the assertions independent of that choice.
    fn resolve<'a>(
        root: &'a serde_json::Value,
        node: &'a serde_json::Value,
    ) -> &'a serde_json::Value {
        match node.get("$ref").and_then(serde_json::Value::as_str) {
            Some(reference) => {
                let pointer = reference.strip_prefix('#').unwrap_or(reference);
                root.pointer(pointer)
                    .unwrap_or_else(|| panic!("unresolvable $ref: {}", reference))
            }
            None => node,
        }
    }

    /// Parameters with one string field and one unsigned integer field.
    #[derive(Serialize, Deserialize, JsonSchema, Debug, PartialEq)]
    struct TestTool {
        pub message: String,
        pub count: u32,
    }

    impl TypedTool for TestTool {
        fn name() -> &'static str {
            "test_tool"
        }
        fn description() -> &'static str {
            "A test tool for unit testing"
        }
    }

    #[test]
    fn test_typed_tool_creation() {
        let tool = TestTool::create_tool();
        assert_eq!(tool.function.name, "test_tool");
        assert_eq!(tool.function.description, "A test tool for unit testing");
        assert_eq!(tool.tool_type, "function");
        let schema = &tool.function.parameters;
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"].is_object());
        assert!(schema["properties"]["message"].is_object());
        assert!(schema["properties"]["count"].is_object());
    }

    #[test]
    fn test_schema_generation() {
        let schema = TestTool::get_schema();
        assert_eq!(schema["type"], "object");
        // assert_eq! (not assert!(a == b)) so a failure prints both values.
        assert_eq!(schema["properties"]["message"]["type"], "string");
        assert_eq!(schema["properties"]["count"]["type"], "integer");
        let required = schema["required"]
            .as_array()
            .expect("schema should list required fields");
        assert_eq!(required.len(), 2);
        assert!(required.contains(&serde_json::json!("message")));
        assert!(required.contains(&serde_json::json!("count")));
    }

    #[test]
    fn test_json_conversion() {
        let test_tool = TestTool {
            message: "Hello".to_string(),
            count: 42,
        };
        assert!(test_tool.validate().is_ok());
        let json_value = test_tool.to_json_value().unwrap();
        let converted_back = TestTool::from_json_value(json_value).unwrap();
        assert_eq!(test_tool, converted_back);
    }

    #[test]
    fn test_from_json_value_rejects_invalid_input() {
        // Missing required field.
        assert!(TestTool::from_json_value(serde_json::json!({ "message": "hi" })).is_err());
        // Negative number cannot fit in a u32.
        assert!(
            TestTool::from_json_value(serde_json::json!({ "message": "hi", "count": -1 }))
                .is_err()
        );
    }

    /// Parameters whose `operation` field is a unit-only enum.
    #[derive(Serialize, Deserialize, JsonSchema, Debug, PartialEq)]
    struct EnumTool {
        pub operation: TestOperation,
        pub value: f64,
    }

    /// Serialized as lowercase strings ("square", "sqrt", "abs").
    #[derive(Serialize, Deserialize, JsonSchema, Debug, PartialEq)]
    #[serde(rename_all = "lowercase")]
    enum TestOperation {
        Square,
        Sqrt,
        Abs,
    }

    impl TypedTool for EnumTool {
        fn name() -> &'static str {
            "math_tool"
        }
        fn description() -> &'static str {
            "Perform mathematical operations"
        }
    }

    #[test]
    fn test_enum_schema_generation() {
        let tool = EnumTool::create_tool();
        assert_eq!(tool.function.name, "math_tool");
        assert_eq!(tool.function.description, "Perform mathematical operations");

        let schema = &tool.function.parameters;
        assert!(schema["properties"]["operation"].is_object());
        // The enum may be inlined or referenced via $ref; check the actual variants.
        let operation = resolve(schema, &schema["properties"]["operation"]);
        let variants = operation["enum"]
            .as_array()
            .expect("operation schema should list its enum variants");
        assert_eq!(variants.len(), 3);
        for name in ["square", "sqrt", "abs"].iter() {
            assert!(
                variants.contains(&serde_json::json!(name)),
                "missing enum variant {}",
                name
            );
        }
        assert_eq!(schema["properties"]["value"]["type"], "number");
    }

    #[test]
    fn test_enum_json_conversion() {
        let tool = EnumTool {
            operation: TestOperation::Sqrt,
            value: 4.0,
        };
        let json_value = tool.to_json_value().unwrap();
        assert_eq!(
            json_value,
            serde_json::json!({ "operation": "sqrt", "value": 4.0 })
        );
        assert_eq!(EnumTool::from_json_value(json_value).unwrap(), tool);
    }
}