1use std::collections::HashMap;
4
5use super::{
6 rust_tool::{ErasedTool, RustTool},
7 types::{ToolContext, ToolDefinition, ToolError, ToolOutput},
8};
9
/// Name-keyed collection of type-erased tools whose definitions can be listed and which can
/// be invoked by name via [`ToolRegistry::dispatch`].
///
/// Tools are stored behind [`ErasedTool`] so heterogeneous [`RustTool`] implementations
/// (each with its own `Params` type) can share a single map.
pub struct ToolRegistry {
    /// Registered tools keyed by their [`RustTool::NAME`]; at most one tool per name.
    tools: HashMap<String, Box<dyn ErasedTool>>,
}
13
14impl std::fmt::Debug for ToolRegistry {
15 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16 f.debug_struct("ToolRegistry")
17 .field("tool_count", &self.tools.len())
18 .field("tool_names", &self.tools.keys().collect::<Vec<_>>())
19 .finish()
20 }
21}
22
23impl Default for ToolRegistry {
24 fn default() -> Self {
25 Self::new()
26 }
27}
28
29impl ToolRegistry {
30 #[must_use]
32 pub fn new() -> Self {
33 Self {
34 tools: HashMap::new(),
35 }
36 }
37
38 pub fn register<T: RustTool + 'static>(&mut self, tool: T) -> &mut Self {
42 self.tools.insert(T::NAME.to_string(), Box::new(tool));
43 self
44 }
45
46 #[must_use]
83 pub fn with_tool<T: RustTool + 'static>(mut self, tool: T) -> Self {
84 self.register(tool);
85 self
86 }
87
88 fn try_definition<'a>(
90 name: &'a str,
91 tool: &'a dyn ErasedTool,
92 ) -> Option<(&'a str, ToolDefinition)> {
93 match tool.definition() {
94 Ok(def) => Some((name, def)),
95 Err(e) => {
96 tracing::warn!(tool = %name, error = %e, "Skipping tool with invalid schema");
97 None
98 }
99 }
100 }
101
102 #[must_use]
106 pub fn definitions(&self) -> Vec<ToolDefinition> {
107 self.tools
108 .iter()
109 .filter_map(|(name, tool)| {
110 Self::try_definition(name, tool.as_ref()).map(|(_, def)| def)
111 })
112 .collect()
113 }
114
115 pub async fn dispatch(
121 &self,
122 name: &str,
123 args: serde_json::Value,
124 ctx: &ToolContext,
125 ) -> Result<ToolOutput, ToolError> {
126 let tool = self
127 .tools
128 .get(name)
129 .ok_or_else(|| ToolError::new(format!("Unknown tool: {name}")))?;
130 tool.call_erased(args, ctx).await
131 }
132
133 #[must_use]
135 pub fn len(&self) -> usize {
136 self.tools.len()
137 }
138
139 #[must_use]
141 pub fn is_empty(&self) -> bool {
142 self.tools.is_empty()
143 }
144
145 pub fn iter(&self) -> impl Iterator<Item = (&String, ToolDefinition)> + '_ {
149 self.tools.iter().filter_map(|(name, tool)| {
150 Self::try_definition(name, tool.as_ref()).map(|(_, def)| (name, def))
151 })
152 }
153}
154
155impl<'a> IntoIterator for &'a ToolRegistry {
159 type Item = (&'a String, ToolDefinition);
160 type IntoIter = Box<dyn Iterator<Item = (&'a String, ToolDefinition)> + 'a>;
161
162 fn into_iter(self) -> Self::IntoIter {
163 Box::new(self.tools.iter().filter_map(|(name, tool)| {
164 ToolRegistry::try_definition(name, tool.as_ref()).map(|(_, def)| (name, def))
165 }))
166 }
167}
168
// Tests for `ToolRegistry` and the surrounding tool types (`ToolDefinition`, `ToolOutput`,
// `ToolError`, `ToolContext`), plus tools generated by the `#[llm_tool]` macro.
#[cfg(test)]
mod tests {
    use crate::llm_tool;
    use serde::Deserialize;

    use super::{
        super::{EmptyParams, definition_of},
        *,
    };

    /// Fresh context with no conversation id, for tests that don't inspect the context.
    fn test_ctx() -> ToolContext {
        ToolContext::new(None)
    }

    /// Parameters consisting of a single required `path` string.
    #[derive(Deserialize, schemars::JsonSchema)]
    struct PathParams {
        /// Path string passed to the tool.
        path: String,
    }

    /// Echoes its `path` parameter back as the output content.
    struct SampleTool;

    impl RustTool for SampleTool {
        type Params = PathParams;
        const NAME: &'static str = "sample";
        const DESCRIPTION: &'static str = "A sample tool";
        async fn call(
            &self,
            params: Self::Params,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput, ToolError> {
            Ok(params.path.into())
        }
    }

    /// One required field plus two optional ones, used for required/optional schema checks.
    #[derive(Deserialize, schemars::JsonSchema)]
    struct RunCommandParams {
        /// Command line to run.
        command: String,
        /// Optional timeout; never supplied by these tests.
        #[serde(default)]
        timeout: Option<i64>,
        /// Optional environment variables; never supplied by these tests.
        #[serde(default)]
        env: Option<std::collections::HashMap<String, String>>,
    }

    /// Pretends to run `command`; asserts that the optional fields were left unset.
    struct RunCommandTool;

    impl RustTool for RunCommandTool {
        type Params = RunCommandParams;
        const NAME: &'static str = "run_command";
        const DESCRIPTION: &'static str = "Runs a command.";
        async fn call(
            &self,
            params: Self::Params,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput, ToolError> {
            assert!(params.timeout.is_none());
            assert!(params.env.is_none());
            Ok(format!("Ran: {}", params.command).into())
        }
    }

    /// A `ToolDefinition` must survive a JSON serialize/deserialize round trip unchanged.
    #[test]
    fn tool_definition_serde_roundtrip() {
        let def = definition_of(&SampleTool).expect("schema");
        let json = serde_json::to_string(&def).expect("serialize");
        let parsed: ToolDefinition = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed.name, def.name);
        assert_eq!(parsed.description, def.description);
        assert_eq!(parsed.parameter_schema, def.parameter_schema);
    }

    /// Tool taking no parameters.
    struct EmptyParamTool;
    impl RustTool for EmptyParamTool {
        type Params = EmptyParams;
        const NAME: &'static str = "empty";
        const DESCRIPTION: &'static str = "No params";
        async fn call(
            &self,
            _params: Self::Params,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput, ToolError> {
            Ok("ok".into())
        }
    }

    #[test]
    fn tool_definition_with_empty_schema() {
        let tool = definition_of(&EmptyParamTool).expect("schema");
        let json = serde_json::to_string(&tool).expect("serialize");
        let parsed: ToolDefinition = serde_json::from_str(&json).expect("deserialize");
        // Compare as JSON values so representation details of the schema type don't matter.
        let orig_json = serde_json::to_value(&tool.parameter_schema).unwrap();
        let parsed_json = serde_json::to_value(&parsed.parameter_schema).unwrap();
        assert_eq!(orig_json, parsed_json);
    }

    #[test]
    fn tool_definition_with_complex_schema() {
        let tool = definition_of(&RunCommandTool).expect("schema");
        let schema_json = serde_json::to_value(&tool.parameter_schema).expect("schema to json");
        let required = schema_json["required"]
            .as_array()
            .expect("required should be an array");
        assert!(
            required.iter().any(|v| v == "command"),
            "'command' should be required, got: {required:?}"
        );
    }

    #[tokio::test]
    async fn registry_dispatch_valid_tool() {
        let mut d = ToolRegistry::new();
        d.register(SampleTool);
        let result = d
            .dispatch(
                "sample",
                serde_json::json!({"path": "/tmp/foo"}),
                &test_ctx(),
            )
            .await;
        assert_eq!(result.unwrap().content(), "/tmp/foo");
    }

    #[tokio::test]
    async fn registry_dispatch_unknown_tool() {
        let d = ToolRegistry::new();
        let result = d
            .dispatch("nonexistent", serde_json::json!({}), &test_ctx())
            .await;
        assert_eq!(
            result.unwrap_err(),
            ToolError::new("Unknown tool: nonexistent")
        );
    }

    #[tokio::test]
    async fn registry_dispatch_invalid_args() {
        let mut d = ToolRegistry::new();
        d.register(SampleTool);
        let result = d
            // `path` must be a string; a number should fail deserialization.
            .dispatch("sample", serde_json::json!({"path": 42}), &test_ctx())
            .await;
        assert!(
            result.is_err(),
            "Expected deserialization error for wrong type"
        );
        let err = result.unwrap_err();
        assert!(
            err.message.contains("deserialize"),
            "Error should mention deserialization, got: {err}"
        );
    }

    #[tokio::test]
    async fn registry_dispatch_missing_required_field() {
        let mut d = ToolRegistry::new();
        d.register(SampleTool);
        let result = d
            // `path` is required but absent.
            .dispatch("sample", serde_json::json!({}), &test_ctx())
            .await;
        assert!(result.is_err(), "Expected error for missing required field");
    }

    #[test]
    fn registry_definitions_returns_all() {
        let mut d = ToolRegistry::new();
        d.register(SampleTool);
        d.register(RunCommandTool);

        let defs = d.definitions();
        assert_eq!(defs.len(), 2);

        let mut names: Vec<&str> = defs.iter().map(|d| d.name.as_str()).collect();
        names.sort_unstable();
        assert_eq!(names, vec!["run_command", "sample"]);
    }

    #[test]
    fn registry_register_chaining() {
        let mut d = ToolRegistry::new();
        d.register(SampleTool).register(RunCommandTool);
        assert_eq!(d.len(), 2);
        assert!(!d.is_empty());
    }

    #[test]
    fn registry_with_tool_owned_chaining() {
        let d = ToolRegistry::new()
            .with_tool(SampleTool)
            .with_tool(RunCommandTool);
        assert_eq!(d.len(), 2);
        assert!(!d.is_empty());

        let defs = d.definitions();
        let mut names: Vec<&str> = defs.iter().map(|d| d.name.as_str()).collect();
        names.sort_unstable();
        assert_eq!(names, vec!["run_command", "sample"]);
    }

    #[test]
    fn registry_default_is_empty() {
        let d = ToolRegistry::default();
        assert!(d.is_empty());
        assert_eq!(d.len(), 0);
    }

    /// Registering a second tool under an existing name replaces the first.
    #[tokio::test]
    async fn registry_replaces_on_duplicate_name() {
        struct AlternateSample;
        impl RustTool for AlternateSample {
            type Params = PathParams;
            const NAME: &'static str = "sample";
            const DESCRIPTION: &'static str = "Alternate sample";
            async fn call(
                &self,
                params: Self::Params,
                _ctx: &ToolContext,
            ) -> Result<ToolOutput, ToolError> {
                Ok(format!("alt: {}", params.path).into())
            }
        }

        let mut d = ToolRegistry::new();
        d.register(SampleTool);
        d.register(AlternateSample);
        assert_eq!(d.len(), 1);

        let result = d
            .dispatch("sample", serde_json::json!({"path": "x"}), &test_ctx())
            .await;
        assert_eq!(result.unwrap().content(), "alt: x");
    }

    /// Errors returned by a tool's `call` are passed through `dispatch` unchanged.
    #[tokio::test]
    async fn registry_tool_returning_error() {
        struct FailingTool;
        impl RustTool for FailingTool {
            type Params = EmptyParams;
            const NAME: &'static str = "fail";
            const DESCRIPTION: &'static str = "Always fails";
            async fn call(
                &self,
                _params: Self::Params,
                _ctx: &ToolContext,
            ) -> Result<ToolOutput, ToolError> {
                Err(ToolError::new("intentional failure"))
            }
        }

        let mut d = ToolRegistry::new();
        d.register(FailingTool);
        let result = d.dispatch("fail", serde_json::json!({}), &test_ctx()).await;
        assert_eq!(result.unwrap_err(), ToolError::new("intentional failure"));
    }

    #[test]
    fn registry_debug_shows_tool_names() {
        let mut d = ToolRegistry::new();
        d.register(SampleTool);
        let dbg = format!("{d:?}");
        assert!(dbg.contains("ToolRegistry"));
        assert!(dbg.contains("sample"));
        assert!(dbg.contains("tool_count: 1"));
    }

    /// Tool that awaits a short `tokio` sleep before returning, to exercise real suspension.
    struct AsyncSleepTool;

    impl RustTool for AsyncSleepTool {
        type Params = EmptyParams;
        const NAME: &'static str = "async_sleep";
        const DESCRIPTION: &'static str = "Sleeps briefly then returns.";

        async fn call(
            &self,
            _params: Self::Params,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput, ToolError> {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            Ok("slept".into())
        }
    }

    #[tokio::test]
    async fn async_tool_with_tokio_sleep() {
        let mut d = ToolRegistry::new();
        d.register(AsyncSleepTool);
        let result = d
            .dispatch("async_sleep", serde_json::json!({}), &test_ctx())
            .await;
        assert_eq!(result.unwrap().content(), "slept");
    }

    /// Tool that reads a file with `tokio::fs`, mapping I/O failures to `ToolError`.
    struct AsyncReadFileTool;

    #[derive(Deserialize, schemars::JsonSchema)]
    struct ReadFileParams {
        /// Path of the file to read.
        path: String,
    }

    impl RustTool for AsyncReadFileTool {
        type Params = ReadFileParams;
        const NAME: &'static str = "read_file";
        const DESCRIPTION: &'static str = "Reads a file asynchronously.";

        async fn call(
            &self,
            params: Self::Params,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput, ToolError> {
            tokio::fs::read_to_string(&params.path)
                .await
                .map(ToolOutput::from)
                .map_err(|e| ToolError::new(format!("IO error: {e}")))
        }
    }

    #[tokio::test]
    async fn async_tool_with_tokio_fs() {
        let tmp = tempfile::NamedTempFile::new().expect("create tempfile");
        std::fs::write(tmp.path(), "hello async").expect("write tempfile");

        let mut d = ToolRegistry::new();
        d.register(AsyncReadFileTool);

        let path_str = tmp.path().to_str().expect("path to str").to_owned();
        let result = d
            .dispatch(
                "read_file",
                serde_json::json!({"path": path_str}),
                &test_ctx(),
            )
            .await;
        assert_eq!(result.unwrap().content(), "hello async");
    }

    #[tokio::test]
    async fn async_tool_tokio_fs_missing_file() {
        let mut d = ToolRegistry::new();
        d.register(AsyncReadFileTool);
        let result = d
            .dispatch(
                "read_file",
                serde_json::json!({"path": "/nonexistent/file.txt"}),
                &test_ctx(),
            )
            .await;
        let err = result.unwrap_err();
        assert!(
            err.message.contains("IO error"),
            "Expected IO error, got: {err}"
        );
    }

    /// Tool whose `call` waits for a value sent on a channel. The receiver can be taken
    /// only once, so only the first call can succeed.
    struct ChannelTool {
        tx: tokio::sync::mpsc::Sender<String>,
        rx: std::sync::Mutex<Option<tokio::sync::mpsc::Receiver<String>>>,
    }

    impl ChannelTool {
        fn new() -> Self {
            let (tx, rx) = tokio::sync::mpsc::channel(1);
            Self {
                tx,
                rx: std::sync::Mutex::new(Some(rx)),
            }
        }
    }

    impl RustTool for ChannelTool {
        type Params = EmptyParams;
        const NAME: &'static str = "channel_tool";
        const DESCRIPTION: &'static str = "Awaits a value from a channel.";

        async fn call(
            &self,
            _params: Self::Params,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput, ToolError> {
            // The std mutex guard is a temporary dropped at the end of this statement,
            // so it is never held across the `.await` below.
            let mut rx = self
                .rx
                .lock()
                .unwrap()
                .take()
                .ok_or_else(|| ToolError::new("channel already consumed"))?;
            rx.recv()
                .await
                .map(ToolOutput::from)
                .ok_or_else(|| ToolError::new("channel closed"))
        }
    }

    #[tokio::test]
    async fn async_tool_awaits_channel() {
        let tool = ChannelTool::new();
        let tx = tool.tx.clone();

        let mut d = ToolRegistry::new();
        d.register(tool);

        let ctx = test_ctx();
        // The dispatch future must stay pending until the sender fires.
        let dispatch_future = d.dispatch("channel_tool", serde_json::json!({}), &ctx);
        let send_future = async move {
            tx.send("from_channel".to_string()).await.unwrap();
        };

        let (result, ()) = tokio::join!(dispatch_future, send_future);
        assert_eq!(result.unwrap().content(), "from_channel");
    }

    /// Dispatch futures for different tools can be polled concurrently on a shared `&self`.
    #[tokio::test]
    async fn concurrent_dispatches_to_different_tools() {
        let mut d = ToolRegistry::new();
        d.register(SampleTool);
        d.register(AsyncSleepTool);
        d.register(RunCommandTool);

        let ctx = test_ctx();
        let (r1, r2, r3) = tokio::join!(
            d.dispatch("sample", serde_json::json!({"path": "a"}), &ctx),
            d.dispatch("async_sleep", serde_json::json!({}), &ctx),
            d.dispatch("run_command", serde_json::json!({"command": "ls"}), &ctx),
        );

        assert_eq!(r1.unwrap().content(), "a");
        assert_eq!(r2.unwrap().content(), "slept");
        assert_eq!(r3.unwrap().content(), "Ran: ls");
    }

    #[tokio::test]
    async fn concurrent_dispatches_to_same_tool() {
        let mut d = ToolRegistry::new();
        d.register(SampleTool);

        let ctx = test_ctx();
        let futs: Vec<_> = (0..10)
            .map(|i| d.dispatch("sample", serde_json::json!({"path": format!("p{i}")}), &ctx))
            .collect();

        let results = futures::future::join_all(futs).await;
        for (i, r) in results.into_iter().enumerate() {
            assert_eq!(r.unwrap().content(), format!("p{i}"));
        }
    }

    /// Parameters whose field doc comments must surface as JSON-schema `description`s.
    #[derive(Deserialize, schemars::JsonSchema)]
    struct DocumentedParams {
        /// The hostname or IP address to connect to.
        hostname: String,
        /// TCP port number (1-65535).
        port: u16,
        /// Optional connection timeout in seconds.
        #[serde(default)]
        timeout: Option<f64>,
    }

    /// Formats its parameters as `hostname:port:timeout` without connecting anywhere.
    struct DocumentedTool;
    impl RustTool for DocumentedTool {
        type Params = DocumentedParams;
        const NAME: &'static str = "connect";
        const DESCRIPTION: &'static str = "Connects to a remote host.";
        async fn call(&self, p: Self::Params, _ctx: &ToolContext) -> Result<ToolOutput, ToolError> {
            Ok(format!("{}:{}:{:?}", p.hostname, p.port, p.timeout).into())
        }
    }

    #[test]
    fn schema_contains_field_descriptions() {
        let def = definition_of(&DocumentedTool).expect("schema");
        let schema = &def.parameter_schema;

        let props = schema["properties"].as_object().expect("properties object");
        assert!(props.contains_key("hostname"), "missing hostname");
        assert!(props.contains_key("port"), "missing port");
        assert!(props.contains_key("timeout"), "missing timeout");

        let hostname_desc = props["hostname"]["description"]
            .as_str()
            .expect("hostname description");
        assert!(
            hostname_desc.contains("hostname"),
            "hostname description should mention 'hostname', got: {hostname_desc}"
        );

        let port_desc = props["port"]["description"]
            .as_str()
            .expect("port description");
        assert!(
            port_desc.contains("1-65535"),
            "port description should mention range, got: {port_desc}"
        );
    }

    #[test]
    fn schema_required_vs_optional_fields() {
        let def = definition_of(&DocumentedTool).expect("schema");
        let schema = &def.parameter_schema;

        let required = schema["required"]
            .as_array()
            .expect("required should be an array");

        assert!(
            required.iter().any(|v| v == "hostname"),
            "hostname required"
        );
        assert!(required.iter().any(|v| v == "port"), "port required");
        assert!(
            !required.iter().any(|v| v == "timeout"),
            "timeout should NOT be required"
        );
    }

    #[tokio::test]
    async fn dispatch_with_optional_field_missing() {
        let mut d = ToolRegistry::new();
        d.register(DocumentedTool);

        let result = d
            .dispatch(
                "connect",
                serde_json::json!({"hostname": "example.com", "port": 443}),
                &test_ctx(),
            )
            .await;
        assert_eq!(result.unwrap().content(), "example.com:443:None");
    }

    #[tokio::test]
    async fn dispatch_with_optional_field_present() {
        let mut d = ToolRegistry::new();
        d.register(DocumentedTool);

        let result = d
            .dispatch(
                "connect",
                serde_json::json!({"hostname": "localhost", "port": 8080, "timeout": 30.0}),
                &test_ctx(),
            )
            .await;
        assert_eq!(result.unwrap().content(), "localhost:8080:Some(30.0)");
    }

    #[tokio::test]
    async fn dispatch_with_extra_fields_ignored() {
        let mut d = ToolRegistry::new();
        // `PathParams` does not deny unknown fields, so extras must be ignored.
        d.register(SampleTool);

        let result = d
            .dispatch(
                "sample",
                serde_json::json!({"path": "/tmp/x", "unknown_field": 42}),
                &test_ctx(),
            )
            .await;
        assert_eq!(result.unwrap().content(), "/tmp/x");
    }

    /// Sequential dispatches through the type-erased path, each with its own temporary
    /// context, must both complete successfully.
    #[tokio::test]
    async fn erased_dispatch_preserves_borrow_lifetime() {
        let mut d = ToolRegistry::new();
        d.register(AsyncSleepTool);
        d.register(SampleTool);

        let r1 = d
            .dispatch("async_sleep", serde_json::json!({}), &test_ctx())
            .await;
        let r2 = d
            .dispatch("sample", serde_json::json!({"path": "test"}), &test_ctx())
            .await;

        assert_eq!(r1.unwrap().content(), "slept");
        assert_eq!(r2.unwrap().content(), "test");
    }

    #[tokio::test]
    async fn dispatch_returns_meaningful_error_for_wrong_type() {
        let mut d = ToolRegistry::new();
        d.register(RunCommandTool);

        let result = d
            // `command` must be a string, not an object.
            .dispatch(
                "run_command",
                serde_json::json!({"command": {"nested": "object"}}),
                &test_ctx(),
            )
            .await;
        let err = result.unwrap_err();
        assert!(
            err.message
                .contains("Failed to deserialize tool parameters"),
            "Error should mention deserialization failure, got: {err}"
        );
    }

    #[llm_tool]
    /// Echoes the message back after a short async sleep.
    async fn async_delayed_echo(
        message: String,
    ) -> Result<String, ToolError> {
        tokio::time::sleep(std::time::Duration::from_millis(1)).await;
        Ok(format!("echo: {message}"))
    }

    #[tokio::test]
    async fn tool_macro_async_fn_dispatches_with_await() {
        let mut d = ToolRegistry::new();
        d.register(AsyncDelayedEcho);

        let result = d
            .dispatch(
                "async_delayed_echo",
                serde_json::json!({"message": "hello async"}),
                &test_ctx(),
            )
            .await;
        assert_eq!(result.unwrap().content(), "echo: hello async");
    }

    #[llm_tool]
    /// Reads a file asynchronously and returns its contents.
    async fn async_file_reader(
        path: String,
    ) -> Result<String, ToolError> {
        tokio::fs::read_to_string(&path)
            .await
            .map_err(|e| ToolError::new(format!("IO error: {e}")))
    }

    #[tokio::test]
    async fn tool_macro_async_fn_reads_file() {
        let tmp = tempfile::NamedTempFile::new().expect("create tempfile");
        std::fs::write(tmp.path(), "async macro content").expect("write");

        let mut d = ToolRegistry::new();
        d.register(AsyncFileReader);

        let path_str = tmp.path().to_str().expect("path").to_owned();
        let result = d
            .dispatch(
                "async_file_reader",
                serde_json::json!({"path": path_str}),
                &test_ctx(),
            )
            .await;
        assert_eq!(result.unwrap().content(), "async macro content");
    }

    #[llm_tool]
    /// Greets `name`, using `greeting` when provided and "Hello" otherwise.
    fn greet_optional(
        name: String,
        greeting: Option<String>,
    ) -> Result<String, ToolError> {
        let g = greeting.unwrap_or_else(|| "Hello".to_string());
        Ok(format!("{g}, {name}!"))
    }

    #[test]
    fn tool_macro_option_param_not_in_required() {
        let def = definition_of(&GreetOptional).expect("schema");
        let schema = &def.parameter_schema;

        let required = schema["required"]
            .as_array()
            .expect("required should be an array");

        assert!(
            required.iter().any(|v| v == "name"),
            "'name' should be required, got: {required:?}"
        );
        assert!(
            !required.iter().any(|v| v == "greeting"),
            "'greeting' (Option<String>) should NOT be required, got: {required:?}"
        );
    }

    #[tokio::test]
    async fn tool_macro_option_param_missing_from_json() {
        let mut d = ToolRegistry::new();
        d.register(GreetOptional);

        let result = d
            .dispatch(
                "greet_optional",
                serde_json::json!({"name": "World"}),
                &test_ctx(),
            )
            .await;
        assert_eq!(result.unwrap().content(), "Hello, World!");
    }

    #[tokio::test]
    async fn tool_macro_option_param_provided_in_json() {
        let mut d = ToolRegistry::new();
        d.register(GreetOptional);

        let result = d
            .dispatch(
                "greet_optional",
                serde_json::json!({"name": "World", "greeting": "Hi"}),
                &test_ctx(),
            )
            .await;
        assert_eq!(result.unwrap().content(), "Hi, World!");
    }

    #[llm_tool]
    /// Appends the optional `suffix` to `input` after a short async sleep.
    async fn async_optional_tool(
        input: String,
        suffix: Option<String>,
    ) -> Result<String, ToolError> {
        tokio::time::sleep(std::time::Duration::from_millis(1)).await;
        let s = suffix.unwrap_or_default();
        Ok(format!("{input}{s}"))
    }

    #[tokio::test]
    async fn tool_macro_async_with_optional_param() {
        let mut d = ToolRegistry::new();
        d.register(AsyncOptionalTool);

        let r1 = d
            .dispatch(
                "async_optional_tool",
                serde_json::json!({"input": "base"}),
                &test_ctx(),
            )
            .await;
        assert_eq!(r1.unwrap().content(), "base");

        let r2 = d
            .dispatch(
                "async_optional_tool",
                serde_json::json!({"input": "base", "suffix": "_ext"}),
                &test_ctx(),
            )
            .await;
        assert_eq!(r2.unwrap().content(), "base_ext");
    }

    #[test]
    fn tool_macro_async_optional_schema_correctness() {
        let def = definition_of(&AsyncOptionalTool).expect("schema");
        let schema = &def.parameter_schema;

        let required = schema["required"].as_array().expect("required array");
        assert!(required.iter().any(|v| v == "input"), "'input' required");
        assert!(
            !required.iter().any(|v| v == "suffix"),
            "'suffix' (Option) should NOT be required"
        );
    }

    /// `&ToolRegistry` iteration yields each registry key paired with that tool's definition.
    #[test]
    fn into_iter_yields_all_tool_name_definition_pairs() {
        let mut d = ToolRegistry::new();
        d.register(SampleTool);
        d.register(RunCommandTool);

        let mut pairs: Vec<(String, String)> = (&d)
            .into_iter()
            .map(|(name, def)| (name.clone(), def.name))
            .collect();
        pairs.sort();

        assert_eq!(pairs.len(), 2);
        assert_eq!(pairs[0].0, "run_command");
        assert_eq!(pairs[0].1, "run_command");
        assert_eq!(pairs[1].0, "sample");
        assert_eq!(pairs[1].1, "sample");
    }

    #[test]
    fn into_iter_empty_registry_yields_nothing() {
        let d = ToolRegistry::new();
        let count = (&d).into_iter().count();
        assert_eq!(count, 0);
    }

    #[test]
    fn into_iter_for_loop_syntax() {
        let mut d = ToolRegistry::new();
        d.register(SampleTool);

        let mut found = false;
        for (name, def) in &d {
            if name == "sample" {
                assert_eq!(def.description, "A sample tool");
                found = true;
            }
        }
        assert!(found, "Expected to find 'sample' tool via for-in loop");
    }

    /// A context built without a conversation id reports `None` and is not idle.
    #[test]
    fn tool_context_conversation_id_none_by_default() {
        let ctx = ToolContext::new(None);
        assert!(ctx.conversation_id().is_none());
        assert!(!ctx.is_idle());
    }

    #[test]
    fn tool_context_conversation_id_returns_value() {
        let ctx = ToolContext::new(Some("conv-123".to_owned()));
        assert_eq!(ctx.conversation_id(), Some("conv-123"));
    }

    #[test]
    fn tool_context_get_set_state_roundtrip() {
        let ctx = ToolContext::new(None);

        // Missing keys fall back to the supplied default.
        let val = ctx.get_state("missing", serde_json::json!("fallback"));
        assert_eq!(val, serde_json::json!("fallback"));

        ctx.set_state("counter", serde_json::json!(42))
            .expect("set_state");
        let val = ctx.get_state("counter", serde_json::json!(0));
        assert_eq!(val, serde_json::json!(42));

        // Setting an existing key overwrites it.
        ctx.set_state("counter", serde_json::json!(99))
            .expect("set_state");
        let val = ctx.get_state("counter", serde_json::json!(0));
        assert_eq!(val, serde_json::json!(99));
    }

    #[test]
    fn tool_context_state_persists_across_reads() {
        let ctx = ToolContext::new(None);
        ctx.set_state("key", serde_json::json!({"nested": true}))
            .expect("set_state");

        // Reading must not consume or alter the stored value.
        let v1 = ctx.get_state("key", serde_json::json!(null));
        let v2 = ctx.get_state("key", serde_json::json!(null));
        assert_eq!(v1, v2);
        assert_eq!(v1, serde_json::json!({"nested": true}));
    }

    /// `dispatch` passes the caller's context through, so state written by one call is
    /// visible to the next call made with the same context.
    #[tokio::test]
    async fn dispatch_passes_context_to_tool() {
        struct ContextAwareTool;

        impl RustTool for ContextAwareTool {
            type Params = EmptyParams;
            const NAME: &'static str = "ctx_tool";
            const DESCRIPTION: &'static str = "Reads conversation_id from context.";

            async fn call(
                &self,
                _params: Self::Params,
                ctx: &ToolContext,
            ) -> Result<ToolOutput, ToolError> {
                let conv = ctx.conversation_id().unwrap_or("none");
                let count = ctx.get_state("call_count", serde_json::json!(0));
                let n = count.as_i64().unwrap_or(0);
                ctx.set_state("call_count", serde_json::json!(n + 1))
                    .map_err(|e| ToolError::new(format!("set_state failed: {e}")))?;
                Ok(format!("conv={conv}, call={n}").into())
            }
        }

        let mut d = ToolRegistry::new();
        d.register(ContextAwareTool);

        let ctx = ToolContext::new(Some("test-conv".to_owned()));

        let r1 = d.dispatch("ctx_tool", serde_json::json!({}), &ctx).await;
        assert_eq!(r1.unwrap().content(), "conv=test-conv, call=0");

        let r2 = d.dispatch("ctx_tool", serde_json::json!({}), &ctx).await;
        assert_eq!(r2.unwrap().content(), "conv=test-conv, call=1");
    }

    /// Structured metadata attached by `MetadataTool`.
    #[derive(serde::Serialize)]
    struct ProcessMeta {
        bytes_read: usize,
        source: String,
    }

    /// Returns text content plus struct-derived metadata.
    struct MetadataTool;

    impl RustTool for MetadataTool {
        type Params = PathParams;
        const NAME: &'static str = "metadata_tool";
        const DESCRIPTION: &'static str = "Returns output with metadata.";

        async fn call(
            &self,
            params: Self::Params,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput, ToolError> {
            ToolOutput::new(format!("processed: {}", params.path)).with_metadata(&ProcessMeta {
                bytes_read: 1024,
                source: params.path,
            })
        }
    }

    #[tokio::test]
    async fn dispatch_preserves_tool_output_metadata() {
        let mut d = ToolRegistry::new();
        d.register(MetadataTool);

        let result = d
            .dispatch(
                "metadata_tool",
                serde_json::json!({"path": "/etc/hosts"}),
                &test_ctx(),
            )
            .await
            .unwrap();

        assert_eq!(result.content(), "processed: /etc/hosts");
        assert_eq!(result.metadata()["bytes_read"], 1024);
        assert_eq!(result.metadata()["source"], "/etc/hosts");
        assert_eq!(result.metadata().len(), 2);
    }

    // Purely synchronous: nothing here awaits, so a plain `#[test]` is sufficient.
    #[test]
    fn dispatch_tool_output_display_uses_content() {
        let output = ToolOutput::new("hello world").with_meta("ignored", serde_json::json!(true));
        assert_eq!(output.to_string(), "hello world");
    }

    // Purely synchronous: nothing here awaits, so a plain `#[test]` is sufficient.
    #[test]
    fn dispatch_tool_output_into_content_consumes() {
        let output = ToolOutput::new("owned").with_meta("key", serde_json::json!("val"));
        let content: String = output.into_content();
        assert_eq!(content, "owned");
    }

    #[test]
    fn tool_output_from_str_has_empty_metadata() {
        let output: ToolOutput = "plain".into();
        assert_eq!(output.content(), "plain");
        assert!(output.metadata().is_empty());
    }

    #[test]
    fn tool_output_from_string_has_empty_metadata() {
        let output: ToolOutput = "owned".to_string().into();
        assert_eq!(output.content(), "owned");
        assert!(output.metadata().is_empty());
    }

    /// `with_meta` accumulates key/value pairs on an error without touching its message.
    #[test]
    fn tool_error_with_metadata() {
        let err = ToolError::new("HTTP request failed")
            .with_meta("status_code", serde_json::json!(503))
            .with_meta("url", serde_json::json!("https://example.com"));

        assert_eq!(err.message, "HTTP request failed");
        assert_eq!(err.metadata()["status_code"], 503);
        assert_eq!(err.metadata()["url"], "https://example.com");
        assert_eq!(err.metadata().len(), 2);
    }

    #[test]
    fn tool_error_without_metadata_is_empty() {
        let err = ToolError::new("simple error");
        assert!(err.metadata().is_empty());
    }

    #[test]
    fn tool_error_display_ignores_metadata() {
        let err = ToolError::new("visible").with_meta("hidden", serde_json::json!(true));
        assert_eq!(err.to_string(), "visible");
    }

    #[test]
    fn tool_error_equality_includes_metadata() {
        let a = ToolError::new("err").with_meta("k", serde_json::json!(1));
        let b = ToolError::new("err").with_meta("k", serde_json::json!(1));
        let c = ToolError::new("err").with_meta("k", serde_json::json!(2));
        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    /// Always fails with an error carrying `retry_after_secs` metadata.
    struct MetadataErrorTool;

    impl RustTool for MetadataErrorTool {
        type Params = EmptyParams;
        const NAME: &'static str = "metadata_error_tool";
        const DESCRIPTION: &'static str = "Always fails with metadata.";

        async fn call(
            &self,
            _params: Self::Params,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput, ToolError> {
            Err(ToolError::new("service unavailable")
                .with_meta("retry_after_secs", serde_json::json!(30)))
        }
    }

    #[tokio::test]
    async fn dispatch_preserves_tool_error_metadata() {
        let mut d = ToolRegistry::new();
        d.register(MetadataErrorTool);

        let err = d
            .dispatch("metadata_error_tool", serde_json::json!({}), &test_ctx())
            .await
            .unwrap_err();

        assert_eq!(err.message, "service unavailable");
        assert_eq!(err.metadata()["retry_after_secs"], 30);
    }

    #[llm_tool]
    /// Echoes `input` and records its byte length as `input_len` metadata.
    fn tool_with_metadata(
        input: String,
    ) -> Result<ToolOutput, ToolError> {
        Ok(ToolOutput::new(format!("echoed: {input}"))
            .with_meta("input_len", serde_json::json!(input.len())))
    }

    #[tokio::test]
    async fn macro_tool_returning_tool_output_preserves_metadata() {
        let mut d = ToolRegistry::new();
        d.register(ToolWithMetadata);

        let result = d
            .dispatch(
                "tool_with_metadata",
                serde_json::json!({"input": "hello"}),
                &test_ctx(),
            )
            .await
            .unwrap();

        assert_eq!(result.content(), "echoed: hello");
        assert_eq!(result.metadata()["input_len"], 5);
    }

    /// `with_metadata` flattens a serializable struct's fields into the metadata map.
    #[test]
    fn tool_output_with_metadata_struct() {
        #[derive(serde::Serialize)]
        struct Meta {
            status: String,
            count: u32,
        }

        let out = ToolOutput::new("done")
            .with_metadata(&Meta {
                status: "ok".into(),
                count: 42,
            })
            .unwrap();

        assert_eq!(out.metadata()["status"], "ok");
        assert_eq!(out.metadata()["count"], 42);
        assert_eq!(out.metadata().len(), 2);
    }

    #[test]
    fn tool_output_with_metadata_merges_with_existing() {
        #[derive(serde::Serialize)]
        struct Extra {
            source: String,
        }

        let out = ToolOutput::new("data")
            .with_meta("version", serde_json::json!(1))
            .with_metadata(&Extra {
                source: "cache".into(),
            })
            .unwrap();

        assert_eq!(out.metadata()["version"], 1);
        assert_eq!(out.metadata()["source"], "cache");
        assert_eq!(out.metadata().len(), 2);
    }

    #[test]
    fn tool_output_with_metadata_rejects_non_object() {
        let err = ToolOutput::new("x").with_metadata(&42_i32).unwrap_err();

        assert!(
            err.message.contains("JSON object"),
            "Expected object error, got: {err}"
        );
    }

    #[test]
    fn tool_error_with_metadata_struct() {
        #[derive(serde::Serialize)]
        struct ErrorMeta {
            status_code: u16,
            url: String,
        }

        let err = ToolError::new("HTTP request failed")
            .with_metadata(&ErrorMeta {
                status_code: 503,
                url: "https://example.com".into(),
            })
            .unwrap();

        assert_eq!(err.message, "HTTP request failed");
        assert_eq!(err.metadata()["status_code"], 503);
        assert_eq!(err.metadata()["url"], "https://example.com");
        assert_eq!(err.metadata().len(), 2);
    }

    /// `from_metadata` fills both the content (which mentions each value) and the metadata
    /// map from a single struct.
    #[test]
    fn tool_output_from_metadata_populates_both() {
        #[derive(serde::Serialize)]
        struct Weather {
            location: String,
            temp_f: i32,
        }

        let out = ToolOutput::from_metadata(&Weather {
            location: "Seattle".into(),
            temp_f: 72,
        })
        .unwrap();

        // Content should be a textual rendering that includes the field values.
        assert!(out.content().contains("Seattle"));
        assert!(out.content().contains("72"));

        // Metadata should carry each field as a separate entry.
        assert_eq!(out.metadata()["location"], "Seattle");
        assert_eq!(out.metadata()["temp_f"], 72);
        assert_eq!(out.metadata().len(), 2);
    }

    #[test]
    fn tool_output_from_metadata_rejects_non_object() {
        let err = ToolOutput::from_metadata(&"just a string").unwrap_err();
        assert!(
            err.message.contains("JSON object"),
            "Expected object error, got: {err}"
        );
    }
}