1use std::collections::HashMap;
4use std::sync::Arc;
5
6use neuron_types::{Tool, ToolContext, ToolDefinition, ToolDyn, ToolError, ToolOutput};
7
8use crate::middleware::{Next, ToolCall, ToolMiddleware};
9
10pub struct ToolRegistry {
15 tools: HashMap<String, Arc<dyn ToolDyn>>,
16 global_middleware: Vec<Arc<dyn ToolMiddleware>>,
17 tool_middleware: HashMap<String, Vec<Arc<dyn ToolMiddleware>>>,
18}
19
20impl ToolRegistry {
21 #[must_use]
23 pub fn new() -> Self {
24 Self {
25 tools: HashMap::new(),
26 global_middleware: Vec::new(),
27 tool_middleware: HashMap::new(),
28 }
29 }
30
31 pub fn register<T: Tool + 'static>(&mut self, tool: T) {
33 let name = T::NAME.to_string();
34 self.tools.insert(name, Arc::new(tool));
35 }
36
37 pub fn register_dyn(&mut self, tool: Arc<dyn ToolDyn>) {
39 let name = tool.name().to_string();
40 self.tools.insert(name, tool);
41 }
42
43 #[must_use]
45 pub fn get(&self, name: &str) -> Option<Arc<dyn ToolDyn>> {
46 self.tools.get(name).cloned()
47 }
48
49 #[must_use]
51 pub fn definitions(&self) -> Vec<ToolDefinition> {
52 self.tools.values().map(|t| t.definition()).collect()
53 }
54
55 pub fn add_middleware(&mut self, m: impl ToolMiddleware + 'static) -> &mut Self {
57 self.global_middleware.push(Arc::new(m));
58 self
59 }
60
61 pub fn add_tool_middleware(
63 &mut self,
64 tool_name: &str,
65 m: impl ToolMiddleware + 'static,
66 ) -> &mut Self {
67 self.tool_middleware
68 .entry(tool_name.to_string())
69 .or_default()
70 .push(Arc::new(m));
71 self
72 }
73
74 pub async fn execute(
79 &self,
80 name: &str,
81 input: serde_json::Value,
82 ctx: &ToolContext,
83 ) -> Result<ToolOutput, ToolError> {
84 let tool = self
85 .tools
86 .get(name)
87 .ok_or_else(|| ToolError::NotFound(name.to_string()))?;
88
89 let call = ToolCall {
90 id: String::new(),
91 name: name.to_string(),
92 input,
93 };
94
95 let mut chain: Vec<Arc<dyn ToolMiddleware>> = self.global_middleware.clone();
97 if let Some(per_tool) = self.tool_middleware.get(name) {
98 chain.extend(per_tool.iter().cloned());
99 }
100
101 let next = Next::new(tool.as_ref(), &chain);
102 next.run(&call, ctx).await
103 }
104}
105
106impl Default for ToolRegistry {
107 fn default() -> Self {
108 Self::new()
109 }
110}