Skip to main content

neuron_tool/
registry.rs

1//! Tool registry: register, lookup, and execute tools.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use neuron_types::{Tool, ToolContext, ToolDefinition, ToolDyn, ToolError, ToolOutput};
7
8use crate::middleware::{Next, ToolCall, ToolMiddleware};
9
10/// Registry of tools with optional middleware pipelines.
11///
12/// Tools are stored as type-erased [`ToolDyn`] trait objects.
13/// Middleware can be added globally (applies to all tools) or per-tool.
14pub struct ToolRegistry {
15    tools: HashMap<String, Arc<dyn ToolDyn>>,
16    global_middleware: Vec<Arc<dyn ToolMiddleware>>,
17    tool_middleware: HashMap<String, Vec<Arc<dyn ToolMiddleware>>>,
18}
19
20impl ToolRegistry {
21    /// Create an empty registry.
22    #[must_use]
23    pub fn new() -> Self {
24        Self {
25            tools: HashMap::new(),
26            global_middleware: Vec::new(),
27            tool_middleware: HashMap::new(),
28        }
29    }
30
31    /// Register a strongly-typed tool (auto-erased to `ToolDyn`).
32    pub fn register<T: Tool + 'static>(&mut self, tool: T) {
33        let name = T::NAME.to_string();
34        self.tools.insert(name, Arc::new(tool));
35    }
36
37    /// Register a pre-erased tool.
38    pub fn register_dyn(&mut self, tool: Arc<dyn ToolDyn>) {
39        let name = tool.name().to_string();
40        self.tools.insert(name, tool);
41    }
42
43    /// Look up a tool by name.
44    #[must_use]
45    pub fn get(&self, name: &str) -> Option<Arc<dyn ToolDyn>> {
46        self.tools.get(name).cloned()
47    }
48
49    /// Get definitions for all registered tools.
50    #[must_use]
51    pub fn definitions(&self) -> Vec<ToolDefinition> {
52        self.tools.values().map(|t| t.definition()).collect()
53    }
54
55    /// Add global middleware (applies to all tool executions).
56    pub fn add_middleware(&mut self, m: impl ToolMiddleware + 'static) -> &mut Self {
57        self.global_middleware.push(Arc::new(m));
58        self
59    }
60
61    /// Add middleware that only applies to a specific tool.
62    pub fn add_tool_middleware(
63        &mut self,
64        tool_name: &str,
65        m: impl ToolMiddleware + 'static,
66    ) -> &mut Self {
67        self.tool_middleware
68            .entry(tool_name.to_string())
69            .or_default()
70            .push(Arc::new(m));
71        self
72    }
73
74    /// Execute a tool by name, running it through the middleware chain.
75    ///
76    /// Middleware order: global middleware first, then per-tool middleware,
77    /// then the actual tool.
78    pub async fn execute(
79        &self,
80        name: &str,
81        input: serde_json::Value,
82        ctx: &ToolContext,
83    ) -> Result<ToolOutput, ToolError> {
84        let tool = self
85            .tools
86            .get(name)
87            .ok_or_else(|| ToolError::NotFound(name.to_string()))?;
88
89        let call = ToolCall {
90            id: String::new(),
91            name: name.to_string(),
92            input,
93        };
94
95        // Build combined middleware chain: global + per-tool
96        let mut chain: Vec<Arc<dyn ToolMiddleware>> = self.global_middleware.clone();
97        if let Some(per_tool) = self.tool_middleware.get(name) {
98            chain.extend(per_tool.iter().cloned());
99        }
100
101        let next = Next::new(tool.as_ref(), &chain);
102        next.run(&call, ctx).await
103    }
104}
105
106impl Default for ToolRegistry {
107    fn default() -> Self {
108        Self::new()
109    }
110}