Skip to main content

claude_api/tool_dispatch/
registry.rs

1//! [`ToolRegistry`] and the [`FnTool`] closure adapter.
2//!
3//! The registry holds heterogeneous [`Tool`] implementations behind
4//! `Arc<dyn Tool>` and supports two registration shapes:
5//!
6//! - [`ToolRegistry::register_tool`] takes anything that implements [`Tool`]
7//!   directly.
8//! - [`ToolRegistry::register`] / [`ToolRegistry::register_described`] take
9//!   a closure plus name/schema; the closure is wrapped in an internal
10//!   [`FnTool`] adapter.
11//!
//! Both reduce to the same `Arc<dyn Tool>`. [`ToolRegistry::to_messages_tools`]
//! converts the registered schemas into the `Vec` of
//! [`Tool`](crate::messages::tools::Tool) values expected by `CreateMessageRequest::tools`.
15//!
//! For typed inputs derived from a Rust struct, see
//! [`crate::tool_dispatch::TypedTool`] and the `schemars-tools` feature.
18
19use std::collections::HashMap;
20use std::future::Future;
21use std::marker::PhantomData;
22use std::sync::Arc;
23
24use async_trait::async_trait;
25
26use crate::messages::tools::{CustomTool, Tool as MessagesTool};
27use crate::tool_dispatch::tool::{Tool, ToolError};
28
29/// In-memory registry of tools keyed by name.
30///
31/// Names must be unique; registering twice with the same name **replaces**
32/// the existing entry. Use [`Self::contains`] to check first when overwrite
33/// is undesired.
34#[derive(Default)]
35pub struct ToolRegistry {
36    tools: HashMap<String, Arc<dyn Tool>>,
37}
38
39impl ToolRegistry {
40    /// An empty registry.
41    #[must_use]
42    pub fn new() -> Self {
43        Self::default()
44    }
45
46    /// Register a value that implements [`Tool`] directly. Useful for tools
47    /// that have their own state or non-trivial logic worth giving a
48    /// dedicated type.
49    pub fn register_tool<T: Tool>(&mut self, tool: T) -> &mut Self {
50        let name = tool.name().to_owned();
51        self.tools.insert(name, Arc::new(tool));
52        self
53    }
54
55    /// Register a closure-based tool. The closure receives the model's raw
56    /// input as a [`serde_json::Value`] and returns the tool result. Use
57    /// [`ToolError::invalid_input`] for input-shape failures and
58    /// [`ToolError::execution`] to wrap any other error type.
59    ///
60    /// # Example
61    ///
62    /// ```
63    /// use claude_api::tool_dispatch::ToolRegistry;
64    /// use serde_json::json;
65    ///
66    /// let mut registry = ToolRegistry::new();
67    /// registry.register(
68    ///     "echo",
69    ///     json!({"type": "object", "properties": {"text": {"type": "string"}}}),
70    ///     |input| async move { Ok(input) },
71    /// );
72    /// assert!(registry.contains("echo"));
73    /// ```
74    pub fn register<F, Fut>(
75        &mut self,
76        name: impl Into<String>,
77        schema: serde_json::Value,
78        handler: F,
79    ) -> &mut Self
80    where
81        F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
82        Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
83    {
84        let name = name.into();
85        let tool = FnTool::new(name.clone(), schema, handler);
86        self.tools.insert(name, Arc::new(tool));
87        self
88    }
89
90    /// Like [`Self::register`] but also attaches a description.
91    pub fn register_described<F, Fut>(
92        &mut self,
93        name: impl Into<String>,
94        description: impl Into<String>,
95        schema: serde_json::Value,
96        handler: F,
97    ) -> &mut Self
98    where
99        F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
100        Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
101    {
102        let name = name.into();
103        let mut tool = FnTool::new(name.clone(), schema, handler);
104        tool.description = Some(description.into());
105        self.tools.insert(name, Arc::new(tool));
106        self
107    }
108
109    /// Borrow a registered tool by name.
110    #[must_use]
111    pub fn get(&self, name: &str) -> Option<&Arc<dyn Tool>> {
112        self.tools.get(name)
113    }
114
115    /// Whether a tool with the given name is registered.
116    #[must_use]
117    pub fn contains(&self, name: &str) -> bool {
118        self.tools.contains_key(name)
119    }
120
121    /// Number of registered tools.
122    #[must_use]
123    pub fn len(&self) -> usize {
124        self.tools.len()
125    }
126
127    /// Whether the registry is empty.
128    #[must_use]
129    pub fn is_empty(&self) -> bool {
130        self.tools.is_empty()
131    }
132
133    /// Iterator over registered tool names.
134    pub fn names(&self) -> impl Iterator<Item = &str> {
135        self.tools.keys().map(String::as_str)
136    }
137
138    /// Build the [`Vec<MessagesTool>`](crate::messages::tools::Tool) you
139    /// pass to `CreateMessageRequestBuilder::tools`. Includes name,
140    /// description, and schema for every registered tool.
141    #[must_use]
142    pub fn to_messages_tools(&self) -> Vec<MessagesTool> {
143        self.tools
144            .values()
145            .map(|t| {
146                let mut ct = CustomTool::new(t.name(), t.schema());
147                if let Some(desc) = t.description() {
148                    ct = ct.description(desc);
149                }
150                MessagesTool::Custom(ct)
151            })
152            .collect()
153    }
154
155    /// Look up a tool by name and invoke it with the given input.
156    ///
157    /// Returns [`ToolError::Unknown`] if no tool by that name is registered.
158    /// Other errors are propagated from the tool's `invoke` impl.
159    pub async fn dispatch(
160        &self,
161        name: &str,
162        input: serde_json::Value,
163    ) -> Result<serde_json::Value, ToolError> {
164        let tool = self.tools.get(name).ok_or_else(|| ToolError::Unknown {
165            name: name.to_owned(),
166        })?;
167        tool.invoke(input).await
168    }
169}
170
171impl std::fmt::Debug for ToolRegistry {
172    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173        // Tools don't necessarily implement Debug; show names only.
174        f.debug_struct("ToolRegistry")
175            .field("tools", &self.tools.keys().collect::<Vec<_>>())
176            .finish()
177    }
178}
179
180/// Internal adapter: wraps a closure and exposes it through the [`Tool`]
181/// trait. Created by [`ToolRegistry::register`] and
182/// [`ToolRegistry::register_described`].
183pub struct FnTool<F, Fut>
184where
185    F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
186    Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
187{
188    name: String,
189    schema: serde_json::Value,
190    description: Option<String>,
191    handler: F,
192    _phantom: PhantomData<fn() -> Fut>,
193}
194
195impl<F, Fut> FnTool<F, Fut>
196where
197    F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
198    Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
199{
200    /// Build an `FnTool` from a name, JSON schema, and async closure.
201    pub fn new(name: impl Into<String>, schema: serde_json::Value, handler: F) -> Self {
202        Self {
203            name: name.into(),
204            schema,
205            description: None,
206            handler,
207            _phantom: PhantomData,
208        }
209    }
210
211    /// Attach a description.
212    #[must_use]
213    pub fn with_description(mut self, description: impl Into<String>) -> Self {
214        self.description = Some(description.into());
215        self
216    }
217}
218
219#[async_trait]
220impl<F, Fut> Tool for FnTool<F, Fut>
221where
222    F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
223    Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
224{
225    fn name(&self) -> &str {
226        &self.name
227    }
228
229    fn description(&self) -> Option<&str> {
230        self.description.as_deref()
231    }
232
233    fn schema(&self) -> serde_json::Value {
234        self.schema.clone()
235    }
236
237    async fn invoke(&self, input: serde_json::Value) -> Result<serde_json::Value, ToolError> {
238        (self.handler)(input).await
239    }
240}
241
#[cfg(test)]
mod tests {
    use super::*;
    use crate::messages::tools::Tool as MessagesTool;
    use pretty_assertions::assert_eq;
    use serde_json::{Value, json};

    fn echo_schema() -> Value {
        json!({"type": "object", "properties": {"text": {"type": "string"}}})
    }

    // A trait-impl tool, so we cover both registration paths.
    struct UpperTool;

    #[async_trait]
    impl Tool for UpperTool {
        // Trait dictates the return type; literal-vs-stored is up to the impl.
        #[allow(clippy::unnecessary_literal_bound)]
        fn name(&self) -> &str {
            "upper"
        }
        fn schema(&self) -> Value {
            json!({"type": "object", "properties": {"text": {"type": "string"}}})
        }
        async fn invoke(&self, input: Value) -> Result<Value, ToolError> {
            let s = input
                .get("text")
                .and_then(Value::as_str)
                .ok_or_else(|| ToolError::invalid_input("missing 'text'"))?;
            Ok(json!({"upper": s.to_uppercase()}))
        }
    }

    #[tokio::test]
    async fn register_and_dispatch_closure_tool() {
        let mut registry = ToolRegistry::new();
        registry.register("echo", echo_schema(), |input| async move { Ok(input) });

        assert!(registry.contains("echo"));
        assert_eq!(registry.len(), 1);

        let result = registry
            .dispatch("echo", json!({"text": "hi"}))
            .await
            .unwrap();
        assert_eq!(result, json!({"text": "hi"}));
    }

    #[tokio::test]
    async fn register_and_dispatch_trait_tool() {
        let mut registry = ToolRegistry::new();
        registry.register_tool(UpperTool);

        let result = registry
            .dispatch("upper", json!({"text": "rust"}))
            .await
            .unwrap();
        assert_eq!(result, json!({"upper": "RUST"}));
    }

    #[tokio::test]
    async fn closure_and_trait_tools_coexist() {
        let mut registry = ToolRegistry::new();
        registry
            .register_tool(UpperTool)
            .register("echo", echo_schema(), |input| async move { Ok(input) });

        assert_eq!(registry.len(), 2);
        let names: std::collections::HashSet<_> = registry.names().collect();
        assert!(names.contains("upper"));
        assert!(names.contains("echo"));

        let r1 = registry
            .dispatch("upper", json!({"text": "ok"}))
            .await
            .unwrap();
        let r2 = registry
            .dispatch("echo", json!({"text": "ok"}))
            .await
            .unwrap();
        assert_eq!(r1, json!({"upper": "OK"}));
        assert_eq!(r2, json!({"text": "ok"}));
    }

    #[tokio::test]
    async fn dispatch_unknown_returns_unknown_error() {
        let registry = ToolRegistry::new();
        let err = registry.dispatch("nope", json!({})).await.unwrap_err();
        let ToolError::Unknown { name } = err else {
            panic!("expected Unknown variant");
        };
        assert_eq!(name, "nope");
    }

    #[tokio::test]
    async fn dispatch_propagates_invalid_input_error_from_tool() {
        let mut registry = ToolRegistry::new();
        registry.register_tool(UpperTool);
        let err = registry.dispatch("upper", json!({})).await.unwrap_err();
        let ToolError::InvalidInput(msg) = err else {
            panic!("expected InvalidInput");
        };
        assert!(msg.contains("'text'"));
    }

    #[tokio::test]
    async fn duplicate_register_replaces_previous_entry() {
        let mut registry = ToolRegistry::new();
        registry.register("dup", echo_schema(), |_| async move {
            Ok(json!({"version": "first"}))
        });
        registry.register("dup", echo_schema(), |_| async move {
            Ok(json!({"version": "second"}))
        });
        assert_eq!(registry.len(), 1);
        let r = registry.dispatch("dup", json!({})).await.unwrap();
        assert_eq!(r, json!({"version": "second"}));
    }

    #[test]
    fn to_messages_tools_includes_name_schema_and_description() {
        let mut registry = ToolRegistry::new();
        registry.register_tool(UpperTool).register_described(
            "echo",
            "Returns its input verbatim.",
            echo_schema(),
            |input| async move { Ok(input) },
        );

        let tools = registry.to_messages_tools();
        assert_eq!(tools.len(), 2);

        // Every entry is a Custom tool with the right name and schema.
        let mut by_name: std::collections::HashMap<String, MessagesTool> =
            std::collections::HashMap::new();
        for t in tools {
            let MessagesTool::Custom(ct) = &t else {
                panic!("expected custom variant");
            };
            by_name.insert(ct.name.clone(), t);
        }

        let MessagesTool::Custom(echo) = by_name.get("echo").unwrap() else {
            panic!("expected echo Custom");
        };
        assert_eq!(
            echo.description.as_deref(),
            Some("Returns its input verbatim.")
        );
        assert!(echo.input_schema.is_object());

        let MessagesTool::Custom(upper) = by_name.get("upper").unwrap() else {
            panic!("expected upper Custom");
        };
        assert_eq!(upper.description, None); // UpperTool didn't override description
    }

    #[tokio::test]
    async fn registry_works_through_dyn_dispatch() {
        // Sanity check: tools live behind Arc<dyn Tool> and dispatch correctly
        // through trait objects.
        let mut registry = ToolRegistry::new();
        registry.register_tool(UpperTool);

        let tool: &Arc<dyn Tool> = registry.get("upper").unwrap();
        let r = tool.invoke(json!({"text": "abc"})).await.unwrap();
        assert_eq!(r, json!({"upper": "ABC"}));
    }

    #[test]
    fn debug_impl_lists_tool_names() {
        let mut registry = ToolRegistry::new();
        registry.register_tool(UpperTool);
        let dbg = format!("{registry:?}");
        assert!(dbg.contains("upper"), "{dbg}");
    }

    #[test]
    fn registry_is_send_and_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<ToolRegistry>();
    }

    // The empty registry: every accessor must agree that nothing is registered.
    #[test]
    fn empty_registry_has_no_tools() {
        let registry = ToolRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert_eq!(registry.names().count(), 0);
        assert!(registry.get("missing").is_none());
        assert!(!registry.contains("missing"));
        assert!(registry.to_messages_tools().is_empty());
    }

    // `FnTool` used directly (not via `register*`) reports its metadata and
    // forwards invocations to the wrapped closure.
    #[tokio::test]
    async fn fn_tool_exposes_metadata_and_invokes_handler() {
        let tool = FnTool::new("echo", echo_schema(), |input: Value| async move {
            Ok::<Value, ToolError>(input)
        })
        .with_description("Returns its input verbatim.");

        assert_eq!(tool.name(), "echo");
        assert_eq!(tool.description(), Some("Returns its input verbatim."));
        assert_eq!(tool.schema(), echo_schema());

        let out = tool.invoke(json!({"text": "hi"})).await.unwrap();
        assert_eq!(out, json!({"text": "hi"}));
    }

    // A hand-built `FnTool` goes through the trait-object registration path.
    #[tokio::test]
    async fn fn_tool_can_be_registered_as_trait_tool() {
        let tool = FnTool::new("echo", echo_schema(), |input: Value| async move {
            Ok::<Value, ToolError>(input)
        });
        let mut registry = ToolRegistry::new();
        registry.register_tool(tool);

        assert!(registry.contains("echo"));
        assert!(!registry.is_empty());
        let r = registry.dispatch("echo", json!({"a": 1})).await.unwrap();
        assert_eq!(r, json!({"a": 1}));
    }
}
420    fn registry_is_send_and_sync() {
421        fn assert_send_sync<T: Send + Sync>() {}
422        assert_send_sync::<ToolRegistry>();
423    }
424}