Skip to main content

claude_api/tool_dispatch/
registry.rs

1//! [`ToolRegistry`] and the [`FnTool`] closure adapter.
2//!
3//! The registry holds heterogeneous [`Tool`] implementations behind
4//! `Arc<dyn Tool>` and supports two registration shapes:
5//!
6//! - [`ToolRegistry::register_tool`] takes anything that implements [`Tool`]
7//!   directly.
8//! - [`ToolRegistry::register`] takes a closure plus name/schema; the
9//!   closure is wrapped in an internal [`FnTool`] that implements [`Tool`].
10//!
11//! Both reduce to the same `Arc<dyn Tool>`, so the agent loop runner
12//! (#20) and the model's tool list ([`ToolRegistry::to_messages_tools`])
13//! treat them identically.
14
15use std::collections::HashMap;
16use std::future::Future;
17use std::marker::PhantomData;
18use std::sync::Arc;
19
20use async_trait::async_trait;
21
22use crate::messages::tools::{CustomTool, Tool as MessagesTool};
23use crate::tool_dispatch::tool::{Tool, ToolError};
24
25/// In-memory registry of tools keyed by name.
26///
27/// Names must be unique; registering twice with the same name **replaces**
28/// the existing entry. Use [`Self::contains`] to check first when overwrite
29/// is undesired.
30#[derive(Default)]
31pub struct ToolRegistry {
32    tools: HashMap<String, Arc<dyn Tool>>,
33}
34
35impl ToolRegistry {
36    /// An empty registry.
37    #[must_use]
38    pub fn new() -> Self {
39        Self::default()
40    }
41
42    /// Register a value that implements [`Tool`] directly. Useful for tools
43    /// that have their own state or non-trivial logic worth giving a
44    /// dedicated type.
45    pub fn register_tool<T: Tool>(&mut self, tool: T) -> &mut Self {
46        let name = tool.name().to_owned();
47        self.tools.insert(name, Arc::new(tool));
48        self
49    }
50
51    /// Register a closure-based tool. The closure receives the model's raw
52    /// input as a [`serde_json::Value`] and returns the tool result. Use
53    /// [`ToolError::invalid_input`] for input-shape failures and
54    /// [`ToolError::execution`] to wrap any other error type.
55    ///
56    /// # Example
57    ///
58    /// ```
59    /// use claude_api::tool_dispatch::ToolRegistry;
60    /// use serde_json::json;
61    ///
62    /// let mut registry = ToolRegistry::new();
63    /// registry.register(
64    ///     "echo",
65    ///     json!({"type": "object", "properties": {"text": {"type": "string"}}}),
66    ///     |input| async move { Ok(input) },
67    /// );
68    /// assert!(registry.contains("echo"));
69    /// ```
70    pub fn register<F, Fut>(
71        &mut self,
72        name: impl Into<String>,
73        schema: serde_json::Value,
74        handler: F,
75    ) -> &mut Self
76    where
77        F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
78        Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
79    {
80        let name = name.into();
81        let tool = FnTool::new(name.clone(), schema, handler);
82        self.tools.insert(name, Arc::new(tool));
83        self
84    }
85
86    /// Like [`Self::register`] but also attaches a description.
87    pub fn register_described<F, Fut>(
88        &mut self,
89        name: impl Into<String>,
90        description: impl Into<String>,
91        schema: serde_json::Value,
92        handler: F,
93    ) -> &mut Self
94    where
95        F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
96        Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
97    {
98        let name = name.into();
99        let mut tool = FnTool::new(name.clone(), schema, handler);
100        tool.description = Some(description.into());
101        self.tools.insert(name, Arc::new(tool));
102        self
103    }
104
105    /// Borrow a registered tool by name.
106    #[must_use]
107    pub fn get(&self, name: &str) -> Option<&Arc<dyn Tool>> {
108        self.tools.get(name)
109    }
110
111    /// Whether a tool with the given name is registered.
112    #[must_use]
113    pub fn contains(&self, name: &str) -> bool {
114        self.tools.contains_key(name)
115    }
116
117    /// Number of registered tools.
118    #[must_use]
119    pub fn len(&self) -> usize {
120        self.tools.len()
121    }
122
123    /// Whether the registry is empty.
124    #[must_use]
125    pub fn is_empty(&self) -> bool {
126        self.tools.is_empty()
127    }
128
129    /// Iterator over registered tool names.
130    pub fn names(&self) -> impl Iterator<Item = &str> {
131        self.tools.keys().map(String::as_str)
132    }
133
134    /// Build the [`Vec<MessagesTool>`](crate::messages::tools::Tool) you
135    /// pass to `CreateMessageRequestBuilder::tools`. Includes name,
136    /// description, and schema for every registered tool.
137    #[must_use]
138    pub fn to_messages_tools(&self) -> Vec<MessagesTool> {
139        self.tools
140            .values()
141            .map(|t| {
142                let mut ct = CustomTool::new(t.name(), t.schema());
143                if let Some(desc) = t.description() {
144                    ct = ct.description(desc);
145                }
146                MessagesTool::Custom(ct)
147            })
148            .collect()
149    }
150
151    /// Look up a tool by name and invoke it with the given input.
152    ///
153    /// Returns [`ToolError::Unknown`] if no tool by that name is registered.
154    /// Other errors are propagated from the tool's `invoke` impl.
155    pub async fn dispatch(
156        &self,
157        name: &str,
158        input: serde_json::Value,
159    ) -> Result<serde_json::Value, ToolError> {
160        let tool = self.tools.get(name).ok_or_else(|| ToolError::Unknown {
161            name: name.to_owned(),
162        })?;
163        tool.invoke(input).await
164    }
165}
166
167impl std::fmt::Debug for ToolRegistry {
168    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169        // Tools don't necessarily implement Debug; show names only.
170        f.debug_struct("ToolRegistry")
171            .field("tools", &self.tools.keys().collect::<Vec<_>>())
172            .finish()
173    }
174}
175
176/// Internal adapter: wraps a closure and exposes it through the [`Tool`]
177/// trait. Created by [`ToolRegistry::register`] and
178/// [`ToolRegistry::register_described`].
179pub struct FnTool<F, Fut>
180where
181    F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
182    Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
183{
184    name: String,
185    schema: serde_json::Value,
186    description: Option<String>,
187    handler: F,
188    _phantom: PhantomData<fn() -> Fut>,
189}
190
191impl<F, Fut> FnTool<F, Fut>
192where
193    F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
194    Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
195{
196    /// Build an `FnTool` from a name, JSON schema, and async closure.
197    pub fn new(name: impl Into<String>, schema: serde_json::Value, handler: F) -> Self {
198        Self {
199            name: name.into(),
200            schema,
201            description: None,
202            handler,
203            _phantom: PhantomData,
204        }
205    }
206
207    /// Attach a description.
208    #[must_use]
209    pub fn with_description(mut self, description: impl Into<String>) -> Self {
210        self.description = Some(description.into());
211        self
212    }
213}
214
215#[async_trait]
216impl<F, Fut> Tool for FnTool<F, Fut>
217where
218    F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
219    Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
220{
221    fn name(&self) -> &str {
222        &self.name
223    }
224
225    fn description(&self) -> Option<&str> {
226        self.description.as_deref()
227    }
228
229    fn schema(&self) -> serde_json::Value {
230        self.schema.clone()
231    }
232
233    async fn invoke(&self, input: serde_json::Value) -> Result<serde_json::Value, ToolError> {
234        (self.handler)(input).await
235    }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use crate::messages::tools::Tool as MessagesTool;
    use pretty_assertions::assert_eq;
    use serde_json::{Value, json};

    /// Schema shared by the fixtures: an object with a string `text` property.
    fn echo_schema() -> Value {
        json!({"type": "object", "properties": {"text": {"type": "string"}}})
    }

    /// Tool implemented through the trait rather than a closure, so both
    /// registration paths get exercised.
    struct UpperTool;

    #[async_trait]
    impl Tool for UpperTool {
        // Trait dictates the return type; literal-vs-stored is up to the impl.
        #[allow(clippy::unnecessary_literal_bound)]
        fn name(&self) -> &str {
            "upper"
        }

        fn schema(&self) -> Value {
            echo_schema()
        }

        async fn invoke(&self, input: Value) -> Result<Value, ToolError> {
            let Some(text) = input.get("text").and_then(Value::as_str) else {
                return Err(ToolError::invalid_input("missing 'text'"));
            };
            Ok(json!({"upper": text.to_uppercase()}))
        }
    }

    /// A closure registered by name can be found and invoked.
    #[tokio::test]
    async fn register_and_dispatch_closure_tool() {
        let mut reg = ToolRegistry::new();
        reg.register("echo", echo_schema(), |payload| async move { Ok(payload) });

        assert!(reg.contains("echo"));
        assert_eq!(reg.len(), 1);

        let reply = reg.dispatch("echo", json!({"text": "hi"})).await.unwrap();
        assert_eq!(reply, json!({"text": "hi"}));
    }

    /// A trait-implementing tool is keyed by its own `name()`.
    #[tokio::test]
    async fn register_and_dispatch_trait_tool() {
        let mut reg = ToolRegistry::new();
        reg.register_tool(UpperTool);

        let reply = reg
            .dispatch("upper", json!({"text": "rust"}))
            .await
            .unwrap();
        assert_eq!(reply, json!({"upper": "RUST"}));
    }

    /// Both registration shapes live side by side in one registry.
    #[tokio::test]
    async fn closure_and_trait_tools_coexist() {
        let mut reg = ToolRegistry::new();
        reg.register_tool(UpperTool)
            .register("echo", echo_schema(), |payload| async move { Ok(payload) });

        assert_eq!(reg.len(), 2);
        let known: std::collections::HashSet<_> = reg.names().collect();
        assert!(known.contains("upper"));
        assert!(known.contains("echo"));

        let from_upper = reg.dispatch("upper", json!({"text": "ok"})).await.unwrap();
        let from_echo = reg.dispatch("echo", json!({"text": "ok"})).await.unwrap();
        assert_eq!(from_upper, json!({"upper": "OK"}));
        assert_eq!(from_echo, json!({"text": "ok"}));
    }

    /// Dispatching to a name nobody registered reports that name back.
    #[tokio::test]
    async fn dispatch_unknown_returns_unknown_error() {
        let reg = ToolRegistry::new();
        let failure = reg.dispatch("nope", json!({})).await.unwrap_err();
        let ToolError::Unknown { name } = failure else {
            panic!("expected Unknown variant");
        };
        assert_eq!(name, "nope");
    }

    /// Errors raised inside a tool come out of `dispatch` untouched.
    #[tokio::test]
    async fn dispatch_propagates_invalid_input_error_from_tool() {
        let mut reg = ToolRegistry::new();
        reg.register_tool(UpperTool);

        let failure = reg.dispatch("upper", json!({})).await.unwrap_err();
        let ToolError::InvalidInput(msg) = failure else {
            panic!("expected InvalidInput");
        };
        assert!(msg.contains("'text'"));
    }

    /// Registering under an existing name overwrites the earlier tool.
    #[tokio::test]
    async fn duplicate_register_replaces_previous_entry() {
        let mut reg = ToolRegistry::new();
        reg.register("dup", echo_schema(), |_| async move {
            Ok(json!({"version": "first"}))
        });
        reg.register("dup", echo_schema(), |_| async move {
            Ok(json!({"version": "second"}))
        });

        assert_eq!(reg.len(), 1);
        let reply = reg.dispatch("dup", json!({})).await.unwrap();
        assert_eq!(reply, json!({"version": "second"}));
    }

    /// The model-facing tool list carries each tool's name, schema and
    /// (when present) description.
    #[test]
    fn to_messages_tools_includes_name_schema_and_description() {
        let mut reg = ToolRegistry::new();
        reg.register_tool(UpperTool);
        reg.register_described(
            "echo",
            "Returns its input verbatim.",
            echo_schema(),
            |payload| async move { Ok(payload) },
        );

        let listed = reg.to_messages_tools();
        assert_eq!(listed.len(), 2);

        // Index the entries by name; each one must be the Custom variant.
        let by_name: std::collections::HashMap<String, MessagesTool> = listed
            .into_iter()
            .map(|entry| {
                let MessagesTool::Custom(custom) = &entry else {
                    panic!("expected custom variant");
                };
                (custom.name.clone(), entry)
            })
            .collect();

        let MessagesTool::Custom(echo) = by_name.get("echo").unwrap() else {
            panic!("expected echo Custom");
        };
        assert_eq!(
            echo.description.as_deref(),
            Some("Returns its input verbatim.")
        );
        assert!(echo.input_schema.is_object());

        let MessagesTool::Custom(upper) = by_name.get("upper").unwrap() else {
            panic!("expected upper Custom");
        };
        // UpperTool didn't override description.
        assert_eq!(upper.description, None);
    }

    /// Tools fetched with `get` are usable directly as `Arc<dyn Tool>`.
    #[tokio::test]
    async fn registry_works_through_dyn_dispatch() {
        let mut reg = ToolRegistry::new();
        reg.register_tool(UpperTool);

        let handle: &Arc<dyn Tool> = reg.get("upper").unwrap();
        let reply = handle.invoke(json!({"text": "abc"})).await.unwrap();
        assert_eq!(reply, json!({"upper": "ABC"}));
    }

    /// The hand-written `Debug` impl mentions registered names.
    #[test]
    fn debug_impl_lists_tool_names() {
        let mut reg = ToolRegistry::new();
        reg.register_tool(UpperTool);

        let rendered = format!("{reg:?}");
        assert!(rendered.contains("upper"), "{rendered}");
    }

    /// Compile-time check: the registry can be shared across threads.
    #[test]
    fn registry_is_send_and_sync() {
        fn require_send_sync<T: Send + Sync>() {}
        require_send_sync::<ToolRegistry>();
    }
}
420}