Skip to main content

agentic_tools_core/
registry.rs

1//! Tool registry for dynamic dispatch and type-safe native calls.
2
3use crate::context::ToolContext;
4use crate::error::ToolError;
5use crate::fmt::TextFormat;
6use crate::fmt::TextOptions;
7use crate::schema::mcp_schema;
8use crate::tool::Tool;
9use crate::tool::ToolCodec;
10use futures::future::BoxFuture;
11use schemars::Schema;
12use serde_json::Value;
13use std::any::TypeId;
14use std::collections::HashMap;
15use std::collections::HashSet;
16use std::marker::PhantomData;
17use std::sync::Arc;
18
/// Result from `dispatch_json_formatted` containing both JSON data and optional text.
#[derive(Debug, Clone)]
pub struct FormattedResult {
    /// The JSON-serialized output data.
    pub data: Value,
    /// Human-readable text representation of the same output.
    ///
    /// Tools added through `ToolRegistryBuilder::register` always set this to `Some`,
    /// because that method requires the output type to implement `TextFormat`. The
    /// `Option` leaves room for other `ErasedTool` implementations to return `None`.
    pub text: Option<String>,
}
28
/// Type-erased tool for dynamic dispatch.
///
/// Lets tools with different `Input`/`Output` types be stored side by side in a
/// [`ToolRegistry`] as `Arc<dyn ErasedTool>`. `ToolRegistryBuilder::register` produces
/// implementations of this trait for concrete `Tool` + `ToolCodec` pairs.
pub trait ErasedTool: Send + Sync {
    /// Get the tool's name; the registry uses it as the lookup key.
    fn name(&self) -> &'static str;

    /// Get the tool's description.
    fn description(&self) -> &'static str;

    /// Get the input JSON schema.
    fn input_schema(&self) -> Schema;

    /// Get the output JSON schema, or `None` when no output schema is available.
    fn output_schema(&self) -> Option<Schema>;

    /// Call the tool with JSON arguments.
    ///
    /// The returned future is `'static`, so it cannot borrow `self` or `ctx`;
    /// implementations must move or clone whatever they need into it.
    fn call_json(
        &self,
        args: Value,
        ctx: &ToolContext,
    ) -> BoxFuture<'static, Result<Value, ToolError>>;

    /// Call the tool with JSON arguments, returning both JSON data and formatted text.
    ///
    /// This method enables dual output for MCP and NAPI servers. How the text is
    /// produced is up to the implementation; the adapter created by
    /// `ToolRegistryBuilder::register` renders it with the output type's `TextFormat`
    /// implementation. Like [`ErasedTool::call_json`], the returned future is `'static`.
    fn call_json_formatted(
        &self,
        args: Value,
        ctx: &ToolContext,
        text_opts: &TextOptions,
    ) -> BoxFuture<'static, Result<FormattedResult, ToolError>>;

    /// Get the `TypeId` used for type-safe handle retrieval.
    ///
    /// For adapters built by `ToolRegistryBuilder::register` this is `TypeId::of::<T>()`
    /// of the concrete tool type `T`.
    ///
    /// NOTE: this shares its name with `std::any::Any::type_id`. If `Any` is in scope,
    /// `arc.type_id()` on an `Arc<dyn ErasedTool>` can resolve to `Any::type_id` for the
    /// `Arc` itself; write `ErasedTool::type_id(&*arc)` to be unambiguous.
    fn type_id(&self) -> TypeId;
}
65
66/// Registry of tools for dynamic dispatch and type-safe native calls.
67pub struct ToolRegistry {
68    map: HashMap<String, Arc<dyn ErasedTool>>,
69    by_type: HashMap<TypeId, String>,
70}
71
72impl ToolRegistry {
73    /// Create a new registry builder.
74    pub fn builder() -> ToolRegistryBuilder {
75        ToolRegistryBuilder::default()
76    }
77
78    /// List all tool names in the registry.
79    pub fn list_names(&self) -> Vec<String> {
80        self.map.keys().cloned().collect()
81    }
82
83    /// Get a tool by name.
84    pub fn get(&self, name: &str) -> Option<&Arc<dyn ErasedTool>> {
85        self.map.get(name)
86    }
87
88    /// Create a subset registry containing only the specified tools.
89    ///
90    /// Tools not found in the registry are silently ignored.
91    #[must_use]
92    pub fn subset<'a>(&self, names: impl IntoIterator<Item = &'a str>) -> Self {
93        let allowed: HashSet<&str> = names.into_iter().collect();
94
95        // Copy the allowed entries into the new map
96        let mut map = HashMap::new();
97        for (k, v) in &self.map {
98            if allowed.contains(k.as_str()) {
99                map.insert(k.clone(), Arc::clone(v));
100            }
101        }
102
103        // Reuse original TypeIds from by_type (don't recompute via trait object
104        // to avoid cross-crate monomorphization issues with TypeId)
105        let mut by_type = HashMap::new();
106        for (type_id, name) in &self.by_type {
107            if allowed.contains(name.as_str()) {
108                by_type.insert(*type_id, name.clone());
109            }
110        }
111
112        Self { map, by_type }
113    }
114
115    /// Dispatch a tool call using JSON arguments.
116    pub async fn dispatch_json(
117        &self,
118        name: &str,
119        args: Value,
120        ctx: &ToolContext,
121    ) -> Result<Value, ToolError> {
122        let entry = self
123            .map
124            .get(name)
125            .ok_or_else(|| ToolError::invalid_input(format!("Unknown tool: {name}")))?;
126        entry.call_json(args, ctx).await
127    }
128
129    /// Dispatch a tool call using JSON arguments, returning both JSON data and formatted text.
130    ///
131    /// This method enables dual output for MCP and NAPI servers. The text is derived
132    /// from the tool's `TextFormat` implementation if available, otherwise it falls back
133    /// to pretty-printed JSON.
134    pub async fn dispatch_json_formatted(
135        &self,
136        name: &str,
137        args: Value,
138        ctx: &ToolContext,
139        text_opts: &TextOptions,
140    ) -> Result<FormattedResult, ToolError> {
141        let entry = self
142            .map
143            .get(name)
144            .ok_or_else(|| ToolError::invalid_input(format!("Unknown tool: {name}")))?;
145        entry.call_json_formatted(args, ctx, text_opts).await
146    }
147
148    /// Get a type-safe handle for calling a tool natively (zero JSON).
149    ///
150    /// Returns an error if the tool type is not registered.
151    pub fn handle<T: Tool>(&self) -> Result<ToolHandle<T>, ToolError> {
152        let type_id = TypeId::of::<T>();
153        self.by_type.get(&type_id).ok_or_else(|| {
154            ToolError::invalid_input(format!(
155                "Tool type not registered: {}",
156                std::any::type_name::<T>()
157            ))
158        })?;
159        Ok(ToolHandle {
160            _marker: PhantomData,
161        })
162    }
163
164    /// Check if a tool is registered by name.
165    pub fn contains(&self, name: &str) -> bool {
166        self.map.contains_key(name)
167    }
168
169    /// Get the number of registered tools.
170    pub fn len(&self) -> usize {
171        self.map.len()
172    }
173
174    /// Check if the registry is empty.
175    pub fn is_empty(&self) -> bool {
176        self.map.is_empty()
177    }
178
179    /// Clone and return erased tool entries (Arc) for composition.
180    ///
181    /// This enables merging multiple registries by iterating over their
182    /// erased tool entries and re-registering them in a new registry.
183    pub fn iter_erased(&self) -> Vec<Arc<dyn ErasedTool>> {
184        self.map.values().cloned().collect()
185    }
186
187    /// Merge multiple registries into one.
188    ///
189    /// Later entries with duplicate names overwrite earlier ones.
190    /// This is useful for composing domain-specific registries into
191    /// a unified registry.
192    pub fn merge_all(regs: impl IntoIterator<Item = Self>) -> Self {
193        let mut builder = Self::builder();
194        for reg in regs {
195            for erased in reg.iter_erased() {
196                builder = builder.register_erased(erased);
197            }
198        }
199        builder.finish()
200    }
201}
202
203/// Builder for constructing a [`ToolRegistry`].
204#[derive(Default)]
205pub struct ToolRegistryBuilder {
206    items: Vec<(String, TypeId, Arc<dyn ErasedTool>)>,
207}
208
209impl ToolRegistryBuilder {
210    /// Register a tool with its codec.
211    ///
212    /// Use `()` as the codec when the tool's Input/Output types
213    /// already implement serde and schemars traits.
214    ///
215    /// The tool's output type must implement [`TextFormat`] for human-readable
216    /// formatting. Types can override `fmt_text()` for custom formatting, or
217    /// use the default which produces pretty-printed JSON.
218    #[must_use]
219    pub fn register<T, C>(mut self, tool: T) -> Self
220    where
221        T: Tool + Clone + 'static,
222        C: ToolCodec<T> + 'static,
223        T::Output: TextFormat,
224    {
225        struct Impl<T: Tool + Clone, C: ToolCodec<T>> {
226            tool: T,
227            _codec: PhantomData<C>,
228        }
229
230        impl<T: Tool + Clone, C: ToolCodec<T>> ErasedTool for Impl<T, C>
231        where
232            T::Output: TextFormat,
233        {
234            fn name(&self) -> &'static str {
235                T::NAME
236            }
237
238            fn description(&self) -> &'static str {
239                T::DESCRIPTION
240            }
241
242            fn input_schema(&self) -> Schema {
243                mcp_schema::cached_schema_for::<C::WireIn>()
244                    .as_ref()
245                    .clone()
246            }
247
248            fn output_schema(&self) -> Option<Schema> {
249                match mcp_schema::cached_output_schema_for::<C::WireOut>() {
250                    Ok(arc) => Some(arc.as_ref().clone()),
251                    Err(_) => None,
252                }
253            }
254
255            fn call_json(
256                &self,
257                args: Value,
258                ctx: &ToolContext,
259            ) -> BoxFuture<'static, Result<Value, ToolError>> {
260                let wire_in: Result<C::WireIn, _> = serde_json::from_value(args);
261                let ctx = ctx.clone();
262                let tool = self.tool.clone();
263
264                match wire_in {
265                    Err(e) => Box::pin(async move { Err(ToolError::invalid_input(e.to_string())) }),
266                    Ok(wire) => match C::decode(wire) {
267                        Err(e) => Box::pin(async move { Err(e) }),
268                        Ok(native_in) => {
269                            let fut = tool.call(native_in, &ctx);
270                            Box::pin(async move {
271                                let out = fut.await?;
272                                let wired = C::encode(out)?;
273                                serde_json::to_value(wired)
274                                    .map_err(|e| ToolError::internal(e.to_string()))
275                            })
276                        }
277                    },
278                }
279            }
280
281            fn call_json_formatted(
282                &self,
283                args: Value,
284                ctx: &ToolContext,
285                text_opts: &TextOptions,
286            ) -> BoxFuture<'static, Result<FormattedResult, ToolError>> {
287                let wire_in: Result<C::WireIn, _> = serde_json::from_value(args);
288                let ctx = ctx.clone();
289                let tool = self.tool.clone();
290                let text_opts = text_opts.clone();
291
292                match wire_in {
293                    Err(e) => Box::pin(async move { Err(ToolError::invalid_input(e.to_string())) }),
294                    Ok(wire) => match C::decode(wire) {
295                        Err(e) => Box::pin(async move { Err(e) }),
296                        Ok(native_in) => {
297                            let fut = tool.call(native_in, &ctx);
298                            Box::pin(async move {
299                                let out = fut.await?;
300                                // Format text from the native output using TextFormat
301                                let text = out.fmt_text(&text_opts);
302                                // Then encode to wire and JSON-serialize for data
303                                let wired = C::encode(out)?;
304                                let data = serde_json::to_value(&wired)
305                                    .map_err(|e| ToolError::internal(e.to_string()))?;
306                                Ok(FormattedResult {
307                                    data,
308                                    text: Some(text),
309                                })
310                            })
311                        }
312                    },
313                }
314            }
315
316            fn type_id(&self) -> TypeId {
317                TypeId::of::<T>()
318            }
319        }
320
321        let erased: Arc<dyn ErasedTool> = Arc::new(Impl::<T, C> {
322            tool,
323            _codec: PhantomData,
324        });
325        self.items
326            .push((T::NAME.to_string(), TypeId::of::<T>(), erased));
327        self
328    }
329
330    /// Register an already-erased tool entry.
331    ///
332    /// This enables merging registries by iterating over their erased tools
333    /// and re-registering them without needing the concrete tool types.
334    #[must_use]
335    pub fn register_erased(mut self, erased: Arc<dyn ErasedTool>) -> Self {
336        let name = erased.name().to_string();
337        let type_id = erased.type_id();
338        self.items.push((name, type_id, erased));
339        self
340    }
341
342    /// Build the registry from registered tools.
343    pub fn finish(self) -> ToolRegistry {
344        let mut map = HashMap::new();
345        let mut by_type = HashMap::new();
346        for (name, type_id, erased) in self.items {
347            by_type.insert(type_id, name.clone());
348            map.insert(name, erased);
349        }
350        ToolRegistry { map, by_type }
351    }
352}
353
354/// Type-safe handle for calling a tool natively without JSON serialization.
355///
356/// Obtained from [`ToolRegistry::handle`].
357pub struct ToolHandle<T: Tool> {
358    _marker: PhantomData<T>,
359}
360
361impl<T: Tool> ToolHandle<T> {
362    /// Call the tool directly with native types (zero JSON serialization).
363    pub async fn call(
364        &self,
365        tool: &T,
366        input: T::Input,
367        ctx: &ToolContext,
368    ) -> Result<T::Output, ToolError> {
369        tool.call(input, ctx).await
370    }
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal tool that greets its string input.
    #[derive(Clone)]
    struct TestTool;

    impl Tool for TestTool {
        type Input = String;
        type Output = String;
        const NAME: &'static str = "test_tool";
        const DESCRIPTION: &'static str = "A test tool";

        fn call(
            &self,
            input: Self::Input,
            _ctx: &ToolContext,
        ) -> BoxFuture<'static, Result<Self::Output, ToolError>> {
            Box::pin(async move { Ok(format!("Hello, {input}!")) })
        }
    }

    /// Registry holding exactly one `TestTool`, registered with the `()` codec.
    fn single_tool_registry() -> ToolRegistry {
        ToolRegistry::builder()
            .register::<TestTool, ()>(TestTool)
            .finish()
    }

    #[test]
    fn test_registry_builder() {
        let reg = single_tool_registry();

        assert!(reg.contains("test_tool"));
        assert_eq!(reg.len(), 1);
        assert!(!reg.is_empty());
    }

    #[test]
    fn test_registry_list_names() {
        assert_eq!(single_tool_registry().list_names(), vec!["test_tool"]);
    }

    #[test]
    fn test_registry_subset() {
        let reg = single_tool_registry();

        assert!(reg.subset(["test_tool"]).contains("test_tool"));
        assert!(reg.subset(["nonexistent"]).is_empty());
    }

    #[test]
    fn test_tool_handle() {
        assert!(single_tool_registry().handle::<TestTool>().is_ok());
    }

    #[tokio::test]
    async fn test_dispatch_json_formatted() {
        let reg = single_tool_registry();
        let ctx = ToolContext::default();
        let opts = TextOptions::default();

        let result = reg
            .dispatch_json_formatted("test_tool", serde_json::json!("World"), &ctx, &opts)
            .await;

        assert!(result.is_ok());
        let FormattedResult { data, text } = result.unwrap();
        assert_eq!(data, serde_json::json!("Hello, World!"));
        // A text rendering must be present and mention the greeting.
        assert!(text.is_some());
        assert!(text.unwrap().contains("Hello, World!"));
    }

    #[tokio::test]
    async fn test_dispatch_json_formatted_unknown_tool() {
        let reg = single_tool_registry();
        let ctx = ToolContext::default();
        let opts = TextOptions::default();

        let outcome = reg
            .dispatch_json_formatted("nonexistent", serde_json::json!("test"), &ctx, &opts)
            .await;

        assert!(outcome.is_err());
    }

    #[test]
    fn test_iter_erased() {
        let entries = single_tool_registry().iter_erased();

        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].name(), "test_tool");
    }

    #[test]
    fn test_register_erased_roundtrip() {
        // Pull the erased entry out of one registry and feed it into a fresh builder.
        let erased = single_tool_registry()
            .iter_erased()
            .into_iter()
            .next()
            .unwrap();
        let rebuilt = ToolRegistry::builder().register_erased(erased).finish();

        assert_eq!(rebuilt.len(), 1);
        assert!(rebuilt.contains("test_tool"));
        assert_eq!(rebuilt.get("test_tool").unwrap().name(), "test_tool");
    }

    #[test]
    fn test_merge_all_combines_registries() {
        // Two "domain" registries that both expose the same tool name.
        let merged = ToolRegistry::merge_all([single_tool_registry(), single_tool_registry()]);

        // The duplicate name collapses into a single (last-wins) entry.
        assert_eq!(merged.len(), 1);
        assert!(merged.contains("test_tool"));
    }

    #[test]
    fn test_merge_all_empty() {
        assert!(ToolRegistry::merge_all(std::iter::empty::<ToolRegistry>()).is_empty());
    }

    #[test]
    fn test_merge_all_preserves_subset() {
        let merged = ToolRegistry::merge_all([single_tool_registry()]);
        let picked = merged.subset(["test_tool"]);

        assert_eq!(picked.len(), 1);
        assert!(picked.contains("test_tool"));
    }
}