Skip to main content

agentic_tools_core/
registry.rs

1//! Tool registry for dynamic dispatch and type-safe native calls.
2
3use crate::context::ToolContext;
4use crate::error::ToolError;
5use crate::fmt::TextFormat;
6use crate::fmt::TextOptions;
7use crate::schema::mcp_schema;
8use crate::tool::Tool;
9use crate::tool::ToolCodec;
10use futures::future::BoxFuture;
11use schemars::Schema;
12use serde_json::Value;
13use std::any::TypeId;
14use std::collections::HashMap;
15use std::collections::HashSet;
16use std::marker::PhantomData;
17use std::sync::Arc;
18
19/// Result from dispatch_json_formatted containing both JSON data and optional text.
20#[derive(Debug, Clone)]
21pub struct FormattedResult {
22    /// The JSON-serialized output data.
23    pub data: Value,
24    /// Human-readable text representation. None if no TextFormat implementation exists
25    /// and fallback wasn't requested.
26    pub text: Option<String>,
27}
28
29/// Type-erased tool for dynamic dispatch.
30pub trait ErasedTool: Send + Sync {
31    /// Get the tool's name.
32    fn name(&self) -> &'static str;
33
34    /// Get the tool's description.
35    fn description(&self) -> &'static str;
36
37    /// Get the input JSON schema.
38    fn input_schema(&self) -> Schema;
39
40    /// Get the output JSON schema (if available).
41    fn output_schema(&self) -> Option<Schema>;
42
43    /// Call the tool with JSON arguments.
44    fn call_json(
45        &self,
46        args: Value,
47        ctx: &ToolContext,
48    ) -> BoxFuture<'static, Result<Value, ToolError>>;
49
50    /// Call the tool with JSON arguments, returning both JSON data and formatted text.
51    ///
52    /// This method enables dual output for MCP and NAPI servers. The text is derived
53    /// from the tool's TextFormat implementation if available, otherwise it falls back
54    /// to pretty-printed JSON.
55    fn call_json_formatted(
56        &self,
57        args: Value,
58        ctx: &ToolContext,
59        text_opts: &TextOptions,
60    ) -> BoxFuture<'static, Result<FormattedResult, ToolError>>;
61
62    /// Get the TypeId for type-safe handle retrieval.
63    fn type_id(&self) -> TypeId;
64}
65
66/// Registry of tools for dynamic dispatch and type-safe native calls.
67pub struct ToolRegistry {
68    map: HashMap<String, Arc<dyn ErasedTool>>,
69    by_type: HashMap<TypeId, String>,
70}
71
72impl ToolRegistry {
73    /// Create a new registry builder.
74    pub fn builder() -> ToolRegistryBuilder {
75        ToolRegistryBuilder::default()
76    }
77
78    /// List all tool names in the registry.
79    pub fn list_names(&self) -> Vec<String> {
80        self.map.keys().cloned().collect()
81    }
82
83    /// Get a tool by name.
84    pub fn get(&self, name: &str) -> Option<&Arc<dyn ErasedTool>> {
85        self.map.get(name)
86    }
87
88    /// Create a subset registry containing only the specified tools.
89    ///
90    /// Tools not found in the registry are silently ignored.
91    pub fn subset<'a>(&self, names: impl IntoIterator<Item = &'a str>) -> ToolRegistry {
92        let allowed: HashSet<&str> = names.into_iter().collect();
93
94        // Copy the allowed entries into the new map
95        let mut map = HashMap::new();
96        for (k, v) in &self.map {
97            if allowed.contains(k.as_str()) {
98                map.insert(k.clone(), v.clone());
99            }
100        }
101
102        // Reuse original TypeIds from by_type (don't recompute via trait object
103        // to avoid cross-crate monomorphization issues with TypeId)
104        let mut by_type = HashMap::new();
105        for (type_id, name) in &self.by_type {
106            if allowed.contains(name.as_str()) {
107                by_type.insert(*type_id, name.clone());
108            }
109        }
110
111        ToolRegistry { map, by_type }
112    }
113
114    /// Dispatch a tool call using JSON arguments.
115    pub async fn dispatch_json(
116        &self,
117        name: &str,
118        args: Value,
119        ctx: &ToolContext,
120    ) -> Result<Value, ToolError> {
121        let entry = self
122            .map
123            .get(name)
124            .ok_or_else(|| ToolError::invalid_input(format!("Unknown tool: {}", name)))?;
125        entry.call_json(args, ctx).await
126    }
127
128    /// Dispatch a tool call using JSON arguments, returning both JSON data and formatted text.
129    ///
130    /// This method enables dual output for MCP and NAPI servers. The text is derived
131    /// from the tool's TextFormat implementation if available, otherwise it falls back
132    /// to pretty-printed JSON.
133    pub async fn dispatch_json_formatted(
134        &self,
135        name: &str,
136        args: Value,
137        ctx: &ToolContext,
138        text_opts: &TextOptions,
139    ) -> Result<FormattedResult, ToolError> {
140        let entry = self
141            .map
142            .get(name)
143            .ok_or_else(|| ToolError::invalid_input(format!("Unknown tool: {}", name)))?;
144        entry.call_json_formatted(args, ctx, text_opts).await
145    }
146
147    /// Get a type-safe handle for calling a tool natively (zero JSON).
148    ///
149    /// Returns an error if the tool type is not registered.
150    pub fn handle<T: Tool>(&self) -> Result<ToolHandle<T>, ToolError> {
151        let type_id = TypeId::of::<T>();
152        self.by_type.get(&type_id).ok_or_else(|| {
153            ToolError::invalid_input(format!(
154                "Tool type not registered: {}",
155                std::any::type_name::<T>()
156            ))
157        })?;
158        Ok(ToolHandle {
159            _marker: PhantomData,
160        })
161    }
162
163    /// Check if a tool is registered by name.
164    pub fn contains(&self, name: &str) -> bool {
165        self.map.contains_key(name)
166    }
167
168    /// Get the number of registered tools.
169    pub fn len(&self) -> usize {
170        self.map.len()
171    }
172
173    /// Check if the registry is empty.
174    pub fn is_empty(&self) -> bool {
175        self.map.is_empty()
176    }
177
178    /// Clone and return erased tool entries (Arc) for composition.
179    ///
180    /// This enables merging multiple registries by iterating over their
181    /// erased tool entries and re-registering them in a new registry.
182    pub fn iter_erased(&self) -> Vec<Arc<dyn ErasedTool>> {
183        self.map.values().cloned().collect()
184    }
185
186    /// Merge multiple registries into one.
187    ///
188    /// Later entries with duplicate names overwrite earlier ones.
189    /// This is useful for composing domain-specific registries into
190    /// a unified registry.
191    pub fn merge_all(regs: impl IntoIterator<Item = ToolRegistry>) -> ToolRegistry {
192        let mut builder = ToolRegistry::builder();
193        for reg in regs {
194            for erased in reg.iter_erased() {
195                builder = builder.register_erased(erased);
196            }
197        }
198        builder.finish()
199    }
200}
201
202/// Builder for constructing a [`ToolRegistry`].
203#[derive(Default)]
204pub struct ToolRegistryBuilder {
205    items: Vec<(String, TypeId, Arc<dyn ErasedTool>)>,
206}
207
208impl ToolRegistryBuilder {
209    /// Register a tool with its codec.
210    ///
211    /// Use `()` as the codec when the tool's Input/Output types
212    /// already implement serde and schemars traits.
213    ///
214    /// The tool's output type must implement [`TextFormat`] for human-readable
215    /// formatting. Types can override `fmt_text()` for custom formatting, or
216    /// use the default which produces pretty-printed JSON.
217    pub fn register<T, C>(mut self, tool: T) -> Self
218    where
219        T: Tool + Clone + 'static,
220        C: ToolCodec<T> + 'static,
221        T::Output: TextFormat,
222    {
223        struct Impl<T: Tool + Clone, C: ToolCodec<T>> {
224            tool: T,
225            _codec: PhantomData<C>,
226        }
227
228        impl<T: Tool + Clone, C: ToolCodec<T>> ErasedTool for Impl<T, C>
229        where
230            T::Output: TextFormat,
231        {
232            fn name(&self) -> &'static str {
233                T::NAME
234            }
235
236            fn description(&self) -> &'static str {
237                T::DESCRIPTION
238            }
239
240            fn input_schema(&self) -> Schema {
241                mcp_schema::cached_schema_for::<C::WireIn>()
242                    .as_ref()
243                    .clone()
244            }
245
246            fn output_schema(&self) -> Option<Schema> {
247                match mcp_schema::cached_output_schema_for::<C::WireOut>() {
248                    Ok(arc) => Some(arc.as_ref().clone()),
249                    Err(_) => None,
250                }
251            }
252
253            fn call_json(
254                &self,
255                args: Value,
256                ctx: &ToolContext,
257            ) -> BoxFuture<'static, Result<Value, ToolError>> {
258                let wire_in: Result<C::WireIn, _> = serde_json::from_value(args);
259                let ctx = ctx.clone();
260                let tool = self.tool.clone();
261
262                match wire_in {
263                    Err(e) => Box::pin(async move { Err(ToolError::invalid_input(e.to_string())) }),
264                    Ok(wire) => match C::decode(wire) {
265                        Err(e) => Box::pin(async move { Err(e) }),
266                        Ok(native_in) => {
267                            let fut = tool.call(native_in, &ctx);
268                            Box::pin(async move {
269                                let out = fut.await?;
270                                let wired = C::encode(out)?;
271                                serde_json::to_value(wired)
272                                    .map_err(|e| ToolError::internal(e.to_string()))
273                            })
274                        }
275                    },
276                }
277            }
278
279            fn call_json_formatted(
280                &self,
281                args: Value,
282                ctx: &ToolContext,
283                text_opts: &TextOptions,
284            ) -> BoxFuture<'static, Result<FormattedResult, ToolError>> {
285                let wire_in: Result<C::WireIn, _> = serde_json::from_value(args);
286                let ctx = ctx.clone();
287                let tool = self.tool.clone();
288                let text_opts = text_opts.clone();
289
290                match wire_in {
291                    Err(e) => Box::pin(async move { Err(ToolError::invalid_input(e.to_string())) }),
292                    Ok(wire) => match C::decode(wire) {
293                        Err(e) => Box::pin(async move { Err(e) }),
294                        Ok(native_in) => {
295                            let fut = tool.call(native_in, &ctx);
296                            Box::pin(async move {
297                                let out = fut.await?;
298                                // Format text from the native output using TextFormat
299                                let text = out.fmt_text(&text_opts);
300                                // Then encode to wire and JSON-serialize for data
301                                let wired = C::encode(out)?;
302                                let data = serde_json::to_value(&wired)
303                                    .map_err(|e| ToolError::internal(e.to_string()))?;
304                                Ok(FormattedResult {
305                                    data,
306                                    text: Some(text),
307                                })
308                            })
309                        }
310                    },
311                }
312            }
313
314            fn type_id(&self) -> TypeId {
315                TypeId::of::<T>()
316            }
317        }
318
319        let erased: Arc<dyn ErasedTool> = Arc::new(Impl::<T, C> {
320            tool,
321            _codec: PhantomData,
322        });
323        self.items
324            .push((T::NAME.to_string(), TypeId::of::<T>(), erased));
325        self
326    }
327
328    /// Register an already-erased tool entry.
329    ///
330    /// This enables merging registries by iterating over their erased tools
331    /// and re-registering them without needing the concrete tool types.
332    pub fn register_erased(mut self, erased: Arc<dyn ErasedTool>) -> Self {
333        let name = erased.name().to_string();
334        let type_id = erased.type_id();
335        self.items.push((name, type_id, erased));
336        self
337    }
338
339    /// Build the registry from registered tools.
340    pub fn finish(self) -> ToolRegistry {
341        let mut map = HashMap::new();
342        let mut by_type = HashMap::new();
343        for (name, type_id, erased) in self.items {
344            by_type.insert(type_id, name.clone());
345            map.insert(name, erased);
346        }
347        ToolRegistry { map, by_type }
348    }
349}
350
351/// Type-safe handle for calling a tool natively without JSON serialization.
352///
353/// Obtained from [`ToolRegistry::handle`].
354pub struct ToolHandle<T: Tool> {
355    _marker: PhantomData<T>,
356}
357
358impl<T: Tool> ToolHandle<T> {
359    /// Call the tool directly with native types (zero JSON serialization).
360    pub async fn call(
361        &self,
362        tool: &T,
363        input: T::Input,
364        ctx: &ToolContext,
365    ) -> Result<T::Output, ToolError> {
366        tool.call(input, ctx).await
367    }
368}
369
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone)]
    struct TestTool;

    impl Tool for TestTool {
        type Input = String;
        type Output = String;
        const NAME: &'static str = "test_tool";
        const DESCRIPTION: &'static str = "A test tool";

        fn call(
            &self,
            input: Self::Input,
            _ctx: &ToolContext,
        ) -> BoxFuture<'static, Result<Self::Output, ToolError>> {
            Box::pin(async move { Ok(format!("Hello, {}!", input)) })
        }
    }

    /// A second, distinct tool type used to exercise type-based lookups.
    #[derive(Clone)]
    struct OtherTool;

    impl Tool for OtherTool {
        type Input = String;
        type Output = String;
        const NAME: &'static str = "other_tool";
        const DESCRIPTION: &'static str = "Another test tool";

        fn call(
            &self,
            input: Self::Input,
            _ctx: &ToolContext,
        ) -> BoxFuture<'static, Result<Self::Output, ToolError>> {
            Box::pin(async move { Ok(input.to_uppercase()) })
        }
    }

    /// Registry containing only `TestTool`.
    fn single_tool_registry() -> ToolRegistry {
        ToolRegistry::builder()
            .register::<TestTool, ()>(TestTool)
            .finish()
    }

    /// Registry containing both `TestTool` and `OtherTool`.
    fn two_tool_registry() -> ToolRegistry {
        ToolRegistry::builder()
            .register::<TestTool, ()>(TestTool)
            .register::<OtherTool, ()>(OtherTool)
            .finish()
    }

    #[test]
    fn test_registry_builder() {
        let registry = single_tool_registry();

        assert!(registry.contains("test_tool"));
        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());
    }

    #[test]
    fn test_registry_list_names() {
        let registry = single_tool_registry();

        let names = registry.list_names();
        assert_eq!(names, vec!["test_tool"]);
    }

    #[test]
    fn test_registry_list_names_multiple() {
        let mut names = two_tool_registry().list_names();
        names.sort();
        assert_eq!(names, vec!["other_tool", "test_tool"]);
    }

    #[test]
    fn test_registry_subset() {
        let registry = single_tool_registry();

        let subset = registry.subset(["test_tool"]);
        assert!(subset.contains("test_tool"));

        let empty_subset = registry.subset(["nonexistent"]);
        assert!(empty_subset.is_empty());
    }

    #[test]
    fn test_subset_prunes_type_handles() {
        let subset = two_tool_registry().subset(["test_tool"]);

        assert_eq!(subset.len(), 1);
        assert!(subset.handle::<TestTool>().is_ok());
        assert!(subset.handle::<OtherTool>().is_err());
    }

    #[test]
    fn test_tool_handle() {
        let registry = single_tool_registry();

        let handle = registry.handle::<TestTool>();
        assert!(handle.is_ok());
    }

    #[test]
    fn test_tool_handle_unregistered_type() {
        let registry = single_tool_registry();

        assert!(registry.handle::<OtherTool>().is_err());
    }

    #[tokio::test]
    async fn test_dispatch_json() {
        let registry = single_tool_registry();
        let ctx = ToolContext::default();

        let data = registry
            .dispatch_json("test_tool", serde_json::json!("World"), &ctx)
            .await
            .expect("dispatch to a registered tool should succeed");

        assert_eq!(data, serde_json::json!("Hello, World!"));
    }

    #[tokio::test]
    async fn test_dispatch_json_invalid_input() {
        let registry = single_tool_registry();
        let ctx = ToolContext::default();

        // The tool expects a JSON string; a number must fail to deserialize.
        let result = registry
            .dispatch_json("test_tool", serde_json::json!(42), &ctx)
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_dispatch_json_formatted() {
        let registry = single_tool_registry();

        let ctx = ToolContext::default();
        let args = serde_json::json!("World");
        let opts = TextOptions::default();

        let formatted = registry
            .dispatch_json_formatted("test_tool", args, &ctx, &opts)
            .await
            .expect("dispatch to a registered tool should succeed");

        assert_eq!(formatted.data, serde_json::json!("Hello, World!"));
        // Text is rendered from the native output via its TextFormat implementation.
        let text = formatted.text.expect("registered tools always produce text");
        assert!(text.contains("Hello, World!"));
    }

    #[tokio::test]
    async fn test_dispatch_json_formatted_unknown_tool() {
        let registry = single_tool_registry();

        let ctx = ToolContext::default();
        let args = serde_json::json!("test");
        let opts = TextOptions::default();

        let result = registry
            .dispatch_json_formatted("nonexistent", args, &ctx, &opts)
            .await;

        assert!(result.is_err());
    }

    #[test]
    fn test_iter_erased() {
        let registry = single_tool_registry();

        let erased = registry.iter_erased();
        assert_eq!(erased.len(), 1);
        assert_eq!(erased[0].name(), "test_tool");
    }

    #[test]
    fn test_register_erased_roundtrip() {
        // Create a registry with a tool
        let r1 = single_tool_registry();

        // Extract erased tool and re-register
        let erased = r1.iter_erased().into_iter().next().unwrap();
        let r2 = ToolRegistry::builder().register_erased(erased).finish();

        // Verify the tool was re-registered correctly, including its type handle
        assert_eq!(r2.len(), 1);
        assert!(r2.contains("test_tool"));
        assert_eq!(r2.get("test_tool").unwrap().name(), "test_tool");
        assert!(r2.handle::<TestTool>().is_ok());
    }

    #[test]
    fn test_merge_all_combines_registries() {
        // Create two registries with the same tool (simulating domain registries)
        let r1 = single_tool_registry();
        let r2 = single_tool_registry();

        // Merge them
        let merged = ToolRegistry::merge_all(vec![r1, r2]);

        // Duplicate names should result in last-wins (still only one tool)
        assert_eq!(merged.len(), 1);
        assert!(merged.contains("test_tool"));
    }

    #[test]
    fn test_merge_all_empty() {
        let merged = ToolRegistry::merge_all(Vec::<ToolRegistry>::new());
        assert!(merged.is_empty());
    }

    #[test]
    fn test_merge_all_preserves_subset() {
        let r1 = single_tool_registry();

        let merged = ToolRegistry::merge_all(vec![r1]);
        let subset = merged.subset(["test_tool"]);

        assert_eq!(subset.len(), 1);
        assert!(subset.contains("test_tool"));
    }
}