Skip to main content

llmoxide_tools/
registry.rs

1use crate::runner::ToolError;
2use async_trait::async_trait;
3use llmoxide::types::{ToolCall, ToolSpec};
4use schemars::{JsonSchema, schema_for};
5use serde::{Serialize, de::DeserializeOwned};
6use std::collections::BTreeMap;
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10
/// Pinned, heap-allocated, type-erased future that resolves to `T`.
///
/// The alias deliberately carries no `Send` bound, so it can hold futures
/// that are not safe to move across threads.
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a>>;
12
/// Descriptive metadata attached to a registered tool.
#[derive(Debug, Clone)]
pub struct ToolMeta {
    /// Name the tool is registered and looked up under.
    pub name: String,
    /// Optional human-readable description; `None` until one is supplied.
    pub description: Option<String>,
}

impl ToolMeta {
    /// Creates metadata for a tool called `name`, with no description.
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        Self {
            name,
            description: None,
        }
    }

    /// Builder-style setter: consumes `self` and returns it with the
    /// description replaced by `description`.
    pub fn description(self, description: impl Into<String>) -> Self {
        Self {
            description: Some(description.into()),
            ..self
        }
    }
}
32
33#[async_trait(?Send)]
34trait DynTool: Send + Sync {
35    fn spec(&self) -> ToolSpec;
36    fn name(&self) -> &str;
37    async fn call(&self, call: &ToolCall) -> Result<serde_json::Value, ToolError>;
38}
39
40struct ToolImpl<TArgs, TResult> {
41    meta: ToolMeta,
42    handler: Arc<dyn Fn(TArgs) -> BoxFuture<'static, Result<TResult, ToolError>> + Send + Sync>,
43    _phantom: std::marker::PhantomData<(TArgs, TResult)>,
44}
45
46impl<TArgs, TResult> ToolImpl<TArgs, TResult>
47where
48    TArgs: DeserializeOwned + JsonSchema + Send + Sync + 'static,
49    TResult: Serialize + Send + Sync + 'static,
50{
51    fn schema_json() -> serde_json::Value {
52        let schema = schema_for!(TArgs);
53        serde_json::to_value(&schema.schema).unwrap_or(serde_json::Value::Null)
54    }
55}
56
57#[async_trait(?Send)]
58impl<TArgs, TResult> DynTool for ToolImpl<TArgs, TResult>
59where
60    TArgs: DeserializeOwned + JsonSchema + Send + Sync + 'static,
61    TResult: Serialize + Send + Sync + 'static,
62{
63    fn spec(&self) -> ToolSpec {
64        ToolSpec {
65            name: self.meta.name.clone(),
66            description: self.meta.description.clone(),
67            parameters: Self::schema_json(),
68        }
69    }
70
71    fn name(&self) -> &str {
72        &self.meta.name
73    }
74
75    async fn call(&self, call: &ToolCall) -> Result<serde_json::Value, ToolError> {
76        let args: TArgs = serde_json::from_value(call.arguments.clone()).map_err(|e| {
77            ToolError::InvalidArguments {
78                tool: self.meta.name.clone(),
79                details: e.to_string(),
80            }
81        })?;
82
83        let res = (self.handler)(args).await?;
84        serde_json::to_value(res).map_err(|e| ToolError::Handler {
85            tool: self.meta.name.clone(),
86            details: e.to_string(),
87        })
88    }
89}
90
91/// Registry of typed tools, convertible to provider tool schemas.
92#[derive(Clone, Default)]
93pub struct ToolRegistry {
94    tools: Arc<BTreeMap<String, Arc<dyn DynTool>>>,
95}
96
97impl ToolRegistry {
98    pub fn new() -> Self {
99        Self::default()
100    }
101
102    /// Register a tool with typed args and typed result.
103    ///
104    /// - `TArgs` must implement `JsonSchema` so we can produce a provider tool schema.
105    /// - The handler returns a typed result which is serialized to JSON for tool responses.
106    pub fn register<TArgs, TResult, Fut, F>(&mut self, meta: ToolMeta, handler: F) -> &mut Self
107    where
108        TArgs: DeserializeOwned + JsonSchema + Send + Sync + 'static,
109        TResult: Serialize + Send + Sync + 'static,
110        Fut: Future<Output = Result<TResult, ToolError>> + 'static,
111        F: Fn(TArgs) -> Fut + Send + Sync + 'static,
112    {
113        let mut map: BTreeMap<String, Arc<dyn DynTool>> = (*self.tools).clone();
114        let name = meta.name.clone();
115        let handler = Arc::new(
116            move |args: TArgs| -> BoxFuture<'static, Result<TResult, ToolError>> {
117                Box::pin(handler(args))
118            },
119        );
120        let tool = ToolImpl::<TArgs, TResult> {
121            meta,
122            handler,
123            _phantom: std::marker::PhantomData,
124        };
125        map.insert(name, Arc::new(tool));
126        self.tools = Arc::new(map);
127        self
128    }
129
130    pub fn specs(&self) -> Vec<ToolSpec> {
131        self.tools.values().map(|t| t.spec()).collect()
132    }
133
134    pub(crate) async fn dispatch(
135        &self,
136        call: &ToolCall,
137    ) -> Result<(String, serde_json::Value), ToolError> {
138        let Some(tool) = self.tools.get(&call.name) else {
139            return Err(ToolError::UnknownTool {
140                tool: call.name.clone(),
141            });
142        };
143        let out = tool.call(call).await?;
144        Ok((tool.name().to_string(), out))
145    }
146}