Skip to main content

oxide_agent/tools/
mod.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use crate::error::OxideError;
7use crate::types::ToolDefinition;
8
9// ── Handler type ──────────────────────────────────────────────────────────────
10
11type AsyncResult = Pin<Box<dyn Future<Output = Result<serde_json::Value, OxideError>> + Send>>;
12type HandlerFn =
13    Arc<dyn Fn(serde_json::Value) -> AsyncResult + Send + Sync>;
14
15// ── Registered tool ───────────────────────────────────────────────────────────
16
17struct RegisteredTool {
18    definition: ToolDefinition,
19    handler: HandlerFn,
20}
21
22// ── ToolRegistry ──────────────────────────────────────────────────────────────
23
24/// Registry that maps tool names to their JSON Schema definitions and async
25/// handler functions.
26///
27/// Pass the definitions to [`ChatRequest::tools`] when calling Ollama and call
28/// [`ToolRegistry::dispatch`] when the model returns a tool-call in its reply.
29#[derive(Default)]
30pub struct ToolRegistry {
31    tools: HashMap<String, RegisteredTool>,
32}
33
34impl ToolRegistry {
35    pub fn new() -> Self {
36        Self::default()
37    }
38
39    /// Register a tool with its JSON Schema definition and an async handler.
40    ///
41    /// The handler receives the `arguments` object from Ollama's tool-call
42    /// response and must return a JSON value that will be fed back as the
43    /// tool result.
44    pub fn register<F, Fut>(
45        &mut self,
46        definition: ToolDefinition,
47        handler: F,
48    ) where
49        F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
50        Fut: Future<Output = Result<serde_json::Value, OxideError>> + Send + 'static,
51    {
52        let name = definition.function.name.clone();
53        self.tools.insert(
54            name,
55            RegisteredTool {
56                definition,
57                handler: Arc::new(move |args| Box::pin(handler(args))),
58            },
59        );
60    }
61
62    /// All definitions, ready to pass directly to [`ChatRequest::tools`].
63    pub fn definitions(&self) -> Vec<ToolDefinition> {
64        self.tools.values().map(|t| t.definition.clone()).collect()
65    }
66
67    /// Execute the handler for `tool_name` with the given arguments.
68    ///
69    /// Returns `Err(OxideError::Other)` if the tool name is not registered.
70    pub async fn dispatch(
71        &self,
72        tool_name: &str,
73        args: serde_json::Value,
74    ) -> Result<serde_json::Value, OxideError> {
75        let tool = self.tools.get(tool_name).ok_or_else(|| {
76            OxideError::Other(format!("unknown tool: {tool_name}"))
77        })?;
78
79        (tool.handler)(args).await
80    }
81
82    pub fn contains(&self, name: &str) -> bool {
83        self.tools.contains_key(name)
84    }
85
86    pub fn len(&self) -> usize {
87        self.tools.len()
88    }
89
90    pub fn is_empty(&self) -> bool {
91        self.tools.is_empty()
92    }
93}
94
95// ── Helper: build a ToolDefinition inline ────────────────────────────────────
96
97/// Convenience builder for [`ToolDefinition`] without the `#[ollama_tool]`
98/// macro.
99pub struct ToolBuilder {
100    name: String,
101    description: String,
102    properties: serde_json::Map<String, serde_json::Value>,
103    required: Vec<String>,
104}
105
106impl ToolBuilder {
107    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
108        Self {
109            name: name.into(),
110            description: description.into(),
111            properties: serde_json::Map::new(),
112            required: Vec::new(),
113        }
114    }
115
116    /// Add a string parameter.
117    pub fn string_param(
118        mut self,
119        name: impl Into<String>,
120        description: impl Into<String>,
121        required: bool,
122    ) -> Self {
123        let n = name.into();
124        self.properties.insert(
125            n.clone(),
126            serde_json::json!({"type": "string", "description": description.into()}),
127        );
128        if required {
129            self.required.push(n);
130        }
131        self
132    }
133
134    /// Add a numeric parameter.
135    pub fn number_param(
136        mut self,
137        name: impl Into<String>,
138        description: impl Into<String>,
139        required: bool,
140    ) -> Self {
141        let n = name.into();
142        self.properties.insert(
143            n.clone(),
144            serde_json::json!({"type": "number", "description": description.into()}),
145        );
146        if required {
147            self.required.push(n);
148        }
149        self
150    }
151
152    /// Add a boolean parameter.
153    pub fn bool_param(
154        mut self,
155        name: impl Into<String>,
156        description: impl Into<String>,
157        required: bool,
158    ) -> Self {
159        let n = name.into();
160        self.properties.insert(
161            n.clone(),
162            serde_json::json!({"type": "boolean", "description": description.into()}),
163        );
164        if required {
165            self.required.push(n);
166        }
167        self
168    }
169
170    pub fn build(self) -> ToolDefinition {
171        use crate::types::FunctionDefinition;
172        ToolDefinition {
173            kind: "function".into(),
174            function: FunctionDefinition {
175                name: self.name,
176                description: self.description,
177                parameters: serde_json::json!({
178                    "type": "object",
179                    "properties": serde_json::Value::Object(self.properties),
180                    "required": self.required,
181                }),
182            },
183        }
184    }
185}
186
187// ── Tests ─────────────────────────────────────────────────────────────────────
188
#[cfg(test)]
mod tests {
    use super::*;

    /// Dispatching to a registered tool runs its handler with the supplied
    /// arguments and hands back the handler's result.
    #[tokio::test]
    async fn registry_dispatch_calls_handler() {
        let add_tool = ToolBuilder::new("add", "Add two numbers")
            .number_param("a", "First operand", true)
            .number_param("b", "Second operand", true)
            .build();

        let mut tools = ToolRegistry::new();
        tools.register(add_tool, |args| async move {
            let lhs = args["a"].as_f64().unwrap_or(0.0);
            let rhs = args["b"].as_f64().unwrap_or(0.0);
            Ok(serde_json::json!(lhs + rhs))
        });

        let call_args = serde_json::json!({"a": 3.0, "b": 4.0});
        let sum = tools.dispatch("add", call_args).await.unwrap();

        assert_eq!(sum, serde_json::json!(7.0));
    }

    /// Dispatching a name that was never registered yields `OxideError::Other`.
    #[tokio::test]
    async fn unknown_tool_returns_error() {
        let tools = ToolRegistry::new();
        let outcome = tools.dispatch("nonexistent", serde_json::json!({})).await;
        assert!(matches!(outcome, Err(OxideError::Other(_))));
    }

    /// A registered tool's definition shows up in `definitions()`.
    #[test]
    fn definitions_are_returned() {
        let greet_tool = ToolBuilder::new("greet", "Say hello").build();

        let mut tools = ToolRegistry::new();
        tools.register(greet_tool, |_| async move { Ok(serde_json::json!("hello")) });

        let defs = tools.definitions();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].function.name, "greet");
    }
}