Skip to main content

llm_tool/registry/
mod.rs

//! Tool registry: registration and concurrent dispatch of named tools.
2
3use alloc::{boxed::Box, format, vec::Vec};
4
5use super::{
6    rust_tool::{ErasedTool, RustTool, definition_of},
7    types::{ToolContext, ToolDefinition, ToolError, ToolOutput},
8};
9use crate::compat::HashMap;
10
11/// Entry holding a cached [`ToolDefinition`] alongside the type-erased tool.
12///
13/// The definition is computed once at registration time so that
14/// [`ToolRegistry::definitions`] and [`ToolRegistry::iter`] never
15/// regenerate JSON schemas.
16struct RegisteredTool {
17    definition: ToolDefinition,
18    erased: Box<dyn ErasedTool>,
19}
20
21pub struct ToolRegistry {
22    tools: HashMap<&'static str, RegisteredTool>,
23}
24
25impl core::fmt::Debug for ToolRegistry {
26    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
27        let names: Vec<&str> = self
28            .tools
29            .values()
30            .map(|r| r.definition.name.as_str())
31            .collect();
32        f.debug_struct("ToolRegistry")
33            .field("tool_count", &self.tools.len())
34            .field("tool_names", &names)
35            .finish()
36    }
37}
38
39impl Default for ToolRegistry {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45impl ToolRegistry {
46    /// Create an empty registry.
47    #[must_use]
48    pub fn new() -> Self {
49        Self {
50            tools: HashMap::new(),
51        }
52    }
53
54    /// Register a [`RustTool`]. Returns `&mut Self` for chaining.
55    ///
56    /// The tool's [`ToolDefinition`] (including JSON schema) is computed once
57    /// here and cached for the lifetime of the registration.
58    ///
59    /// If a tool with the same name was already registered, it is replaced.
60    ///
61    /// # Panics
62    ///
63    /// Panics if the tool's JSON schema cannot be serialized. This indicates a
64    /// bug in the tool's `Params` type (e.g. a broken `JsonSchema` impl).
65    pub fn register<T: RustTool + 'static>(&mut self, tool: T) -> &mut Self {
66        let definition = definition_of(&tool)
67            .unwrap_or_else(|e| panic!("Failed to build definition for tool '{}': {e}", T::NAME));
68        self.tools.insert(
69            T::NAME,
70            RegisteredTool {
71                definition,
72                erased: Box::new(tool),
73            },
74        );
75        self
76    }
77
78    /// Register a [`RustTool`], consuming and returning `Self` for owned chaining.
79    ///
80    /// This is the owned counterpart of [`register`](Self::register), enabling
81    /// patterns like:
82    /// ```
83    /// use llm_tool::{RustTool, ToolContext, ToolError, ToolOutput, ToolRegistry};
84    /// use schemars::JsonSchema;
85    /// use serde::Deserialize;
86    ///
87    /// #[derive(Deserialize, JsonSchema)]
88    /// struct NoParams {}
89    ///
90    /// struct ToolA;
91    /// impl RustTool for ToolA {
92    ///     type Params = NoParams;
93    ///     const NAME: &'static str = "tool_a";
94    ///     const DESCRIPTION: &'static str = "Tool A";
95    ///     async fn call(&self, _: NoParams, _: &ToolContext) -> Result<ToolOutput, ToolError> {
96    ///         Ok("a".into())
97    ///     }
98    /// }
99    ///
100    /// struct ToolB;
101    /// impl RustTool for ToolB {
102    ///     type Params = NoParams;
103    ///     const NAME: &'static str = "tool_b";
104    ///     const DESCRIPTION: &'static str = "Tool B";
105    ///     async fn call(&self, _: NoParams, _: &ToolContext) -> Result<ToolOutput, ToolError> {
106    ///         Ok("b".into())
107    ///     }
108    /// }
109    ///
110    /// let registry = ToolRegistry::new().with_tool(ToolA).with_tool(ToolB);
111    ///
112    /// assert_eq!(registry.definitions().len(), 2);
113    /// ```
114    #[must_use]
115    pub fn with_tool<T: RustTool + 'static>(mut self, tool: T) -> Self {
116        self.register(tool);
117        self
118    }
119
120    /// Collect [`ToolDefinition`]s for all registered tools.
121    ///
122    /// Returns clones of the cached definitions computed at registration time.
123    #[must_use]
124    pub fn definitions(&self) -> Vec<ToolDefinition> {
125        self.tools
126            .values()
127            .map(|entry| entry.definition.clone())
128            .collect()
129    }
130
131    /// Dispatch a tool call by name with raw JSON arguments and a context.
132    ///
133    /// # Errors
134    ///
135    /// Returns `Err` if the tool name is unknown or the handler returns an error.
136    pub async fn dispatch(
137        &self,
138        name: &str,
139        args: serde_json::Value,
140        ctx: &ToolContext,
141    ) -> Result<ToolOutput, ToolError> {
142        let entry = self
143            .tools
144            .get(name)
145            .ok_or_else(|| ToolError::new(format!("Unknown tool: {name}")))?;
146        entry.erased.call_erased(args, ctx).await
147    }
148
149    /// Dispatch a tool call by name with a raw JSON string argument.
150    ///
151    /// # Errors
152    ///
153    /// Returns `Err` if JSON parsing fails, the tool name is unknown, or the handler fails.
154    pub async fn dispatch_str(
155        &self,
156        name: &str,
157        args_json: &str,
158        ctx: &ToolContext,
159    ) -> Result<ToolOutput, ToolError> {
160        let args = serde_json::from_str(args_json)
161            .map_err(|e| ToolError::new(format!("Malformed JSON arguments: {e}")))?;
162        self.dispatch(name, args, ctx).await
163    }
164
165    /// Number of registered tools.
166    #[must_use]
167    pub fn len(&self) -> usize {
168        self.tools.len()
169    }
170
171    /// Whether the registry has no registered tools.
172    #[must_use]
173    pub fn is_empty(&self) -> bool {
174        self.tools.is_empty()
175    }
176
177    /// Iterate over `(name, definition)` pairs for every registered tool.
178    ///
179    /// Returns clones of the cached definitions computed at registration time.
180    pub fn iter(&self) -> impl Iterator<Item = (&'static str, ToolDefinition)> + '_ {
181        self.tools
182            .iter()
183            .map(|(name, entry)| (*name, entry.definition.clone()))
184    }
185}
186
187/// Iterate over `(name, definition)` pairs for every registered tool.
188///
189/// Yields `(&'static str, ToolDefinition)` for each tool in the registry.
190impl<'a> IntoIterator for &'a ToolRegistry {
191    type Item = (&'static str, ToolDefinition);
192    type IntoIter = Box<dyn Iterator<Item = (&'static str, ToolDefinition)> + 'a>;
193
194    fn into_iter(self) -> Self::IntoIter {
195        Box::new(
196            self.tools
197                .iter()
198                .map(|(name, entry)| (*name, entry.definition.clone())),
199        )
200    }
201}
202
203#[cfg(all(test, feature = "std"))]
204mod tests;