use crate::runner::ToolError;
use async_trait::async_trait;
use llmoxide::types::{ToolCall, ToolSpec};
use schemars::{JsonSchema, schema_for};
use serde::{Serialize, de::DeserializeOwned};
use std::collections::BTreeMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
/// Pinned, heap-allocated, type-erased future that yields `Out` and may borrow
/// data for `'fut`.
///
/// No `Send` bound is imposed, so non-`Send` futures are accepted.
type BoxFuture<'fut, Out> = Pin<Box<dyn Future<Output = Out> + 'fut>>;
/// Descriptive metadata for a tool: the name it is registered under and an
/// optional human-readable description.
#[derive(Debug, Clone)]
pub struct ToolMeta {
    /// Tool name; used as the registry key and as the advertised spec name.
    pub name: String,
    /// Optional description copied into the tool's `ToolSpec`.
    pub description: Option<String>,
}

impl ToolMeta {
    /// Creates metadata for a tool called `name`, with no description.
    pub fn new(name: impl Into<String>) -> Self {
        let name: String = name.into();
        Self {
            description: None,
            name,
        }
    }

    /// Builder-style setter: returns `self` with `description` attached,
    /// replacing any description set earlier.
    pub fn description(self, description: impl Into<String>) -> Self {
        Self {
            description: Some(description.into()),
            ..self
        }
    }
}
/// Type-erased interface over a registered tool, so tools with different
/// argument/result types can live in one map as `Arc<dyn DynTool>`.
///
/// `#[async_trait(?Send)]` means the future returned by `call` is not required
/// to be `Send`. The `Send + Sync` supertraits still allow the trait objects
/// themselves to be shared behind an `Arc` across threads.
#[async_trait(?Send)]
trait DynTool: Send + Sync {
    /// The tool's registered name.
    fn name(&self) -> &str;

    /// Builds the `ToolSpec` (name, description, parameter schema) for this tool.
    fn spec(&self) -> ToolSpec;

    /// Decodes the call's arguments, runs the tool and returns its JSON result.
    async fn call(&self, tool_call: &ToolCall) -> Result<serde_json::Value, ToolError>;
}
/// Concrete tool backed by a typed async handler.
///
/// `TArgs` is decoded from the call's JSON arguments and `TResult` is encoded
/// back to JSON by the `DynTool` implementation for this type.
struct ToolImpl<TArgs, TResult> {
    meta: ToolMeta,
    handler: Arc<dyn Fn(TArgs) -> BoxFuture<'static, Result<TResult, ToolError>> + Send + Sync>,
    // NOTE(review): `handler` already mentions both type parameters, so this
    // marker looks redundant; kept because the struct is constructed with it.
    _phantom: std::marker::PhantomData<(TArgs, TResult)>,
}
impl<TArgs, TResult> ToolImpl<TArgs, TResult>
where
    TArgs: DeserializeOwned + JsonSchema + Send + Sync + 'static,
    TResult: Serialize + Send + Sync + 'static,
{
    /// JSON Schema for `TArgs`, advertised as the tool's `parameters`.
    ///
    /// With schemars' default (draft-07) settings, `schema_for!` puts the
    /// schemas of named nested types into the root's `definitions` and refers
    /// to them with `$ref: "#/definitions/<Name>"`. Serializing only
    /// `root.schema` would drop those definitions and leave the references
    /// dangling, so they are re-attached under a top-level `definitions` key.
    /// For argument types without nested named types the output is just the
    /// root schema.
    ///
    /// Falls back to `Value::Null` if the root schema cannot be serialized.
    fn schema_json() -> serde_json::Value {
        let root = schema_for!(TArgs);
        let mut value = serde_json::to_value(&root.schema).unwrap_or(serde_json::Value::Null);
        if !root.definitions.is_empty() {
            if let (serde_json::Value::Object(object), Ok(definitions)) =
                (&mut value, serde_json::to_value(&root.definitions))
            {
                object.insert("definitions".to_owned(), definitions);
            }
        }
        value
    }
}
#[async_trait(?Send)]
impl<TArgs, TResult> DynTool for ToolImpl<TArgs, TResult>
where
TArgs: DeserializeOwned + JsonSchema + Send + Sync + 'static,
TResult: Serialize + Send + Sync + 'static,
{
fn spec(&self) -> ToolSpec {
ToolSpec {
name: self.meta.name.clone(),
description: self.meta.description.clone(),
parameters: Self::schema_json(),
}
}
fn name(&self) -> &str {
&self.meta.name
}
async fn call(&self, call: &ToolCall) -> Result<serde_json::Value, ToolError> {
let args: TArgs = serde_json::from_value(call.arguments.clone()).map_err(|e| {
ToolError::InvalidArguments {
tool: self.meta.name.clone(),
details: e.to_string(),
}
})?;
let res = (self.handler)(args).await?;
serde_json::to_value(res).map_err(|e| ToolError::Handler {
tool: self.meta.name.clone(),
details: e.to_string(),
})
}
}
/// Cheaply clonable collection of tools, keyed by tool name.
///
/// The map lives behind an `Arc`, so cloning a registry shares it. Registering
/// a tool copies the map only while it is still shared, so clones taken earlier
/// never observe later registrations.
#[derive(Clone, Default)]
pub struct ToolRegistry {
    tools: Arc<BTreeMap<String, Arc<dyn DynTool>>>,
}
impl ToolRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `handler` under `meta.name` and returns `self` for chaining.
    ///
    /// The tool's arguments are deserialized into `TArgs`, whose JSON Schema is
    /// advertised as the tool's parameters. The handler's `TResult` is
    /// serialized back to JSON.
    ///
    /// Registering a name that is already present replaces the previous tool.
    pub fn register<TArgs, TResult, Fut, F>(&mut self, meta: ToolMeta, handler: F) -> &mut Self
    where
        TArgs: DeserializeOwned + JsonSchema + Send + Sync + 'static,
        TResult: Serialize + Send + Sync + 'static,
        Fut: Future<Output = Result<TResult, ToolError>> + 'static,
        F: Fn(TArgs) -> Fut + Send + Sync + 'static,
    {
        let name = meta.name.clone();
        // Erase the concrete future type so every tool stores the same handler type.
        let handler = Arc::new(
            move |args: TArgs| -> BoxFuture<'static, Result<TResult, ToolError>> {
                Box::pin(handler(args))
            },
        );
        let tool = ToolImpl::<TArgs, TResult> {
            meta,
            handler,
            _phantom: std::marker::PhantomData,
        };
        // Copy-on-write: `make_mut` clones the map only if another registry
        // clone still shares it, instead of copying it on every registration.
        Arc::make_mut(&mut self.tools).insert(name, Arc::new(tool));
        self
    }

    /// Returns the spec of every registered tool, ordered by tool name.
    pub fn specs(&self) -> Vec<ToolSpec> {
        self.tools.values().map(|t| t.spec()).collect()
    }

    /// Runs the tool named by `call.name`.
    ///
    /// Returns the tool's registered name together with its JSON output.
    ///
    /// # Errors
    /// `ToolError::UnknownTool` if no tool has that name. Otherwise the error
    /// is whatever the tool's `call` returns.
    pub(crate) async fn dispatch(
        &self,
        call: &ToolCall,
    ) -> Result<(String, serde_json::Value), ToolError> {
        let Some(tool) = self.tools.get(&call.name) else {
            return Err(ToolError::UnknownTool {
                tool: call.name.clone(),
            });
        };
        let out = tool.call(call).await?;
        Ok((tool.name().to_string(), out))
    }
}