llmoxide_tools/
registry.rs1use crate::runner::ToolError;
2use async_trait::async_trait;
3use llmoxide::types::{ToolCall, ToolSpec};
4use schemars::{JsonSchema, schema_for};
5use serde::{Serialize, de::DeserializeOwned};
6use std::collections::BTreeMap;
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10
/// Heap-allocated, pinned future that yields `T` and is valid for the lifetime `'a`.
///
/// The alias carries no `Send` bound, so the boxed future may hold non-`Send` state.
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a>>;
12
/// Descriptive metadata for a tool: its name and an optional description.
#[derive(Debug, Clone)]
pub struct ToolMeta {
    /// Name of the tool.
    pub name: String,
    /// Optional human-readable description; `None` until one is supplied.
    pub description: Option<String>,
}

impl ToolMeta {
    /// Creates metadata for a tool called `name`, with no description set.
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        ToolMeta {
            name,
            description: None,
        }
    }

    /// Consumes `self` and returns it with `description` set, replacing any
    /// previously stored description.
    pub fn description(self, description: impl Into<String>) -> Self {
        ToolMeta {
            description: Some(description.into()),
            ..self
        }
    }
}
32
33#[async_trait(?Send)]
34trait DynTool: Send + Sync {
35 fn spec(&self) -> ToolSpec;
36 fn name(&self) -> &str;
37 async fn call(&self, call: &ToolCall) -> Result<serde_json::Value, ToolError>;
38}
39
40struct ToolImpl<TArgs, TResult> {
41 meta: ToolMeta,
42 handler: Arc<dyn Fn(TArgs) -> BoxFuture<'static, Result<TResult, ToolError>> + Send + Sync>,
43 _phantom: std::marker::PhantomData<(TArgs, TResult)>,
44}
45
46impl<TArgs, TResult> ToolImpl<TArgs, TResult>
47where
48 TArgs: DeserializeOwned + JsonSchema + Send + Sync + 'static,
49 TResult: Serialize + Send + Sync + 'static,
50{
51 fn schema_json() -> serde_json::Value {
52 let schema = schema_for!(TArgs);
53 serde_json::to_value(&schema.schema).unwrap_or(serde_json::Value::Null)
54 }
55}
56
57#[async_trait(?Send)]
58impl<TArgs, TResult> DynTool for ToolImpl<TArgs, TResult>
59where
60 TArgs: DeserializeOwned + JsonSchema + Send + Sync + 'static,
61 TResult: Serialize + Send + Sync + 'static,
62{
63 fn spec(&self) -> ToolSpec {
64 ToolSpec {
65 name: self.meta.name.clone(),
66 description: self.meta.description.clone(),
67 parameters: Self::schema_json(),
68 }
69 }
70
71 fn name(&self) -> &str {
72 &self.meta.name
73 }
74
75 async fn call(&self, call: &ToolCall) -> Result<serde_json::Value, ToolError> {
76 let args: TArgs = serde_json::from_value(call.arguments.clone()).map_err(|e| {
77 ToolError::InvalidArguments {
78 tool: self.meta.name.clone(),
79 details: e.to_string(),
80 }
81 })?;
82
83 let res = (self.handler)(args).await?;
84 serde_json::to_value(res).map_err(|e| ToolError::Handler {
85 tool: self.meta.name.clone(),
86 details: e.to_string(),
87 })
88 }
89}
90
91#[derive(Clone, Default)]
93pub struct ToolRegistry {
94 tools: Arc<BTreeMap<String, Arc<dyn DynTool>>>,
95}
96
97impl ToolRegistry {
98 pub fn new() -> Self {
99 Self::default()
100 }
101
102 pub fn register<TArgs, TResult, Fut, F>(&mut self, meta: ToolMeta, handler: F) -> &mut Self
107 where
108 TArgs: DeserializeOwned + JsonSchema + Send + Sync + 'static,
109 TResult: Serialize + Send + Sync + 'static,
110 Fut: Future<Output = Result<TResult, ToolError>> + 'static,
111 F: Fn(TArgs) -> Fut + Send + Sync + 'static,
112 {
113 let mut map: BTreeMap<String, Arc<dyn DynTool>> = (*self.tools).clone();
114 let name = meta.name.clone();
115 let handler = Arc::new(
116 move |args: TArgs| -> BoxFuture<'static, Result<TResult, ToolError>> {
117 Box::pin(handler(args))
118 },
119 );
120 let tool = ToolImpl::<TArgs, TResult> {
121 meta,
122 handler,
123 _phantom: std::marker::PhantomData,
124 };
125 map.insert(name, Arc::new(tool));
126 self.tools = Arc::new(map);
127 self
128 }
129
130 pub fn specs(&self) -> Vec<ToolSpec> {
131 self.tools.values().map(|t| t.spec()).collect()
132 }
133
134 pub(crate) async fn dispatch(
135 &self,
136 call: &ToolCall,
137 ) -> Result<(String, serde_json::Value), ToolError> {
138 let Some(tool) = self.tools.get(&call.name) else {
139 return Err(ToolError::UnknownTool {
140 tool: call.name.clone(),
141 });
142 };
143 let out = tool.call(call).await?;
144 Ok((tool.name().to_string(), out))
145 }
146}