1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use crate::error::OxideError;
7use crate::types::ToolDefinition;
8
9type AsyncResult = Pin<Box<dyn Future<Output = Result<serde_json::Value, OxideError>> + Send>>;
12type HandlerFn =
13 Arc<dyn Fn(serde_json::Value) -> AsyncResult + Send + Sync>;
14
15struct RegisteredTool {
18 definition: ToolDefinition,
19 handler: HandlerFn,
20}
21
22#[derive(Default)]
30pub struct ToolRegistry {
31 tools: HashMap<String, RegisteredTool>,
32}
33
34impl ToolRegistry {
35 pub fn new() -> Self {
36 Self::default()
37 }
38
39 pub fn register<F, Fut>(
45 &mut self,
46 definition: ToolDefinition,
47 handler: F,
48 ) where
49 F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
50 Fut: Future<Output = Result<serde_json::Value, OxideError>> + Send + 'static,
51 {
52 let name = definition.function.name.clone();
53 self.tools.insert(
54 name,
55 RegisteredTool {
56 definition,
57 handler: Arc::new(move |args| Box::pin(handler(args))),
58 },
59 );
60 }
61
62 pub fn definitions(&self) -> Vec<ToolDefinition> {
64 self.tools.values().map(|t| t.definition.clone()).collect()
65 }
66
67 pub async fn dispatch(
71 &self,
72 tool_name: &str,
73 args: serde_json::Value,
74 ) -> Result<serde_json::Value, OxideError> {
75 let tool = self.tools.get(tool_name).ok_or_else(|| {
76 OxideError::Other(format!("unknown tool: {tool_name}"))
77 })?;
78
79 (tool.handler)(args).await
80 }
81
82 pub fn contains(&self, name: &str) -> bool {
83 self.tools.contains_key(name)
84 }
85
86 pub fn len(&self) -> usize {
87 self.tools.len()
88 }
89
90 pub fn is_empty(&self) -> bool {
91 self.tools.is_empty()
92 }
93}
94
95pub struct ToolBuilder {
100 name: String,
101 description: String,
102 properties: serde_json::Map<String, serde_json::Value>,
103 required: Vec<String>,
104}
105
106impl ToolBuilder {
107 pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
108 Self {
109 name: name.into(),
110 description: description.into(),
111 properties: serde_json::Map::new(),
112 required: Vec::new(),
113 }
114 }
115
116 pub fn string_param(
118 mut self,
119 name: impl Into<String>,
120 description: impl Into<String>,
121 required: bool,
122 ) -> Self {
123 let n = name.into();
124 self.properties.insert(
125 n.clone(),
126 serde_json::json!({"type": "string", "description": description.into()}),
127 );
128 if required {
129 self.required.push(n);
130 }
131 self
132 }
133
134 pub fn number_param(
136 mut self,
137 name: impl Into<String>,
138 description: impl Into<String>,
139 required: bool,
140 ) -> Self {
141 let n = name.into();
142 self.properties.insert(
143 n.clone(),
144 serde_json::json!({"type": "number", "description": description.into()}),
145 );
146 if required {
147 self.required.push(n);
148 }
149 self
150 }
151
152 pub fn bool_param(
154 mut self,
155 name: impl Into<String>,
156 description: impl Into<String>,
157 required: bool,
158 ) -> Self {
159 let n = name.into();
160 self.properties.insert(
161 n.clone(),
162 serde_json::json!({"type": "boolean", "description": description.into()}),
163 );
164 if required {
165 self.required.push(n);
166 }
167 self
168 }
169
170 pub fn build(self) -> ToolDefinition {
171 use crate::types::FunctionDefinition;
172 ToolDefinition {
173 kind: "function".into(),
174 function: FunctionDefinition {
175 name: self.name,
176 description: self.description,
177 parameters: serde_json::json!({
178 "type": "object",
179 "properties": serde_json::Value::Object(self.properties),
180 "required": self.required,
181 }),
182 },
183 }
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Dispatching a registered tool runs its handler and returns the handler's value.
    #[tokio::test]
    async fn registry_dispatch_calls_handler() {
        let adder = ToolBuilder::new("add", "Add two numbers")
            .number_param("a", "First operand", true)
            .number_param("b", "Second operand", true)
            .build();

        let mut registry = ToolRegistry::new();
        registry.register(adder, |args| async move {
            let operand = |key: &str| args[key].as_f64().unwrap_or(0.0);
            Ok(json!(operand("a") + operand("b")))
        });

        let sum = registry
            .dispatch("add", json!({"a": 3.0, "b": 4.0}))
            .await
            .expect("dispatching a registered tool should succeed");

        assert_eq!(sum, json!(7.0));
    }

    /// Dispatching a name that was never registered yields `OxideError::Other`.
    #[tokio::test]
    async fn unknown_tool_returns_error() {
        let outcome = ToolRegistry::new()
            .dispatch("nonexistent", json!({}))
            .await;

        assert!(matches!(outcome, Err(OxideError::Other(_))));
    }

    /// A registered tool's definition is reported by `definitions`.
    #[test]
    fn definitions_are_returned() {
        let mut registry = ToolRegistry::new();
        registry.register(ToolBuilder::new("greet", "Say hello").build(), |_| async move {
            Ok(json!("hello"))
        });

        let defs = registry.definitions();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].function.name, "greet");
    }
}