use schemars::JsonSchema;
use serde::{de::DeserializeOwned, Serialize};
use std::sync::Arc;
use crate::context::ToolContext;
use crate::error::ToolError;
/// A statically described tool: typed parameters, typed output, metadata
/// constants, and an async entry point.
///
/// Everything is exposed through associated items, so a tool is addressed by
/// its type rather than by an instance.
pub trait Tool: Send + Sync + 'static {
    /// Input type; its JSON Schema is what [`Tool::schema`] returns.
    type Params: DeserializeOwned + JsonSchema + Send + Sync + 'static;

    /// Result type; its JSON Schema is what [`Tool::output_schema`] returns.
    type Output: Serialize + JsonSchema + Send + Sync + 'static;

    /// Name of the tool.
    const NAME: &'static str;

    /// Description of the tool.
    const DESCRIPTION: &'static str;

    /// Optional string pairs attached to the tool; empty unless overridden.
    ///
    /// NOTE(review): the meaning of each tuple element is not visible here
    /// (presumably a label plus an example payload) — confirm with consumers.
    const EXAMPLES: &'static [(&'static str, &'static str)] = &[];

    /// Returns the JSON Schema generated from [`Tool::Params`].
    ///
    /// The schema is built on first use and cached for the lifetime of the
    /// process, so repeated calls return the same `&'static` value.
    ///
    /// # Panics
    ///
    /// Panics if the generated schema cannot be serialized to JSON.
    fn schema() -> &'static serde_json::Value {
        cached_schema_for::<Self::Params>()
    }

    /// Returns the JSON Schema generated from [`Tool::Output`].
    ///
    /// This default implementation always returns `Some`; the `Option` lets
    /// implementors override it to return `None`.
    ///
    /// # Panics
    ///
    /// Panics if the generated schema cannot be serialized to JSON.
    fn output_schema() -> Option<&'static serde_json::Value> {
        Some(cached_schema_for::<Self::Output>())
    }

    /// Runs the tool with the given context and already-deserialized
    /// parameters.
    ///
    /// # Errors
    ///
    /// Returns a [`ToolError`] when the implementation fails.
    async fn call(ctx: Arc<ToolContext>, params: Self::Params) -> Result<Self::Output, ToolError>;
}

/// Builds the JSON Schema for `T` once per process and hands out the cached
/// `&'static` value on every later call.
///
/// The cache is process-wide rather than thread-local, so the number of leaked
/// schemas is bounded by the number of distinct types, not by how many threads
/// ever asked for one.
fn cached_schema_for<T: JsonSchema + 'static>() -> &'static serde_json::Value {
    use std::any::TypeId;
    use std::collections::HashMap;
    use std::sync::{Mutex, OnceLock, PoisonError};

    // A `static` inside a generic function is shared by every instantiation,
    // hence the `TypeId` key.
    static SCHEMAS: OnceLock<Mutex<HashMap<TypeId, &'static serde_json::Value>>> =
        OnceLock::new();

    let cache = SCHEMAS.get_or_init(|| Mutex::new(HashMap::new()));
    let type_id = TypeId::of::<T>();

    // Fast path. The guard is a temporary and is released at the end of this
    // statement. A poisoned lock is still usable: the map only ever gains
    // fully-built entries.
    let cached = cache
        .lock()
        .unwrap_or_else(PoisonError::into_inner)
        .get(&type_id)
        .copied();
    if let Some(schema) = cached {
        return schema;
    }

    // Generate outside the lock so a `JsonSchema` impl that itself asks for a
    // cached schema cannot deadlock.
    let value =
        serde_json::to_value(schemars::schema_for!(T)).expect("Failed to serialize schema");

    // Leak only if the slot is still vacant; if another thread won the race,
    // `value` is dropped and the existing entry is returned.
    *cache
        .lock()
        .unwrap_or_else(PoisonError::into_inner)
        .entry(type_id)
        .or_insert_with(|| {
            let leaked: &'static serde_json::Value = Box::leak(Box::new(value));
            leaked
        })
}
/// Optional boolean annotations a [`Tool`] can declare about itself.
///
/// Every flag defaults to `false`; an implementor overrides only the ones it
/// wants to set.
///
/// NOTE(review): the names resemble the MCP tool-annotation hints
/// (`readOnlyHint`, `idempotentHint`, `destructiveHint`, `openWorldHint`).
/// If that is the intent, confirm the all-`false` defaults are wanted.
pub trait HasAnnotations: Tool {
/// Presumably: the tool does not modify its environment — confirm.
const READ_ONLY: bool = false;
/// Presumably: repeating a call with the same arguments has no further
/// effect — confirm.
const IDEMPOTENT: bool = false;
/// Presumably: the tool may perform destructive changes — confirm.
const DESTRUCTIVE: bool = false;
/// Presumably: the tool interacts with external entities — confirm.
const OPEN_WORLD: bool = false;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Input for [`DoubleTool`].
    #[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
    struct TestParams {
        value: i32,
    }

    /// Output of [`DoubleTool`].
    #[derive(Debug, serde::Serialize, schemars::JsonSchema)]
    struct TestOutput {
        doubled: i32,
    }

    /// Minimal tool used to exercise the trait's default items.
    struct DoubleTool;

    impl Tool for DoubleTool {
        type Params = TestParams;
        type Output = TestOutput;

        const NAME: &'static str = "double";
        const DESCRIPTION: &'static str = "Double the input value";

        async fn call(
            _ctx: Arc<ToolContext>,
            params: Self::Params,
        ) -> Result<Self::Output, ToolError> {
            Ok(TestOutput {
                doubled: params.value * 2,
            })
        }
    }

    impl HasAnnotations for DoubleTool {
        const READ_ONLY: bool = true;
        const IDEMPOTENT: bool = true;
    }

    // Only `test_tool_call` awaits anything, so the remaining tests are plain
    // `#[test]`s and do not start a runtime.

    #[test]
    fn test_tool_meta() {
        assert_eq!(DoubleTool::NAME, "double");
        assert_eq!(DoubleTool::DESCRIPTION, "Double the input value");
        // `EXAMPLES` is not overridden, so the trait default applies.
        assert!(DoubleTool::EXAMPLES.is_empty());
    }

    #[test]
    fn test_schema_generation() {
        let schema = <DoubleTool as Tool>::schema();
        assert!(schema.is_object());
        assert_eq!(
            schema.get("title").and_then(|title| title.as_str()),
            Some("TestParams")
        );
    }

    #[test]
    fn test_output_schema_generation() {
        let schema = <DoubleTool as Tool>::output_schema()
            .expect("default output_schema always returns a schema");
        assert!(schema.is_object());
        assert_eq!(
            schema.get("title").and_then(|title| title.as_str()),
            Some("TestOutput")
        );
    }

    #[test]
    fn test_schema_is_cached() {
        // Repeated calls on one thread must return the very same leaked value.
        assert!(std::ptr::eq(
            <DoubleTool as Tool>::schema(),
            <DoubleTool as Tool>::schema()
        ));
    }

    #[tokio::test]
    async fn test_tool_call() {
        let ctx = Arc::new(ToolContext::new());
        let params = TestParams { value: 5 };
        let result = DoubleTool::call(ctx, params).await.unwrap();
        assert_eq!(result.doubled, 10);
    }

    #[test]
    fn test_annotations() {
        assert!(<DoubleTool as HasAnnotations>::READ_ONLY);
        assert!(<DoubleTool as HasAnnotations>::IDEMPOTENT);
        // Not overridden: both fall back to the `false` defaults.
        assert!(!<DoubleTool as HasAnnotations>::DESTRUCTIVE);
        assert!(!<DoubleTool as HasAnnotations>::OPEN_WORLD);
    }
}