llm_stack/tool/registry.rs
1//! Tool registry for managing and executing tools.
2
3use std::collections::HashMap;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::time::Duration;
8
9use rand::Rng;
10
11use super::ToolHandler;
12use crate::chat::{ToolCall, ToolResult};
13use crate::intercept::domain::{ToolExec, ToolRequest, ToolResponse};
14use crate::intercept::{InterceptorStack, Operation};
15use crate::provider::{ToolDefinition, ToolRetryConfig};
16
17/// A registry of tool handlers, indexed by name.
18///
19/// Generic over context type `Ctx` which is passed to tool handlers on
20/// execution. Default is `()` for backwards compatibility.
21///
22/// Provides validation of tool call arguments against their schemas
23/// and parallel execution of multiple tool calls.
24///
25/// # Interceptors
26///
27/// Tool execution can be wrapped with interceptors for cross-cutting concerns
28/// like logging, approval gates, or rate limiting:
29///
30/// ```rust,ignore
31/// use llm_stack::{ToolRegistry, tool_fn};
32/// use llm_stack::intercept::{InterceptorStack, ToolExec, Approval, ApprovalDecision};
33///
34/// let mut registry: ToolRegistry<()> = ToolRegistry::new()
35/// .with_interceptors(
36/// InterceptorStack::<ToolExec<()>>::new()
37/// .with(Approval::new(|req| {
38/// if req.name.starts_with("dangerous_") {
39/// ApprovalDecision::Deny("Not allowed".into())
40/// } else {
41/// ApprovalDecision::Allow
42/// }
43/// }))
44/// );
45/// ```
46pub struct ToolRegistry<Ctx = ()>
47where
48 Ctx: Send + Sync + 'static,
49{
50 pub(crate) handlers: HashMap<String, Arc<dyn ToolHandler<Ctx>>>,
51 interceptors: InterceptorStack<ToolExec<Ctx>>,
52}
53
54impl<Ctx> Default for ToolRegistry<Ctx>
55where
56 Ctx: Send + Sync + 'static,
57{
58 fn default() -> Self {
59 Self {
60 handlers: HashMap::new(),
61 interceptors: InterceptorStack::new(),
62 }
63 }
64}
65
66impl<Ctx> Clone for ToolRegistry<Ctx>
67where
68 Ctx: Send + Sync + 'static,
69{
70 /// Clone the registry.
71 ///
72 /// This is cheap — it clones `Arc` pointers to handlers, not the
73 /// handlers themselves.
74 fn clone(&self) -> Self {
75 Self {
76 handlers: self.handlers.clone(),
77 interceptors: self.interceptors.clone(),
78 }
79 }
80}
81
82impl<Ctx> std::fmt::Debug for ToolRegistry<Ctx>
83where
84 Ctx: Send + Sync + 'static,
85{
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 f.debug_struct("ToolRegistry")
88 .field("tools", &self.handlers.keys().collect::<Vec<_>>())
89 .field("interceptors", &self.interceptors.len())
90 .finish()
91 }
92}
93
94impl<Ctx: Send + Sync + 'static> ToolRegistry<Ctx> {
95 /// Creates an empty registry.
96 pub fn new() -> Self {
97 Self::default()
98 }
99
100 /// Registers a tool handler.
101 ///
102 /// If a handler with the same name already exists, it is replaced.
103 pub fn register(&mut self, handler: impl ToolHandler<Ctx> + 'static) -> &mut Self {
104 let name = handler.definition().name.clone();
105 self.handlers.insert(name, Arc::new(handler));
106 self
107 }
108
109 /// Registers a shared tool handler.
110 pub fn register_shared(&mut self, handler: Arc<dyn ToolHandler<Ctx>>) -> &mut Self {
111 let name = handler.definition().name.clone();
112 self.handlers.insert(name, handler);
113 self
114 }
115
116 /// Returns the handler for the given tool name.
117 pub fn get(&self, name: &str) -> Option<&Arc<dyn ToolHandler<Ctx>>> {
118 self.handlers.get(name)
119 }
120
121 /// Returns whether a tool with the given name is registered.
122 pub fn contains(&self, name: &str) -> bool {
123 self.handlers.contains_key(name)
124 }
125
126 /// Returns the definitions of all registered tools.
127 ///
128 /// Pass this to [`ChatParams::tools`](crate::provider::ChatParams::tools) to tell the model which
129 /// tools are available.
130 pub fn definitions(&self) -> Vec<ToolDefinition> {
131 self.handlers.values().map(|h| h.definition()).collect()
132 }
133
134 /// Returns the number of registered tools.
135 pub fn len(&self) -> usize {
136 self.handlers.len()
137 }
138
139 /// Returns true if no tools are registered.
140 pub fn is_empty(&self) -> bool {
141 self.handlers.is_empty()
142 }
143
144 /// Returns a new registry excluding the named tools.
145 ///
146 /// Useful for creating scoped registries in Master/Worker patterns
147 /// where workers should not have access to certain tools (e.g., `spawn_task`).
148 ///
149 /// # Example
150 ///
151 /// ```rust
152 /// use llm_stack::ToolRegistry;
153 ///
154 /// let master_registry: ToolRegistry<()> = ToolRegistry::new();
155 /// // ... register tools ...
156 ///
157 /// // Workers can't spawn or use admin tools
158 /// let worker_registry = master_registry.without(["spawn_task", "admin_tool"]);
159 /// ```
160 #[must_use]
161 pub fn without<'a>(&self, names: impl IntoIterator<Item = &'a str>) -> Self {
162 use std::collections::HashSet;
163 let exclude: HashSet<&str> = names.into_iter().collect();
164 let mut new = Self {
165 handlers: HashMap::new(),
166 interceptors: self.interceptors.clone(),
167 };
168 for (name, handler) in &self.handlers {
169 if !exclude.contains(name.as_str()) {
170 new.handlers.insert(name.clone(), Arc::clone(handler));
171 }
172 }
173 new
174 }
175
176 /// Returns a new registry with only the named tools.
177 ///
178 /// Useful for creating minimal registries with specific capabilities.
179 ///
180 /// # Example
181 ///
182 /// ```rust
183 /// use llm_stack::ToolRegistry;
184 ///
185 /// let full_registry: ToolRegistry<()> = ToolRegistry::new();
186 /// // ... register tools ...
187 ///
188 /// // Read-only registry with just search tools
189 /// let search_registry = full_registry.only(["search_docs", "search_web"]);
190 /// ```
191 #[must_use]
192 pub fn only<'a>(&self, names: impl IntoIterator<Item = &'a str>) -> Self {
193 use std::collections::HashSet;
194 let include: HashSet<&str> = names.into_iter().collect();
195 let mut new = Self {
196 handlers: HashMap::new(),
197 interceptors: self.interceptors.clone(),
198 };
199 for (name, handler) in &self.handlers {
200 if include.contains(name.as_str()) {
201 new.handlers.insert(name.clone(), Arc::clone(handler));
202 }
203 }
204 new
205 }
206
207 /// Sets the interceptor stack for all tool executions.
208 ///
209 /// Interceptors run in the order added (first = outermost). They can
210 /// inspect, modify, or block tool calls before they reach the handler.
211 ///
212 /// # Example
213 ///
214 /// ```rust,ignore
215 /// use llm_stack::{ToolRegistry, tool_fn};
216 /// use llm_stack::intercept::{InterceptorStack, ToolExec, Approval, ApprovalDecision, Retry};
217 ///
218 /// let registry: ToolRegistry<()> = ToolRegistry::new()
219 /// .with_interceptors(
220 /// InterceptorStack::<ToolExec<()>>::new()
221 /// .with(Approval::new(|req| {
222 /// if req.name == "dangerous" {
223 /// ApprovalDecision::Deny("Not allowed".into())
224 /// } else {
225 /// ApprovalDecision::Allow
226 /// }
227 /// }))
228 /// .with(Retry::default())
229 /// );
230 /// ```
231 #[must_use]
232 pub fn with_interceptors(mut self, interceptors: InterceptorStack<ToolExec<Ctx>>) -> Self {
233 self.interceptors = interceptors;
234 self
235 }
236
237 /// Executes a single tool call with schema validation and optional retry.
238 ///
239 /// 1. Looks up the handler by [`ToolCall::name`]
240 /// 2. Validates arguments against the tool's parameter schema
241 /// 3. Runs the call through interceptors (if any)
242 /// 4. Invokes the handler with the provided context
243 /// 5. If the tool has retry configuration and execution fails,
244 /// retries with exponential backoff
245 ///
246 /// Returns a [`ToolResult`] (always succeeds at the outer level).
247 /// Execution errors are captured in `ToolResult::is_error`.
248 pub async fn execute(&self, call: &ToolCall, ctx: &Ctx) -> ToolResult {
249 let Some(handler) = self.handlers.get(&call.name) else {
250 return ToolResult {
251 tool_call_id: call.id.clone(),
252 content: format!("Unknown tool: {}", call.name),
253 is_error: true,
254 };
255 };
256
257 // Validate arguments against schema
258 #[cfg(feature = "schema")]
259 {
260 let definition = handler.definition();
261 if let Err(e) = definition.parameters.validate(&call.arguments) {
262 return ToolResult {
263 tool_call_id: call.id.clone(),
264 content: format!("Invalid arguments for tool '{}': {e}", call.name),
265 is_error: true,
266 };
267 }
268 }
269
270 // Create the request for interceptors
271 let request = ToolRequest {
272 name: call.name.clone(),
273 call_id: call.id.clone(),
274 arguments: call.arguments.clone(),
275 };
276
277 // Create the operation that executes the handler
278 let operation = ToolHandlerOperation {
279 handler: handler.clone(),
280 ctx,
281 retry_config: handler.definition().retry,
282 };
283
284 // Execute through interceptor stack
285 let response = self.interceptors.execute(&request, &operation).await;
286
287 ToolResult {
288 tool_call_id: call.id.clone(),
289 content: response.content,
290 is_error: response.is_error,
291 }
292 }
293
294 /// Executes a tool by name with the given arguments.
295 ///
296 /// This is a lower-level method used internally when the tool call
297 /// components are already separated (e.g., for streaming execution).
298 pub(crate) async fn execute_by_name(
299 &self,
300 name: &str,
301 call_id: &str,
302 arguments: serde_json::Value,
303 ctx: &Ctx,
304 ) -> ToolResult {
305 let call = ToolCall {
306 id: call_id.to_string(),
307 name: name.to_string(),
308 arguments,
309 };
310 self.execute(&call, ctx).await
311 }
312
313 /// Executes multiple tool calls, preserving order.
314 ///
315 /// When `parallel` is true, all calls run concurrently via
316 /// `futures::future::join_all`. When false, they run sequentially.
317 pub async fn execute_all(
318 &self,
319 calls: &[ToolCall],
320 ctx: &Ctx,
321 parallel: bool,
322 ) -> Vec<ToolResult> {
323 if !parallel || calls.len() <= 1 {
324 let mut results = Vec::with_capacity(calls.len());
325 for call in calls {
326 results.push(self.execute(call, ctx).await);
327 }
328 return results;
329 }
330
331 // Parallel execution using join_all (no spawn needed)
332 let futures: Vec<_> = calls.iter().map(|call| self.execute(call, ctx)).collect();
333 futures::future::join_all(futures).await
334 }
335}
336
337/// Computes backoff duration with exponential growth and jitter.
338///
339/// Formula: `min(initial * multiplier^attempt, max) * random(1-jitter, 1)`
340fn compute_backoff(config: &ToolRetryConfig, attempt: u32) -> Duration {
341 // Safe to cast: attempt is bounded by max_retries which is u32,
342 // and reasonable values are << i32::MAX
343 #[allow(clippy::cast_possible_wrap)]
344 let base =
345 config.initial_backoff.as_secs_f64() * config.backoff_multiplier.powi(attempt as i32);
346 let capped = base.min(config.max_backoff.as_secs_f64());
347
348 // Apply jitter: random value in range [1-jitter, 1]
349 let jitter_factor = if config.jitter > 0.0 {
350 let min_factor = 1.0 - config.jitter;
351 let mut rng = rand::rng();
352 rng.random_range(min_factor..=1.0)
353 } else {
354 1.0
355 };
356
357 Duration::from_secs_f64(capped * jitter_factor)
358}
359
360/// Wraps a tool handler as an [`Operation`] for the interceptor stack.
361///
362/// This struct captures the handler, context, and retry config so that
363/// the interceptor stack can execute the tool.
364struct ToolHandlerOperation<'a, Ctx: Send + Sync + 'static> {
365 handler: Arc<dyn ToolHandler<Ctx>>,
366 ctx: &'a Ctx,
367 retry_config: Option<ToolRetryConfig>,
368}
369
370impl<Ctx: Send + Sync + 'static> Operation<ToolExec<Ctx>> for ToolHandlerOperation<'_, Ctx> {
371 fn execute<'b>(
372 &'b self,
373 input: &'b ToolRequest,
374 ) -> Pin<Box<dyn Future<Output = ToolResponse> + Send + 'b>>
375 where
376 ToolRequest: Sync,
377 {
378 Box::pin(async move {
379 match &self.retry_config {
380 Some(config) => execute_with_retry(&self.handler, input, self.ctx, config).await,
381 None => execute_once(&self.handler, input, self.ctx).await,
382 }
383 })
384 }
385}
386
387/// Executes a tool once without retry.
388async fn execute_once<Ctx: Send + Sync + 'static>(
389 handler: &Arc<dyn ToolHandler<Ctx>>,
390 request: &ToolRequest,
391 ctx: &Ctx,
392) -> ToolResponse {
393 match handler.execute(request.arguments.clone(), ctx).await {
394 Ok(output) => ToolResponse {
395 content: output.content,
396 is_error: false,
397 },
398 Err(e) => ToolResponse {
399 content: e.message,
400 is_error: true,
401 },
402 }
403}
404
405/// Executes a tool with retry logic.
406async fn execute_with_retry<Ctx: Send + Sync + 'static>(
407 handler: &Arc<dyn ToolHandler<Ctx>>,
408 request: &ToolRequest,
409 ctx: &Ctx,
410 config: &ToolRetryConfig,
411) -> ToolResponse {
412 let mut attempt = 0u32;
413
414 loop {
415 match handler.execute(request.arguments.clone(), ctx).await {
416 Ok(output) => {
417 return ToolResponse {
418 content: output.content,
419 is_error: false,
420 };
421 }
422 Err(e) => {
423 let error_msg = e.message;
424
425 // Check if we should retry this error
426 let should_retry = config
427 .retry_if
428 .as_ref()
429 .is_none_or(|predicate| predicate(&error_msg));
430
431 if !should_retry || attempt >= config.max_retries {
432 return ToolResponse {
433 content: error_msg,
434 is_error: true,
435 };
436 }
437
438 // Calculate backoff with jitter
439 let backoff = compute_backoff(config, attempt);
440 tokio::time::sleep(backoff).await;
441
442 attempt += 1;
443 }
444 }
445 }
446}