Skip to main content

llm_stack/tool/
registry.rs

1//! Tool registry for managing and executing tools.
2
3use std::collections::HashMap;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::time::Duration;
8
9use rand::Rng;
10
11use super::ToolHandler;
12use crate::chat::{ToolCall, ToolResult};
13use crate::intercept::domain::{ToolExec, ToolRequest, ToolResponse};
14use crate::intercept::{InterceptorStack, Operation};
15use crate::provider::{ToolDefinition, ToolRetryConfig};
16
17/// A registry of tool handlers, indexed by name.
18///
19/// Generic over context type `Ctx` which is passed to tool handlers on
20/// execution. Default is `()` for backwards compatibility.
21///
22/// Provides validation of tool call arguments against their schemas
23/// and parallel execution of multiple tool calls.
24///
25/// # Interceptors
26///
27/// Tool execution can be wrapped with interceptors for cross-cutting concerns
28/// like logging, approval gates, or rate limiting:
29///
30/// ```rust,ignore
31/// use llm_stack::ToolRegistry;
32/// use llm_stack::tool::tool_fn;
33/// use llm_stack::intercept::{InterceptorStack, ToolExec, Approval, ApprovalDecision};
34///
35/// let mut registry: ToolRegistry<()> = ToolRegistry::new()
36///     .with_interceptors(
37///         InterceptorStack::<ToolExec<()>>::new()
38///             .with(Approval::new(|req| {
39///                 if req.name.starts_with("dangerous_") {
40///                     ApprovalDecision::Deny("Not allowed".into())
41///                 } else {
42///                     ApprovalDecision::Allow
43///                 }
44///             }))
45///     );
46/// ```
47pub struct ToolRegistry<Ctx = ()>
48where
49    Ctx: Send + Sync + 'static,
50{
51    pub(crate) handlers: HashMap<String, Arc<dyn ToolHandler<Ctx>>>,
52    interceptors: InterceptorStack<ToolExec<Ctx>>,
53}
54
55impl<Ctx> Default for ToolRegistry<Ctx>
56where
57    Ctx: Send + Sync + 'static,
58{
59    fn default() -> Self {
60        Self {
61            handlers: HashMap::new(),
62            interceptors: InterceptorStack::new(),
63        }
64    }
65}
66
67impl<Ctx> Clone for ToolRegistry<Ctx>
68where
69    Ctx: Send + Sync + 'static,
70{
71    /// Clone the registry.
72    ///
73    /// This is cheap — it clones `Arc` pointers to handlers, not the
74    /// handlers themselves.
75    fn clone(&self) -> Self {
76        Self {
77            handlers: self.handlers.clone(),
78            interceptors: self.interceptors.clone(),
79        }
80    }
81}
82
83impl<Ctx> std::fmt::Debug for ToolRegistry<Ctx>
84where
85    Ctx: Send + Sync + 'static,
86{
87    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88        f.debug_struct("ToolRegistry")
89            .field("tools", &self.handlers.keys().collect::<Vec<_>>())
90            .field("interceptors", &self.interceptors.len())
91            .finish()
92    }
93}
94
95impl<Ctx: Send + Sync + 'static> ToolRegistry<Ctx> {
96    /// Creates an empty registry.
97    pub fn new() -> Self {
98        Self::default()
99    }
100
101    /// Registers a tool handler.
102    ///
103    /// If a handler with the same name already exists, it is replaced.
104    pub fn register(&mut self, handler: impl ToolHandler<Ctx> + 'static) -> &mut Self {
105        let name = handler.definition().name.clone();
106        self.handlers.insert(name, Arc::new(handler));
107        self
108    }
109
110    /// Registers a shared tool handler.
111    pub fn register_shared(&mut self, handler: Arc<dyn ToolHandler<Ctx>>) -> &mut Self {
112        let name = handler.definition().name.clone();
113        self.handlers.insert(name, handler);
114        self
115    }
116
117    /// Returns the handler for the given tool name.
118    pub fn get(&self, name: &str) -> Option<&Arc<dyn ToolHandler<Ctx>>> {
119        self.handlers.get(name)
120    }
121
122    /// Returns whether a tool with the given name is registered.
123    pub fn contains(&self, name: &str) -> bool {
124        self.handlers.contains_key(name)
125    }
126
127    /// Returns the definitions of all registered tools.
128    ///
129    /// Pass this to [`ChatParams::tools`](crate::provider::ChatParams::tools) to tell the model which
130    /// tools are available.
131    pub fn definitions(&self) -> Vec<ToolDefinition> {
132        self.handlers.values().map(|h| h.definition()).collect()
133    }
134
135    /// Returns the number of registered tools.
136    pub fn len(&self) -> usize {
137        self.handlers.len()
138    }
139
140    /// Returns true if no tools are registered.
141    pub fn is_empty(&self) -> bool {
142        self.handlers.is_empty()
143    }
144
145    /// Returns a new registry excluding the named tools.
146    ///
147    /// Useful for creating scoped registries in Master/Worker patterns
148    /// where workers should not have access to certain tools (e.g., `spawn_task`).
149    ///
150    /// # Example
151    ///
152    /// ```rust
153    /// use llm_stack::ToolRegistry;
154    ///
155    /// let master_registry: ToolRegistry<()> = ToolRegistry::new();
156    /// // ... register tools ...
157    ///
158    /// // Workers can't spawn or use admin tools
159    /// let worker_registry = master_registry.without(["spawn_task", "admin_tool"]);
160    /// ```
161    #[must_use]
162    pub fn without<'a>(&self, names: impl IntoIterator<Item = &'a str>) -> Self {
163        use std::collections::HashSet;
164        let exclude: HashSet<&str> = names.into_iter().collect();
165        let mut new = Self {
166            handlers: HashMap::new(),
167            interceptors: self.interceptors.clone(),
168        };
169        for (name, handler) in &self.handlers {
170            if !exclude.contains(name.as_str()) {
171                new.handlers.insert(name.clone(), Arc::clone(handler));
172            }
173        }
174        new
175    }
176
177    /// Returns a new registry with only the named tools.
178    ///
179    /// Useful for creating minimal registries with specific capabilities.
180    ///
181    /// # Example
182    ///
183    /// ```rust
184    /// use llm_stack::ToolRegistry;
185    ///
186    /// let full_registry: ToolRegistry<()> = ToolRegistry::new();
187    /// // ... register tools ...
188    ///
189    /// // Read-only registry with just search tools
190    /// let search_registry = full_registry.only(["search_docs", "search_web"]);
191    /// ```
192    #[must_use]
193    pub fn only<'a>(&self, names: impl IntoIterator<Item = &'a str>) -> Self {
194        use std::collections::HashSet;
195        let include: HashSet<&str> = names.into_iter().collect();
196        let mut new = Self {
197            handlers: HashMap::new(),
198            interceptors: self.interceptors.clone(),
199        };
200        for (name, handler) in &self.handlers {
201            if include.contains(name.as_str()) {
202                new.handlers.insert(name.clone(), Arc::clone(handler));
203            }
204        }
205        new
206    }
207
208    /// Sets the interceptor stack for all tool executions.
209    ///
210    /// Interceptors run in the order added (first = outermost). They can
211    /// inspect, modify, or block tool calls before they reach the handler.
212    ///
213    /// # Example
214    ///
215    /// ```rust,ignore
216    /// use llm_stack::ToolRegistry;
217    /// use llm_stack::tool::tool_fn;
218    /// use llm_stack::intercept::{InterceptorStack, ToolExec, Approval, ApprovalDecision, Retry};
219    ///
220    /// let registry: ToolRegistry<()> = ToolRegistry::new()
221    ///     .with_interceptors(
222    ///         InterceptorStack::<ToolExec<()>>::new()
223    ///             .with(Approval::new(|req| {
224    ///                 if req.name == "dangerous" {
225    ///                     ApprovalDecision::Deny("Not allowed".into())
226    ///                 } else {
227    ///                     ApprovalDecision::Allow
228    ///                 }
229    ///             }))
230    ///             .with(Retry::default())
231    ///     );
232    /// ```
233    #[must_use]
234    pub fn with_interceptors(mut self, interceptors: InterceptorStack<ToolExec<Ctx>>) -> Self {
235        self.interceptors = interceptors;
236        self
237    }
238
239    /// Executes a single tool call with schema validation and optional retry.
240    ///
241    /// 1. Looks up the handler by [`ToolCall::name`]
242    /// 2. Validates arguments against the tool's parameter schema
243    /// 3. Runs the call through interceptors (if any)
244    /// 4. Invokes the handler with the provided context
245    /// 5. If the tool has retry configuration and execution fails,
246    ///    retries with exponential backoff
247    ///
248    /// Returns a [`ToolResult`] (always succeeds at the outer level).
249    /// Execution errors are captured in `ToolResult::is_error`.
250    pub async fn execute(&self, call: &ToolCall, ctx: &Ctx) -> ToolResult {
251        self.execute_inner(&call.name, &call.id, call.arguments.clone(), ctx)
252            .await
253    }
254
255    /// Executes a tool by name with the given arguments.
256    ///
257    /// This is a lower-level method used internally when the tool call
258    /// components are already separated (e.g., from `execute_with_events`).
259    /// Accepts owned arguments to avoid an extra deep clone of `serde_json::Value`.
260    pub(crate) async fn execute_by_name(
261        &self,
262        name: &str,
263        call_id: &str,
264        arguments: serde_json::Value,
265        ctx: &Ctx,
266    ) -> ToolResult {
267        self.execute_inner(name, call_id, arguments, ctx).await
268    }
269
270    /// Shared implementation for `execute` and `execute_by_name`.
271    async fn execute_inner(
272        &self,
273        name: &str,
274        call_id: &str,
275        arguments: serde_json::Value,
276        ctx: &Ctx,
277    ) -> ToolResult {
278        let Some(handler) = self.handlers.get(name) else {
279            return ToolResult {
280                tool_call_id: call_id.to_string(),
281                content: format!("Unknown tool: {name}"),
282                is_error: true,
283            };
284        };
285
286        // Validate arguments against schema
287        #[cfg(feature = "schema")]
288        {
289            let definition = handler.definition();
290            if let Err(e) = definition.parameters.validate(&arguments) {
291                return ToolResult {
292                    tool_call_id: call_id.to_string(),
293                    content: format!("Invalid arguments for tool '{name}': {e}"),
294                    is_error: true,
295                };
296            }
297        }
298
299        let request = ToolRequest {
300            name: name.to_string(),
301            call_id: call_id.to_string(),
302            arguments,
303        };
304
305        let operation = ToolHandlerOperation {
306            handler: handler.clone(),
307            ctx,
308            retry_config: handler.definition().retry,
309        };
310
311        let response = self.interceptors.execute(&request, &operation).await;
312
313        ToolResult {
314            tool_call_id: request.call_id,
315            content: response.content,
316            is_error: response.is_error,
317        }
318    }
319
320    /// Executes multiple tool calls, preserving order.
321    ///
322    /// When `parallel` is true, all calls run concurrently via
323    /// `futures::future::join_all`. When false, they run sequentially.
324    pub async fn execute_all(
325        &self,
326        calls: &[ToolCall],
327        ctx: &Ctx,
328        parallel: bool,
329    ) -> Vec<ToolResult> {
330        if !parallel || calls.len() <= 1 {
331            let mut results = Vec::with_capacity(calls.len());
332            for call in calls {
333                results.push(self.execute(call, ctx).await);
334            }
335            return results;
336        }
337
338        // Parallel execution using join_all (no spawn needed)
339        let futures: Vec<_> = calls.iter().map(|call| self.execute(call, ctx)).collect();
340        futures::future::join_all(futures).await
341    }
342}
343
344/// Computes backoff duration with exponential growth and jitter.
345///
346/// Formula: `min(initial * multiplier^attempt, max) * random(1-jitter, 1)`
347fn compute_backoff(config: &ToolRetryConfig, attempt: u32) -> Duration {
348    // Safe to cast: attempt is bounded by max_retries which is u32,
349    // and reasonable values are << i32::MAX
350    #[allow(clippy::cast_possible_wrap)]
351    let base =
352        config.initial_backoff.as_secs_f64() * config.backoff_multiplier.powi(attempt as i32);
353    let capped = base.min(config.max_backoff.as_secs_f64());
354
355    // Apply jitter: random value in range [1-jitter, 1]
356    let jitter_factor = if config.jitter > 0.0 {
357        let min_factor = 1.0 - config.jitter;
358        let mut rng = rand::rng();
359        rng.random_range(min_factor..=1.0)
360    } else {
361        1.0
362    };
363
364    Duration::from_secs_f64(capped * jitter_factor)
365}
366
367/// Wraps a tool handler as an [`Operation`] for the interceptor stack.
368///
369/// This struct captures the handler, context, and retry config so that
370/// the interceptor stack can execute the tool.
371struct ToolHandlerOperation<'a, Ctx: Send + Sync + 'static> {
372    handler: Arc<dyn ToolHandler<Ctx>>,
373    ctx: &'a Ctx,
374    retry_config: Option<ToolRetryConfig>,
375}
376
377impl<Ctx: Send + Sync + 'static> Operation<ToolExec<Ctx>> for ToolHandlerOperation<'_, Ctx> {
378    fn execute<'b>(
379        &'b self,
380        input: &'b ToolRequest,
381    ) -> Pin<Box<dyn Future<Output = ToolResponse> + Send + 'b>>
382    where
383        ToolRequest: Sync,
384    {
385        Box::pin(async move {
386            match &self.retry_config {
387                Some(config) => execute_with_retry(&self.handler, input, self.ctx, config).await,
388                None => execute_once(&self.handler, input, self.ctx).await,
389            }
390        })
391    }
392}
393
394/// Executes a tool once without retry.
395async fn execute_once<Ctx: Send + Sync + 'static>(
396    handler: &Arc<dyn ToolHandler<Ctx>>,
397    request: &ToolRequest,
398    ctx: &Ctx,
399) -> ToolResponse {
400    match handler.execute(request.arguments.clone(), ctx).await {
401        Ok(output) => ToolResponse {
402            content: output.content,
403            is_error: false,
404        },
405        Err(e) => ToolResponse {
406            content: e.message,
407            is_error: true,
408        },
409    }
410}
411
412/// Executes a tool with retry logic.
413async fn execute_with_retry<Ctx: Send + Sync + 'static>(
414    handler: &Arc<dyn ToolHandler<Ctx>>,
415    request: &ToolRequest,
416    ctx: &Ctx,
417    config: &ToolRetryConfig,
418) -> ToolResponse {
419    let mut attempt = 0u32;
420
421    loop {
422        match handler.execute(request.arguments.clone(), ctx).await {
423            Ok(output) => {
424                return ToolResponse {
425                    content: output.content,
426                    is_error: false,
427                };
428            }
429            Err(e) => {
430                let error_msg = e.message;
431
432                // Check if we should retry this error
433                let should_retry = config
434                    .retry_if
435                    .as_ref()
436                    .is_none_or(|predicate| predicate(&error_msg));
437
438                if !should_retry || attempt >= config.max_retries {
439                    return ToolResponse {
440                        content: error_msg,
441                        is_error: true,
442                    };
443                }
444
445                // Calculate backoff with jitter
446                let backoff = compute_backoff(config, attempt);
447                tokio::time::sleep(backoff).await;
448
449                attempt += 1;
450            }
451        }
452    }
453}