Skip to main content

llm_tool/registry/
mod.rs

//! Tool registry: registration and concurrent dispatch of named tools.
2
3use alloc::{boxed::Box, format, vec::Vec};
4
5use super::{
6    rust_tool::{ErasedTool, RustTool, definition_of},
7    types::{ToolContext, ToolDefinition, ToolError, ToolOutput},
8};
9use crate::compat::HashMap;
10
11/// Entry holding a cached [`ToolDefinition`] alongside the type-erased tool.
12///
13/// The definition is computed once at registration time so that
14/// [`ToolRegistry::definitions`] and [`ToolRegistry::iter`] never
15/// regenerate JSON schemas.
16struct RegisteredTool {
17    definition: ToolDefinition,
18    erased: Box<dyn ErasedTool>,
19}
20
21/// A registry of named tools available for dynamic dispatch.
22///
23/// Holds type-erased tool implementations and cached [`ToolDefinition`](super::types::ToolDefinition)
24/// schemas for fast lookup and execution.
25pub struct ToolRegistry {
26    tools: HashMap<&'static str, RegisteredTool>,
27}
28
29impl core::fmt::Debug for ToolRegistry {
30    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
31        let names: Vec<&str> = self
32            .tools
33            .values()
34            .map(|r| r.definition.name.as_str())
35            .collect();
36        f.debug_struct("ToolRegistry")
37            .field("tool_count", &self.tools.len())
38            .field("tool_names", &names)
39            .finish()
40    }
41}
42
43impl Default for ToolRegistry {
44    fn default() -> Self {
45        Self::new()
46    }
47}
48
49impl ToolRegistry {
50    /// Create an empty registry.
51    #[must_use]
52    pub fn new() -> Self {
53        Self {
54            tools: HashMap::new(),
55        }
56    }
57
58    /// Register a [`RustTool`]. Returns `&mut Self` for chaining.
59    ///
60    /// The tool's [`ToolDefinition`] (including JSON schema) is computed once
61    /// here and cached for the lifetime of the registration.
62    ///
63    /// If a tool with the same name was already registered, it is replaced.
64    ///
65    /// # Panics
66    ///
67    /// Panics if the tool's JSON schema cannot be serialized. This indicates a
68    /// bug in the tool's `Params` type (e.g. a broken `JsonSchema` impl).
69    pub fn register<T: RustTool + 'static>(&mut self, tool: T) -> &mut Self {
70        let definition = definition_of(&tool)
71            .unwrap_or_else(|e| panic!("Failed to build definition for tool '{}': {e}", T::NAME));
72        self.tools.insert(
73            T::NAME,
74            RegisteredTool {
75                definition,
76                erased: Box::new(tool),
77            },
78        );
79        self
80    }
81
82    /// Register a [`RustTool`], consuming and returning `Self` for owned chaining.
83    ///
84    /// This is the owned counterpart of [`register`](Self::register), enabling
85    /// patterns like:
86    /// ```
87    /// use llm_tool::{RustTool, ToolContext, ToolError, ToolOutput, ToolRegistry};
88    /// use schemars::JsonSchema;
89    /// use serde::Deserialize;
90    ///
91    /// #[derive(Deserialize, JsonSchema)]
92    /// struct NoParams {}
93    ///
94    /// struct ToolA;
95    /// impl RustTool for ToolA {
96    ///     type Params = NoParams;
97    ///     const NAME: &'static str = "tool_a";
98    ///     const DESCRIPTION: &'static str = "Tool A";
99    ///     async fn call(&self, _: NoParams, _: &ToolContext) -> Result<ToolOutput, ToolError> {
100    ///         Ok("a".into())
101    ///     }
102    /// }
103    ///
104    /// struct ToolB;
105    /// impl RustTool for ToolB {
106    ///     type Params = NoParams;
107    ///     const NAME: &'static str = "tool_b";
108    ///     const DESCRIPTION: &'static str = "Tool B";
109    ///     async fn call(&self, _: NoParams, _: &ToolContext) -> Result<ToolOutput, ToolError> {
110    ///         Ok("b".into())
111    ///     }
112    /// }
113    ///
114    /// let registry = ToolRegistry::new().with_tool(ToolA).with_tool(ToolB);
115    ///
116    /// assert_eq!(registry.definitions().len(), 2);
117    /// ```
118    #[must_use]
119    pub fn with_tool<T: RustTool + 'static>(mut self, tool: T) -> Self {
120        self.register(tool);
121        self
122    }
123
124    /// Collect [`ToolDefinition`]s for all registered tools.
125    ///
126    /// Returns clones of the cached definitions computed at registration time.
127    #[must_use]
128    pub fn definitions(&self) -> Vec<ToolDefinition> {
129        self.tools
130            .values()
131            .map(|entry| entry.definition.clone())
132            .collect()
133    }
134
135    /// Dispatch a tool call by name with raw JSON arguments and a context.
136    ///
137    /// # Errors
138    ///
139    /// Returns `Err` if the tool name is unknown or the handler returns an error.
140    pub async fn dispatch(
141        &self,
142        name: &str,
143        args: serde_json::Value,
144        ctx: &ToolContext,
145    ) -> Result<ToolOutput, ToolError> {
146        let entry = self
147            .tools
148            .get(name)
149            .ok_or_else(|| ToolError::new(format!("Unknown tool: {name}")))?;
150        entry.erased.call_erased(args, ctx).await
151    }
152
153    /// Dispatch a tool call by name with a raw JSON string argument.
154    ///
155    /// # Errors
156    ///
157    /// Returns `Err` if JSON parsing fails, the tool name is unknown, or the handler fails.
158    pub async fn dispatch_str(
159        &self,
160        name: &str,
161        args_json: &str,
162        ctx: &ToolContext,
163    ) -> Result<ToolOutput, ToolError> {
164        let args = serde_json::from_str(args_json)
165            .map_err(|e| ToolError::new(format!("Malformed JSON arguments: {e}")))?;
166        self.dispatch(name, args, ctx).await
167    }
168
169    /// Number of registered tools.
170    #[must_use]
171    pub fn len(&self) -> usize {
172        self.tools.len()
173    }
174
175    /// Whether the registry has no registered tools.
176    #[must_use]
177    pub fn is_empty(&self) -> bool {
178        self.tools.is_empty()
179    }
180
181    /// Iterate over `(name, definition)` pairs for every registered tool.
182    ///
183    /// Returns clones of the cached definitions computed at registration time.
184    pub fn iter(&self) -> impl Iterator<Item = (&'static str, ToolDefinition)> + '_ {
185        self.tools
186            .iter()
187            .map(|(name, entry)| (*name, entry.definition.clone()))
188    }
189}
190
191/// Iterate over `(name, definition)` pairs for every registered tool.
192///
193/// Yields `(&'static str, ToolDefinition)` for each tool in the registry.
194impl<'a> IntoIterator for &'a ToolRegistry {
195    type Item = (&'static str, ToolDefinition);
196    type IntoIter = Box<dyn Iterator<Item = (&'static str, ToolDefinition)> + 'a>;
197
198    fn into_iter(self) -> Self::IntoIter {
199        Box::new(
200            self.tools
201                .iter()
202                .map(|(name, entry)| (*name, entry.definition.clone())),
203        )
204    }
205}
206
207#[cfg(all(test, feature = "std"))]
208mod tests;