Skip to main content

llm_tool/
registry.rs

//! Tool registry: registration and concurrent dispatch of named tools.
2
3use std::collections::HashMap;
4
5use super::{
6    rust_tool::{ErasedTool, RustTool},
7    types::{ToolContext, ToolDefinition, ToolError, ToolOutput},
8};
9
10pub struct ToolRegistry {
11    tools: HashMap<String, Box<dyn ErasedTool>>,
12}
13
14impl std::fmt::Debug for ToolRegistry {
15    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16        f.debug_struct("ToolRegistry")
17            .field("tool_count", &self.tools.len())
18            .field("tool_names", &self.tools.keys().collect::<Vec<_>>())
19            .finish()
20    }
21}
22
23impl Default for ToolRegistry {
24    fn default() -> Self {
25        Self::new()
26    }
27}
28
29impl ToolRegistry {
30    /// Create an empty registry.
31    #[must_use]
32    pub fn new() -> Self {
33        Self {
34            tools: HashMap::new(),
35        }
36    }
37
38    /// Register a [`RustTool`]. Returns `&mut Self` for chaining.
39    ///
40    /// If a tool with the same name was already registered, it is replaced.
41    pub fn register<T: RustTool + 'static>(&mut self, tool: T) -> &mut Self {
42        self.tools.insert(T::NAME.to_string(), Box::new(tool));
43        self
44    }
45
46    /// Register a [`RustTool`], consuming and returning `Self` for owned chaining.
47    ///
48    /// This is the owned counterpart of [`register`](Self::register), enabling
49    /// patterns like:
50    /// ```
51    /// use llm_tool::{RustTool, ToolContext, ToolError, ToolOutput, ToolRegistry};
52    /// use schemars::JsonSchema;
53    /// use serde::Deserialize;
54    ///
55    /// #[derive(Deserialize, JsonSchema)]
56    /// struct NoParams {}
57    ///
58    /// struct ToolA;
59    /// impl RustTool for ToolA {
60    ///     type Params = NoParams;
61    ///     const NAME: &'static str = "tool_a";
62    ///     const DESCRIPTION: &'static str = "Tool A";
63    ///     async fn call(&self, _: NoParams, _: &ToolContext) -> Result<ToolOutput, ToolError> {
64    ///         Ok("a".into())
65    ///     }
66    /// }
67    ///
68    /// struct ToolB;
69    /// impl RustTool for ToolB {
70    ///     type Params = NoParams;
71    ///     const NAME: &'static str = "tool_b";
72    ///     const DESCRIPTION: &'static str = "Tool B";
73    ///     async fn call(&self, _: NoParams, _: &ToolContext) -> Result<ToolOutput, ToolError> {
74    ///         Ok("b".into())
75    ///     }
76    /// }
77    ///
78    /// let registry = ToolRegistry::new().with_tool(ToolA).with_tool(ToolB);
79    ///
80    /// assert_eq!(registry.definitions().len(), 2);
81    /// ```
82    #[must_use]
83    pub fn with_tool<T: RustTool + 'static>(mut self, tool: T) -> Self {
84        self.register(tool);
85        self
86    }
87
88    /// Try to build a definition for a named tool, logging on failure.
89    fn try_definition<'a>(
90        name: &'a str,
91        tool: &'a dyn ErasedTool,
92    ) -> Option<(&'a str, ToolDefinition)> {
93        match tool.definition() {
94            Ok(def) => Some((name, def)),
95            Err(e) => {
96                tracing::warn!(tool = %name, error = %e, "Skipping tool with invalid schema");
97                None
98            }
99        }
100    }
101
102    /// Collect [`ToolDefinition`]s for all registered tools.
103    ///
104    /// Logs a warning and skips any tool whose schema serialization fails.
105    #[must_use]
106    pub fn definitions(&self) -> Vec<ToolDefinition> {
107        self.tools
108            .iter()
109            .filter_map(|(name, tool)| {
110                Self::try_definition(name, tool.as_ref()).map(|(_, def)| def)
111            })
112            .collect()
113    }
114
115    /// Dispatch a tool call by name with raw JSON arguments and a context.
116    ///
117    /// # Errors
118    ///
119    /// Returns `Err` if the tool name is unknown or the handler returns an error.
120    pub async fn dispatch(
121        &self,
122        name: &str,
123        args: serde_json::Value,
124        ctx: &ToolContext,
125    ) -> Result<ToolOutput, ToolError> {
126        let tool = self
127            .tools
128            .get(name)
129            .ok_or_else(|| ToolError::new(format!("Unknown tool: {name}")))?;
130        tool.call_erased(args, ctx).await
131    }
132
133    /// Number of registered tools.
134    #[must_use]
135    pub fn len(&self) -> usize {
136        self.tools.len()
137    }
138
139    /// Whether the registry has no registered tools.
140    #[must_use]
141    pub fn is_empty(&self) -> bool {
142        self.tools.is_empty()
143    }
144
145    /// Iterate over `(name, definition)` pairs for every registered tool.
146    ///
147    /// Logs a warning and skips tools whose schema serialization fails.
148    pub fn iter(&self) -> impl Iterator<Item = (&String, ToolDefinition)> + '_ {
149        self.tools.iter().filter_map(|(name, tool)| {
150            Self::try_definition(name, tool.as_ref()).map(|(_, def)| (name, def))
151        })
152    }
153}
154
155/// Iterate over `(name, definition)` pairs for every registered tool.
156///
157/// Yields `(&String, ToolDefinition)` for each tool in the registry.
158impl<'a> IntoIterator for &'a ToolRegistry {
159    type Item = (&'a String, ToolDefinition);
160    type IntoIter = Box<dyn Iterator<Item = (&'a String, ToolDefinition)> + 'a>;
161
162    fn into_iter(self) -> Self::IntoIter {
163        Box::new(self.tools.iter().filter_map(|(name, tool)| {
164            ToolRegistry::try_definition(name, tool.as_ref()).map(|(_, def)| (name, def))
165        }))
166    }
167}
168
169#[cfg(test)]
170mod tests {
171    use crate::llm_tool;
172    use serde::Deserialize;
173
174    use super::{
175        super::{EmptyParams, definition_of},
176        *,
177    };
178
179    /// Create a default `ToolContext` for tests.
180    fn test_ctx() -> ToolContext {
181        ToolContext::new(None)
182    }
183
184    // ── Sample tool structs for tests ────────────────────────────────
185
186    #[derive(Deserialize, schemars::JsonSchema)]
187    struct PathParams {
188        /// Filesystem path.
189        path: String,
190    }
191
192    struct SampleTool;
193
194    impl RustTool for SampleTool {
195        type Params = PathParams;
196        const NAME: &'static str = "sample";
197        const DESCRIPTION: &'static str = "A sample tool";
198        async fn call(
199            &self,
200            params: Self::Params,
201            _ctx: &ToolContext,
202        ) -> Result<ToolOutput, ToolError> {
203            Ok(params.path.into())
204        }
205    }
206
207    #[derive(Deserialize, schemars::JsonSchema)]
208    struct RunCommandParams {
209        /// Command to run.
210        command: String,
211        /// Timeout in seconds.
212        #[serde(default)]
213        timeout: Option<i64>,
214        /// Environment variables.
215        #[serde(default)]
216        env: Option<std::collections::HashMap<String, String>>,
217    }
218
219    struct RunCommandTool;
220
221    impl RustTool for RunCommandTool {
222        type Params = RunCommandParams;
223        const NAME: &'static str = "run_command";
224        const DESCRIPTION: &'static str = "Runs a command.";
225        async fn call(
226            &self,
227            params: Self::Params,
228            _ctx: &ToolContext,
229        ) -> Result<ToolOutput, ToolError> {
230            assert!(params.timeout.is_none());
231            assert!(params.env.is_none());
232            Ok(format!("Ran: {}", params.command).into())
233        }
234    }
235
236    // ── ToolDefinition tests ─────────────────────────────────────────
237
238    #[test]
239    fn tool_definition_serde_roundtrip() {
240        let def = definition_of(&SampleTool).expect("schema");
241        let json = serde_json::to_string(&def).expect("serialize");
242        let parsed: ToolDefinition = serde_json::from_str(&json).expect("deserialize");
243        assert_eq!(parsed.name, def.name);
244        assert_eq!(parsed.description, def.description);
245        assert_eq!(parsed.parameter_schema, def.parameter_schema);
246    }
247
248    struct EmptyParamTool;
249    impl RustTool for EmptyParamTool {
250        type Params = EmptyParams;
251        const NAME: &'static str = "empty";
252        const DESCRIPTION: &'static str = "No params";
253        async fn call(
254            &self,
255            _params: Self::Params,
256            _ctx: &ToolContext,
257        ) -> Result<ToolOutput, ToolError> {
258            Ok("ok".into())
259        }
260    }
261
262    #[test]
263    fn tool_definition_with_empty_schema() {
264        let tool = definition_of(&EmptyParamTool).expect("schema");
265        let json = serde_json::to_string(&tool).expect("serialize");
266        let parsed: ToolDefinition = serde_json::from_str(&json).expect("deserialize");
267        // Compare via JSON to handle serde normalization (None vs empty struct).
268        let orig_json = serde_json::to_value(&tool.parameter_schema).unwrap();
269        let parsed_json = serde_json::to_value(&parsed.parameter_schema).unwrap();
270        assert_eq!(orig_json, parsed_json);
271    }
272
273    #[test]
274    fn tool_definition_with_complex_schema() {
275        let tool = definition_of(&RunCommandTool).expect("schema");
276        let schema_json = serde_json::to_value(&tool.parameter_schema).expect("schema to json");
277        // The schema should have 'command' as a required field.
278        let required = schema_json["required"]
279            .as_array()
280            .expect("required should be an array");
281        assert!(
282            required.iter().any(|v| v == "command"),
283            "'command' should be required, got: {required:?}"
284        );
285    }
286
287    // ── ToolRegistry tests ────────────────────────────────────────
288
289    #[tokio::test]
290    async fn registry_dispatch_valid_tool() {
291        let mut d = ToolRegistry::new();
292        d.register(SampleTool);
293        let result = d
294            .dispatch(
295                "sample",
296                serde_json::json!({"path": "/tmp/foo"}),
297                &test_ctx(),
298            )
299            .await;
300        assert_eq!(result.unwrap().content(), "/tmp/foo");
301    }
302
303    #[tokio::test]
304    async fn registry_dispatch_unknown_tool() {
305        let d = ToolRegistry::new();
306        let result = d
307            .dispatch("nonexistent", serde_json::json!({}), &test_ctx())
308            .await;
309        assert_eq!(
310            result.unwrap_err(),
311            ToolError::new("Unknown tool: nonexistent")
312        );
313    }
314
315    #[tokio::test]
316    async fn registry_dispatch_invalid_args() {
317        let mut d = ToolRegistry::new();
318        d.register(SampleTool);
319        // SampleTool expects {"path": String}, not an integer.
320        let result = d
321            .dispatch("sample", serde_json::json!({"path": 42}), &test_ctx())
322            .await;
323        assert!(
324            result.is_err(),
325            "Expected deserialization error for wrong type"
326        );
327        let err = result.unwrap_err();
328        assert!(
329            err.message.contains("deserialize"),
330            "Error should mention deserialization, got: {err}"
331        );
332    }
333
334    #[tokio::test]
335    async fn registry_dispatch_missing_required_field() {
336        let mut d = ToolRegistry::new();
337        d.register(SampleTool);
338        // Missing the required "path" field entirely.
339        let result = d
340            .dispatch("sample", serde_json::json!({}), &test_ctx())
341            .await;
342        assert!(result.is_err(), "Expected error for missing required field");
343    }
344
345    #[test]
346    fn registry_definitions_returns_all() {
347        let mut d = ToolRegistry::new();
348        d.register(SampleTool);
349        d.register(RunCommandTool);
350
351        let defs = d.definitions();
352        assert_eq!(defs.len(), 2);
353
354        let mut names: Vec<&str> = defs.iter().map(|d| d.name.as_str()).collect();
355        names.sort_unstable();
356        assert_eq!(names, vec!["run_command", "sample"]);
357    }
358
359    #[test]
360    fn registry_register_chaining() {
361        let mut d = ToolRegistry::new();
362        d.register(SampleTool).register(RunCommandTool);
363        assert_eq!(d.len(), 2);
364        assert!(!d.is_empty());
365    }
366
367    #[test]
368    fn registry_with_tool_owned_chaining() {
369        let d = ToolRegistry::new()
370            .with_tool(SampleTool)
371            .with_tool(RunCommandTool);
372        assert_eq!(d.len(), 2);
373        assert!(!d.is_empty());
374
375        let defs = d.definitions();
376        let mut names: Vec<&str> = defs.iter().map(|d| d.name.as_str()).collect();
377        names.sort_unstable();
378        assert_eq!(names, vec!["run_command", "sample"]);
379    }
380
381    #[test]
382    fn registry_default_is_empty() {
383        let d = ToolRegistry::default();
384        assert!(d.is_empty());
385        assert_eq!(d.len(), 0);
386    }
387
388    #[tokio::test]
389    async fn registry_replaces_on_duplicate_name() {
390        struct AlternateSample;
391        impl RustTool for AlternateSample {
392            type Params = PathParams;
393            const NAME: &'static str = "sample";
394            const DESCRIPTION: &'static str = "Alternate sample";
395            async fn call(
396                &self,
397                params: Self::Params,
398                _ctx: &ToolContext,
399            ) -> Result<ToolOutput, ToolError> {
400                Ok(format!("alt: {}", params.path).into())
401            }
402        }
403
404        let mut d = ToolRegistry::new();
405        d.register(SampleTool);
406        d.register(AlternateSample);
407        assert_eq!(d.len(), 1);
408
409        let result = d
410            .dispatch("sample", serde_json::json!({"path": "x"}), &test_ctx())
411            .await;
412        assert_eq!(result.unwrap().content(), "alt: x");
413    }
414
415    #[tokio::test]
416    async fn registry_tool_returning_error() {
417        struct FailingTool;
418        impl RustTool for FailingTool {
419            type Params = EmptyParams;
420            const NAME: &'static str = "fail";
421            const DESCRIPTION: &'static str = "Always fails";
422            async fn call(
423                &self,
424                _params: Self::Params,
425                _ctx: &ToolContext,
426            ) -> Result<ToolOutput, ToolError> {
427                Err(ToolError::new("intentional failure"))
428            }
429        }
430
431        let mut d = ToolRegistry::new();
432        d.register(FailingTool);
433        let result = d.dispatch("fail", serde_json::json!({}), &test_ctx()).await;
434        assert_eq!(result.unwrap_err(), ToolError::new("intentional failure"));
435    }
436
437    #[test]
438    fn registry_debug_shows_tool_names() {
439        let mut d = ToolRegistry::new();
440        d.register(SampleTool);
441        let dbg = format!("{d:?}");
442        assert!(dbg.contains("ToolRegistry"));
443        assert!(dbg.contains("sample"));
444        assert!(dbg.contains("tool_count: 1"));
445    }
446
447    // ── Async-specific tests ────────────────────────────────────────
448
449    /// A tool that actually awaits a tokio sleep, proving async dispatch works.
450    struct AsyncSleepTool;
451
452    impl RustTool for AsyncSleepTool {
453        type Params = EmptyParams;
454        const NAME: &'static str = "async_sleep";
455        const DESCRIPTION: &'static str = "Sleeps briefly then returns.";
456
457        async fn call(
458            &self,
459            _params: Self::Params,
460            _ctx: &ToolContext,
461        ) -> Result<ToolOutput, ToolError> {
462            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
463            Ok("slept".into())
464        }
465    }
466
467    #[tokio::test]
468    async fn async_tool_with_tokio_sleep() {
469        let mut d = ToolRegistry::new();
470        d.register(AsyncSleepTool);
471        let result = d
472            .dispatch("async_sleep", serde_json::json!({}), &test_ctx())
473            .await;
474        assert_eq!(result.unwrap().content(), "slept");
475    }
476
477    /// A tool that reads a file using `tokio::fs`.
478    struct AsyncReadFileTool;
479
480    #[derive(Deserialize, schemars::JsonSchema)]
481    struct ReadFileParams {
482        /// Path to the file to read.
483        path: String,
484    }
485
486    impl RustTool for AsyncReadFileTool {
487        type Params = ReadFileParams;
488        const NAME: &'static str = "read_file";
489        const DESCRIPTION: &'static str = "Reads a file asynchronously.";
490
491        async fn call(
492            &self,
493            params: Self::Params,
494            _ctx: &ToolContext,
495        ) -> Result<ToolOutput, ToolError> {
496            tokio::fs::read_to_string(&params.path)
497                .await
498                .map(ToolOutput::from)
499                .map_err(|e| ToolError::new(format!("IO error: {e}")))
500        }
501    }
502
503    #[tokio::test]
504    async fn async_tool_with_tokio_fs() {
505        let tmp = tempfile::NamedTempFile::new().expect("create tempfile");
506        std::fs::write(tmp.path(), "hello async").expect("write tempfile");
507
508        let mut d = ToolRegistry::new();
509        d.register(AsyncReadFileTool);
510
511        let path_str = tmp.path().to_str().expect("path to str").to_owned();
512        let result = d
513            .dispatch(
514                "read_file",
515                serde_json::json!({"path": path_str}),
516                &test_ctx(),
517            )
518            .await;
519        assert_eq!(result.unwrap().content(), "hello async");
520    }
521
522    #[tokio::test]
523    async fn async_tool_tokio_fs_missing_file() {
524        let mut d = ToolRegistry::new();
525        d.register(AsyncReadFileTool);
526        let result = d
527            .dispatch(
528                "read_file",
529                serde_json::json!({"path": "/nonexistent/file.txt"}),
530                &test_ctx(),
531            )
532            .await;
533        let err = result.unwrap_err();
534        assert!(
535            err.message.contains("IO error"),
536            "Expected IO error, got: {err}"
537        );
538    }
539
540    /// A tool that uses a tokio channel to receive its result, proving
541    /// the full async machinery works end-to-end.
542    struct ChannelTool {
543        tx: tokio::sync::mpsc::Sender<String>,
544        rx: std::sync::Mutex<Option<tokio::sync::mpsc::Receiver<String>>>,
545    }
546
547    impl ChannelTool {
548        fn new() -> Self {
549            let (tx, rx) = tokio::sync::mpsc::channel(1);
550            Self {
551                tx,
552                rx: std::sync::Mutex::new(Some(rx)),
553            }
554        }
555    }
556
557    impl RustTool for ChannelTool {
558        type Params = EmptyParams;
559        const NAME: &'static str = "channel_tool";
560        const DESCRIPTION: &'static str = "Awaits a value from a channel.";
561
562        async fn call(
563            &self,
564            _params: Self::Params,
565            _ctx: &ToolContext,
566        ) -> Result<ToolOutput, ToolError> {
567            let mut rx = self
568                .rx
569                .lock()
570                .unwrap()
571                .take()
572                .ok_or_else(|| ToolError::new("channel already consumed"))?;
573            rx.recv()
574                .await
575                .map(ToolOutput::from)
576                .ok_or_else(|| ToolError::new("channel closed"))
577        }
578    }
579
580    #[tokio::test]
581    async fn async_tool_awaits_channel() {
582        let tool = ChannelTool::new();
583        let tx = tool.tx.clone();
584
585        let mut d = ToolRegistry::new();
586        d.register(tool);
587
588        // Send the value from another task.
589        let ctx = test_ctx();
590        let dispatch_future = d.dispatch("channel_tool", serde_json::json!({}), &ctx);
591        let send_future = async move {
592            tx.send("from_channel".to_string()).await.unwrap();
593        };
594
595        let (result, ()) = tokio::join!(dispatch_future, send_future);
596        assert_eq!(result.unwrap().content(), "from_channel");
597    }
598
599    // ── Concurrent dispatch tests ───────────────────────────────────
600
601    #[tokio::test]
602    async fn concurrent_dispatches_to_different_tools() {
603        let mut d = ToolRegistry::new();
604        d.register(SampleTool);
605        d.register(AsyncSleepTool);
606        d.register(RunCommandTool);
607
608        let ctx = test_ctx();
609        let (r1, r2, r3) = tokio::join!(
610            d.dispatch("sample", serde_json::json!({"path": "a"}), &ctx),
611            d.dispatch("async_sleep", serde_json::json!({}), &ctx),
612            d.dispatch("run_command", serde_json::json!({"command": "ls"}), &ctx),
613        );
614
615        assert_eq!(r1.unwrap().content(), "a");
616        assert_eq!(r2.unwrap().content(), "slept");
617        assert_eq!(r3.unwrap().content(), "Ran: ls");
618    }
619
620    #[tokio::test]
621    async fn concurrent_dispatches_to_same_tool() {
622        let mut d = ToolRegistry::new();
623        d.register(SampleTool);
624
625        let ctx = test_ctx();
626        let futs: Vec<_> = (0..10)
627            .map(|i| d.dispatch("sample", serde_json::json!({"path": format!("p{i}")}), &ctx))
628            .collect();
629
630        let results = futures::future::join_all(futs).await;
631        for (i, r) in results.into_iter().enumerate() {
632            assert_eq!(r.unwrap().content(), format!("p{i}"));
633        }
634    }
635
636    // ── Schema / doc comment tests ──────────────────────────────────
637
638    #[derive(Deserialize, schemars::JsonSchema)]
639    struct DocumentedParams {
640        /// The target hostname to connect to.
641        hostname: String,
642        /// Port number (1-65535).
643        port: u16,
644        /// Optional timeout in seconds.
645        #[serde(default)]
646        timeout: Option<f64>,
647    }
648
649    struct DocumentedTool;
650    impl RustTool for DocumentedTool {
651        type Params = DocumentedParams;
652        const NAME: &'static str = "connect";
653        const DESCRIPTION: &'static str = "Connects to a remote host.";
654        async fn call(&self, p: Self::Params, _ctx: &ToolContext) -> Result<ToolOutput, ToolError> {
655            Ok(format!("{}:{}:{:?}", p.hostname, p.port, p.timeout).into())
656        }
657    }
658
659    #[test]
660    fn schema_contains_field_descriptions() {
661        let def = definition_of(&DocumentedTool).expect("schema");
662        let schema = &def.parameter_schema;
663
664        // Check the properties contain our fields.
665        let props = schema["properties"].as_object().expect("properties object");
666        assert!(props.contains_key("hostname"), "missing hostname");
667        assert!(props.contains_key("port"), "missing port");
668        assert!(props.contains_key("timeout"), "missing timeout");
669
670        // Check the descriptions from doc comments made it through.
671        let hostname_desc = props["hostname"]["description"]
672            .as_str()
673            .expect("hostname description");
674        assert!(
675            hostname_desc.contains("hostname"),
676            "hostname description should mention 'hostname', got: {hostname_desc}"
677        );
678
679        let port_desc = props["port"]["description"]
680            .as_str()
681            .expect("port description");
682        assert!(
683            port_desc.contains("1-65535"),
684            "port description should mention range, got: {port_desc}"
685        );
686    }
687
688    #[test]
689    fn schema_required_vs_optional_fields() {
690        let def = definition_of(&DocumentedTool).expect("schema");
691        let schema = &def.parameter_schema;
692
693        let required = schema["required"]
694            .as_array()
695            .expect("required should be an array");
696
697        // hostname and port are required, timeout is Option → not required.
698        assert!(
699            required.iter().any(|v| v == "hostname"),
700            "hostname required"
701        );
702        assert!(required.iter().any(|v| v == "port"), "port required");
703        assert!(
704            !required.iter().any(|v| v == "timeout"),
705            "timeout should NOT be required"
706        );
707    }
708
709    #[tokio::test]
710    async fn dispatch_with_optional_field_missing() {
711        let mut d = ToolRegistry::new();
712        d.register(DocumentedTool);
713
714        // Dispatch without `timeout` (it has serde(default)).
715        let result = d
716            .dispatch(
717                "connect",
718                serde_json::json!({"hostname": "example.com", "port": 443}),
719                &test_ctx(),
720            )
721            .await;
722        assert_eq!(result.unwrap().content(), "example.com:443:None");
723    }
724
725    #[tokio::test]
726    async fn dispatch_with_optional_field_present() {
727        let mut d = ToolRegistry::new();
728        d.register(DocumentedTool);
729
730        let result = d
731            .dispatch(
732                "connect",
733                serde_json::json!({"hostname": "localhost", "port": 8080, "timeout": 30.0}),
734                &test_ctx(),
735            )
736            .await;
737        assert_eq!(result.unwrap().content(), "localhost:8080:Some(30.0)");
738    }
739
740    #[tokio::test]
741    async fn dispatch_with_extra_fields_ignored() {
742        // serde's default behavior ignores unknown fields.
743        let mut d = ToolRegistry::new();
744        d.register(SampleTool);
745
746        let result = d
747            .dispatch(
748                "sample",
749                serde_json::json!({"path": "/tmp/x", "unknown_field": 42}),
750                &test_ctx(),
751            )
752            .await;
753        assert_eq!(result.unwrap().content(), "/tmp/x");
754    }
755
756    // ── BoxFuture / ErasedTool edge case tests ──────────────────────
757
758    #[tokio::test]
759    async fn erased_dispatch_preserves_borrow_lifetime() {
760        // Ensures the BoxToolFuture lifetime is tied to &self correctly,
761        // i.e. the registry can be borrowed immutably while the future runs.
762        let mut d = ToolRegistry::new();
763        d.register(AsyncSleepTool);
764        d.register(SampleTool);
765
766        // Dispatch two calls on the same registry reference.
767        let r1 = d
768            .dispatch("async_sleep", serde_json::json!({}), &test_ctx())
769            .await;
770        let r2 = d
771            .dispatch("sample", serde_json::json!({"path": "test"}), &test_ctx())
772            .await;
773
774        assert_eq!(r1.unwrap().content(), "slept");
775        assert_eq!(r2.unwrap().content(), "test");
776    }
777
778    #[tokio::test]
779    async fn dispatch_returns_meaningful_error_for_wrong_type() {
780        let mut d = ToolRegistry::new();
781        d.register(RunCommandTool);
782
783        // `command` expects a String, pass an object instead.
784        let result = d
785            .dispatch(
786                "run_command",
787                serde_json::json!({"command": {"nested": "object"}}),
788                &test_ctx(),
789            )
790            .await;
791        let err = result.unwrap_err();
792        assert!(
793            err.message
794                .contains("Failed to deserialize tool parameters"),
795            "Error should mention deserialization failure, got: {err}"
796        );
797    }
798
799    // ── R7: #[llm_tool] on async fn ─────────────────────────────────────
800
801    /// Async tool defined with the `#[llm_tool]` proc macro. The body uses
802    /// `.await` to prove it runs in an async context.
803    #[llm_tool]
804    async fn async_delayed_echo(
805        /// The message to echo back.
806        message: String,
807    ) -> Result<String, ToolError> {
808        tokio::time::sleep(std::time::Duration::from_millis(1)).await;
809        Ok(format!("echo: {message}"))
810    }
811
812    #[tokio::test]
813    async fn tool_macro_async_fn_dispatches_with_await() {
814        let mut d = ToolRegistry::new();
815        d.register(AsyncDelayedEcho);
816
817        let result = d
818            .dispatch(
819                "async_delayed_echo",
820                serde_json::json!({"message": "hello async"}),
821                &test_ctx(),
822            )
823            .await;
824        assert_eq!(result.unwrap().content(), "echo: hello async");
825    }
826
827    /// Async tool that reads a file via `tokio::fs`, proving real I/O works.
828    #[llm_tool]
829    async fn async_file_reader(
830        /// Path to read.
831        path: String,
832    ) -> Result<String, ToolError> {
833        tokio::fs::read_to_string(&path)
834            .await
835            .map_err(|e| ToolError::new(format!("IO error: {e}")))
836    }
837
838    #[tokio::test]
839    async fn tool_macro_async_fn_reads_file() {
840        let tmp = tempfile::NamedTempFile::new().expect("create tempfile");
841        std::fs::write(tmp.path(), "async macro content").expect("write");
842
843        let mut d = ToolRegistry::new();
844        d.register(AsyncFileReader);
845
846        let path_str = tmp.path().to_str().expect("path").to_owned();
847        let result = d
848            .dispatch(
849                "async_file_reader",
850                serde_json::json!({"path": path_str}),
851                &test_ctx(),
852            )
853            .await;
854        assert_eq!(result.unwrap().content(), "async macro content");
855    }
856
857    // ── R8: Option<T> auto-default via #[llm_tool] ────────────────────
858
859    /// Tool with an optional greeting parameter.
860    #[llm_tool]
861    fn greet_optional(
862        /// Name to greet.
863        name: String,
864        /// Custom greeting (defaults to None if omitted).
865        greeting: Option<String>,
866    ) -> Result<String, ToolError> {
867        let g = greeting.unwrap_or_else(|| "Hello".to_string());
868        Ok(format!("{g}, {name}!"))
869    }
870
871    #[test]
872    fn tool_macro_option_param_not_in_required() {
873        let def = definition_of(&GreetOptional).expect("schema");
874        let schema = &def.parameter_schema;
875
876        let required = schema["required"]
877            .as_array()
878            .expect("required should be an array");
879
880        // `name` is required, `greeting` is Option<T> → not required.
881        assert!(
882            required.iter().any(|v| v == "name"),
883            "'name' should be required, got: {required:?}"
884        );
885        assert!(
886            !required.iter().any(|v| v == "greeting"),
887            "'greeting' (Option<String>) should NOT be required, got: {required:?}"
888        );
889    }
890
891    #[tokio::test]
892    async fn tool_macro_option_param_missing_from_json() {
893        let mut d = ToolRegistry::new();
894        d.register(GreetOptional);
895
896        // Dispatch without the optional `greeting` field.
897        let result = d
898            .dispatch(
899                "greet_optional",
900                serde_json::json!({"name": "World"}),
901                &test_ctx(),
902            )
903            .await;
904        assert_eq!(result.unwrap().content(), "Hello, World!");
905    }
906
907    #[tokio::test]
908    async fn tool_macro_option_param_provided_in_json() {
909        let mut d = ToolRegistry::new();
910        d.register(GreetOptional);
911
912        // Dispatch with the optional `greeting` field present.
913        let result = d
914            .dispatch(
915                "greet_optional",
916                serde_json::json!({"name": "World", "greeting": "Hi"}),
917                &test_ctx(),
918            )
919            .await;
920        assert_eq!(result.unwrap().content(), "Hi, World!");
921    }
922
923    /// Tool combining async + Option<T> to verify both features work together.
924    #[llm_tool]
925    async fn async_optional_tool(
926        /// Required input.
927        input: String,
928        /// Optional suffix.
929        suffix: Option<String>,
930    ) -> Result<String, ToolError> {
931        tokio::time::sleep(std::time::Duration::from_millis(1)).await;
932        let s = suffix.unwrap_or_default();
933        Ok(format!("{input}{s}"))
934    }
935
936    #[tokio::test]
937    async fn tool_macro_async_with_optional_param() {
938        let mut d = ToolRegistry::new();
939        d.register(AsyncOptionalTool);
940
941        // Without optional param.
942        let r1 = d
943            .dispatch(
944                "async_optional_tool",
945                serde_json::json!({"input": "base"}),
946                &test_ctx(),
947            )
948            .await;
949        assert_eq!(r1.unwrap().content(), "base");
950
951        // With optional param.
952        let r2 = d
953            .dispatch(
954                "async_optional_tool",
955                serde_json::json!({"input": "base", "suffix": "_ext"}),
956                &test_ctx(),
957            )
958            .await;
959        assert_eq!(r2.unwrap().content(), "base_ext");
960    }
961
962    #[test]
963    fn tool_macro_async_optional_schema_correctness() {
964        let def = definition_of(&AsyncOptionalTool).expect("schema");
965        let schema = &def.parameter_schema;
966
967        let required = schema["required"].as_array().expect("required array");
968        assert!(required.iter().any(|v| v == "input"), "'input' required");
969        assert!(
970            !required.iter().any(|v| v == "suffix"),
971            "'suffix' (Option) should NOT be required"
972        );
973    }
974
975    // ── IntoIterator tests ──────────────────────────────────────────
976
977    #[test]
978    fn into_iter_yields_all_tool_name_definition_pairs() {
979        let mut d = ToolRegistry::new();
980        d.register(SampleTool);
981        d.register(RunCommandTool);
982
983        let mut pairs: Vec<(String, String)> = (&d)
984            .into_iter()
985            .map(|(name, def)| (name.clone(), def.name))
986            .collect();
987        pairs.sort();
988
989        assert_eq!(pairs.len(), 2);
990        assert_eq!(pairs[0].0, "run_command");
991        assert_eq!(pairs[0].1, "run_command");
992        assert_eq!(pairs[1].0, "sample");
993        assert_eq!(pairs[1].1, "sample");
994    }
995
996    #[test]
997    fn into_iter_empty_registry_yields_nothing() {
998        let d = ToolRegistry::new();
999        let count = (&d).into_iter().count();
1000        assert_eq!(count, 0);
1001    }
1002
1003    #[test]
1004    fn into_iter_for_loop_syntax() {
1005        let mut d = ToolRegistry::new();
1006        d.register(SampleTool);
1007
1008        let mut found = false;
1009        for (name, def) in &d {
1010            if name == "sample" {
1011                assert_eq!(def.description, "A sample tool");
1012                found = true;
1013            }
1014        }
1015        assert!(found, "Expected to find 'sample' tool via for-in loop");
1016    }
1017
1018    // ── ToolContext tests ───────────────────────────────────────────
1019
1020    #[test]
1021    fn tool_context_conversation_id_none_by_default() {
1022        let ctx = ToolContext::new(None);
1023        assert!(ctx.conversation_id().is_none());
1024        assert!(!ctx.is_idle());
1025    }
1026
1027    #[test]
1028    fn tool_context_conversation_id_returns_value() {
1029        let ctx = ToolContext::new(Some("conv-123".to_owned()));
1030        assert_eq!(ctx.conversation_id(), Some("conv-123"));
1031    }
1032
1033    #[test]
1034    fn tool_context_get_set_state_roundtrip() {
1035        let ctx = ToolContext::new(None);
1036
1037        // Default for missing key.
1038        let val = ctx.get_state("missing", serde_json::json!("fallback"));
1039        assert_eq!(val, serde_json::json!("fallback"));
1040
1041        // Set and retrieve.
1042        ctx.set_state("counter", serde_json::json!(42))
1043            .expect("set_state");
1044        let val = ctx.get_state("counter", serde_json::json!(0));
1045        assert_eq!(val, serde_json::json!(42));
1046
1047        // Overwrite.
1048        ctx.set_state("counter", serde_json::json!(99))
1049            .expect("set_state");
1050        let val = ctx.get_state("counter", serde_json::json!(0));
1051        assert_eq!(val, serde_json::json!(99));
1052    }
1053
1054    #[test]
1055    fn tool_context_state_persists_across_reads() {
1056        let ctx = ToolContext::new(None);
1057        ctx.set_state("key", serde_json::json!({"nested": true}))
1058            .expect("set_state");
1059
1060        // Multiple reads return the same value.
1061        let v1 = ctx.get_state("key", serde_json::json!(null));
1062        let v2 = ctx.get_state("key", serde_json::json!(null));
1063        assert_eq!(v1, v2);
1064        assert_eq!(v1, serde_json::json!({"nested": true}));
1065    }
1066
1067    #[tokio::test]
1068    async fn dispatch_passes_context_to_tool() {
1069        /// A tool that reads from the `ToolContext` state.
1070        struct ContextAwareTool;
1071
1072        impl RustTool for ContextAwareTool {
1073            type Params = EmptyParams;
1074            const NAME: &'static str = "ctx_tool";
1075            const DESCRIPTION: &'static str = "Reads conversation_id from context.";
1076
1077            async fn call(
1078                &self,
1079                _params: Self::Params,
1080                ctx: &ToolContext,
1081            ) -> Result<ToolOutput, ToolError> {
1082                let conv = ctx.conversation_id().unwrap_or("none");
1083                let count = ctx.get_state("call_count", serde_json::json!(0));
1084                let n = count.as_i64().unwrap_or(0);
1085                ctx.set_state("call_count", serde_json::json!(n + 1))
1086                    .map_err(|e| ToolError::new(format!("set_state failed: {e}")))?;
1087                Ok(format!("conv={conv}, call={n}").into())
1088            }
1089        }
1090
1091        let mut d = ToolRegistry::new();
1092        d.register(ContextAwareTool);
1093
1094        let ctx = ToolContext::new(Some("test-conv".to_owned()));
1095
1096        // First call.
1097        let r1 = d.dispatch("ctx_tool", serde_json::json!({}), &ctx).await;
1098        assert_eq!(r1.unwrap().content(), "conv=test-conv, call=0");
1099
1100        // Second call — state persists.
1101        let r2 = d.dispatch("ctx_tool", serde_json::json!({}), &ctx).await;
1102        assert_eq!(r2.unwrap().content(), "conv=test-conv, call=1");
1103    }
1104
1105    // ── ToolOutput metadata tests ───────────────────────────────────
1106
    /// Typed metadata payload attached to its output by `MetadataTool`.
    /// Field order is kept as-is because it is the serialization order.
    #[derive(serde::Serialize)]
    struct ProcessMeta {
        /// Byte count reported by the tool (a fixed value in these tests).
        bytes_read: usize,
        /// The path the tool was asked to process.
        source: String,
    }
1112
1113    /// A tool that attaches typed metadata to its output.
1114    struct MetadataTool;
1115
1116    impl RustTool for MetadataTool {
1117        type Params = PathParams;
1118        const NAME: &'static str = "metadata_tool";
1119        const DESCRIPTION: &'static str = "Returns output with metadata.";
1120
1121        async fn call(
1122            &self,
1123            params: Self::Params,
1124            _ctx: &ToolContext,
1125        ) -> Result<ToolOutput, ToolError> {
1126            ToolOutput::new(format!("processed: {}", params.path)).with_metadata(&ProcessMeta {
1127                bytes_read: 1024,
1128                source: params.path,
1129            })
1130        }
1131    }
1132
1133    #[tokio::test]
1134    async fn dispatch_preserves_tool_output_metadata() {
1135        let mut d = ToolRegistry::new();
1136        d.register(MetadataTool);
1137
1138        let result = d
1139            .dispatch(
1140                "metadata_tool",
1141                serde_json::json!({"path": "/etc/hosts"}),
1142                &test_ctx(),
1143            )
1144            .await
1145            .unwrap();
1146
1147        assert_eq!(result.content(), "processed: /etc/hosts");
1148        assert_eq!(result.metadata()["bytes_read"], 1024);
1149        assert_eq!(result.metadata()["source"], "/etc/hosts");
1150        assert_eq!(result.metadata().len(), 2);
1151    }
1152
1153    #[tokio::test]
1154    async fn dispatch_tool_output_display_uses_content() {
1155        let output = ToolOutput::new("hello world").with_meta("ignored", serde_json::json!(true));
1156        assert_eq!(output.to_string(), "hello world");
1157    }
1158
1159    #[tokio::test]
1160    async fn dispatch_tool_output_into_content_consumes() {
1161        let output = ToolOutput::new("owned").with_meta("key", serde_json::json!("val"));
1162        let content: String = output.into_content();
1163        assert_eq!(content, "owned");
1164    }
1165
1166    #[test]
1167    fn tool_output_from_str_has_empty_metadata() {
1168        let output: ToolOutput = "plain".into();
1169        assert_eq!(output.content(), "plain");
1170        assert!(output.metadata().is_empty());
1171    }
1172
1173    #[test]
1174    fn tool_output_from_string_has_empty_metadata() {
1175        let output: ToolOutput = "owned".to_string().into();
1176        assert_eq!(output.content(), "owned");
1177        assert!(output.metadata().is_empty());
1178    }
1179
1180    // ── ToolError metadata tests ────────────────────────────────────
1181
1182    #[test]
1183    fn tool_error_with_metadata() {
1184        let err = ToolError::new("HTTP request failed")
1185            .with_meta("status_code", serde_json::json!(503))
1186            .with_meta("url", serde_json::json!("https://example.com"));
1187
1188        assert_eq!(err.message, "HTTP request failed");
1189        assert_eq!(err.metadata()["status_code"], 503);
1190        assert_eq!(err.metadata()["url"], "https://example.com");
1191        assert_eq!(err.metadata().len(), 2);
1192    }
1193
1194    #[test]
1195    fn tool_error_without_metadata_is_empty() {
1196        let err = ToolError::new("simple error");
1197        assert!(err.metadata().is_empty());
1198    }
1199
1200    #[test]
1201    fn tool_error_display_ignores_metadata() {
1202        let err = ToolError::new("visible").with_meta("hidden", serde_json::json!(true));
1203        assert_eq!(err.to_string(), "visible");
1204    }
1205
1206    #[test]
1207    fn tool_error_equality_includes_metadata() {
1208        let a = ToolError::new("err").with_meta("k", serde_json::json!(1));
1209        let b = ToolError::new("err").with_meta("k", serde_json::json!(1));
1210        let c = ToolError::new("err").with_meta("k", serde_json::json!(2));
1211        assert_eq!(a, b);
1212        assert_ne!(a, c);
1213    }
1214
1215    /// A tool that returns `ToolError` with metadata.
1216    struct MetadataErrorTool;
1217
1218    impl RustTool for MetadataErrorTool {
1219        type Params = EmptyParams;
1220        const NAME: &'static str = "metadata_error_tool";
1221        const DESCRIPTION: &'static str = "Always fails with metadata.";
1222
1223        async fn call(
1224            &self,
1225            _params: Self::Params,
1226            _ctx: &ToolContext,
1227        ) -> Result<ToolOutput, ToolError> {
1228            Err(ToolError::new("service unavailable")
1229                .with_meta("retry_after_secs", serde_json::json!(30)))
1230        }
1231    }
1232
1233    #[tokio::test]
1234    async fn dispatch_preserves_tool_error_metadata() {
1235        let mut d = ToolRegistry::new();
1236        d.register(MetadataErrorTool);
1237
1238        let err = d
1239            .dispatch("metadata_error_tool", serde_json::json!({}), &test_ctx())
1240            .await
1241            .unwrap_err();
1242
1243        assert_eq!(err.message, "service unavailable");
1244        assert_eq!(err.metadata()["retry_after_secs"], 30);
1245    }
1246
1247    // ── #[llm_tool] macro returning ToolOutput directly ─────────────────
1248
1249    /// A tool that returns ToolOutput with metadata via the macro.
1250    #[llm_tool]
1251    fn tool_with_metadata(
1252        /// Input value.
1253        input: String,
1254    ) -> Result<ToolOutput, ToolError> {
1255        Ok(ToolOutput::new(format!("echoed: {input}"))
1256            .with_meta("input_len", serde_json::json!(input.len())))
1257    }
1258
1259    #[tokio::test]
1260    async fn macro_tool_returning_tool_output_preserves_metadata() {
1261        let mut d = ToolRegistry::new();
1262        d.register(ToolWithMetadata);
1263
1264        let result = d
1265            .dispatch(
1266                "tool_with_metadata",
1267                serde_json::json!({"input": "hello"}),
1268                &test_ctx(),
1269            )
1270            .await
1271            .unwrap();
1272
1273        assert_eq!(result.content(), "echoed: hello");
1274        assert_eq!(result.metadata()["input_len"], 5);
1275    }
1276
1277    // ── with_metadata struct-based tests ─────────────────────────────
1278
1279    #[test]
1280    fn tool_output_with_metadata_struct() {
1281        #[derive(serde::Serialize)]
1282        struct Meta {
1283            status: String,
1284            count: u32,
1285        }
1286
1287        let out = ToolOutput::new("done")
1288            .with_metadata(&Meta {
1289                status: "ok".into(),
1290                count: 42,
1291            })
1292            .unwrap();
1293
1294        assert_eq!(out.metadata()["status"], "ok");
1295        assert_eq!(out.metadata()["count"], 42);
1296        assert_eq!(out.metadata().len(), 2);
1297    }
1298
1299    #[test]
1300    fn tool_output_with_metadata_merges_with_existing() {
1301        #[derive(serde::Serialize)]
1302        struct Extra {
1303            source: String,
1304        }
1305
1306        let out = ToolOutput::new("data")
1307            .with_meta("version", serde_json::json!(1))
1308            .with_metadata(&Extra {
1309                source: "cache".into(),
1310            })
1311            .unwrap();
1312
1313        assert_eq!(out.metadata()["version"], 1);
1314        assert_eq!(out.metadata()["source"], "cache");
1315        assert_eq!(out.metadata().len(), 2);
1316    }
1317
1318    #[test]
1319    fn tool_output_with_metadata_rejects_non_object() {
1320        let err = ToolOutput::new("x").with_metadata(&42_i32).unwrap_err();
1321
1322        assert!(
1323            err.message.contains("JSON object"),
1324            "Expected object error, got: {err}"
1325        );
1326    }
1327
1328    #[test]
1329    fn tool_error_with_metadata_struct() {
1330        #[derive(serde::Serialize)]
1331        struct ErrorMeta {
1332            status_code: u16,
1333            url: String,
1334        }
1335
1336        let err = ToolError::new("HTTP request failed")
1337            .with_metadata(&ErrorMeta {
1338                status_code: 503,
1339                url: "https://example.com".into(),
1340            })
1341            .unwrap();
1342
1343        assert_eq!(err.message, "HTTP request failed");
1344        assert_eq!(err.metadata()["status_code"], 503);
1345        assert_eq!(err.metadata()["url"], "https://example.com");
1346        assert_eq!(err.metadata().len(), 2);
1347    }
1348
1349    // ── from_metadata tests ─────────────────────────────────────────
1350
1351    #[test]
1352    fn tool_output_from_metadata_populates_both() {
1353        #[derive(serde::Serialize)]
1354        struct Weather {
1355            location: String,
1356            temp_f: i32,
1357        }
1358
1359        let out = ToolOutput::from_metadata(&Weather {
1360            location: "Seattle".into(),
1361            temp_f: 72,
1362        })
1363        .unwrap();
1364
1365        // Content is the JSON string sent to the model.
1366        assert!(out.content().contains("Seattle"));
1367        assert!(out.content().contains("72"));
1368
1369        // Metadata has typed fields for hooks.
1370        assert_eq!(out.metadata()["location"], "Seattle");
1371        assert_eq!(out.metadata()["temp_f"], 72);
1372        assert_eq!(out.metadata().len(), 2);
1373    }
1374
1375    #[test]
1376    fn tool_output_from_metadata_rejects_non_object() {
1377        let err = ToolOutput::from_metadata(&"just a string").unwrap_err();
1378        assert!(
1379            err.message.contains("JSON object"),
1380            "Expected object error, got: {err}"
1381        );
1382    }
1383}