Skip to main content

llm_tool/
registry.rs

//! Tool registry: registration and concurrent dispatch of named tools.
2
3use std::collections::HashMap;
4
5use super::{
6    rust_tool::{ErasedTool, RustTool, definition_of},
7    types::{ToolContext, ToolDefinition, ToolError, ToolOutput},
8};
9
/// Entry holding a cached [`ToolDefinition`] alongside the type-erased tool.
///
/// The definition is computed once at registration time so that
/// [`ToolRegistry::definitions`] and [`ToolRegistry::iter`] never
/// regenerate JSON schemas.
struct RegisteredTool {
    /// Definition produced by `definition_of` when the tool was registered;
    /// handed out (cloned) by the listing methods.
    definition: ToolDefinition,
    /// The tool itself behind a type-erased trait object, so tools with
    /// different `Params` types can share one map. Invoked by
    /// [`ToolRegistry::dispatch`] through `call_erased`.
    erased: Box<dyn ErasedTool>,
}
19
/// A collection of named tools that can be listed as [`ToolDefinition`]s and
/// invoked by name with raw JSON arguments.
///
/// Tools are added with [`ToolRegistry::register`] (by `&mut self`, chainable)
/// or [`ToolRegistry::with_tool`] (by value, chainable), and are keyed by their
/// [`RustTool::NAME`]; registering a second tool under an existing name
/// replaces the first. [`ToolRegistry::dispatch`] only needs `&self`, so several
/// dispatches may be in flight against the same registry at once.
pub struct ToolRegistry {
    /// Registered tools keyed by [`RustTool::NAME`]. The map does not preserve
    /// insertion order.
    tools: HashMap<&'static str, RegisteredTool>,
}
23
24impl std::fmt::Debug for ToolRegistry {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        let names: Vec<&str> = self
27            .tools
28            .values()
29            .map(|r| r.definition.name.as_str())
30            .collect();
31        f.debug_struct("ToolRegistry")
32            .field("tool_count", &self.tools.len())
33            .field("tool_names", &names)
34            .finish()
35    }
36}
37
38impl Default for ToolRegistry {
39    fn default() -> Self {
40        Self::new()
41    }
42}
43
44impl ToolRegistry {
45    /// Create an empty registry.
46    #[must_use]
47    pub fn new() -> Self {
48        Self {
49            tools: HashMap::new(),
50        }
51    }
52
53    /// Register a [`RustTool`]. Returns `&mut Self` for chaining.
54    ///
55    /// The tool's [`ToolDefinition`] (including JSON schema) is computed once
56    /// here and cached for the lifetime of the registration.
57    ///
58    /// If a tool with the same name was already registered, it is replaced.
59    ///
60    /// # Panics
61    ///
62    /// Panics if the tool's JSON schema cannot be serialized. This indicates a
63    /// bug in the tool's `Params` type (e.g. a broken `JsonSchema` impl).
64    pub fn register<T: RustTool + 'static>(&mut self, tool: T) -> &mut Self {
65        let definition = definition_of(&tool)
66            .unwrap_or_else(|e| panic!("Failed to build definition for tool '{}': {e}", T::NAME));
67        self.tools.insert(
68            T::NAME,
69            RegisteredTool {
70                definition,
71                erased: Box::new(tool),
72            },
73        );
74        self
75    }
76
77    /// Register a [`RustTool`], consuming and returning `Self` for owned chaining.
78    ///
79    /// This is the owned counterpart of [`register`](Self::register), enabling
80    /// patterns like:
81    /// ```
82    /// use llm_tool::{RustTool, ToolContext, ToolError, ToolOutput, ToolRegistry};
83    /// use schemars::JsonSchema;
84    /// use serde::Deserialize;
85    ///
86    /// #[derive(Deserialize, JsonSchema)]
87    /// struct NoParams {}
88    ///
89    /// struct ToolA;
90    /// impl RustTool for ToolA {
91    ///     type Params = NoParams;
92    ///     const NAME: &'static str = "tool_a";
93    ///     const DESCRIPTION: &'static str = "Tool A";
94    ///     async fn call(&self, _: NoParams, _: &ToolContext) -> Result<ToolOutput, ToolError> {
95    ///         Ok("a".into())
96    ///     }
97    /// }
98    ///
99    /// struct ToolB;
100    /// impl RustTool for ToolB {
101    ///     type Params = NoParams;
102    ///     const NAME: &'static str = "tool_b";
103    ///     const DESCRIPTION: &'static str = "Tool B";
104    ///     async fn call(&self, _: NoParams, _: &ToolContext) -> Result<ToolOutput, ToolError> {
105    ///         Ok("b".into())
106    ///     }
107    /// }
108    ///
109    /// let registry = ToolRegistry::new().with_tool(ToolA).with_tool(ToolB);
110    ///
111    /// assert_eq!(registry.definitions().len(), 2);
112    /// ```
113    #[must_use]
114    pub fn with_tool<T: RustTool + 'static>(mut self, tool: T) -> Self {
115        self.register(tool);
116        self
117    }
118
119    /// Collect [`ToolDefinition`]s for all registered tools.
120    ///
121    /// Returns clones of the cached definitions computed at registration time.
122    #[must_use]
123    pub fn definitions(&self) -> Vec<ToolDefinition> {
124        self.tools
125            .values()
126            .map(|entry| entry.definition.clone())
127            .collect()
128    }
129
130    /// Dispatch a tool call by name with raw JSON arguments and a context.
131    ///
132    /// # Errors
133    ///
134    /// Returns `Err` if the tool name is unknown or the handler returns an error.
135    pub async fn dispatch(
136        &self,
137        name: &str,
138        args: serde_json::Value,
139        ctx: &ToolContext,
140    ) -> Result<ToolOutput, ToolError> {
141        let entry = self
142            .tools
143            .get(name)
144            .ok_or_else(|| ToolError::new(format!("Unknown tool: {name}")))?;
145        entry.erased.call_erased(args, ctx).await
146    }
147
148    /// Number of registered tools.
149    #[must_use]
150    pub fn len(&self) -> usize {
151        self.tools.len()
152    }
153
154    /// Whether the registry has no registered tools.
155    #[must_use]
156    pub fn is_empty(&self) -> bool {
157        self.tools.is_empty()
158    }
159
160    /// Iterate over `(name, definition)` pairs for every registered tool.
161    ///
162    /// Returns clones of the cached definitions computed at registration time.
163    pub fn iter(&self) -> impl Iterator<Item = (&'static str, ToolDefinition)> + '_ {
164        self.tools
165            .iter()
166            .map(|(name, entry)| (*name, entry.definition.clone()))
167    }
168}
169
170/// Iterate over `(name, definition)` pairs for every registered tool.
171///
172/// Yields `(&'static str, ToolDefinition)` for each tool in the registry.
173impl<'a> IntoIterator for &'a ToolRegistry {
174    type Item = (&'static str, ToolDefinition);
175    type IntoIter = Box<dyn Iterator<Item = (&'static str, ToolDefinition)> + 'a>;
176
177    fn into_iter(self) -> Self::IntoIter {
178        Box::new(
179            self.tools
180                .iter()
181                .map(|(name, entry)| (*name, entry.definition.clone())),
182        )
183    }
184}
185
186#[cfg(test)]
187mod tests {
188    use serde::Deserialize;
189
190    use super::{
191        super::{EmptyParams, definition_of},
192        *,
193    };
194    use crate::llm_tool;
195
196    /// Create a default `ToolContext` for tests.
197    fn test_ctx() -> ToolContext {
198        ToolContext::new(None)
199    }
200
201    // ── Sample tool structs for tests ────────────────────────────────
202
203    #[derive(Deserialize, schemars::JsonSchema)]
204    struct PathParams {
205        /// Filesystem path.
206        path: String,
207    }
208
209    struct SampleTool;
210
211    impl RustTool for SampleTool {
212        type Params = PathParams;
213        const NAME: &'static str = "sample";
214        const DESCRIPTION: &'static str = "A sample tool";
215        async fn call(
216            &self,
217            params: Self::Params,
218            _ctx: &ToolContext,
219        ) -> Result<ToolOutput, ToolError> {
220            Ok(params.path.into())
221        }
222    }
223
224    #[derive(Deserialize, schemars::JsonSchema)]
225    struct RunCommandParams {
226        /// Command to run.
227        command: String,
228        /// Timeout in seconds.
229        #[serde(default)]
230        timeout: Option<i64>,
231        /// Environment variables.
232        #[serde(default)]
233        env: Option<std::collections::HashMap<String, String>>,
234    }
235
236    struct RunCommandTool;
237
238    impl RustTool for RunCommandTool {
239        type Params = RunCommandParams;
240        const NAME: &'static str = "run_command";
241        const DESCRIPTION: &'static str = "Runs a command.";
242        async fn call(
243            &self,
244            params: Self::Params,
245            _ctx: &ToolContext,
246        ) -> Result<ToolOutput, ToolError> {
247            assert!(params.timeout.is_none());
248            assert!(params.env.is_none());
249            Ok(format!("Ran: {}", params.command).into())
250        }
251    }
252
253    // ── ToolDefinition tests ─────────────────────────────────────────
254
255    #[test]
256    fn tool_definition_serde_roundtrip() {
257        let def = definition_of(&SampleTool).expect("schema");
258        let json = serde_json::to_string(&def).expect("serialize");
259        let parsed: ToolDefinition = serde_json::from_str(&json).expect("deserialize");
260        assert_eq!(parsed.name, def.name);
261        assert_eq!(parsed.description, def.description);
262        assert_eq!(parsed.parameter_schema, def.parameter_schema);
263    }
264
265    struct EmptyParamTool;
266    impl RustTool for EmptyParamTool {
267        type Params = EmptyParams;
268        const NAME: &'static str = "empty";
269        const DESCRIPTION: &'static str = "No params";
270        async fn call(
271            &self,
272            _params: Self::Params,
273            _ctx: &ToolContext,
274        ) -> Result<ToolOutput, ToolError> {
275            Ok("ok".into())
276        }
277    }
278
279    #[test]
280    fn tool_definition_with_empty_schema() {
281        let tool = definition_of(&EmptyParamTool).expect("schema");
282        let json = serde_json::to_string(&tool).expect("serialize");
283        let parsed: ToolDefinition = serde_json::from_str(&json).expect("deserialize");
284        // Compare via JSON to handle serde normalization (None vs empty struct).
285        let orig_json = serde_json::to_value(&tool.parameter_schema).unwrap();
286        let parsed_json = serde_json::to_value(&parsed.parameter_schema).unwrap();
287        assert_eq!(orig_json, parsed_json);
288    }
289
290    #[test]
291    fn tool_definition_with_complex_schema() {
292        let tool = definition_of(&RunCommandTool).expect("schema");
293        let schema_json = serde_json::to_value(&tool.parameter_schema).expect("schema to json");
294        // The schema should have 'command' as a required field.
295        let required = schema_json["required"]
296            .as_array()
297            .expect("required should be an array");
298        assert!(
299            required.iter().any(|v| v == "command"),
300            "'command' should be required, got: {required:?}"
301        );
302    }
303
304    // ── ToolRegistry tests ────────────────────────────────────────
305
306    #[tokio::test]
307    async fn registry_dispatch_valid_tool() {
308        let mut d = ToolRegistry::new();
309        d.register(SampleTool);
310        let result = d
311            .dispatch(
312                "sample",
313                serde_json::json!({"path": "/tmp/foo"}),
314                &test_ctx(),
315            )
316            .await;
317        assert_eq!(result.unwrap().content(), "/tmp/foo");
318    }
319
320    #[tokio::test]
321    async fn registry_dispatch_unknown_tool() {
322        let d = ToolRegistry::new();
323        let result = d
324            .dispatch("nonexistent", serde_json::json!({}), &test_ctx())
325            .await;
326        assert_eq!(
327            result.unwrap_err(),
328            ToolError::new("Unknown tool: nonexistent")
329        );
330    }
331
332    #[tokio::test]
333    async fn registry_dispatch_invalid_args() {
334        let mut d = ToolRegistry::new();
335        d.register(SampleTool);
336        // SampleTool expects {"path": String}, not an integer.
337        let result = d
338            .dispatch("sample", serde_json::json!({"path": 42}), &test_ctx())
339            .await;
340        let err = result.unwrap_err();
341        assert!(
342            err.message.contains("deserialize"),
343            "Error should mention deserialization, got: {err}"
344        );
345    }
346
347    #[tokio::test]
348    async fn registry_dispatch_missing_required_field() {
349        let mut d = ToolRegistry::new();
350        d.register(SampleTool);
351        // Missing the required "path" field entirely.
352        let err = d
353            .dispatch("sample", serde_json::json!({}), &test_ctx())
354            .await
355            .expect_err("Expected error for missing required field");
356        assert!(
357            err.message.contains("missing field"),
358            "Error should mention missing field, got: {err}"
359        );
360    }
361
362    #[test]
363    fn registry_definitions_returns_all() {
364        let mut d = ToolRegistry::new();
365        d.register(SampleTool);
366        d.register(RunCommandTool);
367
368        let defs = d.definitions();
369        assert_eq!(defs.len(), 2);
370
371        let mut names: Vec<&str> = defs.iter().map(|d| d.name.as_str()).collect();
372        names.sort_unstable();
373        assert_eq!(names, vec!["run_command", "sample"]);
374    }
375
376    #[test]
377    fn registry_register_chaining() {
378        let mut d = ToolRegistry::new();
379        d.register(SampleTool).register(RunCommandTool);
380        assert_eq!(d.len(), 2);
381        assert!(!d.is_empty());
382    }
383
384    #[test]
385    fn registry_with_tool_owned_chaining() {
386        let d = ToolRegistry::new()
387            .with_tool(SampleTool)
388            .with_tool(RunCommandTool);
389        assert_eq!(d.len(), 2);
390        assert!(!d.is_empty());
391
392        let defs = d.definitions();
393        let mut names: Vec<&str> = defs.iter().map(|d| d.name.as_str()).collect();
394        names.sort_unstable();
395        assert_eq!(names, vec!["run_command", "sample"]);
396    }
397
398    #[test]
399    fn registry_default_is_empty() {
400        let d = ToolRegistry::default();
401        assert!(d.is_empty());
402        assert_eq!(d.len(), 0);
403    }
404
405    #[tokio::test]
406    async fn registry_replaces_on_duplicate_name() {
407        struct AlternateSample;
408        impl RustTool for AlternateSample {
409            type Params = PathParams;
410            const NAME: &'static str = "sample";
411            const DESCRIPTION: &'static str = "Alternate sample";
412            async fn call(
413                &self,
414                params: Self::Params,
415                _ctx: &ToolContext,
416            ) -> Result<ToolOutput, ToolError> {
417                Ok(format!("alt: {}", params.path).into())
418            }
419        }
420
421        let mut d = ToolRegistry::new();
422        d.register(SampleTool);
423        d.register(AlternateSample);
424        assert_eq!(d.len(), 1);
425
426        let result = d
427            .dispatch("sample", serde_json::json!({"path": "x"}), &test_ctx())
428            .await;
429        assert_eq!(result.unwrap().content(), "alt: x");
430    }
431
432    #[tokio::test]
433    async fn registry_tool_returning_error() {
434        struct FailingTool;
435        impl RustTool for FailingTool {
436            type Params = EmptyParams;
437            const NAME: &'static str = "fail";
438            const DESCRIPTION: &'static str = "Always fails";
439            async fn call(
440                &self,
441                _params: Self::Params,
442                _ctx: &ToolContext,
443            ) -> Result<ToolOutput, ToolError> {
444                Err(ToolError::new("intentional failure"))
445            }
446        }
447
448        let mut d = ToolRegistry::new();
449        d.register(FailingTool);
450        let result = d.dispatch("fail", serde_json::json!({}), &test_ctx()).await;
451        assert_eq!(result.unwrap_err(), ToolError::new("intentional failure"));
452    }
453
454    #[test]
455    fn registry_debug_shows_tool_names() {
456        let mut d = ToolRegistry::new();
457        d.register(SampleTool);
458        let dbg = format!("{d:?}");
459        assert!(dbg.contains("ToolRegistry"));
460        assert!(dbg.contains("sample"));
461        assert!(dbg.contains("tool_count: 1"));
462    }
463
464    // ── Async-specific tests ────────────────────────────────────────
465
466    /// A tool that actually awaits a tokio sleep, proving async dispatch works.
467    struct AsyncSleepTool;
468
469    impl RustTool for AsyncSleepTool {
470        type Params = EmptyParams;
471        const NAME: &'static str = "async_sleep";
472        const DESCRIPTION: &'static str = "Sleeps briefly then returns.";
473
474        async fn call(
475            &self,
476            _params: Self::Params,
477            _ctx: &ToolContext,
478        ) -> Result<ToolOutput, ToolError> {
479            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
480            Ok("slept".into())
481        }
482    }
483
484    #[tokio::test]
485    async fn async_tool_with_tokio_sleep() {
486        let mut d = ToolRegistry::new();
487        d.register(AsyncSleepTool);
488        let result = d
489            .dispatch("async_sleep", serde_json::json!({}), &test_ctx())
490            .await;
491        assert_eq!(result.unwrap().content(), "slept");
492    }
493
494    /// A tool that reads a file using `tokio::fs`.
495    struct AsyncReadFileTool;
496
497    #[derive(Deserialize, schemars::JsonSchema)]
498    struct ReadFileParams {
499        /// Path to the file to read.
500        path: String,
501    }
502
503    impl RustTool for AsyncReadFileTool {
504        type Params = ReadFileParams;
505        const NAME: &'static str = "read_file";
506        const DESCRIPTION: &'static str = "Reads a file asynchronously.";
507
508        async fn call(
509            &self,
510            params: Self::Params,
511            _ctx: &ToolContext,
512        ) -> Result<ToolOutput, ToolError> {
513            tokio::fs::read_to_string(&params.path)
514                .await
515                .map(ToolOutput::from)
516                .map_err(|e| ToolError::new(format!("IO error: {e}")))
517        }
518    }
519
520    #[tokio::test]
521    async fn async_tool_with_tokio_fs() {
522        let tmp = tempfile::NamedTempFile::new().expect("create tempfile");
523        std::fs::write(tmp.path(), "hello async").expect("write tempfile");
524
525        let mut d = ToolRegistry::new();
526        d.register(AsyncReadFileTool);
527
528        let path_str = tmp.path().to_str().expect("path to str").to_owned();
529        let result = d
530            .dispatch(
531                "read_file",
532                serde_json::json!({"path": path_str}),
533                &test_ctx(),
534            )
535            .await;
536        assert_eq!(result.unwrap().content(), "hello async");
537    }
538
539    #[tokio::test]
540    async fn async_tool_tokio_fs_missing_file() {
541        let mut d = ToolRegistry::new();
542        d.register(AsyncReadFileTool);
543        let result = d
544            .dispatch(
545                "read_file",
546                serde_json::json!({"path": "/nonexistent/file.txt"}),
547                &test_ctx(),
548            )
549            .await;
550        let err = result.unwrap_err();
551        assert!(
552            err.message.contains("IO error"),
553            "Expected IO error, got: {err}"
554        );
555    }
556
557    /// A tool that uses a tokio channel to receive its result, proving
558    /// the full async machinery works end-to-end.
559    struct ChannelTool {
560        tx: tokio::sync::mpsc::Sender<String>,
561        rx: std::sync::Mutex<Option<tokio::sync::mpsc::Receiver<String>>>,
562    }
563
564    impl ChannelTool {
565        fn new() -> Self {
566            let (tx, rx) = tokio::sync::mpsc::channel(1);
567            Self {
568                tx,
569                rx: std::sync::Mutex::new(Some(rx)),
570            }
571        }
572    }
573
574    impl RustTool for ChannelTool {
575        type Params = EmptyParams;
576        const NAME: &'static str = "channel_tool";
577        const DESCRIPTION: &'static str = "Awaits a value from a channel.";
578
579        async fn call(
580            &self,
581            _params: Self::Params,
582            _ctx: &ToolContext,
583        ) -> Result<ToolOutput, ToolError> {
584            let mut rx = self
585                .rx
586                .lock()
587                .unwrap()
588                .take()
589                .ok_or_else(|| ToolError::new("channel already consumed"))?;
590            rx.recv()
591                .await
592                .map(ToolOutput::from)
593                .ok_or_else(|| ToolError::new("channel closed"))
594        }
595    }
596
597    #[tokio::test]
598    async fn async_tool_awaits_channel() {
599        let tool = ChannelTool::new();
600        let tx = tool.tx.clone();
601
602        let mut d = ToolRegistry::new();
603        d.register(tool);
604
605        // Send the value from another task.
606        let ctx = test_ctx();
607        let dispatch_future = d.dispatch("channel_tool", serde_json::json!({}), &ctx);
608        let send_future = async move {
609            tx.send("from_channel".to_string()).await.unwrap();
610        };
611
612        let (result, ()) = tokio::join!(dispatch_future, send_future);
613        assert_eq!(result.unwrap().content(), "from_channel");
614    }
615
616    // ── Concurrent dispatch tests ───────────────────────────────────
617
618    #[tokio::test]
619    async fn concurrent_dispatches_to_different_tools() {
620        let mut d = ToolRegistry::new();
621        d.register(SampleTool);
622        d.register(AsyncSleepTool);
623        d.register(RunCommandTool);
624
625        let ctx = test_ctx();
626        let (r1, r2, r3) = tokio::join!(
627            d.dispatch("sample", serde_json::json!({"path": "a"}), &ctx),
628            d.dispatch("async_sleep", serde_json::json!({}), &ctx),
629            d.dispatch("run_command", serde_json::json!({"command": "ls"}), &ctx),
630        );
631
632        assert_eq!(r1.unwrap().content(), "a");
633        assert_eq!(r2.unwrap().content(), "slept");
634        assert_eq!(r3.unwrap().content(), "Ran: ls");
635    }
636
637    #[tokio::test]
638    async fn concurrent_dispatches_to_same_tool() {
639        let mut d = ToolRegistry::new();
640        d.register(SampleTool);
641
642        let ctx = test_ctx();
643        let futs: Vec<_> = (0..10)
644            .map(|i| d.dispatch("sample", serde_json::json!({"path": format!("p{i}")}), &ctx))
645            .collect();
646
647        let results = futures::future::join_all(futs).await;
648        for (i, r) in results.into_iter().enumerate() {
649            assert_eq!(r.unwrap().content(), format!("p{i}"));
650        }
651    }
652
653    // ── Schema / doc comment tests ──────────────────────────────────
654
655    #[derive(Deserialize, schemars::JsonSchema)]
656    struct DocumentedParams {
657        /// The target hostname to connect to.
658        hostname: String,
659        /// Port number (1-65535).
660        port: u16,
661        /// Optional timeout in seconds.
662        #[serde(default)]
663        timeout: Option<f64>,
664    }
665
666    struct DocumentedTool;
667    impl RustTool for DocumentedTool {
668        type Params = DocumentedParams;
669        const NAME: &'static str = "connect";
670        const DESCRIPTION: &'static str = "Connects to a remote host.";
671        async fn call(&self, p: Self::Params, _ctx: &ToolContext) -> Result<ToolOutput, ToolError> {
672            Ok(format!("{}:{}:{:?}", p.hostname, p.port, p.timeout).into())
673        }
674    }
675
676    #[test]
677    fn schema_contains_field_descriptions() {
678        let def = definition_of(&DocumentedTool).expect("schema");
679        let schema = &def.parameter_schema;
680
681        // Check the properties contain our fields.
682        let props = schema["properties"].as_object().expect("properties object");
683        assert!(props.contains_key("hostname"), "missing hostname");
684        assert!(props.contains_key("port"), "missing port");
685        assert!(props.contains_key("timeout"), "missing timeout");
686
687        // Check the descriptions from doc comments made it through.
688        let hostname_desc = props["hostname"]["description"]
689            .as_str()
690            .expect("hostname description");
691        assert!(
692            hostname_desc.contains("hostname"),
693            "hostname description should mention 'hostname', got: {hostname_desc}"
694        );
695
696        let port_desc = props["port"]["description"]
697            .as_str()
698            .expect("port description");
699        assert!(
700            port_desc.contains("1-65535"),
701            "port description should mention range, got: {port_desc}"
702        );
703    }
704
705    #[test]
706    fn schema_required_vs_optional_fields() {
707        let def = definition_of(&DocumentedTool).expect("schema");
708        let schema = &def.parameter_schema;
709
710        let required = schema["required"]
711            .as_array()
712            .expect("required should be an array");
713
714        // hostname and port are required, timeout is Option → not required.
715        assert!(
716            required.iter().any(|v| v == "hostname"),
717            "hostname required"
718        );
719        assert!(required.iter().any(|v| v == "port"), "port required");
720        assert!(
721            !required.iter().any(|v| v == "timeout"),
722            "timeout should NOT be required"
723        );
724    }
725
726    #[tokio::test]
727    async fn dispatch_with_optional_field_missing() {
728        let mut d = ToolRegistry::new();
729        d.register(DocumentedTool);
730
731        // Dispatch without `timeout` (it has serde(default)).
732        let result = d
733            .dispatch(
734                "connect",
735                serde_json::json!({"hostname": "example.com", "port": 443}),
736                &test_ctx(),
737            )
738            .await;
739        assert_eq!(result.unwrap().content(), "example.com:443:None");
740    }
741
742    #[tokio::test]
743    async fn dispatch_with_optional_field_present() {
744        let mut d = ToolRegistry::new();
745        d.register(DocumentedTool);
746
747        let result = d
748            .dispatch(
749                "connect",
750                serde_json::json!({"hostname": "localhost", "port": 8080, "timeout": 30.0}),
751                &test_ctx(),
752            )
753            .await;
754        assert_eq!(result.unwrap().content(), "localhost:8080:Some(30.0)");
755    }
756
757    #[tokio::test]
758    async fn dispatch_with_extra_fields_ignored() {
759        // serde's default behavior ignores unknown fields.
760        let mut d = ToolRegistry::new();
761        d.register(SampleTool);
762
763        let result = d
764            .dispatch(
765                "sample",
766                serde_json::json!({"path": "/tmp/x", "unknown_field": 42}),
767                &test_ctx(),
768            )
769            .await;
770        assert_eq!(result.unwrap().content(), "/tmp/x");
771    }
772
773    // ── BoxFuture / ErasedTool edge case tests ──────────────────────
774
775    #[tokio::test]
776    async fn erased_dispatch_preserves_borrow_lifetime() {
777        // Ensures the BoxToolFuture lifetime is tied to &self correctly,
778        // i.e. the registry can be borrowed immutably while the future runs.
779        let mut d = ToolRegistry::new();
780        d.register(AsyncSleepTool);
781        d.register(SampleTool);
782
783        // Dispatch two calls on the same registry reference.
784        let r1 = d
785            .dispatch("async_sleep", serde_json::json!({}), &test_ctx())
786            .await;
787        let r2 = d
788            .dispatch("sample", serde_json::json!({"path": "test"}), &test_ctx())
789            .await;
790
791        assert_eq!(r1.unwrap().content(), "slept");
792        assert_eq!(r2.unwrap().content(), "test");
793    }
794
795    #[tokio::test]
796    async fn dispatch_returns_meaningful_error_for_wrong_type() {
797        let mut d = ToolRegistry::new();
798        d.register(RunCommandTool);
799
800        // `command` expects a String, pass an object instead.
801        let result = d
802            .dispatch(
803                "run_command",
804                serde_json::json!({"command": {"nested": "object"}}),
805                &test_ctx(),
806            )
807            .await;
808        let err = result.unwrap_err();
809        assert!(
810            err.message
811                .contains("Failed to deserialize tool parameters"),
812            "Error should mention deserialization failure, got: {err}"
813        );
814    }
815
816    // ── R7: #[llm_tool] on async fn ─────────────────────────────────────
817
818    /// Async tool defined with the `#[llm_tool]` proc macro. The body uses
819    /// `.await` to prove it runs in an async context.
820    #[llm_tool]
821    async fn async_delayed_echo(
822        /// The message to echo back.
823        message: String,
824    ) -> Result<String, ToolError> {
825        tokio::time::sleep(std::time::Duration::from_millis(1)).await;
826        Ok(format!("echo: {message}"))
827    }
828
829    #[tokio::test]
830    async fn tool_macro_async_fn_dispatches_with_await() {
831        let mut d = ToolRegistry::new();
832        d.register(AsyncDelayedEcho);
833
834        let result = d
835            .dispatch(
836                "async_delayed_echo",
837                serde_json::json!({"message": "hello async"}),
838                &test_ctx(),
839            )
840            .await;
841        assert_eq!(result.unwrap().content(), "echo: hello async");
842    }
843
844    /// Async tool that reads a file via `tokio::fs`, proving real I/O works.
845    #[llm_tool]
846    async fn async_file_reader(
847        /// Path to read.
848        path: String,
849    ) -> Result<String, ToolError> {
850        tokio::fs::read_to_string(&path)
851            .await
852            .map_err(|e| ToolError::new(format!("IO error: {e}")))
853    }
854
855    #[tokio::test]
856    async fn tool_macro_async_fn_reads_file() {
857        let tmp = tempfile::NamedTempFile::new().expect("create tempfile");
858        std::fs::write(tmp.path(), "async macro content").expect("write");
859
860        let mut d = ToolRegistry::new();
861        d.register(AsyncFileReader);
862
863        let path_str = tmp.path().to_str().expect("path").to_owned();
864        let result = d
865            .dispatch(
866                "async_file_reader",
867                serde_json::json!({"path": path_str}),
868                &test_ctx(),
869            )
870            .await;
871        assert_eq!(result.unwrap().content(), "async macro content");
872    }
873
874    // ── R8: Option<T> auto-default via #[llm_tool] ────────────────────
875
876    /// Tool with an optional greeting parameter.
877    #[llm_tool]
878    fn greet_optional(
879        /// Name to greet.
880        name: String,
881        /// Custom greeting (defaults to None if omitted).
882        greeting: Option<String>,
883    ) -> Result<String, ToolError> {
884        let g = greeting.unwrap_or_else(|| "Hello".to_string());
885        Ok(format!("{g}, {name}!"))
886    }
887
888    #[test]
889    fn tool_macro_option_param_not_in_required() {
890        let def = definition_of(&GreetOptional).expect("schema");
891        let schema = &def.parameter_schema;
892
893        let required = schema["required"]
894            .as_array()
895            .expect("required should be an array");
896
897        // `name` is required, `greeting` is Option<T> → not required.
898        assert!(
899            required.iter().any(|v| v == "name"),
900            "'name' should be required, got: {required:?}"
901        );
902        assert!(
903            !required.iter().any(|v| v == "greeting"),
904            "'greeting' (Option<String>) should NOT be required, got: {required:?}"
905        );
906    }
907
908    #[tokio::test]
909    async fn tool_macro_option_param_missing_from_json() {
910        let mut d = ToolRegistry::new();
911        d.register(GreetOptional);
912
913        // Dispatch without the optional `greeting` field.
914        let result = d
915            .dispatch(
916                "greet_optional",
917                serde_json::json!({"name": "World"}),
918                &test_ctx(),
919            )
920            .await;
921        assert_eq!(result.unwrap().content(), "Hello, World!");
922    }
923
924    #[tokio::test]
925    async fn tool_macro_option_param_provided_in_json() {
926        let mut d = ToolRegistry::new();
927        d.register(GreetOptional);
928
929        // Dispatch with the optional `greeting` field present.
930        let result = d
931            .dispatch(
932                "greet_optional",
933                serde_json::json!({"name": "World", "greeting": "Hi"}),
934                &test_ctx(),
935            )
936            .await;
937        assert_eq!(result.unwrap().content(), "Hi, World!");
938    }
939
940    /// Tool combining async + Option<T> to verify both features work together.
941    #[llm_tool]
942    async fn async_optional_tool(
943        /// Required input.
944        input: String,
945        /// Optional suffix.
946        suffix: Option<String>,
947    ) -> Result<String, ToolError> {
948        tokio::time::sleep(std::time::Duration::from_millis(1)).await;
949        let s = suffix.unwrap_or_default();
950        Ok(format!("{input}{s}"))
951    }
952
953    #[tokio::test]
954    async fn tool_macro_async_with_optional_param() {
955        let mut d = ToolRegistry::new();
956        d.register(AsyncOptionalTool);
957
958        // Without optional param.
959        let r1 = d
960            .dispatch(
961                "async_optional_tool",
962                serde_json::json!({"input": "base"}),
963                &test_ctx(),
964            )
965            .await;
966        assert_eq!(r1.unwrap().content(), "base");
967
968        // With optional param.
969        let r2 = d
970            .dispatch(
971                "async_optional_tool",
972                serde_json::json!({"input": "base", "suffix": "_ext"}),
973                &test_ctx(),
974            )
975            .await;
976        assert_eq!(r2.unwrap().content(), "base_ext");
977    }
978
979    #[test]
980    fn tool_macro_async_optional_schema_correctness() {
981        let def = definition_of(&AsyncOptionalTool).expect("schema");
982        let schema = &def.parameter_schema;
983
984        let required = schema["required"].as_array().expect("required array");
985        assert!(required.iter().any(|v| v == "input"), "'input' required");
986        assert!(
987            !required.iter().any(|v| v == "suffix"),
988            "'suffix' (Option) should NOT be required"
989        );
990    }
991
992    // ── IntoIterator tests ──────────────────────────────────────────
993
994    #[test]
995    fn into_iter_yields_all_tool_name_definition_pairs() {
996        let mut d = ToolRegistry::new();
997        d.register(SampleTool);
998        d.register(RunCommandTool);
999
1000        let mut pairs: Vec<(&str, String)> = (&d)
1001            .into_iter()
1002            .map(|(name, def)| (name, def.name))
1003            .collect();
1004        pairs.sort();
1005
1006        assert_eq!(pairs.len(), 2);
1007        assert_eq!(pairs[0].0, "run_command");
1008        assert_eq!(pairs[0].1, "run_command");
1009        assert_eq!(pairs[1].0, "sample");
1010        assert_eq!(pairs[1].1, "sample");
1011    }
1012
1013    #[test]
1014    fn into_iter_empty_registry_yields_nothing() {
1015        let d = ToolRegistry::new();
1016        let count = (&d).into_iter().count();
1017        assert_eq!(count, 0);
1018    }
1019
1020    #[test]
1021    fn into_iter_for_loop_syntax() {
1022        let mut d = ToolRegistry::new();
1023        d.register(SampleTool);
1024
1025        let mut found = false;
1026        for (name, def) in &d {
1027            if name == "sample" {
1028                assert_eq!(def.description, "A sample tool");
1029                found = true;
1030            }
1031        }
1032        assert!(found, "Expected to find 'sample' tool via for-in loop");
1033    }
1034
1035    // ── ToolContext tests ───────────────────────────────────────────
1036
1037    #[test]
1038    fn tool_context_conversation_id_none_by_default() {
1039        let ctx = ToolContext::new(None);
1040        assert!(ctx.conversation_id().is_none());
1041        assert!(!ctx.is_idle());
1042    }
1043
1044    #[test]
1045    fn tool_context_conversation_id_returns_value() {
1046        let ctx = ToolContext::new(Some("conv-123".to_owned()));
1047        assert_eq!(ctx.conversation_id(), Some("conv-123"));
1048    }
1049
1050    #[test]
1051    fn tool_context_get_set_state_roundtrip() {
1052        let ctx = ToolContext::new(None);
1053
1054        // Default for missing key.
1055        let val = ctx.get_state("missing", serde_json::json!("fallback"));
1056        assert_eq!(val, serde_json::json!("fallback"));
1057
1058        // Set and retrieve.
1059        ctx.set_state("counter", serde_json::json!(42))
1060            .expect("set_state");
1061        let val = ctx.get_state("counter", serde_json::json!(0));
1062        assert_eq!(val, serde_json::json!(42));
1063
1064        // Overwrite.
1065        ctx.set_state("counter", serde_json::json!(99))
1066            .expect("set_state");
1067        let val = ctx.get_state("counter", serde_json::json!(0));
1068        assert_eq!(val, serde_json::json!(99));
1069    }
1070
1071    #[test]
1072    fn tool_context_state_persists_across_reads() {
1073        let ctx = ToolContext::new(None);
1074        ctx.set_state("key", serde_json::json!({"nested": true}))
1075            .expect("set_state");
1076
1077        // Multiple reads return the same value.
1078        let v1 = ctx.get_state("key", serde_json::json!(null));
1079        let v2 = ctx.get_state("key", serde_json::json!(null));
1080        assert_eq!(v1, v2);
1081        assert_eq!(v1, serde_json::json!({"nested": true}));
1082    }
1083
1084    #[tokio::test]
1085    async fn dispatch_passes_context_to_tool() {
1086        /// A tool that reads from the `ToolContext` state.
1087        struct ContextAwareTool;
1088
1089        impl RustTool for ContextAwareTool {
1090            type Params = EmptyParams;
1091            const NAME: &'static str = "ctx_tool";
1092            const DESCRIPTION: &'static str = "Reads conversation_id from context.";
1093
1094            async fn call(
1095                &self,
1096                _params: Self::Params,
1097                ctx: &ToolContext,
1098            ) -> Result<ToolOutput, ToolError> {
1099                let conv = ctx.conversation_id().unwrap_or("none");
1100                let count = ctx.get_state("call_count", serde_json::json!(0));
1101                let n = count.as_i64().unwrap_or(0);
1102                ctx.set_state("call_count", serde_json::json!(n + 1))
1103                    .map_err(|e| ToolError::new(format!("set_state failed: {e}")))?;
1104                Ok(format!("conv={conv}, call={n}").into())
1105            }
1106        }
1107
1108        let mut d = ToolRegistry::new();
1109        d.register(ContextAwareTool);
1110
1111        let ctx = ToolContext::new(Some("test-conv".to_owned()));
1112
1113        // First call.
1114        let r1 = d.dispatch("ctx_tool", serde_json::json!({}), &ctx).await;
1115        assert_eq!(r1.unwrap().content(), "conv=test-conv, call=0");
1116
1117        // Second call — state persists.
1118        let r2 = d.dispatch("ctx_tool", serde_json::json!({}), &ctx).await;
1119        assert_eq!(r2.unwrap().content(), "conv=test-conv, call=1");
1120    }
1121
1122    // ── ToolOutput metadata tests ───────────────────────────────────
1123
    /// Typed metadata payload that `MetadataTool` attaches to its output via
    /// `ToolOutput::with_metadata`; each field becomes one metadata entry.
    #[derive(serde::Serialize)]
    struct ProcessMeta {
        /// Byte count reported in the metadata (`MetadataTool` always uses 1024).
        bytes_read: usize,
        /// The `path` parameter the tool was invoked with.
        source: String,
    }
1129
1130    /// A tool that attaches typed metadata to its output.
1131    struct MetadataTool;
1132
1133    impl RustTool for MetadataTool {
1134        type Params = PathParams;
1135        const NAME: &'static str = "metadata_tool";
1136        const DESCRIPTION: &'static str = "Returns output with metadata.";
1137
1138        async fn call(
1139            &self,
1140            params: Self::Params,
1141            _ctx: &ToolContext,
1142        ) -> Result<ToolOutput, ToolError> {
1143            ToolOutput::new(format!("processed: {}", params.path)).with_metadata(&ProcessMeta {
1144                bytes_read: 1024,
1145                source: params.path,
1146            })
1147        }
1148    }
1149
1150    #[tokio::test]
1151    async fn dispatch_preserves_tool_output_metadata() {
1152        let mut d = ToolRegistry::new();
1153        d.register(MetadataTool);
1154
1155        let result = d
1156            .dispatch(
1157                "metadata_tool",
1158                serde_json::json!({"path": "/etc/hosts"}),
1159                &test_ctx(),
1160            )
1161            .await
1162            .unwrap();
1163
1164        assert_eq!(result.content(), "processed: /etc/hosts");
1165        assert_eq!(result.metadata()["bytes_read"], 1024);
1166        assert_eq!(result.metadata()["source"], "/etc/hosts");
1167        assert_eq!(result.metadata().len(), 2);
1168    }
1169
1170    #[tokio::test]
1171    async fn dispatch_tool_output_display_uses_content() {
1172        let output = ToolOutput::new("hello world").with_meta("ignored", serde_json::json!(true));
1173        assert_eq!(output.to_string(), "hello world");
1174    }
1175
1176    #[tokio::test]
1177    async fn dispatch_tool_output_into_content_consumes() {
1178        let output = ToolOutput::new("owned").with_meta("key", serde_json::json!("val"));
1179        let content: String = output.into_content();
1180        assert_eq!(content, "owned");
1181    }
1182
1183    #[test]
1184    fn tool_output_from_str_has_empty_metadata() {
1185        let output: ToolOutput = "plain".into();
1186        assert_eq!(output.content(), "plain");
1187        assert!(output.metadata().is_empty());
1188    }
1189
1190    #[test]
1191    fn tool_output_from_string_has_empty_metadata() {
1192        let output: ToolOutput = "owned".to_string().into();
1193        assert_eq!(output.content(), "owned");
1194        assert!(output.metadata().is_empty());
1195    }
1196
1197    // ── ToolError metadata tests ────────────────────────────────────
1198
1199    #[test]
1200    fn tool_error_with_metadata() {
1201        let err = ToolError::new("HTTP request failed")
1202            .with_meta("status_code", serde_json::json!(503))
1203            .with_meta("url", serde_json::json!("https://example.com"));
1204
1205        assert_eq!(err.message, "HTTP request failed");
1206        assert_eq!(err.metadata()["status_code"], 503);
1207        assert_eq!(err.metadata()["url"], "https://example.com");
1208        assert_eq!(err.metadata().len(), 2);
1209    }
1210
1211    #[test]
1212    fn tool_error_without_metadata_is_empty() {
1213        let err = ToolError::new("simple error");
1214        assert!(err.metadata().is_empty());
1215    }
1216
1217    #[test]
1218    fn tool_error_display_ignores_metadata() {
1219        let err = ToolError::new("visible").with_meta("hidden", serde_json::json!(true));
1220        assert_eq!(err.to_string(), "visible");
1221    }
1222
1223    #[test]
1224    fn tool_error_equality_includes_metadata() {
1225        let a = ToolError::new("err").with_meta("k", serde_json::json!(1));
1226        let b = ToolError::new("err").with_meta("k", serde_json::json!(1));
1227        let c = ToolError::new("err").with_meta("k", serde_json::json!(2));
1228        assert_eq!(a, b);
1229        assert_ne!(a, c);
1230    }
1231
1232    /// A tool that returns `ToolError` with metadata.
1233    struct MetadataErrorTool;
1234
1235    impl RustTool for MetadataErrorTool {
1236        type Params = EmptyParams;
1237        const NAME: &'static str = "metadata_error_tool";
1238        const DESCRIPTION: &'static str = "Always fails with metadata.";
1239
1240        async fn call(
1241            &self,
1242            _params: Self::Params,
1243            _ctx: &ToolContext,
1244        ) -> Result<ToolOutput, ToolError> {
1245            Err(ToolError::new("service unavailable")
1246                .with_meta("retry_after_secs", serde_json::json!(30)))
1247        }
1248    }
1249
1250    #[tokio::test]
1251    async fn dispatch_preserves_tool_error_metadata() {
1252        let mut d = ToolRegistry::new();
1253        d.register(MetadataErrorTool);
1254
1255        let err = d
1256            .dispatch("metadata_error_tool", serde_json::json!({}), &test_ctx())
1257            .await
1258            .unwrap_err();
1259
1260        assert_eq!(err.message, "service unavailable");
1261        assert_eq!(err.metadata()["retry_after_secs"], 30);
1262    }
1263
1264    // ── #[llm_tool] macro returning ToolOutput directly ─────────────────
1265
1266    /// A tool that returns ToolOutput with metadata via the macro.
1267    #[llm_tool]
1268    fn tool_with_metadata(
1269        /// Input value.
1270        input: String,
1271    ) -> Result<ToolOutput, ToolError> {
1272        Ok(ToolOutput::new(format!("echoed: {input}"))
1273            .with_meta("input_len", serde_json::json!(input.len())))
1274    }
1275
1276    #[tokio::test]
1277    async fn macro_tool_returning_tool_output_preserves_metadata() {
1278        let mut d = ToolRegistry::new();
1279        d.register(ToolWithMetadata);
1280
1281        let result = d
1282            .dispatch(
1283                "tool_with_metadata",
1284                serde_json::json!({"input": "hello"}),
1285                &test_ctx(),
1286            )
1287            .await
1288            .unwrap();
1289
1290        assert_eq!(result.content(), "echoed: hello");
1291        assert_eq!(result.metadata()["input_len"], 5);
1292    }
1293
1294    // ── with_metadata struct-based tests ─────────────────────────────
1295
1296    #[test]
1297    fn tool_output_with_metadata_struct() {
1298        #[derive(serde::Serialize)]
1299        struct Meta {
1300            status: String,
1301            count: u32,
1302        }
1303
1304        let out = ToolOutput::new("done")
1305            .with_metadata(&Meta {
1306                status: "ok".into(),
1307                count: 42,
1308            })
1309            .unwrap();
1310
1311        assert_eq!(out.metadata()["status"], "ok");
1312        assert_eq!(out.metadata()["count"], 42);
1313        assert_eq!(out.metadata().len(), 2);
1314    }
1315
1316    #[test]
1317    fn tool_output_with_metadata_merges_with_existing() {
1318        #[derive(serde::Serialize)]
1319        struct Extra {
1320            source: String,
1321        }
1322
1323        let out = ToolOutput::new("data")
1324            .with_meta("version", serde_json::json!(1))
1325            .with_metadata(&Extra {
1326                source: "cache".into(),
1327            })
1328            .unwrap();
1329
1330        assert_eq!(out.metadata()["version"], 1);
1331        assert_eq!(out.metadata()["source"], "cache");
1332        assert_eq!(out.metadata().len(), 2);
1333    }
1334
1335    #[test]
1336    fn tool_output_with_metadata_rejects_non_object() {
1337        let err = ToolOutput::new("x").with_metadata(&42_i32).unwrap_err();
1338
1339        assert!(
1340            err.message.contains("JSON object"),
1341            "Expected object error, got: {err}"
1342        );
1343    }
1344
1345    #[test]
1346    fn tool_error_with_metadata_struct() {
1347        #[derive(serde::Serialize)]
1348        struct ErrorMeta {
1349            status_code: u16,
1350            url: String,
1351        }
1352
1353        let err = ToolError::new("HTTP request failed")
1354            .with_metadata(&ErrorMeta {
1355                status_code: 503,
1356                url: "https://example.com".into(),
1357            })
1358            .unwrap();
1359
1360        assert_eq!(err.message, "HTTP request failed");
1361        assert_eq!(err.metadata()["status_code"], 503);
1362        assert_eq!(err.metadata()["url"], "https://example.com");
1363        assert_eq!(err.metadata().len(), 2);
1364    }
1365
1366    // ── from_metadata tests ─────────────────────────────────────────
1367
1368    #[test]
1369    fn tool_output_from_metadata_populates_both() {
1370        #[derive(serde::Serialize)]
1371        struct Weather {
1372            location: String,
1373            temp_f: i32,
1374        }
1375
1376        let out = ToolOutput::from_metadata(&Weather {
1377            location: "Seattle".into(),
1378            temp_f: 72,
1379        })
1380        .unwrap();
1381
1382        // Content is the JSON string sent to the model.
1383        assert!(out.content().contains("Seattle"));
1384        assert!(out.content().contains("72"));
1385
1386        // Metadata has typed fields for hooks.
1387        assert_eq!(out.metadata()["location"], "Seattle");
1388        assert_eq!(out.metadata()["temp_f"], 72);
1389        assert_eq!(out.metadata().len(), 2);
1390    }
1391
1392    #[test]
1393    fn tool_output_from_metadata_rejects_non_object() {
1394        let err = ToolOutput::from_metadata(&"just a string").unwrap_err();
1395        assert!(
1396            err.message.contains("JSON object"),
1397            "Expected object error, got: {err}"
1398        );
1399    }
1400}