Skip to main content

llm_stack/tool/
registry.rs

1//! Tool registry for managing and executing tools.
2
3use std::collections::HashMap;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::time::Duration;
8
9use rand::Rng;
10
11use super::ToolHandler;
12use crate::chat::{ToolCall, ToolResult};
13use crate::intercept::domain::{ToolExec, ToolRequest, ToolResponse};
14use crate::intercept::{InterceptorStack, Operation};
15use crate::provider::{ToolDefinition, ToolRetryConfig};
16
17/// A registry of tool handlers, indexed by name.
18///
19/// Generic over context type `Ctx` which is passed to tool handlers on
20/// execution. Default is `()` for backwards compatibility.
21///
22/// Provides validation of tool call arguments against their schemas
23/// and parallel execution of multiple tool calls.
24///
25/// # Interceptors
26///
27/// Tool execution can be wrapped with interceptors for cross-cutting concerns
28/// like logging, approval gates, or rate limiting:
29///
30/// ```rust,ignore
31/// use llm_stack::{ToolRegistry, tool_fn};
32/// use llm_stack::intercept::{InterceptorStack, ToolExec, Approval, ApprovalDecision};
33///
34/// let mut registry: ToolRegistry<()> = ToolRegistry::new()
35///     .with_interceptors(
36///         InterceptorStack::<ToolExec<()>>::new()
37///             .with(Approval::new(|req| {
38///                 if req.name.starts_with("dangerous_") {
39///                     ApprovalDecision::Deny("Not allowed".into())
40///                 } else {
41///                     ApprovalDecision::Allow
42///                 }
43///             }))
44///     );
45/// ```
46pub struct ToolRegistry<Ctx = ()>
47where
48    Ctx: Send + Sync + 'static,
49{
50    pub(crate) handlers: HashMap<String, Arc<dyn ToolHandler<Ctx>>>,
51    interceptors: InterceptorStack<ToolExec<Ctx>>,
52}
53
54impl<Ctx> Default for ToolRegistry<Ctx>
55where
56    Ctx: Send + Sync + 'static,
57{
58    fn default() -> Self {
59        Self {
60            handlers: HashMap::new(),
61            interceptors: InterceptorStack::new(),
62        }
63    }
64}
65
66impl<Ctx> Clone for ToolRegistry<Ctx>
67where
68    Ctx: Send + Sync + 'static,
69{
70    /// Clone the registry.
71    ///
72    /// This is cheap — it clones `Arc` pointers to handlers, not the
73    /// handlers themselves.
74    fn clone(&self) -> Self {
75        Self {
76            handlers: self.handlers.clone(),
77            interceptors: self.interceptors.clone(),
78        }
79    }
80}
81
82impl<Ctx> std::fmt::Debug for ToolRegistry<Ctx>
83where
84    Ctx: Send + Sync + 'static,
85{
86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87        f.debug_struct("ToolRegistry")
88            .field("tools", &self.handlers.keys().collect::<Vec<_>>())
89            .field("interceptors", &self.interceptors.len())
90            .finish()
91    }
92}
93
94impl<Ctx: Send + Sync + 'static> ToolRegistry<Ctx> {
95    /// Creates an empty registry.
96    pub fn new() -> Self {
97        Self::default()
98    }
99
100    /// Registers a tool handler.
101    ///
102    /// If a handler with the same name already exists, it is replaced.
103    pub fn register(&mut self, handler: impl ToolHandler<Ctx> + 'static) -> &mut Self {
104        let name = handler.definition().name.clone();
105        self.handlers.insert(name, Arc::new(handler));
106        self
107    }
108
109    /// Registers a shared tool handler.
110    pub fn register_shared(&mut self, handler: Arc<dyn ToolHandler<Ctx>>) -> &mut Self {
111        let name = handler.definition().name.clone();
112        self.handlers.insert(name, handler);
113        self
114    }
115
116    /// Returns the handler for the given tool name.
117    pub fn get(&self, name: &str) -> Option<&Arc<dyn ToolHandler<Ctx>>> {
118        self.handlers.get(name)
119    }
120
121    /// Returns whether a tool with the given name is registered.
122    pub fn contains(&self, name: &str) -> bool {
123        self.handlers.contains_key(name)
124    }
125
126    /// Returns the definitions of all registered tools.
127    ///
128    /// Pass this to [`ChatParams::tools`](crate::provider::ChatParams::tools) to tell the model which
129    /// tools are available.
130    pub fn definitions(&self) -> Vec<ToolDefinition> {
131        self.handlers.values().map(|h| h.definition()).collect()
132    }
133
134    /// Returns the number of registered tools.
135    pub fn len(&self) -> usize {
136        self.handlers.len()
137    }
138
139    /// Returns true if no tools are registered.
140    pub fn is_empty(&self) -> bool {
141        self.handlers.is_empty()
142    }
143
144    /// Returns a new registry excluding the named tools.
145    ///
146    /// Useful for creating scoped registries in Master/Worker patterns
147    /// where workers should not have access to certain tools (e.g., `spawn_task`).
148    ///
149    /// # Example
150    ///
151    /// ```rust
152    /// use llm_stack::ToolRegistry;
153    ///
154    /// let master_registry: ToolRegistry<()> = ToolRegistry::new();
155    /// // ... register tools ...
156    ///
157    /// // Workers can't spawn or use admin tools
158    /// let worker_registry = master_registry.without(["spawn_task", "admin_tool"]);
159    /// ```
160    #[must_use]
161    pub fn without<'a>(&self, names: impl IntoIterator<Item = &'a str>) -> Self {
162        use std::collections::HashSet;
163        let exclude: HashSet<&str> = names.into_iter().collect();
164        let mut new = Self {
165            handlers: HashMap::new(),
166            interceptors: self.interceptors.clone(),
167        };
168        for (name, handler) in &self.handlers {
169            if !exclude.contains(name.as_str()) {
170                new.handlers.insert(name.clone(), Arc::clone(handler));
171            }
172        }
173        new
174    }
175
176    /// Returns a new registry with only the named tools.
177    ///
178    /// Useful for creating minimal registries with specific capabilities.
179    ///
180    /// # Example
181    ///
182    /// ```rust
183    /// use llm_stack::ToolRegistry;
184    ///
185    /// let full_registry: ToolRegistry<()> = ToolRegistry::new();
186    /// // ... register tools ...
187    ///
188    /// // Read-only registry with just search tools
189    /// let search_registry = full_registry.only(["search_docs", "search_web"]);
190    /// ```
191    #[must_use]
192    pub fn only<'a>(&self, names: impl IntoIterator<Item = &'a str>) -> Self {
193        use std::collections::HashSet;
194        let include: HashSet<&str> = names.into_iter().collect();
195        let mut new = Self {
196            handlers: HashMap::new(),
197            interceptors: self.interceptors.clone(),
198        };
199        for (name, handler) in &self.handlers {
200            if include.contains(name.as_str()) {
201                new.handlers.insert(name.clone(), Arc::clone(handler));
202            }
203        }
204        new
205    }
206
207    /// Sets the interceptor stack for all tool executions.
208    ///
209    /// Interceptors run in the order added (first = outermost). They can
210    /// inspect, modify, or block tool calls before they reach the handler.
211    ///
212    /// # Example
213    ///
214    /// ```rust,ignore
215    /// use llm_stack::{ToolRegistry, tool_fn};
216    /// use llm_stack::intercept::{InterceptorStack, ToolExec, Approval, ApprovalDecision, Retry};
217    ///
218    /// let registry: ToolRegistry<()> = ToolRegistry::new()
219    ///     .with_interceptors(
220    ///         InterceptorStack::<ToolExec<()>>::new()
221    ///             .with(Approval::new(|req| {
222    ///                 if req.name == "dangerous" {
223    ///                     ApprovalDecision::Deny("Not allowed".into())
224    ///                 } else {
225    ///                     ApprovalDecision::Allow
226    ///                 }
227    ///             }))
228    ///             .with(Retry::default())
229    ///     );
230    /// ```
231    #[must_use]
232    pub fn with_interceptors(mut self, interceptors: InterceptorStack<ToolExec<Ctx>>) -> Self {
233        self.interceptors = interceptors;
234        self
235    }
236
237    /// Executes a single tool call with schema validation and optional retry.
238    ///
239    /// 1. Looks up the handler by [`ToolCall::name`]
240    /// 2. Validates arguments against the tool's parameter schema
241    /// 3. Runs the call through interceptors (if any)
242    /// 4. Invokes the handler with the provided context
243    /// 5. If the tool has retry configuration and execution fails,
244    ///    retries with exponential backoff
245    ///
246    /// Returns a [`ToolResult`] (always succeeds at the outer level).
247    /// Execution errors are captured in `ToolResult::is_error`.
248    pub async fn execute(&self, call: &ToolCall, ctx: &Ctx) -> ToolResult {
249        let Some(handler) = self.handlers.get(&call.name) else {
250            return ToolResult {
251                tool_call_id: call.id.clone(),
252                content: format!("Unknown tool: {}", call.name),
253                is_error: true,
254            };
255        };
256
257        // Validate arguments against schema
258        #[cfg(feature = "schema")]
259        {
260            let definition = handler.definition();
261            if let Err(e) = definition.parameters.validate(&call.arguments) {
262                return ToolResult {
263                    tool_call_id: call.id.clone(),
264                    content: format!("Invalid arguments for tool '{}': {e}", call.name),
265                    is_error: true,
266                };
267            }
268        }
269
270        // Create the request for interceptors
271        let request = ToolRequest {
272            name: call.name.clone(),
273            call_id: call.id.clone(),
274            arguments: call.arguments.clone(),
275        };
276
277        // Create the operation that executes the handler
278        let operation = ToolHandlerOperation {
279            handler: handler.clone(),
280            ctx,
281            retry_config: handler.definition().retry,
282        };
283
284        // Execute through interceptor stack
285        let response = self.interceptors.execute(&request, &operation).await;
286
287        ToolResult {
288            tool_call_id: call.id.clone(),
289            content: response.content,
290            is_error: response.is_error,
291        }
292    }
293
294    /// Executes a tool by name with the given arguments.
295    ///
296    /// This is a lower-level method used internally when the tool call
297    /// components are already separated (e.g., from `execute_with_events`).
298    /// Accepts owned arguments to avoid an extra deep clone of `serde_json::Value`.
299    pub(crate) async fn execute_by_name(
300        &self,
301        name: &str,
302        call_id: &str,
303        arguments: serde_json::Value,
304        ctx: &Ctx,
305    ) -> ToolResult {
306        let Some(handler) = self.handlers.get(name) else {
307            return ToolResult {
308                tool_call_id: call_id.to_string(),
309                content: format!("Unknown tool: {name}"),
310                is_error: true,
311            };
312        };
313
314        // Validate arguments against schema
315        #[cfg(feature = "schema")]
316        {
317            let definition = handler.definition();
318            if let Err(e) = definition.parameters.validate(&arguments) {
319                return ToolResult {
320                    tool_call_id: call_id.to_string(),
321                    content: format!("Invalid arguments for tool '{name}': {e}"),
322                    is_error: true,
323                };
324            }
325        }
326
327        // Build ToolRequest directly — moves owned arguments, no deep clone
328        let request = ToolRequest {
329            name: name.to_string(),
330            call_id: call_id.to_string(),
331            arguments,
332        };
333
334        let operation = ToolHandlerOperation {
335            handler: handler.clone(),
336            ctx,
337            retry_config: handler.definition().retry,
338        };
339
340        let response = self.interceptors.execute(&request, &operation).await;
341
342        ToolResult {
343            tool_call_id: request.call_id,
344            content: response.content,
345            is_error: response.is_error,
346        }
347    }
348
349    /// Executes multiple tool calls, preserving order.
350    ///
351    /// When `parallel` is true, all calls run concurrently via
352    /// `futures::future::join_all`. When false, they run sequentially.
353    pub async fn execute_all(
354        &self,
355        calls: &[ToolCall],
356        ctx: &Ctx,
357        parallel: bool,
358    ) -> Vec<ToolResult> {
359        if !parallel || calls.len() <= 1 {
360            let mut results = Vec::with_capacity(calls.len());
361            for call in calls {
362                results.push(self.execute(call, ctx).await);
363            }
364            return results;
365        }
366
367        // Parallel execution using join_all (no spawn needed)
368        let futures: Vec<_> = calls.iter().map(|call| self.execute(call, ctx)).collect();
369        futures::future::join_all(futures).await
370    }
371}
372
373/// Computes backoff duration with exponential growth and jitter.
374///
375/// Formula: `min(initial * multiplier^attempt, max) * random(1-jitter, 1)`
376fn compute_backoff(config: &ToolRetryConfig, attempt: u32) -> Duration {
377    // Safe to cast: attempt is bounded by max_retries which is u32,
378    // and reasonable values are << i32::MAX
379    #[allow(clippy::cast_possible_wrap)]
380    let base =
381        config.initial_backoff.as_secs_f64() * config.backoff_multiplier.powi(attempt as i32);
382    let capped = base.min(config.max_backoff.as_secs_f64());
383
384    // Apply jitter: random value in range [1-jitter, 1]
385    let jitter_factor = if config.jitter > 0.0 {
386        let min_factor = 1.0 - config.jitter;
387        let mut rng = rand::rng();
388        rng.random_range(min_factor..=1.0)
389    } else {
390        1.0
391    };
392
393    Duration::from_secs_f64(capped * jitter_factor)
394}
395
396/// Wraps a tool handler as an [`Operation`] for the interceptor stack.
397///
398/// This struct captures the handler, context, and retry config so that
399/// the interceptor stack can execute the tool.
400struct ToolHandlerOperation<'a, Ctx: Send + Sync + 'static> {
401    handler: Arc<dyn ToolHandler<Ctx>>,
402    ctx: &'a Ctx,
403    retry_config: Option<ToolRetryConfig>,
404}
405
406impl<Ctx: Send + Sync + 'static> Operation<ToolExec<Ctx>> for ToolHandlerOperation<'_, Ctx> {
407    fn execute<'b>(
408        &'b self,
409        input: &'b ToolRequest,
410    ) -> Pin<Box<dyn Future<Output = ToolResponse> + Send + 'b>>
411    where
412        ToolRequest: Sync,
413    {
414        Box::pin(async move {
415            match &self.retry_config {
416                Some(config) => execute_with_retry(&self.handler, input, self.ctx, config).await,
417                None => execute_once(&self.handler, input, self.ctx).await,
418            }
419        })
420    }
421}
422
423/// Executes a tool once without retry.
424async fn execute_once<Ctx: Send + Sync + 'static>(
425    handler: &Arc<dyn ToolHandler<Ctx>>,
426    request: &ToolRequest,
427    ctx: &Ctx,
428) -> ToolResponse {
429    match handler.execute(request.arguments.clone(), ctx).await {
430        Ok(output) => ToolResponse {
431            content: output.content,
432            is_error: false,
433        },
434        Err(e) => ToolResponse {
435            content: e.message,
436            is_error: true,
437        },
438    }
439}
440
441/// Executes a tool with retry logic.
442async fn execute_with_retry<Ctx: Send + Sync + 'static>(
443    handler: &Arc<dyn ToolHandler<Ctx>>,
444    request: &ToolRequest,
445    ctx: &Ctx,
446    config: &ToolRetryConfig,
447) -> ToolResponse {
448    let mut attempt = 0u32;
449
450    loop {
451        match handler.execute(request.arguments.clone(), ctx).await {
452            Ok(output) => {
453                return ToolResponse {
454                    content: output.content,
455                    is_error: false,
456                };
457            }
458            Err(e) => {
459                let error_msg = e.message;
460
461                // Check if we should retry this error
462                let should_retry = config
463                    .retry_if
464                    .as_ref()
465                    .is_none_or(|predicate| predicate(&error_msg));
466
467                if !should_retry || attempt >= config.max_retries {
468                    return ToolResponse {
469                        content: error_msg,
470                        is_error: true,
471                    };
472                }
473
474                // Calculate backoff with jitter
475                let backoff = compute_backoff(config, attempt);
476                tokio::time::sleep(backoff).await;
477
478                attempt += 1;
479            }
480        }
481    }
482}