Skip to main content

bamboo_agent_core/tools/
registry.rs

1//! Tool registry for managing and executing tools.
2//!
3//! This module provides a thread-safe registry for tool management,
4//! including registration, lookup, and execution of tools.
5//!
6//! # Key Types
7//!
8//! - [`Tool`] - Trait for implementing executable tools
9//! - [`ToolRegistry`] - Thread-safe tool registry
10//! - [`RegistryError`] - Registration errors
11//! - [`SharedTool`] - Reference-counted tool pointer
12//!
13//! # Usage
14//!
15//! ```rust,ignore
16//! use bamboo_agent::agent::core::tools::registry::*;
17//!
18//! // Create a registry
19//! let registry = ToolRegistry::new();
20//!
21//! // Register a tool
22//! registry.register(MyTool::new())?;
23//!
24//! // Get tool schema for LLM
25//! let schemas = registry.list_tools();
26//!
27//! // Execute a tool
28//! let tool = registry.get("my_tool").unwrap();
29//! let result = tool.execute(args).await?;
30//! ```
31//!
32//! # Global Registry
33//!
34//! For convenience, a global singleton registry is available:
35//!
36//! ```rust,ignore
37//! let registry = global_registry();
38//! registry.register(my_tool)?;
39//! ```
40
41use std::sync::{Arc, OnceLock};
42
43use async_trait::async_trait;
44use dashmap::{mapref::entry::Entry, DashMap};
45use thiserror::Error;
46
47use crate::tools::{FunctionSchema, ToolError, ToolExecutionContext, ToolResult, ToolSchema};
48
49/// Trait for implementing executable tools.
50///
51/// All tools must implement this trait to be registered with the tool registry.
52///
53/// # Required Methods
54///
55/// - `name()` - Unique tool identifier
56/// - `description()` - Human-readable tool description
57/// - `parameters_schema()` - JSON Schema for tool parameters
58/// - `execute()` - Async tool execution logic
59///
60/// # Provided Methods
61///
62/// - `to_schema()` - Convert tool to LLM-compatible schema
63///
64/// # Example
65///
66/// ```rust,ignore
67/// struct ReadFileTool;
68///
69/// #[async_trait]
70/// impl Tool for ReadFileTool {
71///     fn name(&self) -> &str {
72///         "read_file"
73///     }
74///
75///     fn description(&self) -> &str {
76///         "Read file contents from disk"
77///     }
78///
79///     fn parameters_schema(&self) -> serde_json::Value {
80///         json!({
81///             "type": "object",
82///             "properties": {
83///                 "path": {"type": "string"}
84///             },
85///             "required": ["path"]
86///         })
87///     }
88///
89///     async fn execute(&self, args: Value) -> Result<ToolResult, ToolError> {
90///         let path = args["path"].as_str().unwrap();
91///         let content = tokio::fs::read_to_string(path).await?;
92///         Ok(ToolResult {
93///             success: true,
94///             result: content,
95///             display_preference: None,
96///         })
97///     }
98/// }
99/// ```
100#[async_trait]
101pub trait Tool: Send + Sync {
102    fn name(&self) -> &str;
103    /// Human-readable tool description for LLM.
104    fn description(&self) -> &str;
105    /// JSON Schema for tool parameters.
106    fn parameters_schema(&self) -> serde_json::Value;
107    /// Declares whether this tool is read-only or mutating for orchestration and
108    /// parallel scheduling decisions. Defaults to mutating to stay conservative.
109    fn mutability(&self) -> crate::tools::ToolMutability {
110        crate::tools::ToolMutability::Mutating
111    }
112    /// Args-aware mutability hook. Defaults to the static mutability declaration.
113    fn call_mutability(&self, _args: &serde_json::Value) -> crate::tools::ToolMutability {
114        self.mutability()
115    }
116    /// Declares whether this tool can safely run in parallel with other
117    /// read-only tools. Defaults to false so tools remain serialized unless
118    /// they opt in explicitly.
119    fn concurrency_safe(&self) -> bool {
120        false
121    }
122    /// Args-aware parallel-safety hook. Defaults to the static declaration.
123    fn call_concurrency_safe(&self, _args: &serde_json::Value) -> bool {
124        self.concurrency_safe()
125    }
126    /// Execute the tool with given arguments.
127    async fn execute(&self, args: serde_json::Value) -> Result<ToolResult, ToolError>;
128
129    /// Execute the tool with a streaming-capable context.
130    ///
131    /// Default implementation falls back to `execute()` for tools that don't
132    /// need streaming.
133    async fn execute_with_context(
134        &self,
135        args: serde_json::Value,
136        _ctx: ToolExecutionContext<'_>,
137    ) -> Result<ToolResult, ToolError> {
138        self.execute(args).await
139    }
140
141    /// Convert tool to LLM-compatible schema.
142    ///
143    /// Creates a [`ToolSchema`] suitable for LLM function calling.
144    fn to_schema(&self) -> ToolSchema {
145        ToolSchema {
146            schema_type: "function".to_string(),
147            function: FunctionSchema {
148                name: self.name().to_string(),
149                description: self.description().to_string(),
150                parameters: self.parameters_schema(),
151            },
152        }
153    }
154}
155
156/// Reference-counted pointer to a tool.
157pub type SharedTool = Arc<dyn Tool>;
158
159/// Errors that can occur during tool registration.
160///
161/// # Variants
162///
163/// * `DuplicateTool` - Tool with same name already registered
164/// * `InvalidTool` - Tool validation failed (e.g., empty name)
165#[derive(Debug, Error, PartialEq, Eq)]
166pub enum RegistryError {
167    /// Tool with same name already exists in registry.
168    #[error("tool with name '{0}' already registered")]
169    DuplicateTool(String),
170
171    /// Tool validation failed.
172    #[error("invalid tool: {0}")]
173    InvalidTool(String),
174}
175
176/// Thread-safe tool registry.
177///
178/// Manages a collection of tools with concurrent access support.
179/// Uses a `DashMap` for lock-free concurrent operations.
180///
181/// # Features
182///
183/// - Thread-safe registration and lookup
184/// - Tool schema generation for LLM
185/// - Global singleton registry support
186///
187/// # Example
188///
189/// ```rust,ignore
190/// let registry = ToolRegistry::new();
191///
192/// // Register tools
193/// registry.register(ReadFileTool::new())?;
194/// registry.register(WriteFileTool::new())?;
195///
196/// // List all tool schemas
197/// let schemas = registry.list_tools();
198///
199/// // Get and execute a tool
200/// if let Some(tool) = registry.get("read_file") {
201///     let result = tool.execute(json!({"path": "test.txt"})).await?;
202/// }
203/// ```
204pub struct ToolRegistry {
205    tools: DashMap<String, SharedTool>,
206}
207
208impl Default for ToolRegistry {
209    fn default() -> Self {
210        Self::new()
211    }
212}
213
214impl ToolRegistry {
215    /// Create a new empty tool registry.
216    pub fn new() -> Self {
217        Self {
218            tools: DashMap::new(),
219        }
220    }
221
222    /// Register a tool in the registry.
223    ///
224    /// # Arguments
225    ///
226    /// * `tool` - Tool to register
227    ///
228    /// # Errors
229    ///
230    /// Returns [`RegistryError::DuplicateTool`] if tool name already exists.
231    /// Returns [`RegistryError::InvalidTool`] if tool name is empty.
232    ///
233    /// # Example
234    ///
235    /// ```rust,ignore
236    /// registry.register(MyTool::new())?;
237    /// ```
238    pub fn register<T>(&self, tool: T) -> Result<(), RegistryError>
239    where
240        T: Tool + 'static,
241    {
242        self.register_shared(Arc::new(tool))
243    }
244
245    /// Register a shared tool reference.
246    ///
247    /// # Arguments
248    ///
249    /// * `tool` - Shared tool reference
250    ///
251    /// # Errors
252    ///
253    /// Same as [`register`](Self::register).
254    pub fn register_shared(&self, tool: SharedTool) -> Result<(), RegistryError> {
255        let name = tool.name().trim();
256
257        if name.is_empty() {
258            return Err(RegistryError::InvalidTool(
259                "tool name cannot be empty".to_string(),
260            ));
261        }
262
263        match self.tools.entry(name.to_string()) {
264            Entry::Occupied(_) => Err(RegistryError::DuplicateTool(name.to_string())),
265            Entry::Vacant(entry) => {
266                entry.insert(tool);
267                Ok(())
268            }
269        }
270    }
271
272    /// Get a tool by name.
273    ///
274    /// # Arguments
275    ///
276    /// * `name` - Tool name
277    ///
278    /// # Returns
279    ///
280    /// Shared tool reference if found, `None` otherwise.
281    pub fn get(&self, name: &str) -> Option<SharedTool> {
282        self.tools.get(name).map(|entry| Arc::clone(entry.value()))
283    }
284
285    /// Check if a tool exists in the registry.
286    pub fn contains(&self, name: &str) -> bool {
287        self.tools.contains_key(name)
288    }
289
290    /// List all tool schemas.
291    ///
292    /// Returns schemas sorted alphabetically by tool name.
293    pub fn list_tools(&self) -> Vec<ToolSchema> {
294        let mut tools: Vec<ToolSchema> = self
295            .tools
296            .iter()
297            .map(|entry| entry.value().to_schema())
298            .collect();
299        tools.sort_by_key(|t| t.function.name.clone());
300        tools
301    }
302
303    /// List all tool names.
304    ///
305    /// Returns names sorted alphabetically.
306    pub fn list_tool_names(&self) -> Vec<String> {
307        let mut names: Vec<String> = self.tools.iter().map(|entry| entry.key().clone()).collect();
308        names.sort();
309        names
310    }
311
312    /// Remove a tool from the registry.
313    ///
314    /// # Returns
315    ///
316    /// `true` if tool was removed, `false` if not found.
317    pub fn unregister(&self, name: &str) -> bool {
318        self.tools.remove(name).is_some()
319    }
320
321    /// Get the number of registered tools.
322    pub fn len(&self) -> usize {
323        self.tools.len()
324    }
325
326    /// Check if registry is empty.
327    pub fn is_empty(&self) -> bool {
328        self.tools.is_empty()
329    }
330
331    /// Remove all tools from the registry.
332    pub fn clear(&self) {
333        self.tools.clear();
334    }
335}
336
337/// Global tool registry singleton.
338static GLOBAL_REGISTRY: OnceLock<ToolRegistry> = OnceLock::new();
339
340/// Get the global tool registry.
341///
342/// The global registry is a singleton that persists for the lifetime
343/// of the application. Useful for sharing tools across components.
344///
345/// # Example
346///
347/// ```rust,ignore
348/// let registry = global_registry();
349/// registry.register(my_tool)?;
350/// ```
351pub fn global_registry() -> &'static ToolRegistry {
352    GLOBAL_REGISTRY.get_or_init(ToolRegistry::new)
353}
354
/// Strip any `::`-separated namespace prefix from a tool name.
///
/// The input is split on `::` from left to right and the final segment is
/// kept. A name without a separator comes back unchanged, and a name ending
/// in `::` yields an empty string.
///
/// # Example
///
/// ```rust,ignore
/// assert_eq!(normalize_tool_name("bamboo::read_file"), "read_file");
/// assert_eq!(normalize_tool_name("read_file"), "read_file");
/// ```
pub fn normalize_tool_name(name: &str) -> &str {
    // Each segment overwrites the accumulator, so the fold yields the last one.
    // `split` always produces at least one segment, so the seed is never returned.
    name.split("::").fold(name, |_previous, segment| segment)
}
374
#[cfg(test)]
mod tests {
    use super::*;

    use serde_json::json;

    struct TestTool {
        name: &'static str,
        description: &'static str,
    }

    #[async_trait]
    impl Tool for TestTool {
        fn name(&self) -> &str {
            self.name
        }

        fn description(&self) -> &str {
            self.description
        }

        fn parameters_schema(&self) -> serde_json::Value {
            json!({
                "type": "object",
                "properties": {}
            })
        }

        async fn execute(&self, _args: serde_json::Value) -> Result<ToolResult, ToolError> {
            Ok(ToolResult {
                success: true,
                result: "ok".to_string(),
                display_preference: None,
            })
        }
    }

    /// Build a [`TestTool`] with the given name and a throwaway description.
    fn named(name: &'static str) -> TestTool {
        TestTool {
            name,
            description: "test tool",
        }
    }

    #[test]
    fn register_and_get() {
        let registry = ToolRegistry::new();
        let tool = TestTool {
            name: "test_tool",
            description: "test tool",
        };

        assert!(registry.register(tool).is_ok());
        assert!(registry.get("test_tool").is_some());
        assert!(registry.get("unknown").is_none());
    }

    #[test]
    fn duplicate_tool_registration() {
        let registry = ToolRegistry::new();

        registry
            .register(TestTool {
                name: "dup",
                description: "first",
            })
            .unwrap();

        let duplicate = registry.register(TestTool {
            name: "dup",
            description: "second",
        });

        assert!(matches!(duplicate, Err(RegistryError::DuplicateTool(name)) if name == "dup"));
    }

    #[test]
    fn duplicate_detection_uses_trimmed_name() {
        let registry = ToolRegistry::new();
        registry.register(named("dup")).unwrap();

        let duplicate = registry.register(named("  dup  "));

        assert!(matches!(duplicate, Err(RegistryError::DuplicateTool(name)) if name == "dup"));
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn register_trims_surrounding_whitespace() {
        let registry = ToolRegistry::new();
        registry.register(named("  padded  ")).unwrap();

        assert!(registry.contains("padded"));
        assert!(registry.get("padded").is_some());
        // Lookups are exact against the trimmed key; the query is not trimmed.
        assert!(registry.get("  padded  ").is_none());
        assert_eq!(registry.list_tool_names(), vec!["padded".to_string()]);
    }

    #[test]
    fn list_tools_returns_registered_tools() {
        let registry = ToolRegistry::new();

        registry
            .register(TestTool {
                name: "tool_a",
                description: "tool a",
            })
            .unwrap();
        registry
            .register(TestTool {
                name: "tool_b",
                description: "tool b",
            })
            .unwrap();

        let tools = registry.list_tools();

        assert_eq!(tools.len(), 2);
        assert_eq!(tools[0].function.name, "tool_a");
        assert_eq!(tools[1].function.name, "tool_b");
    }

    #[test]
    fn list_tool_names_is_sorted() {
        let registry = ToolRegistry::new();
        registry.register(named("zeta")).unwrap();
        registry.register(named("alpha")).unwrap();
        registry.register(named("mid")).unwrap();

        assert_eq!(registry.list_tool_names(), vec!["alpha", "mid", "zeta"]);
    }

    #[test]
    fn register_rejects_empty_tool_name() {
        let registry = ToolRegistry::new();

        let result = registry.register(TestTool {
            name: "",
            description: "invalid",
        });

        assert!(
            matches!(result, Err(RegistryError::InvalidTool(reason)) if reason == "tool name cannot be empty")
        );
    }

    #[test]
    fn register_rejects_whitespace_only_tool_name() {
        let registry = ToolRegistry::new();

        let result = registry.register(named("   "));

        assert_eq!(
            result,
            Err(RegistryError::InvalidTool(
                "tool name cannot be empty".to_string()
            ))
        );
        assert!(registry.is_empty());
    }

    #[test]
    fn register_shared_stores_the_same_instance() {
        let registry = ToolRegistry::new();
        let tool: SharedTool = Arc::new(named("shared"));

        registry.register_shared(Arc::clone(&tool)).unwrap();
        let fetched = registry.get("shared").expect("tool was just registered");

        assert!(Arc::ptr_eq(&tool, &fetched));
    }

    #[test]
    fn unregister_len_and_clear() {
        let registry = ToolRegistry::new();
        assert!(registry.is_empty());

        registry.register(named("a")).unwrap();
        registry.register(named("b")).unwrap();
        assert_eq!(registry.len(), 2);

        assert!(registry.unregister("a"));
        assert!(!registry.unregister("a"));
        assert!(!registry.contains("a"));
        assert!(registry.contains("b"));
        assert_eq!(registry.len(), 1);

        registry.clear();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn to_schema_uses_tool_metadata() {
        let schema = TestTool {
            name: "meta",
            description: "meta tool",
        }
        .to_schema();

        assert_eq!(schema.schema_type, "function");
        assert_eq!(schema.function.name, "meta");
        assert_eq!(schema.function.description, "meta tool");
        assert_eq!(
            schema.function.parameters,
            json!({"type": "object", "properties": {}})
        );
    }

    #[test]
    fn normalize_tool_name_handles_namespaced_inputs() {
        assert_eq!(normalize_tool_name("read_file"), "read_file");
        assert_eq!(normalize_tool_name("default::read_file"), "read_file");
        assert_eq!(normalize_tool_name("a::b::c::read_file"), "read_file");
        assert_eq!(normalize_tool_name("trailing::"), "");
    }
}