use std::marker::PhantomData;
use async_trait::async_trait;
use schemars::JsonSchema;
use serde::de::DeserializeOwned;
use crate::error::ToolError;
use super::ToolFunction;
/// Boxed, `Send` future returned by a [`TypedTool`] handler.
type TypedToolFuture =
    std::pin::Pin<Box<dyn std::future::Future<Output = Result<serde_json::Value, ToolError>> + Send>>;

/// Type-erased async handler stored by [`TypedTool`]; it receives already-deserialized arguments.
type TypedToolHandler<T> = Box<dyn Fn(T) -> TypedToolFuture + Send + Sync>;

/// A tool whose arguments are a strongly typed `T` rather than raw JSON.
///
/// The parameter schema is generated from `T`'s [`JsonSchema`] implementation. Incoming JSON
/// arguments are deserialized into `T` before the handler is invoked.
pub struct TypedTool<T: DeserializeOwned + JsonSchema + Send + Sync + 'static> {
    /// Tool name reported to callers.
    name: String,
    /// Human-readable description reported to callers.
    description: String,
    /// JSON Schema for `T`, generated once when the tool is constructed.
    schema: serde_json::Value,
    /// Async handler invoked with the deserialized arguments.
    handler: TypedToolHandler<T>,
    /// Ties `T` to the struct explicitly (the handler signature also mentions `T`).
    _phantom: PhantomData<T>,
}
impl<T> TypedTool<T>
where
    T: DeserializeOwned + JsonSchema + Send + Sync + 'static,
{
    /// Creates a typed tool named `name` with the given `description`.
    ///
    /// The JSON Schema advertised for the tool's parameters is derived from `T` here, once.
    /// `handler` receives the deserialized `T` on each call, and its future's output is
    /// returned to the caller.
    ///
    /// # Panics
    ///
    /// Panics if the schema produced by `schemars` cannot be converted to a
    /// `serde_json::Value`. This is treated as an invariant and is not expected to happen.
    pub fn new<F, Fut>(name: impl Into<String>, description: impl Into<String>, handler: F) -> Self
    where
        F: Fn(T) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
    {
        // The schema depends only on `T`, so build it up front rather than on every call.
        let schema = serde_json::to_value(schemars::schema_for!(T))
            .expect("schemars schema should serialize to JSON");
        let name: String = name.into();
        let description: String = description.into();
        Self {
            name,
            description,
            schema,
            // Box and pin each concrete future so all handlers share one erased return type.
            handler: Box::new(move |input| Box::pin(handler(input))),
            _phantom: PhantomData,
        }
    }
}
#[async_trait]
impl<T> ToolFunction for TypedTool<T>
where
    T: DeserializeOwned + JsonSchema + Send + Sync + 'static,
{
    /// Returns the name this tool was constructed with.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the description this tool was constructed with.
    fn description(&self) -> &str {
        self.description.as_str()
    }

    /// Returns a copy of the JSON Schema generated for `T` at construction time.
    fn parameters(&self) -> Option<serde_json::Value> {
        let schema = self.schema.clone();
        Some(schema)
    }

    /// Deserializes `args` into `T` and awaits the handler with it.
    ///
    /// # Errors
    ///
    /// Returns `ToolError::InvalidArgs` (carrying the serde error text) when `args` does not
    /// deserialize into `T`. Otherwise returns whatever the handler's future resolves to.
    async fn call(&self, args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
        match serde_json::from_value::<T>(args) {
            Ok(parsed) => (self.handler)(parsed).await,
            Err(e) => Err(ToolError::InvalidArgs(format!(
                "Failed to deserialize arguments: {e}"
            ))),
        }
    }
}