Skip to main content

agentic_tools_core/
registry.rs

1//! Tool registry for dynamic dispatch and type-safe native calls.
2
3use crate::context::ToolContext;
4use crate::error::ToolError;
5use crate::fmt::{TextFormat, TextOptions};
6use crate::schema::mcp_schema;
7use crate::tool::{Tool, ToolCodec};
8use futures::future::BoxFuture;
9use schemars::Schema;
10use serde_json::Value;
11use std::any::TypeId;
12use std::collections::{HashMap, HashSet};
13use std::marker::PhantomData;
14use std::sync::Arc;
15
16/// Result from dispatch_json_formatted containing both JSON data and optional text.
17#[derive(Debug, Clone)]
18pub struct FormattedResult {
19    /// The JSON-serialized output data.
20    pub data: Value,
21    /// Human-readable text representation. None if no TextFormat implementation exists
22    /// and fallback wasn't requested.
23    pub text: Option<String>,
24}
25
26/// Type-erased tool for dynamic dispatch.
27pub trait ErasedTool: Send + Sync {
28    /// Get the tool's name.
29    fn name(&self) -> &'static str;
30
31    /// Get the tool's description.
32    fn description(&self) -> &'static str;
33
34    /// Get the input JSON schema.
35    fn input_schema(&self) -> Schema;
36
37    /// Get the output JSON schema (if available).
38    fn output_schema(&self) -> Option<Schema>;
39
40    /// Call the tool with JSON arguments.
41    fn call_json(
42        &self,
43        args: Value,
44        ctx: &ToolContext,
45    ) -> BoxFuture<'static, Result<Value, ToolError>>;
46
47    /// Call the tool with JSON arguments, returning both JSON data and formatted text.
48    ///
49    /// This method enables dual output for MCP and NAPI servers. The text is derived
50    /// from the tool's TextFormat implementation if available, otherwise it falls back
51    /// to pretty-printed JSON.
52    fn call_json_formatted(
53        &self,
54        args: Value,
55        ctx: &ToolContext,
56        text_opts: &TextOptions,
57    ) -> BoxFuture<'static, Result<FormattedResult, ToolError>>;
58
59    /// Get the TypeId for type-safe handle retrieval.
60    fn type_id(&self) -> TypeId;
61}
62
63/// Registry of tools for dynamic dispatch and type-safe native calls.
64pub struct ToolRegistry {
65    map: HashMap<String, Arc<dyn ErasedTool>>,
66    by_type: HashMap<TypeId, String>,
67}
68
69impl ToolRegistry {
70    /// Create a new registry builder.
71    pub fn builder() -> ToolRegistryBuilder {
72        ToolRegistryBuilder::default()
73    }
74
75    /// List all tool names in the registry.
76    pub fn list_names(&self) -> Vec<String> {
77        self.map.keys().cloned().collect()
78    }
79
80    /// Get a tool by name.
81    pub fn get(&self, name: &str) -> Option<&Arc<dyn ErasedTool>> {
82        self.map.get(name)
83    }
84
85    /// Create a subset registry containing only the specified tools.
86    ///
87    /// Tools not found in the registry are silently ignored.
88    pub fn subset<'a>(&self, names: impl IntoIterator<Item = &'a str>) -> ToolRegistry {
89        let allowed: HashSet<&str> = names.into_iter().collect();
90
91        // Copy the allowed entries into the new map
92        let mut map = HashMap::new();
93        for (k, v) in &self.map {
94            if allowed.contains(k.as_str()) {
95                map.insert(k.clone(), v.clone());
96            }
97        }
98
99        // Reuse original TypeIds from by_type (don't recompute via trait object
100        // to avoid cross-crate monomorphization issues with TypeId)
101        let mut by_type = HashMap::new();
102        for (type_id, name) in &self.by_type {
103            if allowed.contains(name.as_str()) {
104                by_type.insert(*type_id, name.clone());
105            }
106        }
107
108        ToolRegistry { map, by_type }
109    }
110
111    /// Dispatch a tool call using JSON arguments.
112    pub async fn dispatch_json(
113        &self,
114        name: &str,
115        args: Value,
116        ctx: &ToolContext,
117    ) -> Result<Value, ToolError> {
118        let entry = self
119            .map
120            .get(name)
121            .ok_or_else(|| ToolError::invalid_input(format!("Unknown tool: {}", name)))?;
122        entry.call_json(args, ctx).await
123    }
124
125    /// Dispatch a tool call using JSON arguments, returning both JSON data and formatted text.
126    ///
127    /// This method enables dual output for MCP and NAPI servers. The text is derived
128    /// from the tool's TextFormat implementation if available, otherwise it falls back
129    /// to pretty-printed JSON.
130    pub async fn dispatch_json_formatted(
131        &self,
132        name: &str,
133        args: Value,
134        ctx: &ToolContext,
135        text_opts: &TextOptions,
136    ) -> Result<FormattedResult, ToolError> {
137        let entry = self
138            .map
139            .get(name)
140            .ok_or_else(|| ToolError::invalid_input(format!("Unknown tool: {}", name)))?;
141        entry.call_json_formatted(args, ctx, text_opts).await
142    }
143
144    /// Get a type-safe handle for calling a tool natively (zero JSON).
145    ///
146    /// Returns an error if the tool type is not registered.
147    pub fn handle<T: Tool>(&self) -> Result<ToolHandle<T>, ToolError> {
148        let type_id = TypeId::of::<T>();
149        self.by_type.get(&type_id).ok_or_else(|| {
150            ToolError::invalid_input(format!(
151                "Tool type not registered: {}",
152                std::any::type_name::<T>()
153            ))
154        })?;
155        Ok(ToolHandle {
156            _marker: PhantomData,
157        })
158    }
159
160    /// Check if a tool is registered by name.
161    pub fn contains(&self, name: &str) -> bool {
162        self.map.contains_key(name)
163    }
164
165    /// Get the number of registered tools.
166    pub fn len(&self) -> usize {
167        self.map.len()
168    }
169
170    /// Check if the registry is empty.
171    pub fn is_empty(&self) -> bool {
172        self.map.is_empty()
173    }
174
175    /// Clone and return erased tool entries (Arc) for composition.
176    ///
177    /// This enables merging multiple registries by iterating over their
178    /// erased tool entries and re-registering them in a new registry.
179    pub fn iter_erased(&self) -> Vec<Arc<dyn ErasedTool>> {
180        self.map.values().cloned().collect()
181    }
182
183    /// Merge multiple registries into one.
184    ///
185    /// Later entries with duplicate names overwrite earlier ones.
186    /// This is useful for composing domain-specific registries into
187    /// a unified registry.
188    pub fn merge_all(regs: impl IntoIterator<Item = ToolRegistry>) -> ToolRegistry {
189        let mut builder = ToolRegistry::builder();
190        for reg in regs {
191            for erased in reg.iter_erased() {
192                builder = builder.register_erased(erased);
193            }
194        }
195        builder.finish()
196    }
197}
198
199/// Builder for constructing a [`ToolRegistry`].
200#[derive(Default)]
201pub struct ToolRegistryBuilder {
202    items: Vec<(String, TypeId, Arc<dyn ErasedTool>)>,
203}
204
205impl ToolRegistryBuilder {
206    /// Register a tool with its codec.
207    ///
208    /// Use `()` as the codec when the tool's Input/Output types
209    /// already implement serde and schemars traits.
210    ///
211    /// The tool's output type must implement [`TextFormat`] for human-readable
212    /// formatting. Types can override `fmt_text()` for custom formatting, or
213    /// use the default which produces pretty-printed JSON.
214    pub fn register<T, C>(mut self, tool: T) -> Self
215    where
216        T: Tool + Clone + 'static,
217        C: ToolCodec<T> + 'static,
218        T::Output: TextFormat,
219    {
220        struct Impl<T: Tool + Clone, C: ToolCodec<T>> {
221            tool: T,
222            _codec: PhantomData<C>,
223        }
224
225        impl<T: Tool + Clone, C: ToolCodec<T>> ErasedTool for Impl<T, C>
226        where
227            T::Output: TextFormat,
228        {
229            fn name(&self) -> &'static str {
230                T::NAME
231            }
232
233            fn description(&self) -> &'static str {
234                T::DESCRIPTION
235            }
236
237            fn input_schema(&self) -> Schema {
238                mcp_schema::cached_schema_for::<C::WireIn>()
239                    .as_ref()
240                    .clone()
241            }
242
243            fn output_schema(&self) -> Option<Schema> {
244                match mcp_schema::cached_output_schema_for::<C::WireOut>() {
245                    Ok(arc) => Some(arc.as_ref().clone()),
246                    Err(_) => None,
247                }
248            }
249
250            fn call_json(
251                &self,
252                args: Value,
253                ctx: &ToolContext,
254            ) -> BoxFuture<'static, Result<Value, ToolError>> {
255                let wire_in: Result<C::WireIn, _> = serde_json::from_value(args);
256                let ctx = ctx.clone();
257                let tool = self.tool.clone();
258
259                match wire_in {
260                    Err(e) => Box::pin(async move { Err(ToolError::invalid_input(e.to_string())) }),
261                    Ok(wire) => match C::decode(wire) {
262                        Err(e) => Box::pin(async move { Err(e) }),
263                        Ok(native_in) => {
264                            let fut = tool.call(native_in, &ctx);
265                            Box::pin(async move {
266                                let out = fut.await?;
267                                let wired = C::encode(out)?;
268                                serde_json::to_value(wired)
269                                    .map_err(|e| ToolError::internal(e.to_string()))
270                            })
271                        }
272                    },
273                }
274            }
275
276            fn call_json_formatted(
277                &self,
278                args: Value,
279                ctx: &ToolContext,
280                text_opts: &TextOptions,
281            ) -> BoxFuture<'static, Result<FormattedResult, ToolError>> {
282                let wire_in: Result<C::WireIn, _> = serde_json::from_value(args);
283                let ctx = ctx.clone();
284                let tool = self.tool.clone();
285                let text_opts = text_opts.clone();
286
287                match wire_in {
288                    Err(e) => Box::pin(async move { Err(ToolError::invalid_input(e.to_string())) }),
289                    Ok(wire) => match C::decode(wire) {
290                        Err(e) => Box::pin(async move { Err(e) }),
291                        Ok(native_in) => {
292                            let fut = tool.call(native_in, &ctx);
293                            Box::pin(async move {
294                                let out = fut.await?;
295                                // Format text from the native output using TextFormat
296                                let text = out.fmt_text(&text_opts);
297                                // Then encode to wire and JSON-serialize for data
298                                let wired = C::encode(out)?;
299                                let data = serde_json::to_value(&wired)
300                                    .map_err(|e| ToolError::internal(e.to_string()))?;
301                                Ok(FormattedResult {
302                                    data,
303                                    text: Some(text),
304                                })
305                            })
306                        }
307                    },
308                }
309            }
310
311            fn type_id(&self) -> TypeId {
312                TypeId::of::<T>()
313            }
314        }
315
316        let erased: Arc<dyn ErasedTool> = Arc::new(Impl::<T, C> {
317            tool,
318            _codec: PhantomData,
319        });
320        self.items
321            .push((T::NAME.to_string(), TypeId::of::<T>(), erased));
322        self
323    }
324
325    /// Register an already-erased tool entry.
326    ///
327    /// This enables merging registries by iterating over their erased tools
328    /// and re-registering them without needing the concrete tool types.
329    pub fn register_erased(mut self, erased: Arc<dyn ErasedTool>) -> Self {
330        let name = erased.name().to_string();
331        let type_id = erased.type_id();
332        self.items.push((name, type_id, erased));
333        self
334    }
335
336    /// Build the registry from registered tools.
337    pub fn finish(self) -> ToolRegistry {
338        let mut map = HashMap::new();
339        let mut by_type = HashMap::new();
340        for (name, type_id, erased) in self.items {
341            by_type.insert(type_id, name.clone());
342            map.insert(name, erased);
343        }
344        ToolRegistry { map, by_type }
345    }
346}
347
348/// Type-safe handle for calling a tool natively without JSON serialization.
349///
350/// Obtained from [`ToolRegistry::handle`].
351pub struct ToolHandle<T: Tool> {
352    _marker: PhantomData<T>,
353}
354
355impl<T: Tool> ToolHandle<T> {
356    /// Call the tool directly with native types (zero JSON serialization).
357    pub async fn call(
358        &self,
359        tool: &T,
360        input: T::Input,
361        ctx: &ToolContext,
362    ) -> Result<T::Output, ToolError> {
363        tool.call(input, ctx).await
364    }
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone)]
    struct TestTool;

    impl Tool for TestTool {
        type Input = String;
        type Output = String;
        const NAME: &'static str = "test_tool";
        const DESCRIPTION: &'static str = "A test tool";

        fn call(
            &self,
            input: Self::Input,
            _ctx: &ToolContext,
        ) -> BoxFuture<'static, Result<Self::Output, ToolError>> {
            Box::pin(async move { Ok(format!("Hello, {}!", input)) })
        }
    }

    /// A second, distinct tool type used to exercise type-based lookups.
    #[derive(Clone)]
    struct OtherTool;

    impl Tool for OtherTool {
        type Input = String;
        type Output = String;
        const NAME: &'static str = "other_tool";
        const DESCRIPTION: &'static str = "Another test tool";

        fn call(
            &self,
            input: Self::Input,
            _ctx: &ToolContext,
        ) -> BoxFuture<'static, Result<Self::Output, ToolError>> {
            Box::pin(async move { Ok(input) })
        }
    }

    #[test]
    fn test_registry_builder() {
        let registry = ToolRegistry::builder()
            .register::<TestTool, ()>(TestTool)
            .finish();

        assert!(registry.contains("test_tool"));
        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());
    }

    #[test]
    fn test_registry_list_names() {
        let registry = ToolRegistry::builder()
            .register::<TestTool, ()>(TestTool)
            .finish();

        let names = registry.list_names();
        assert_eq!(names, vec!["test_tool"]);
    }

    #[test]
    fn test_registry_subset() {
        let registry = ToolRegistry::builder()
            .register::<TestTool, ()>(TestTool)
            .finish();

        let subset = registry.subset(["test_tool"]);
        assert!(subset.contains("test_tool"));

        let empty_subset = registry.subset(["nonexistent"]);
        assert!(empty_subset.is_empty());
    }

    #[test]
    fn test_subset_keeps_only_requested_tools() {
        let registry = ToolRegistry::builder()
            .register::<TestTool, ()>(TestTool)
            .register::<OtherTool, ()>(OtherTool)
            .finish();

        let subset = registry.subset(["other_tool"]);
        assert_eq!(subset.list_names(), vec!["other_tool"]);
        assert!(subset.handle::<OtherTool>().is_ok());
        assert!(subset.handle::<TestTool>().is_err());
    }

    #[test]
    fn test_tool_handle() {
        let registry = ToolRegistry::builder()
            .register::<TestTool, ()>(TestTool)
            .finish();

        let handle = registry.handle::<TestTool>();
        assert!(handle.is_ok());
    }

    #[test]
    fn test_handle_unregistered_type() {
        let registry = ToolRegistry::builder()
            .register::<TestTool, ()>(TestTool)
            .finish();

        assert!(registry.handle::<OtherTool>().is_err());
    }

    #[tokio::test]
    async fn test_dispatch_json() {
        let registry = ToolRegistry::builder()
            .register::<TestTool, ()>(TestTool)
            .finish();

        let ctx = ToolContext::default();
        let result = registry
            .dispatch_json("test_tool", serde_json::json!("World"), &ctx)
            .await;

        assert_eq!(result.unwrap(), serde_json::json!("Hello, World!"));
    }

    #[tokio::test]
    async fn test_dispatch_json_rejects_malformed_args() {
        let registry = ToolRegistry::builder()
            .register::<TestTool, ()>(TestTool)
            .finish();

        let ctx = ToolContext::default();
        // TestTool expects a string; a number must fail deserialization.
        let result = registry
            .dispatch_json("test_tool", serde_json::json!(42), &ctx)
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_dispatch_json_unknown_tool() {
        let registry = ToolRegistry::builder()
            .register::<TestTool, ()>(TestTool)
            .finish();

        let ctx = ToolContext::default();
        let result = registry
            .dispatch_json("nonexistent", serde_json::json!("test"), &ctx)
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_dispatch_json_formatted() {
        let registry = ToolRegistry::builder()
            .register::<TestTool, ()>(TestTool)
            .finish();

        let ctx = ToolContext::default();
        let args = serde_json::json!("World");
        let opts = TextOptions::default();

        let result = registry
            .dispatch_json_formatted("test_tool", args, &ctx, &opts)
            .await;

        assert!(result.is_ok());
        let formatted = result.unwrap();
        assert_eq!(formatted.data, serde_json::json!("Hello, World!"));
        assert!(formatted.text.is_some());
        // Text should be pretty-printed JSON
        assert!(formatted.text.unwrap().contains("Hello, World!"));
    }

    #[tokio::test]
    async fn test_dispatch_json_formatted_unknown_tool() {
        let registry = ToolRegistry::builder()
            .register::<TestTool, ()>(TestTool)
            .finish();

        let ctx = ToolContext::default();
        let args = serde_json::json!("test");
        let opts = TextOptions::default();

        let result = registry
            .dispatch_json_formatted("nonexistent", args, &ctx, &opts)
            .await;

        assert!(result.is_err());
    }

    #[test]
    fn test_iter_erased() {
        let registry = ToolRegistry::builder()
            .register::<TestTool, ()>(TestTool)
            .finish();

        let erased = registry.iter_erased();
        assert_eq!(erased.len(), 1);
        assert_eq!(erased[0].name(), "test_tool");
    }

    #[test]
    fn test_register_erased_roundtrip() {
        // Create a registry with a tool
        let r1 = ToolRegistry::builder()
            .register::<TestTool, ()>(TestTool)
            .finish();

        // Extract erased tool and re-register
        let erased = r1.iter_erased().into_iter().next().unwrap();
        let r2 = ToolRegistry::builder().register_erased(erased).finish();

        // Verify the tool was re-registered correctly
        assert_eq!(r2.len(), 1);
        assert!(r2.contains("test_tool"));
        assert_eq!(r2.get("test_tool").unwrap().name(), "test_tool");
    }

    #[test]
    fn test_merge_all_combines_registries() {
        // Create two registries with the same tool (simulating domain registries)
        let r1 = ToolRegistry::builder()
            .register::<TestTool, ()>(TestTool)
            .finish();
        let r2 = ToolRegistry::builder()
            .register::<TestTool, ()>(TestTool)
            .finish();

        // Merge them
        let merged = ToolRegistry::merge_all(vec![r1, r2]);

        // Duplicate names should result in last-wins (still only one tool)
        assert_eq!(merged.len(), 1);
        assert!(merged.contains("test_tool"));
    }

    #[test]
    fn test_merge_all_empty() {
        let merged = ToolRegistry::merge_all(Vec::<ToolRegistry>::new());
        assert!(merged.is_empty());
    }

    #[test]
    fn test_merge_all_preserves_subset() {
        let r1 = ToolRegistry::builder()
            .register::<TestTool, ()>(TestTool)
            .finish();

        let merged = ToolRegistry::merge_all(vec![r1]);
        let subset = merged.subset(["test_tool"]);

        assert_eq!(subset.len(), 1);
        assert!(subset.contains("test_tool"));
    }
}