Skip to main content

neuron_tool/
lib.rs

1#![deny(missing_docs)]
//! Tool interface and registry for neuron.
//!
//! Defines the object-safe [`ToolDyn`] trait, which every tool source (local
//! function, MCP server, HTTP endpoint) implements, and [`ToolRegistry`] for
//! managing collections of such tools. Streaming is an opt-in extension exposed
//! through [`ToolDynStreaming`], and [`AliasedTool`] re-exposes a tool under a
//! different name.
7
8use std::collections::HashMap;
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12use thiserror::Error;
13
14/// Errors from tool operations.
15#[non_exhaustive]
16#[derive(Debug, Error)]
17pub enum ToolError {
18    /// The requested tool was not found in the registry.
19    #[error("tool not found: {0}")]
20    NotFound(String),
21
22    /// Tool execution failed.
23    #[error("execution failed: {0}")]
24    ExecutionFailed(String),
25
26    /// The input provided to the tool was invalid.
27    #[error("invalid input: {0}")]
28    InvalidInput(String),
29
30    /// Catch-all for other errors.
31    #[error("{0}")]
32    Other(#[from] Box<dyn std::error::Error + Send + Sync>),
33}
34
/// Scheduling hint telling planners whether a tool may share a batch.
///
/// When a tool gives no hint, [`ToolConcurrencyHint::Exclusive`] is assumed.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolConcurrencyHint {
    /// May run in the same batch as other `Shared` tools.
    Shared,
    /// Must run by itself, with a barrier both before and after.
    Exclusive,
}

impl Default for ToolConcurrencyHint {
    /// Returns [`ToolConcurrencyHint::Exclusive`], the conservative choice.
    fn default() -> Self {
        Self::Exclusive
    }
}
45
46/// Optional streaming interface for tools.
47pub trait ToolDynStreaming: Send + Sync + 'static + ToolDyn {
48    /// Execute the tool with streaming chunk updates.
49    fn call_streaming<'a>(
50        &'a self,
51        input: serde_json::Value,
52        on_chunk: Box<dyn Fn(&str) + Send + Sync + 'a>,
53    ) -> Pin<Box<dyn Future<Output = Result<(), ToolError>> + Send + 'a>>;
54}
55/// Object-safe trait for tool implementations.
56///
57/// Any tool source (local function, MCP server, HTTP endpoint) implements
58/// this trait. Tools are stored as `Arc<dyn ToolDyn>` in [`ToolRegistry`].
59pub trait ToolDyn: Send + Sync {
60    /// The tool's unique name.
61    fn name(&self) -> &str;
62
63    /// Human-readable description of what the tool does.
64    fn description(&self) -> &str;
65
66    /// JSON Schema for the tool's input parameters.
67    fn input_schema(&self) -> serde_json::Value;
68
69    /// Execute the tool with the given input.
70    fn call(
71        &self,
72        input: serde_json::Value,
73    ) -> Pin<Box<dyn Future<Output = Result<serde_json::Value, ToolError>> + Send + '_>>;
74
75    /// If this tool also supports streaming, return a reference to its streaming interface.
76    /// Default is None; streaming is opt-in and non-disruptive.
77    fn maybe_streaming(&self) -> Option<&dyn ToolDynStreaming> {
78        None
79    }
80
81    /// Optional concurrency hint used by planners/deciders.
82    ///
83    /// Default is Exclusive to preserve backward-compatible behavior.
84    fn concurrency_hint(&self) -> ToolConcurrencyHint {
85        ToolConcurrencyHint::Exclusive
86    }
87}
88
89/// A tool wrapper that exposes a different name while delegating behavior to an inner tool.
90///
91/// This is useful when importing tools from external systems (e.g. MCP servers) where the
92/// upstream tool names are not stable or do not match the caller's desired naming scheme.
93pub struct AliasedTool {
94    alias: String,
95    inner: Arc<dyn ToolDyn>,
96}
97
98impl AliasedTool {
99    /// Create a new aliased tool wrapper.
100    pub fn new(alias: impl Into<String>, inner: Arc<dyn ToolDyn>) -> Self {
101        Self {
102            alias: alias.into(),
103            inner,
104        }
105    }
106
107    /// Access the wrapped tool.
108    pub fn inner(&self) -> &Arc<dyn ToolDyn> {
109        &self.inner
110    }
111}
112
113impl ToolDyn for AliasedTool {
114    fn name(&self) -> &str {
115        &self.alias
116    }
117
118    fn description(&self) -> &str {
119        self.inner.description()
120    }
121
122    fn input_schema(&self) -> serde_json::Value {
123        self.inner.input_schema()
124    }
125
126    fn call(
127        &self,
128        input: serde_json::Value,
129    ) -> Pin<Box<dyn Future<Output = Result<serde_json::Value, ToolError>> + Send + '_>> {
130        self.inner.call(input)
131    }
132
133    fn concurrency_hint(&self) -> ToolConcurrencyHint {
134        self.inner.concurrency_hint()
135    }
136}
137
138/// Registry of tools available to a turn.
139///
140/// Holds tools as `Arc<dyn ToolDyn>` keyed by name. The turn's ReAct loop
141/// uses this to look up and execute tools requested by the model.
142#[derive(Clone)]
143pub struct ToolRegistry {
144    tools: HashMap<String, Arc<dyn ToolDyn>>,
145}
146
147impl ToolRegistry {
148    /// Create an empty registry.
149    pub fn new() -> Self {
150        Self {
151            tools: HashMap::new(),
152        }
153    }
154
155    /// Register a tool. Overwrites any existing tool with the same name.
156    pub fn register(&mut self, tool: Arc<dyn ToolDyn>) {
157        self.tools.insert(tool.name().to_string(), tool);
158    }
159
160    /// Look up a tool by name.
161    pub fn get(&self, name: &str) -> Option<&Arc<dyn ToolDyn>> {
162        self.tools.get(name)
163    }
164
165    /// Iterate over all registered tools.
166    pub fn iter(&self) -> impl Iterator<Item = &Arc<dyn ToolDyn>> {
167        self.tools.values()
168    }
169
170    /// Number of registered tools.
171    pub fn len(&self) -> usize {
172        self.tools.len()
173    }
174
175    /// Whether the registry is empty.
176    pub fn is_empty(&self) -> bool {
177        self.tools.is_empty()
178    }
179}
180
181impl Default for ToolRegistry {
182    fn default() -> Self {
183        Self::new()
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicUsize, Ordering};

    fn _assert_send_sync<T: Send + Sync>() {}

    #[test]
    fn tool_dyn_is_object_safe() {
        // Naming `dyn ToolDyn` only compiles if the trait is object safe; the
        // bound additionally checks that the shared handle is Send + Sync.
        _assert_send_sync::<Arc<dyn ToolDyn>>();
    }

    #[test]
    fn tool_error_display() {
        assert_eq!(
            ToolError::NotFound("bash".into()).to_string(),
            "tool not found: bash"
        );
        assert_eq!(
            ToolError::ExecutionFailed("timeout".into()).to_string(),
            "execution failed: timeout"
        );
        assert_eq!(
            ToolError::InvalidInput("missing field".into()).to_string(),
            "invalid input: missing field"
        );
        // The catch-all variant displays the wrapped error's message unchanged.
        let boxed: Box<dyn std::error::Error + Send + Sync> = "boom".into();
        assert_eq!(ToolError::from(boxed).to_string(), "boom");
    }

    struct EchoTool;

    impl ToolDyn for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }
        fn description(&self) -> &str {
            "Echoes input back"
        }
        fn input_schema(&self) -> serde_json::Value {
            json!({"type": "object"})
        }
        fn call(
            &self,
            input: serde_json::Value,
        ) -> Pin<Box<dyn Future<Output = Result<serde_json::Value, ToolError>> + Send + '_>>
        {
            Box::pin(async move { Ok(json!({"echoed": input})) })
        }
    }

    struct FailTool;

    impl ToolDyn for FailTool {
        fn name(&self) -> &str {
            "fail"
        }
        fn description(&self) -> &str {
            "Always fails"
        }
        fn input_schema(&self) -> serde_json::Value {
            json!({"type": "object"})
        }
        fn call(
            &self,
            _input: serde_json::Value,
        ) -> Pin<Box<dyn Future<Output = Result<serde_json::Value, ToolError>> + Send + '_>>
        {
            Box::pin(async { Err(ToolError::ExecutionFailed("always fails".into())) })
        }
    }

    #[test]
    fn trait_defaults_are_conservative() {
        assert_eq!(EchoTool.concurrency_hint(), ToolConcurrencyHint::Exclusive);
        assert!(EchoTool.maybe_streaming().is_none());
        assert!(StreamerTool.maybe_streaming().is_some());
    }

    #[test]
    fn registry_add_and_get() {
        let mut reg = ToolRegistry::new();
        assert!(reg.is_empty());

        reg.register(Arc::new(EchoTool));
        assert_eq!(reg.len(), 1);
        assert!(reg.get("echo").is_some());
        assert!(reg.get("nonexistent").is_none());
    }

    #[test]
    fn registry_iter() {
        let mut reg = ToolRegistry::new();
        reg.register(Arc::new(EchoTool));
        reg.register(Arc::new(FailTool));

        let names: Vec<&str> = reg.iter().map(|t| t.name()).collect();
        assert!(names.contains(&"echo"));
        assert!(names.contains(&"fail"));
    }

    #[tokio::test]
    async fn registry_call_tool() {
        let mut reg = ToolRegistry::new();
        reg.register(Arc::new(EchoTool));

        let tool = reg.get("echo").unwrap();
        let result = tool.call(json!({"msg": "hello"})).await.unwrap();
        assert_eq!(result, json!({"echoed": {"msg": "hello"}}));
    }

    #[tokio::test]
    async fn aliased_tool_exposes_alias_name_and_delegates() {
        let inner: Arc<dyn ToolDyn> = Arc::new(EchoTool);
        let tool: Arc<dyn ToolDyn> = Arc::new(AliasedTool::new("echo_alias", Arc::clone(&inner)));

        assert_eq!(tool.name(), "echo_alias");
        assert_eq!(tool.description(), inner.description());

        let result = tool.call(json!({"msg": "hi"})).await.unwrap();
        assert_eq!(result, json!({"echoed": {"msg": "hi"}}));
    }

    #[tokio::test]
    async fn registry_call_failing_tool() {
        let mut reg = ToolRegistry::new();
        reg.register(Arc::new(FailTool));

        let tool = reg.get("fail").unwrap();
        let result = tool.call(json!({})).await;
        assert!(result.is_err());
    }

    #[test]
    fn registry_overwrite() {
        let mut reg = ToolRegistry::new();
        reg.register(Arc::new(EchoTool));
        assert_eq!(reg.len(), 1);

        // Register a *different* tool under the same name so we can tell which
        // one the registry kept: the later registration must win.
        reg.register(Arc::new(AliasedTool::new("echo", Arc::new(FailTool))));
        assert_eq!(reg.len(), 1);
        let kept = reg.get("echo").expect("\"echo\" should still be registered");
        assert_eq!(kept.description(), "Always fails");
    }

    struct StreamerTool;
    impl ToolDyn for StreamerTool {
        fn name(&self) -> &str {
            "streamer"
        }
        fn description(&self) -> &str {
            "Streams chunks"
        }
        fn input_schema(&self) -> serde_json::Value {
            json!({"type":"object"})
        }
        fn call(
            &self,
            _input: serde_json::Value,
        ) -> Pin<Box<dyn Future<Output = Result<serde_json::Value, ToolError>> + Send + '_>>
        {
            Box::pin(async { Ok(json!({"status":"done"})) })
        }
        fn maybe_streaming(&self) -> Option<&dyn ToolDynStreaming> {
            Some(self)
        }
    }
    impl ToolDynStreaming for StreamerTool {
        fn call_streaming<'a>(
            &'a self,
            _input: serde_json::Value,
            on_chunk: Box<dyn Fn(&str) + Send + Sync + 'a>,
        ) -> Pin<Box<dyn Future<Output = Result<(), ToolError>> + Send + 'a>> {
            Box::pin(async move {
                on_chunk("one");
                on_chunk("two");
                on_chunk("three");
                Ok(())
            })
        }
    }

    #[tokio::test]
    async fn streaming_tool_emits_chunks_and_completes() {
        let count = Arc::new(AtomicUsize::new(0));
        let seen: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let on_chunk = {
            let count = Arc::clone(&count);
            let seen = Arc::clone(&seen);
            Box::new(move |chunk: &str| {
                count.fetch_add(1, Ordering::SeqCst);
                seen.lock().unwrap().push(chunk.to_string());
            })
        };
        let tool = StreamerTool;
        let res = tool.call_streaming(json!({}), on_chunk).await;
        assert!(res.is_ok());
        assert_eq!(count.load(Ordering::SeqCst), 3);
        let got = seen.lock().unwrap().clone();
        assert_eq!(got, vec!["one", "two", "three"]);
    }
}