llm_tool/registry/mod.rs
//! Tool registry: registration of named tools and async dispatch of calls to them.
2
3use alloc::{boxed::Box, format, vec::Vec};
4
5use super::{
6 rust_tool::{ErasedTool, RustTool, definition_of},
7 types::{ToolContext, ToolDefinition, ToolError, ToolOutput},
8};
9use crate::compat::HashMap;
10
11/// Entry holding a cached [`ToolDefinition`] alongside the type-erased tool.
12///
13/// The definition is computed once at registration time so that
14/// [`ToolRegistry::definitions`] and [`ToolRegistry::iter`] never
15/// regenerate JSON schemas.
16struct RegisteredTool {
17 definition: ToolDefinition,
18 erased: Box<dyn ErasedTool>,
19}
20
21/// A registry of named tools available for dynamic dispatch.
22///
23/// Holds type-erased tool implementations and cached [`ToolDefinition`](super::types::ToolDefinition)
24/// schemas for fast lookup and execution.
25pub struct ToolRegistry {
26 tools: HashMap<&'static str, RegisteredTool>,
27}
28
29impl core::fmt::Debug for ToolRegistry {
30 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
31 let names: Vec<&str> = self
32 .tools
33 .values()
34 .map(|r| r.definition.name.as_str())
35 .collect();
36 f.debug_struct("ToolRegistry")
37 .field("tool_count", &self.tools.len())
38 .field("tool_names", &names)
39 .finish()
40 }
41}
42
43impl Default for ToolRegistry {
44 fn default() -> Self {
45 Self::new()
46 }
47}
48
49impl ToolRegistry {
50 /// Create an empty registry.
51 #[must_use]
52 pub fn new() -> Self {
53 Self {
54 tools: HashMap::new(),
55 }
56 }
57
58 /// Register a [`RustTool`]. Returns `&mut Self` for chaining.
59 ///
60 /// The tool's [`ToolDefinition`] (including JSON schema) is computed once
61 /// here and cached for the lifetime of the registration.
62 ///
63 /// If a tool with the same name was already registered, it is replaced.
64 ///
65 /// # Panics
66 ///
67 /// Panics if the tool's JSON schema cannot be serialized. This indicates a
68 /// bug in the tool's `Params` type (e.g. a broken `JsonSchema` impl).
69 pub fn register<T: RustTool + 'static>(&mut self, tool: T) -> &mut Self {
70 let definition = definition_of(&tool)
71 .unwrap_or_else(|e| panic!("Failed to build definition for tool '{}': {e}", T::NAME));
72 self.tools.insert(
73 T::NAME,
74 RegisteredTool {
75 definition,
76 erased: Box::new(tool),
77 },
78 );
79 self
80 }
81
82 /// Register a [`RustTool`], consuming and returning `Self` for owned chaining.
83 ///
84 /// This is the owned counterpart of [`register`](Self::register), enabling
85 /// patterns like:
86 /// ```
87 /// use llm_tool::{RustTool, ToolContext, ToolError, ToolOutput, ToolRegistry};
88 /// use schemars::JsonSchema;
89 /// use serde::Deserialize;
90 ///
91 /// #[derive(Deserialize, JsonSchema)]
92 /// struct NoParams {}
93 ///
94 /// struct ToolA;
95 /// impl RustTool for ToolA {
96 /// type Params = NoParams;
97 /// const NAME: &'static str = "tool_a";
98 /// const DESCRIPTION: &'static str = "Tool A";
99 /// async fn call(&self, _: NoParams, _: &ToolContext) -> Result<ToolOutput, ToolError> {
100 /// Ok("a".into())
101 /// }
102 /// }
103 ///
104 /// struct ToolB;
105 /// impl RustTool for ToolB {
106 /// type Params = NoParams;
107 /// const NAME: &'static str = "tool_b";
108 /// const DESCRIPTION: &'static str = "Tool B";
109 /// async fn call(&self, _: NoParams, _: &ToolContext) -> Result<ToolOutput, ToolError> {
110 /// Ok("b".into())
111 /// }
112 /// }
113 ///
114 /// let registry = ToolRegistry::new().with_tool(ToolA).with_tool(ToolB);
115 ///
116 /// assert_eq!(registry.definitions().len(), 2);
117 /// ```
118 #[must_use]
119 pub fn with_tool<T: RustTool + 'static>(mut self, tool: T) -> Self {
120 self.register(tool);
121 self
122 }
123
124 /// Collect [`ToolDefinition`]s for all registered tools.
125 ///
126 /// Returns clones of the cached definitions computed at registration time.
127 #[must_use]
128 pub fn definitions(&self) -> Vec<ToolDefinition> {
129 self.tools
130 .values()
131 .map(|entry| entry.definition.clone())
132 .collect()
133 }
134
135 /// Dispatch a tool call by name with raw JSON arguments and a context.
136 ///
137 /// # Errors
138 ///
139 /// Returns `Err` if the tool name is unknown or the handler returns an error.
140 pub async fn dispatch(
141 &self,
142 name: &str,
143 args: serde_json::Value,
144 ctx: &ToolContext,
145 ) -> Result<ToolOutput, ToolError> {
146 let entry = self
147 .tools
148 .get(name)
149 .ok_or_else(|| ToolError::new(format!("Unknown tool: {name}")))?;
150 entry.erased.call_erased(args, ctx).await
151 }
152
153 /// Dispatch a tool call by name with a raw JSON string argument.
154 ///
155 /// # Errors
156 ///
157 /// Returns `Err` if JSON parsing fails, the tool name is unknown, or the handler fails.
158 pub async fn dispatch_str(
159 &self,
160 name: &str,
161 args_json: &str,
162 ctx: &ToolContext,
163 ) -> Result<ToolOutput, ToolError> {
164 let args = serde_json::from_str(args_json)
165 .map_err(|e| ToolError::new(format!("Malformed JSON arguments: {e}")))?;
166 self.dispatch(name, args, ctx).await
167 }
168
169 /// Number of registered tools.
170 #[must_use]
171 pub fn len(&self) -> usize {
172 self.tools.len()
173 }
174
175 /// Whether the registry has no registered tools.
176 #[must_use]
177 pub fn is_empty(&self) -> bool {
178 self.tools.is_empty()
179 }
180
181 /// Iterate over `(name, definition)` pairs for every registered tool.
182 ///
183 /// Returns clones of the cached definitions computed at registration time.
184 pub fn iter(&self) -> impl Iterator<Item = (&'static str, ToolDefinition)> + '_ {
185 self.tools
186 .iter()
187 .map(|(name, entry)| (*name, entry.definition.clone()))
188 }
189}
190
191/// Iterate over `(name, definition)` pairs for every registered tool.
192///
193/// Yields `(&'static str, ToolDefinition)` for each tool in the registry.
194impl<'a> IntoIterator for &'a ToolRegistry {
195 type Item = (&'static str, ToolDefinition);
196 type IntoIter = Box<dyn Iterator<Item = (&'static str, ToolDefinition)> + 'a>;
197
198 fn into_iter(self) -> Self::IntoIter {
199 Box::new(
200 self.tools
201 .iter()
202 .map(|(name, entry)| (*name, entry.definition.clone())),
203 )
204 }
205}
206
// Unit tests live in `tests.rs`; they are compiled only for test builds with
// the `std` feature enabled.
#[cfg(all(feature = "std", test))]
mod tests;