use serde_json::Value;
use std::marker::PhantomData;
/// A type-erased tool: arguments arrive as untyped JSON and the outcome is a plain string.
///
/// Implementors must be `Send + Sync`. The trait is used as `dyn Tool` by [`Registry`].
pub trait Tool: Send + Sync {
    /// Name of the tool; [`Registry::dispatch`] compares it against the requested name.
    fn name(&self) -> &str;

    /// Description text, exported under `"description"` by [`Registry::as_openai_tools`].
    fn description(&self) -> &str;

    /// Schema value, exported under `"parameters"` by [`Registry::as_openai_tools`].
    fn schema(&self) -> Value;

    /// Runs the tool on `args`.
    ///
    /// # Errors
    ///
    /// Returns the failure as a `String` message chosen by the implementor.
    fn call(&self, args: Value) -> Result<String, String>;
}
/// Statically typed counterpart of [`Tool`].
///
/// Wrapping an implementor in [`ErasedAdapter`] produces a [`Tool`] that deserializes the
/// JSON arguments into [`TypedTool::Args`] and serializes [`TypedTool::Output`] back to a
/// JSON string.
pub trait TypedTool: Send + Sync {
    /// Name reported through `Tool::name` by the adapter.
    const NAME: &'static str;

    /// Description reported through `Tool::description` by the adapter.
    const DESCRIPTION: &'static str;

    /// Argument type; the adapter builds it from a `serde_json::Value`.
    type Args: serde::de::DeserializeOwned + Send;

    /// Success type; the adapter serializes it with `serde_json::to_string`.
    type Output: serde::Serialize + Send;

    /// Failure type; the adapter converts it to text with `to_string()`.
    type Error: std::fmt::Display + Send + Sync;

    /// Schema value for the arguments; takes no receiver, so it is per type, not per instance.
    fn schema() -> Value;

    /// Runs the tool on already-deserialized arguments.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` when the implementor reports a failure.
    fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error>;
}
/// Wrapper that owns a value of type `T`.
///
/// The `Tool` implementation for this type is provided for `T: TypedTool`. The bound is
/// deliberately left off the struct itself and stated only on the `impl` blocks that need
/// it, so the type can be named without repeating the bound.
pub struct ErasedAdapter<T> {
    /// The wrapped value.
    inner: T,
    /// Zero-sized marker mentioning `T`.
    // NOTE(review): `inner` already stores a `T`, so this marker looks redundant; it is kept
    // because the existing constructor initializes it.
    _marker: PhantomData<fn() -> T>,
}
impl<T: TypedTool> ErasedAdapter<T> {
pub fn new(inner: T) -> Self {
Self {
inner,
_marker: PhantomData,
}
}
pub fn inner(&self) -> &T {
&self.inner
}
}
/// Exposes a [`TypedTool`] through the untyped [`Tool`] interface.
impl<T: TypedTool> Tool for ErasedAdapter<T> {
    /// Returns the wrapped type's `NAME` constant.
    fn name(&self) -> &str {
        <T as TypedTool>::NAME
    }

    /// Returns the wrapped type's `DESCRIPTION` constant.
    fn description(&self) -> &str {
        <T as TypedTool>::DESCRIPTION
    }

    /// Returns the wrapped type's schema.
    fn schema(&self) -> Value {
        <T as TypedTool>::schema()
    }

    /// Deserializes `args` into `T::Args`, runs the wrapped tool, and serializes its output
    /// to a JSON string.
    ///
    /// # Errors
    ///
    /// - `"args deserialize: …"` when `args` cannot be converted into `T::Args`.
    /// - The tool's own error, rendered with `to_string()`, when the wrapped call fails.
    /// - `"output serialize: …"` when the output cannot be serialized.
    fn call(&self, args: Value) -> Result<String, String> {
        let parsed = match serde_json::from_value::<T::Args>(args) {
            Ok(parsed) => parsed,
            Err(err) => return Err(format!("args deserialize: {}", err)),
        };

        let output = match <T as TypedTool>::call(&self.inner, parsed) {
            Ok(output) => output,
            Err(err) => return Err(err.to_string()),
        };

        match serde_json::to_string(&output) {
            Ok(json) => Ok(json),
            Err(err) => Err(format!("output serialize: {}", err)),
        }
    }
}
/// An ordered collection of boxed [`Tool`] objects.
pub struct Registry {
    /// Registered tools, in registration order.
    tools: Vec<Box<dyn Tool>>,
}
impl Registry {
    /// Creates a registry with no tools.
    pub fn new() -> Self {
        let tools = Vec::new();
        Self { tools }
    }

    /// Appends an already type-erased tool.
    ///
    /// Names are not checked for uniqueness. Because [`Registry::dispatch`] uses the first
    /// match, a tool registered under an already-used name is never dispatched to.
    pub fn register(&mut self, tool: Box<dyn Tool>) {
        self.tools.push(tool);
    }

    /// Wraps `tool` in an [`ErasedAdapter`] and appends it.
    pub fn register_typed<T: TypedTool + 'static>(&mut self, tool: T) {
        let erased: Box<dyn Tool> = Box::new(ErasedAdapter::new(tool));
        self.tools.push(erased);
    }

    /// Calls the first registered tool whose name equals `name`, passing it `args`.
    ///
    /// # Errors
    ///
    /// Returns `"unknown tool: <name>"` when no tool matches; otherwise forwards whatever
    /// error the tool's `call` produces.
    pub fn dispatch(&self, name: &str, args: Value) -> Result<String, String> {
        match self.tools.iter().find(|tool| tool.name() == name) {
            Some(tool) => tool.call(args),
            None => Err(format!("unknown tool: {}", name)),
        }
    }

    /// Lists the names of all registered tools in registration order.
    pub fn names(&self) -> Vec<&str> {
        let mut names = Vec::with_capacity(self.tools.len());
        for tool in &self.tools {
            names.push(tool.name());
        }
        names
    }

    /// Builds a JSON array with one `{"type": "function", "function": {...}}` object per
    /// registered tool, carrying its name, description, and schema (as `"parameters"`).
    pub fn as_openai_tools(&self) -> Value {
        let mut entries = Vec::with_capacity(self.tools.len());
        for tool in &self.tools {
            entries.push(serde_json::json!({
                "type": "function",
                "function": {
                    "name": tool.name(),
                    "description": tool.description(),
                    "parameters": tool.schema(),
                }
            }));
        }
        Value::Array(entries)
    }
}
impl Default for Registry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Arguments accepted by the `add` fixture tool.
    #[derive(Deserialize)]
    struct AddArgs {
        a: i64,
        b: i64,
    }

    /// Result produced by the `add` fixture tool.
    #[derive(Serialize)]
    struct AddOut {
        sum: i64,
    }

    /// Minimal typed tool used to exercise the adapter and the registry.
    struct Add;

    impl TypedTool for Add {
        const NAME: &'static str = "add";
        const DESCRIPTION: &'static str = "Add two integers.";

        type Args = AddArgs;
        type Output = AddOut;
        type Error = String;

        fn schema() -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "a": {"type": "integer"},
                    "b": {"type": "integer"}
                },
                "required": ["a", "b"]
            })
        }

        fn call(&self, args: AddArgs) -> Result<AddOut, String> {
            let AddArgs { a, b } = args;
            Ok(AddOut { sum: a + b })
        }
    }

    /// Builds a registry containing only the `add` fixture.
    fn registry_with_add() -> Registry {
        let mut registry = Registry::new();
        registry.register_typed(Add);
        registry
    }

    #[test]
    fn typed_tool_roundtrips_through_erased_adapter() {
        let registry = registry_with_add();
        let payload = serde_json::json!({"a": 2, "b": 3});
        let rendered = registry.dispatch("add", payload).expect("dispatch");
        assert_eq!(rendered, r#"{"sum":5}"#);
    }

    #[test]
    fn typed_tool_args_deserialize_error_is_string() {
        let registry = registry_with_add();
        let payload = serde_json::json!({"a": "not-a-number"});
        let message = registry.dispatch("add", payload).unwrap_err();
        assert!(message.contains("args deserialize"), "got: {}", message);
    }

    #[test]
    fn erased_adapter_name_and_description_are_const() {
        let erased = ErasedAdapter::new(Add);
        assert_eq!(erased.name(), "add");
        assert_eq!(erased.description(), "Add two integers.");
    }
}