Skip to main content

bamboo_agent/agent/core/tools/
registry.rs

1//! Tool registry for managing and executing tools.
2//!
3//! This module provides a thread-safe registry for tool management,
4//! including registration, lookup, and execution of tools.
5//!
6//! # Key Types
7//!
8//! - [`Tool`] - Trait for implementing executable tools
9//! - [`ToolRegistry`] - Thread-safe tool registry
10//! - [`RegistryError`] - Registration errors
11//! - [`SharedTool`] - Reference-counted tool pointer
12//!
13//! # Usage
14//!
15//! ```rust,ignore
16//! use bamboo_agent::agent::core::tools::registry::*;
17//!
18//! // Create a registry
19//! let registry = ToolRegistry::new();
20//!
21//! // Register a tool
22//! registry.register(MyTool::new())?;
23//!
24//! // Get tool schema for LLM
25//! let schemas = registry.list_tools();
26//!
27//! // Execute a tool
28//! let tool = registry.get("my_tool").unwrap();
29//! let result = tool.execute(args).await?;
30//! ```
31//!
32//! # Global Registry
33//!
34//! For convenience, a global singleton registry is available:
35//!
36//! ```rust,ignore
37//! let registry = global_registry();
38//! registry.register(my_tool)?;
39//! ```
40
41use std::sync::{Arc, OnceLock};
42
43use async_trait::async_trait;
44use dashmap::{mapref::entry::Entry, DashMap};
45use thiserror::Error;
46
47use crate::agent::core::tools::{
48    FunctionSchema, ToolError, ToolExecutionContext, ToolResult, ToolSchema,
49};
50
51/// Trait for implementing executable tools.
52///
53/// All tools must implement this trait to be registered with the tool registry.
54///
55/// # Required Methods
56///
57/// - `name()` - Unique tool identifier
58/// - `description()` - Human-readable tool description
59/// - `parameters_schema()` - JSON Schema for tool parameters
60/// - `execute()` - Async tool execution logic
61///
62/// # Provided Methods
63///
64/// - `to_schema()` - Convert tool to LLM-compatible schema
65///
66/// # Example
67///
68/// ```rust,ignore
69/// struct ReadFileTool;
70///
71/// #[async_trait]
72/// impl Tool for ReadFileTool {
73///     fn name(&self) -> &str {
74///         "read_file"
75///     }
76///
77///     fn description(&self) -> &str {
78///         "Read file contents from disk"
79///     }
80///
81///     fn parameters_schema(&self) -> serde_json::Value {
82///         json!({
83///             "type": "object",
84///             "properties": {
85///                 "path": {"type": "string"}
86///             },
87///             "required": ["path"]
88///         })
89///     }
90///
91///     async fn execute(&self, args: Value) -> Result<ToolResult, ToolError> {
92///         let path = args["path"].as_str().unwrap();
93///         let content = tokio::fs::read_to_string(path).await?;
94///         Ok(ToolResult {
95///             success: true,
96///             result: content,
97///             display_preference: None,
98///         })
99///     }
100/// }
101/// ```
102#[async_trait]
103pub trait Tool: Send + Sync {
104    fn name(&self) -> &str;
105    /// Human-readable tool description for LLM.
106    fn description(&self) -> &str;
107    /// JSON Schema for tool parameters.
108    fn parameters_schema(&self) -> serde_json::Value;
109    /// Execute the tool with given arguments.
110    async fn execute(&self, args: serde_json::Value) -> Result<ToolResult, ToolError>;
111
112    /// Execute the tool with a streaming-capable context.
113    ///
114    /// Default implementation falls back to `execute()` for tools that don't
115    /// need streaming.
116    async fn execute_with_context(
117        &self,
118        args: serde_json::Value,
119        _ctx: ToolExecutionContext<'_>,
120    ) -> Result<ToolResult, ToolError> {
121        self.execute(args).await
122    }
123
124    /// Convert tool to LLM-compatible schema.
125    ///
126    /// Creates a [`ToolSchema`] suitable for LLM function calling.
127    fn to_schema(&self) -> ToolSchema {
128        ToolSchema {
129            schema_type: "function".to_string(),
130            function: FunctionSchema {
131                name: self.name().to_string(),
132                description: self.description().to_string(),
133                parameters: self.parameters_schema(),
134            },
135        }
136    }
137}
138
139/// Reference-counted pointer to a tool.
140pub type SharedTool = Arc<dyn Tool>;
141
142/// Errors that can occur during tool registration.
143///
144/// # Variants
145///
146/// * `DuplicateTool` - Tool with same name already registered
147/// * `InvalidTool` - Tool validation failed (e.g., empty name)
148#[derive(Debug, Error, PartialEq, Eq)]
149pub enum RegistryError {
150    /// Tool with same name already exists in registry.
151    #[error("tool with name '{0}' already registered")]
152    DuplicateTool(String),
153
154    /// Tool validation failed.
155    #[error("invalid tool: {0}")]
156    InvalidTool(String),
157}
158
159/// Thread-safe tool registry.
160///
161/// Manages a collection of tools with concurrent access support.
162/// Uses a `DashMap` for lock-free concurrent operations.
163///
164/// # Features
165///
166/// - Thread-safe registration and lookup
167/// - Tool schema generation for LLM
168/// - Global singleton registry support
169///
170/// # Example
171///
172/// ```rust,ignore
173/// let registry = ToolRegistry::new();
174///
175/// // Register tools
176/// registry.register(ReadFileTool::new())?;
177/// registry.register(WriteFileTool::new())?;
178///
179/// // List all tool schemas
180/// let schemas = registry.list_tools();
181///
182/// // Get and execute a tool
183/// if let Some(tool) = registry.get("read_file") {
184///     let result = tool.execute(json!({"path": "test.txt"})).await?;
185/// }
186/// ```
187pub struct ToolRegistry {
188    tools: DashMap<String, SharedTool>,
189}
190
191impl Default for ToolRegistry {
192    fn default() -> Self {
193        Self::new()
194    }
195}
196
197impl ToolRegistry {
198    /// Create a new empty tool registry.
199    pub fn new() -> Self {
200        Self {
201            tools: DashMap::new(),
202        }
203    }
204
205    /// Register a tool in the registry.
206    ///
207    /// # Arguments
208    ///
209    /// * `tool` - Tool to register
210    ///
211    /// # Errors
212    ///
213    /// Returns [`RegistryError::DuplicateTool`] if tool name already exists.
214    /// Returns [`RegistryError::InvalidTool`] if tool name is empty.
215    ///
216    /// # Example
217    ///
218    /// ```rust,ignore
219    /// registry.register(MyTool::new())?;
220    /// ```
221    pub fn register<T>(&self, tool: T) -> Result<(), RegistryError>
222    where
223        T: Tool + 'static,
224    {
225        self.register_shared(Arc::new(tool))
226    }
227
228    /// Register a shared tool reference.
229    ///
230    /// # Arguments
231    ///
232    /// * `tool` - Shared tool reference
233    ///
234    /// # Errors
235    ///
236    /// Same as [`register`](Self::register).
237    pub fn register_shared(&self, tool: SharedTool) -> Result<(), RegistryError> {
238        let name = tool.name().trim();
239
240        if name.is_empty() {
241            return Err(RegistryError::InvalidTool(
242                "tool name cannot be empty".to_string(),
243            ));
244        }
245
246        match self.tools.entry(name.to_string()) {
247            Entry::Occupied(_) => Err(RegistryError::DuplicateTool(name.to_string())),
248            Entry::Vacant(entry) => {
249                entry.insert(tool);
250                Ok(())
251            }
252        }
253    }
254
255    /// Get a tool by name.
256    ///
257    /// # Arguments
258    ///
259    /// * `name` - Tool name
260    ///
261    /// # Returns
262    ///
263    /// Shared tool reference if found, `None` otherwise.
264    pub fn get(&self, name: &str) -> Option<SharedTool> {
265        self.tools.get(name).map(|entry| Arc::clone(entry.value()))
266    }
267
268    /// Check if a tool exists in the registry.
269    pub fn contains(&self, name: &str) -> bool {
270        self.tools.contains_key(name)
271    }
272
273    /// List all tool schemas.
274    ///
275    /// Returns schemas sorted alphabetically by tool name.
276    pub fn list_tools(&self) -> Vec<ToolSchema> {
277        let mut tools: Vec<ToolSchema> = self
278            .tools
279            .iter()
280            .map(|entry| entry.value().to_schema())
281            .collect();
282        tools.sort_by(|left, right| left.function.name.cmp(&right.function.name));
283        tools
284    }
285
286    /// List all tool names.
287    ///
288    /// Returns names sorted alphabetically.
289    pub fn list_tool_names(&self) -> Vec<String> {
290        let mut names: Vec<String> = self.tools.iter().map(|entry| entry.key().clone()).collect();
291        names.sort();
292        names
293    }
294
295    /// Remove a tool from the registry.
296    ///
297    /// # Returns
298    ///
299    /// `true` if tool was removed, `false` if not found.
300    pub fn unregister(&self, name: &str) -> bool {
301        self.tools.remove(name).is_some()
302    }
303
304    /// Get the number of registered tools.
305    pub fn len(&self) -> usize {
306        self.tools.len()
307    }
308
309    /// Check if registry is empty.
310    pub fn is_empty(&self) -> bool {
311        self.tools.is_empty()
312    }
313
314    /// Remove all tools from the registry.
315    pub fn clear(&self) {
316        self.tools.clear();
317    }
318}
319
320/// Global tool registry singleton.
321static GLOBAL_REGISTRY: OnceLock<ToolRegistry> = OnceLock::new();
322
323/// Get the global tool registry.
324///
325/// The global registry is a singleton that persists for the lifetime
326/// of the application. Useful for sharing tools across components.
327///
328/// # Example
329///
330/// ```rust,ignore
331/// let registry = global_registry();
332/// registry.register(my_tool)?;
333/// ```
334pub fn global_registry() -> &'static ToolRegistry {
335    GLOBAL_REGISTRY.get_or_init(ToolRegistry::new)
336}
337
/// Strip any namespace prefix from a tool name.
///
/// The input is split on `::` from left to right and the final segment is
/// kept. A name without `::` comes back unchanged, and a name ending in `::`
/// yields an empty string.
///
/// # Arguments
///
/// * `name` - Tool name, optionally namespaced (e.g. `bamboo::read_file`)
///
/// # Returns
///
/// The segment after the last `::`, borrowed from `name`.
///
/// # Example
///
/// ```rust,ignore
/// assert_eq!(normalize_tool_name("bamboo::read_file"), "read_file");
/// assert_eq!(normalize_tool_name("read_file"), "read_file");
/// ```
pub fn normalize_tool_name(name: &str) -> &str {
    // `split` always yields at least one segment, so `base` ends up as the
    // final one (the whole input when no separator is present).
    let mut base = name;
    for segment in name.split("::") {
        base = segment;
    }
    base
}
357
#[cfg(test)]
mod tests {
    use super::*;

    use serde_json::json;

    /// Minimal tool with a fixed name and description; `execute` always succeeds.
    struct TestTool {
        name: &'static str,
        description: &'static str,
    }

    impl TestTool {
        /// Shorthand for a tool whose description is irrelevant to the test.
        fn named(name: &'static str) -> Self {
            Self {
                name,
                description: "test tool",
            }
        }
    }

    #[async_trait]
    impl Tool for TestTool {
        fn name(&self) -> &str {
            self.name
        }

        fn description(&self) -> &str {
            self.description
        }

        fn parameters_schema(&self) -> serde_json::Value {
            json!({
                "type": "object",
                "properties": {}
            })
        }

        async fn execute(&self, _args: serde_json::Value) -> Result<ToolResult, ToolError> {
            Ok(ToolResult {
                success: true,
                result: "ok".to_string(),
                display_preference: None,
            })
        }
    }

    #[test]
    fn register_and_get() {
        let registry = ToolRegistry::new();
        let tool = TestTool {
            name: "test_tool",
            description: "test tool",
        };

        assert!(registry.register(tool).is_ok());
        assert!(registry.get("test_tool").is_some());
        assert!(registry.get("unknown").is_none());
    }

    #[test]
    fn duplicate_tool_registration() {
        let registry = ToolRegistry::new();

        registry
            .register(TestTool {
                name: "dup",
                description: "first",
            })
            .unwrap();

        let duplicate = registry.register(TestTool {
            name: "dup",
            description: "second",
        });

        assert!(matches!(duplicate, Err(RegistryError::DuplicateTool(name)) if name == "dup"));
    }

    #[test]
    fn duplicate_detection_ignores_surrounding_whitespace() {
        let registry = ToolRegistry::new();
        registry.register(TestTool::named("dup")).unwrap();

        let duplicate = registry.register(TestTool::named(" dup "));

        assert!(matches!(duplicate, Err(RegistryError::DuplicateTool(name)) if name == "dup"));
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn register_shared_accepts_arc_and_detects_duplicates() {
        let registry = ToolRegistry::new();
        let tool: SharedTool = Arc::new(TestTool::named("shared"));

        registry.register_shared(Arc::clone(&tool)).unwrap();
        assert!(registry.contains("shared"));

        assert_eq!(
            registry.register_shared(tool),
            Err(RegistryError::DuplicateTool("shared".to_string()))
        );
    }

    #[test]
    fn list_tools_returns_registered_tools() {
        let registry = ToolRegistry::new();

        registry
            .register(TestTool {
                name: "tool_a",
                description: "tool a",
            })
            .unwrap();
        registry
            .register(TestTool {
                name: "tool_b",
                description: "tool b",
            })
            .unwrap();

        let tools = registry.list_tools();

        assert_eq!(tools.len(), 2);
        assert_eq!(tools[0].function.name, "tool_a");
        assert_eq!(tools[1].function.name, "tool_b");
    }

    #[test]
    fn list_tools_is_sorted_and_carries_schema_fields() {
        let registry = ToolRegistry::new();

        // Registered out of order on purpose.
        registry
            .register(TestTool {
                name: "tool_b",
                description: "tool b",
            })
            .unwrap();
        registry
            .register(TestTool {
                name: "tool_a",
                description: "tool a",
            })
            .unwrap();

        let tools = registry.list_tools();
        let names: Vec<&str> = tools.iter().map(|t| t.function.name.as_str()).collect();

        assert_eq!(names, ["tool_a", "tool_b"]);
        assert_eq!(tools[0].schema_type, "function");
        assert_eq!(tools[0].function.description, "tool a");
        assert_eq!(tools[1].function.description, "tool b");
    }

    #[test]
    fn list_tool_names_is_sorted() {
        let registry = ToolRegistry::new();
        for name in ["zeta", "alpha", "mid"] {
            registry.register(TestTool::named(name)).unwrap();
        }

        assert_eq!(registry.list_tool_names(), vec!["alpha", "mid", "zeta"]);
    }

    #[test]
    fn register_rejects_empty_tool_name() {
        let registry = ToolRegistry::new();

        let result = registry.register(TestTool {
            name: "",
            description: "invalid",
        });

        assert!(
            matches!(result, Err(RegistryError::InvalidTool(reason)) if reason == "tool name cannot be empty")
        );
    }

    #[test]
    fn register_rejects_whitespace_only_tool_name() {
        let registry = ToolRegistry::new();

        let result = registry.register(TestTool::named("   "));

        assert!(
            matches!(result, Err(RegistryError::InvalidTool(reason)) if reason == "tool name cannot be empty")
        );
        assert!(registry.is_empty());
    }

    #[test]
    fn register_stores_trimmed_name() {
        let registry = ToolRegistry::new();
        registry.register(TestTool::named("  padded  ")).unwrap();

        assert!(registry.contains("padded"));
        assert!(registry.get("padded").is_some());
        assert_eq!(registry.list_tool_names(), vec!["padded"]);
    }

    #[test]
    fn unregister_len_and_clear() {
        let registry = ToolRegistry::default();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);

        registry.register(TestTool::named("a")).unwrap();
        registry.register(TestTool::named("b")).unwrap();
        assert_eq!(registry.len(), 2);

        assert!(registry.unregister("a"));
        assert!(!registry.unregister("a"));
        assert!(!registry.contains("a"));
        assert_eq!(registry.len(), 1);

        registry.clear();
        assert!(registry.is_empty());
        assert!(registry.get("b").is_none());
    }

    #[test]
    fn normalize_tool_name_handles_namespaced_inputs() {
        assert_eq!(normalize_tool_name("read_file"), "read_file");
        assert_eq!(normalize_tool_name("default::read_file"), "read_file");
        assert_eq!(normalize_tool_name("a::b::c::read_file"), "read_file");
    }

    #[test]
    fn normalize_tool_name_edge_cases() {
        assert_eq!(normalize_tool_name(""), "");
        assert_eq!(normalize_tool_name("ns::"), "");
        assert_eq!(normalize_tool_name("::x"), "x");
    }
}