use std::fmt::Debug;
use std::future::Future;
use schemars::JsonSchema;
use schemars::schema_for;
use serde::{Deserialize, Serialize};
use crate::ToolDefinition;
use super::tool_error::ToolResult;
/// A named, asynchronously executable tool whose typed parameters are described
/// by a JSON Schema.
///
/// Implementors provide the name, description, parameter/output types and the
/// `execute` logic. The schema and the [`ToolDefinition`] are derived from those
/// by the provided default methods.
pub trait Tool: Send + Sync {
    /// Typed input accepted by [`Tool::execute`].
    ///
    /// It must be deserializable for any input lifetime (i.e. it cannot borrow
    /// from the source data). It must also implement `JsonSchema` so that
    /// [`Tool::parameters_schema`] can describe it.
    type Params: for<'de> Deserialize<'de> + JsonSchema + Send + Debug;

    /// Value produced by a successful [`Tool::execute`]; must be serializable.
    type Output: Serialize + Send;

    /// Identifier of the tool; [`Tool::to_definition`] passes it to
    /// `ToolDefinition::function`.
    fn name(&self) -> &'static str;

    /// Human-readable summary; [`Tool::to_definition`] attaches it via
    /// `with_description`.
    fn description(&self) -> &'static str;

    /// Runs the tool with already-deserialized parameters.
    ///
    /// The returned future must be `Send`.
    fn execute(&self, params: Self::Params) -> impl Future<Output = ToolResult<Self::Output>> + Send;

    /// Returns the JSON Schema generated from [`Tool::Params`] by `schemars`.
    ///
    /// If the schema cannot be converted into a `serde_json::Value`, this
    /// returns an empty JSON object (`{}`) rather than an error.
    fn parameters_schema(&self) -> serde_json::Value {
        match serde_json::to_value(schema_for!(Self::Params)) {
            Ok(value) => value,
            Err(_) => serde_json::json!({}),
        }
    }

    /// Builds a function-style [`ToolDefinition`] from this tool's name,
    /// parameter schema and description.
    fn to_definition(&self) -> ToolDefinition {
        // Evaluate in the same order as a single chained expression would:
        // name, then schema, then description.
        let name = self.name();
        let schema = self.parameters_schema();
        let definition = ToolDefinition::function(name, schema);
        definition.with_description(self.description())
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for [`Tool`]'s required and default methods, driven by a
    //! minimal fixture tool.

    use super::*;
    use serde::Deserialize;

    /// Parameters for the fixture tool: a single required string field.
    #[derive(Debug, Deserialize, JsonSchema)]
    struct TestParams {
        value: String,
    }

    /// Output of the fixture tool.
    #[derive(Serialize)]
    struct TestOutput {
        result: String,
    }

    /// Fixture tool that prefixes its input with `"processed: "`.
    struct TestTool;

    impl Tool for TestTool {
        type Params = TestParams;
        type Output = TestOutput;

        fn name(&self) -> &'static str {
            "test_tool"
        }

        fn description(&self) -> &'static str {
            "A test tool"
        }

        async fn execute(&self, params: Self::Params) -> ToolResult<Self::Output> {
            Ok(TestOutput {
                result: format!("processed: {}", params.value),
            })
        }
    }

    #[test]
    fn test_tool_name() {
        let tool = TestTool;
        assert_eq!(tool.name(), "test_tool");
    }

    #[test]
    fn test_tool_description() {
        let tool = TestTool;
        assert_eq!(tool.description(), "A test tool");
    }

    #[test]
    fn test_tool_parameters_schema() {
        let tool = TestTool;
        let schema = tool.parameters_schema();
        assert!(schema.is_object());
        // `is_object()` alone would also accept the `{}` fallback that
        // `parameters_schema` returns on serialization failure, so pin the
        // generated structure for `TestParams` explicitly.
        assert_eq!(schema["type"], "object");
        assert_eq!(schema["properties"]["value"]["type"], "string");
        assert_eq!(schema["required"], serde_json::json!(["value"]));
    }

    #[test]
    fn test_tool_to_definition() {
        let tool = TestTool;
        let def = tool.to_definition();
        assert_eq!(def.name(), "test_tool");
        assert_eq!(def.description(), Some("A test tool"));
    }

    #[tokio::test]
    async fn test_tool_execute() {
        let tool = TestTool;
        let params = TestParams {
            value: "hello".to_string(),
        };
        let result = tool.execute(params).await.unwrap();
        assert_eq!(result.result, "processed: hello");
    }
}