1use std::sync::Arc;
117
118use anyhow::Result;
119use async_trait::async_trait;
120use dashmap::DashMap;
121use serde_json::Value;
122
123use neuromance_common::tools::{Tool, ToolCall};
124
125mod bool_tool;
126pub mod generic;
127pub mod mcp;
128mod think_tool;
129mod todo_tool;
130pub use bool_tool::BooleanTool;
131pub use think_tool::ThinkTool;
132pub use todo_tool::{TodoReadTool, TodoWriteTool, create_todo_tools};
133
/// A tool that can be stored in a [`ToolRegistry`] and invoked by a
/// [`ToolExecutor`].
#[async_trait]
pub trait ToolImplementation: Send + Sync {
    /// Returns this tool's definition.
    ///
    /// `ToolRegistry::register` uses `function.name` from this definition as
    /// the key the tool is stored under, and `ToolRegistry::get_all_definitions`
    /// returns it as-is.
    fn get_definition(&self) -> Tool;

    /// Runs the tool and returns its textual output.
    ///
    /// `args` is the JSON value `ToolExecutor` builds from the raw tool-call
    /// arguments: an object, an array of strings, a plain string, or whatever
    /// JSON value a single argument parsed to.
    async fn execute(&self, args: &Value) -> Result<String>;

    /// Reports whether this tool is auto-approved; queried through
    /// `ToolRegistry::is_tool_auto_approved`. Defaults to `false`.
    ///
    /// NOTE(review): presumably lets callers skip a user confirmation step
    /// before running the tool — confirm against the caller.
    fn is_auto_approved(&self) -> bool {
        false
    }
}
144
/// Name-keyed collection of tool implementations.
///
/// Backed by a concurrent `DashMap` behind an `Arc`, so entries can be
/// inserted through a shared reference (see `register`).
pub struct ToolRegistry {
    /// Tool name (from `get_definition().function.name`) -> implementation.
    tools: Arc<DashMap<String, Arc<dyn ToolImplementation>>>,
}
148
149impl Default for ToolRegistry {
150 fn default() -> Self {
151 Self::new()
152 }
153}
154
155impl ToolRegistry {
156 pub fn new() -> Self {
157 Self {
158 tools: Arc::new(DashMap::new()),
159 }
160 }
161
162 pub fn register(&self, tool: Arc<dyn ToolImplementation>) {
163 let name = tool.get_definition().function.name.clone();
164 self.tools.insert(name, tool);
165 }
166
167 pub fn get(&self, name: &str) -> Option<Arc<dyn ToolImplementation>> {
168 self.tools.get(name).map(|r| r.value().clone())
169 }
170
171 pub fn get_all_definitions(&self) -> Vec<Tool> {
172 self.tools.iter().map(|t| t.get_definition()).collect()
173 }
174
175 pub fn is_tool_auto_approved(&self, name: &str) -> bool {
176 self.tools
177 .get(name)
178 .map(|t| t.is_auto_approved())
179 .unwrap_or(false)
180 }
181
182 pub fn remove(&mut self, name: &str) -> Option<Arc<dyn ToolImplementation>> {
183 self.tools.remove(name).map(|(_, tool)| tool)
184 }
185
186 pub fn clear(&mut self) {
187 self.tools.clear();
188 }
189
190 pub fn contains(&self, name: &str) -> bool {
191 self.tools.contains_key(name)
192 }
193
194 pub fn tool_names(&self) -> Vec<String> {
195 self.tools.iter().map(|t| t.key().clone()).collect()
196 }
197}
198
/// Dispatches tool calls to the implementations held in its registry.
pub struct ToolExecutor {
    /// Tools this executor can run, keyed by name.
    registry: ToolRegistry,
}
202
203impl ToolExecutor {
204 pub fn new() -> Self {
205 Self {
206 registry: ToolRegistry::new(),
207 }
208 }
209
210 pub fn add_tool<T: ToolImplementation + 'static>(&mut self, tool: T) {
211 self.registry.register(Arc::new(tool));
212 }
213
214 pub fn add_tool_arc(&mut self, tool: Arc<dyn ToolImplementation>) {
215 self.registry.register(tool);
216 }
217
218 pub async fn has_tool(&self, name: &str) -> Result<bool> {
219 Ok(self.registry.contains(name))
220 }
221
222 pub fn get_all_tools(&self) -> Vec<Tool> {
223 self.registry.get_all_definitions()
224 }
225
226 pub fn is_tool_auto_approved(&self, name: &str) -> bool {
227 self.registry.is_tool_auto_approved(name)
228 }
229
230 pub async fn remove_tool(&mut self, name: &str) -> Result<Option<Arc<dyn ToolImplementation>>> {
231 let tool = self.registry.remove(name);
232 Ok(tool)
233 }
234
235 pub async fn reset_tools(&mut self) {
236 self.registry.clear();
237 }
238
239 pub async fn execute_tool(&self, tool_call: &ToolCall) -> Result<String> {
240 let function = &tool_call.function;
241
242 let tool = self
243 .registry
244 .get(&function.name)
245 .ok_or_else(|| anyhow::anyhow!("Unknown tool: '{}'", function.name))?;
246
247 let args = self.parse_arguments(&function.arguments)?;
248
249 tool.execute(&args).await
251 }
252
253 fn parse_arguments(&self, arguments: &[String]) -> Result<Value> {
254 match arguments.len() {
255 0 => Ok(Value::Object(serde_json::Map::new())),
257 1 => serde_json::from_str(&arguments[0])
261 .or_else(|_| Ok(Value::String(arguments[0].clone()))),
262 _ => Ok(Value::Array(
264 arguments.iter().map(|s| Value::String(s.clone())).collect(),
265 )),
266 }
267 }
268}
269
270impl Default for ToolExecutor {
271 fn default() -> Self {
272 Self::new()
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_parse_arguments_empty() {
        let executor = ToolExecutor::new();
        let result = executor.parse_arguments(&[]).unwrap();
        assert_eq!(result, json!({}));
    }

    #[test]
    fn test_parse_arguments_single_json_object() {
        let executor = ToolExecutor::new();
        let args = vec![r#"{"key": "value", "number": 42}"#.to_string()];
        let result = executor.parse_arguments(&args).unwrap();
        assert_eq!(result, json!({"key": "value", "number": 42}));
    }

    #[test]
    fn test_parse_arguments_single_json_array() {
        let executor = ToolExecutor::new();
        let args = vec![r#"["item1", "item2", "item3"]"#.to_string()];
        let result = executor.parse_arguments(&args).unwrap();
        assert_eq!(result, json!(["item1", "item2", "item3"]));
    }

    #[test]
    fn test_parse_arguments_single_string_fallback() {
        let executor = ToolExecutor::new();
        let args = vec!["plain text argument".to_string()];
        let result = executor.parse_arguments(&args).unwrap();
        assert_eq!(result, json!("plain text argument"));
    }

    #[test]
    fn test_parse_arguments_single_invalid_json_fallback() {
        let executor = ToolExecutor::new();
        let args = vec![r#"{"incomplete json"#.to_string()];
        let result = executor.parse_arguments(&args).unwrap();
        assert_eq!(result, json!(r#"{"incomplete json"#));
    }

    #[test]
    fn test_parse_arguments_multiple_strings() {
        let executor = ToolExecutor::new();
        let args = vec!["arg1".to_string(), "arg2".to_string(), "arg3".to_string()];
        let result = executor.parse_arguments(&args).unwrap();
        assert_eq!(result, json!(["arg1", "arg2", "arg3"]));
    }

    #[test]
    fn test_parse_arguments_single_number_string() {
        let executor = ToolExecutor::new();
        let args = vec!["42".to_string()];
        let result = executor.parse_arguments(&args).unwrap();
        assert_eq!(result, json!(42));
    }

    #[test]
    fn test_parse_arguments_single_boolean_string() {
        let executor = ToolExecutor::new();
        let args = vec!["true".to_string()];
        let result = executor.parse_arguments(&args).unwrap();
        assert_eq!(result, json!(true));
    }

    #[test]
    fn test_parse_arguments_single_null_string() {
        // "null" is valid JSON, so it parses to Value::Null rather than
        // falling back to the string "null".
        let executor = ToolExecutor::new();
        let args = vec!["null".to_string()];
        let result = executor.parse_arguments(&args).unwrap();
        assert_eq!(result, Value::Null);
    }

    #[test]
    fn test_parse_arguments_single_json_with_surrounding_whitespace() {
        let executor = ToolExecutor::new();
        let args = vec!["  {\"a\": 1}\n".to_string()];
        let result = executor.parse_arguments(&args).unwrap();
        assert_eq!(result, json!({"a": 1}));
    }

    #[test]
    fn test_parse_arguments_multiple_json_like_strings_are_not_parsed() {
        // With more than one argument, each is kept verbatim as a string,
        // even when it would be valid JSON on its own.
        let executor = ToolExecutor::new();
        let args = vec!["42".to_string(), r#"{"a":1}"#.to_string()];
        let result = executor.parse_arguments(&args).unwrap();
        assert_eq!(result, json!(["42", r#"{"a":1}"#]));
    }
}