use std::{borrow::Cow, future::Future, pin::Pin};
use super::types::{ToolContext, ToolDefinition, ToolError, ToolOutput};
/// Parameter type for tools that take no arguments.
///
/// It deserializes from a JSON object. Unknown fields are ignored because
/// `deny_unknown_fields` is not set. The extra derives (`Copy`, `Default`, `PartialEq`, `Eq`,
/// `Hash`) let callers build and compare values trivially, e.g. `EmptyParams::default()`.
#[derive(
    Debug, Clone, Copy, Default, PartialEq, Eq, Hash, serde::Deserialize, schemars::JsonSchema,
)]
pub struct EmptyParams {}
pub trait RustTool: Send + Sync {
type Params: serde::de::DeserializeOwned + schemars::JsonSchema + Send;
const NAME: &'static str;
const DESCRIPTION: &'static str;
fn description(&self) -> Cow<'static, str> {
Cow::Borrowed(Self::DESCRIPTION)
}
fn call(
&self,
params: Self::Params,
ctx: &ToolContext,
) -> impl std::future::Future<Output = Result<ToolOutput, ToolError>> + Send;
}
/// Builds the [`ToolDefinition`] advertised for `tool`.
///
/// The parameter schema is generated from `T::Params` with `schemars`, converted into a
/// `serde_json::Value`, and then passed through `sanitize_schema_types` to collapse
/// array-valued `"type"` keywords. The name comes from `T::NAME` and the description from
/// [`RustTool::description`].
///
/// # Errors
///
/// Returns a [`ToolError`] if the generated schema cannot be serialized to JSON.
pub fn definition_of<T: RustTool>(tool: &T) -> Result<ToolDefinition, ToolError> {
    let generated = schemars::schema_for!(T::Params);
    let mut parameter_schema = serde_json::to_value(generated).map_err(|err| {
        let message = format!("Failed to serialize schema for tool '{}': {err}", T::NAME);
        ToolError::new(message)
    })?;
    sanitize_schema_types(&mut parameter_schema);

    let name = String::from(T::NAME);
    let description = tool.description().into_owned();
    Ok(ToolDefinition {
        name,
        description,
        parameter_schema,
    })
}
/// Rewrites a JSON Schema in place so that array-valued `"type"` keywords hold a single type.
///
/// A type array such as `["string", "null"]` is collapsed to its first entry that is not the
/// string `"null"`. If every entry is `"null"`, the first entry is kept; an empty array is left
/// untouched. Nullability expressed through a type array is therefore dropped.
///
/// The walk only rewrites schema positions:
/// - Values of instance-data keywords (`default`, `enum`, `const`, `examples`, `example`) are
///   literal JSON data, not schemas, and are left exactly as generated. Otherwise an
///   object-valued default or example that happens to contain a `"type"` array would be
///   altered.
/// - Keywords that map arbitrary names to subschemas (`properties`, `patternProperties`,
///   `definitions`, `$defs`, `dependentSchemas`, `dependencies`) have every value treated as a
///   schema. A property whose *name* collides with a data keyword (e.g. a field called
///   `default`) is therefore still sanitized.
/// - Every other object or array value is walked recursively.
fn sanitize_schema_types(value: &mut serde_json::Value) {
    // Keywords whose values are JSON instance data rather than subschemas.
    const INSTANCE_KEYWORDS: &[&str] = &["default", "enum", "const", "examples", "example"];
    // Keywords whose values are objects mapping user-chosen names to subschemas.
    const NAMED_SCHEMA_MAPS: &[&str] = &[
        "properties",
        "patternProperties",
        "definitions",
        "$defs",
        "dependentSchemas",
        "dependencies",
    ];

    match value {
        serde_json::Value::Object(map) => {
            let collapsed = match map.get("type") {
                Some(serde_json::Value::Array(types)) => types
                    .iter()
                    .find(|t| t.as_str() != Some("null"))
                    .or_else(|| types.first())
                    .cloned(),
                _ => None,
            };
            if let Some(single) = collapsed {
                map.insert("type".to_string(), single);
            }

            for (key, child) in map.iter_mut() {
                let key = key.as_str();
                if INSTANCE_KEYWORDS.contains(&key) {
                    // Literal data: never rewrite it.
                    continue;
                }
                match child {
                    // Keys here are names, not keywords: sanitize each value as a schema.
                    serde_json::Value::Object(named) if NAMED_SCHEMA_MAPS.contains(&key) => {
                        named.values_mut().for_each(sanitize_schema_types);
                    }
                    _ => sanitize_schema_types(child),
                }
            }
        }
        serde_json::Value::Array(items) => items.iter_mut().for_each(sanitize_schema_types),
        _ => {}
    }
}
/// Heap-allocated, type-erased future returned by [`ErasedTool::call_erased`].
///
/// It is `Send` and may borrow data for the lifetime `'fut`.
type BoxToolFuture<'fut> =
    Pin<Box<dyn Future<Output = Result<ToolOutput, ToolError>> + Send + 'fut>>;

/// Object-safe counterpart of [`RustTool`] that takes raw JSON arguments.
///
/// `RustTool` has associated constants, an associated `Params` type, and a method returning
/// `impl Future`, none of which a trait object can carry. This trait hides all of them behind a
/// `serde_json::Value` input and a boxed future, making it usable as `dyn ErasedTool`.
pub(crate) trait ErasedTool: Send + Sync {
    /// Invokes the tool with JSON-encoded `args` and the context `ctx`.
    ///
    /// The returned future may borrow both `self` and `ctx` for `'fut`.
    fn call_erased<'fut>(
        &'fut self,
        args: serde_json::Value,
        ctx: &'fut ToolContext,
    ) -> BoxToolFuture<'fut>;
}
/// Every [`RustTool`] is callable through the erased interface.
///
/// The JSON arguments are deserialized into `T::Params` and, on success, forwarded to
/// [`RustTool::call`]. A decoding failure resolves the future to a [`ToolError`] without
/// calling the tool.
impl<T: RustTool> ErasedTool for T {
    fn call_erased<'s>(
        &'s self,
        args: serde_json::Value,
        ctx: &'s ToolContext,
    ) -> BoxToolFuture<'s> {
        Box::pin(async move {
            // Decoding happens inside the future, so nothing runs until it is first polled.
            match serde_json::from_value::<T::Params>(args) {
                Ok(params) => self.call(params, ctx).await,
                Err(err) => Err(ToolError::new(format!(
                    "Failed to deserialize tool parameters: {err}"
                ))),
            }
        })
    }
}