llm_tool/registry/mod.rs
//! Tool registry: registration and concurrent dispatch of named tools.
2
3use alloc::{boxed::Box, format, vec::Vec};
4
5use super::{
6 rust_tool::{ErasedTool, RustTool, definition_of},
7 types::{ToolContext, ToolDefinition, ToolError, ToolOutput},
8};
9use crate::compat::HashMap;
10
/// Entry holding a cached [`ToolDefinition`] alongside the type-erased tool.
///
/// The definition is computed once at registration time so that
/// [`ToolRegistry::definitions`] and [`ToolRegistry::iter`] never
/// regenerate JSON schemas.
struct RegisteredTool {
    /// Definition (including JSON schema) built by `definition_of` when the
    /// tool was registered; listings hand out clones of this value.
    definition: ToolDefinition,
    /// The tool itself behind the object-safe [`ErasedTool`] interface;
    /// `ToolRegistry::dispatch` invokes it through `call_erased`.
    erased: Box<dyn ErasedTool>,
}
20
/// A collection of named tools that can be listed as [`ToolDefinition`]s and
/// invoked by name with JSON arguments.
///
/// Tools are keyed by their [`RustTool::NAME`]; registering a second tool with
/// the same name replaces the first.
pub struct ToolRegistry {
    /// Registered tools keyed by [`RustTool::NAME`].
    tools: HashMap<&'static str, RegisteredTool>,
}
24
25impl core::fmt::Debug for ToolRegistry {
26 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
27 let names: Vec<&str> = self
28 .tools
29 .values()
30 .map(|r| r.definition.name.as_str())
31 .collect();
32 f.debug_struct("ToolRegistry")
33 .field("tool_count", &self.tools.len())
34 .field("tool_names", &names)
35 .finish()
36 }
37}
38
39impl Default for ToolRegistry {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl ToolRegistry {
46 /// Create an empty registry.
47 #[must_use]
48 pub fn new() -> Self {
49 Self {
50 tools: HashMap::new(),
51 }
52 }
53
54 /// Register a [`RustTool`]. Returns `&mut Self` for chaining.
55 ///
56 /// The tool's [`ToolDefinition`] (including JSON schema) is computed once
57 /// here and cached for the lifetime of the registration.
58 ///
59 /// If a tool with the same name was already registered, it is replaced.
60 ///
61 /// # Panics
62 ///
63 /// Panics if the tool's JSON schema cannot be serialized. This indicates a
64 /// bug in the tool's `Params` type (e.g. a broken `JsonSchema` impl).
65 pub fn register<T: RustTool + 'static>(&mut self, tool: T) -> &mut Self {
66 let definition = definition_of(&tool)
67 .unwrap_or_else(|e| panic!("Failed to build definition for tool '{}': {e}", T::NAME));
68 self.tools.insert(
69 T::NAME,
70 RegisteredTool {
71 definition,
72 erased: Box::new(tool),
73 },
74 );
75 self
76 }
77
78 /// Register a [`RustTool`], consuming and returning `Self` for owned chaining.
79 ///
80 /// This is the owned counterpart of [`register`](Self::register), enabling
81 /// patterns like:
82 /// ```
83 /// use llm_tool::{RustTool, ToolContext, ToolError, ToolOutput, ToolRegistry};
84 /// use schemars::JsonSchema;
85 /// use serde::Deserialize;
86 ///
87 /// #[derive(Deserialize, JsonSchema)]
88 /// struct NoParams {}
89 ///
90 /// struct ToolA;
91 /// impl RustTool for ToolA {
92 /// type Params = NoParams;
93 /// const NAME: &'static str = "tool_a";
94 /// const DESCRIPTION: &'static str = "Tool A";
95 /// async fn call(&self, _: NoParams, _: &ToolContext) -> Result<ToolOutput, ToolError> {
96 /// Ok("a".into())
97 /// }
98 /// }
99 ///
100 /// struct ToolB;
101 /// impl RustTool for ToolB {
102 /// type Params = NoParams;
103 /// const NAME: &'static str = "tool_b";
104 /// const DESCRIPTION: &'static str = "Tool B";
105 /// async fn call(&self, _: NoParams, _: &ToolContext) -> Result<ToolOutput, ToolError> {
106 /// Ok("b".into())
107 /// }
108 /// }
109 ///
110 /// let registry = ToolRegistry::new().with_tool(ToolA).with_tool(ToolB);
111 ///
112 /// assert_eq!(registry.definitions().len(), 2);
113 /// ```
114 #[must_use]
115 pub fn with_tool<T: RustTool + 'static>(mut self, tool: T) -> Self {
116 self.register(tool);
117 self
118 }
119
120 /// Collect [`ToolDefinition`]s for all registered tools.
121 ///
122 /// Returns clones of the cached definitions computed at registration time.
123 #[must_use]
124 pub fn definitions(&self) -> Vec<ToolDefinition> {
125 self.tools
126 .values()
127 .map(|entry| entry.definition.clone())
128 .collect()
129 }
130
131 /// Dispatch a tool call by name with raw JSON arguments and a context.
132 ///
133 /// # Errors
134 ///
135 /// Returns `Err` if the tool name is unknown or the handler returns an error.
136 pub async fn dispatch(
137 &self,
138 name: &str,
139 args: serde_json::Value,
140 ctx: &ToolContext,
141 ) -> Result<ToolOutput, ToolError> {
142 let entry = self
143 .tools
144 .get(name)
145 .ok_or_else(|| ToolError::new(format!("Unknown tool: {name}")))?;
146 entry.erased.call_erased(args, ctx).await
147 }
148
149 /// Dispatch a tool call by name with a raw JSON string argument.
150 ///
151 /// # Errors
152 ///
153 /// Returns `Err` if JSON parsing fails, the tool name is unknown, or the handler fails.
154 pub async fn dispatch_str(
155 &self,
156 name: &str,
157 args_json: &str,
158 ctx: &ToolContext,
159 ) -> Result<ToolOutput, ToolError> {
160 let args = serde_json::from_str(args_json)
161 .map_err(|e| ToolError::new(format!("Malformed JSON arguments: {e}")))?;
162 self.dispatch(name, args, ctx).await
163 }
164
165 /// Number of registered tools.
166 #[must_use]
167 pub fn len(&self) -> usize {
168 self.tools.len()
169 }
170
171 /// Whether the registry has no registered tools.
172 #[must_use]
173 pub fn is_empty(&self) -> bool {
174 self.tools.is_empty()
175 }
176
177 /// Iterate over `(name, definition)` pairs for every registered tool.
178 ///
179 /// Returns clones of the cached definitions computed at registration time.
180 pub fn iter(&self) -> impl Iterator<Item = (&'static str, ToolDefinition)> + '_ {
181 self.tools
182 .iter()
183 .map(|(name, entry)| (*name, entry.definition.clone()))
184 }
185}
186
187/// Iterate over `(name, definition)` pairs for every registered tool.
188///
189/// Yields `(&'static str, ToolDefinition)` for each tool in the registry.
190impl<'a> IntoIterator for &'a ToolRegistry {
191 type Item = (&'static str, ToolDefinition);
192 type IntoIter = Box<dyn Iterator<Item = (&'static str, ToolDefinition)> + 'a>;
193
194 fn into_iter(self) -> Self::IntoIter {
195 Box::new(
196 self.tools
197 .iter()
198 .map(|(name, entry)| (*name, entry.definition.clone())),
199 )
200 }
201}
202
203#[cfg(all(test, feature = "std"))]
204mod tests;