1use std::collections::HashMap;
4
5use super::{
6 rust_tool::{ErasedTool, RustTool, definition_of},
7 types::{ToolContext, ToolDefinition, ToolError, ToolOutput},
8};
9
10struct RegisteredTool {
16 definition: ToolDefinition,
17 erased: Box<dyn ErasedTool>,
18}
19
20pub struct ToolRegistry {
21 tools: HashMap<&'static str, RegisteredTool>,
22}
23
24impl std::fmt::Debug for ToolRegistry {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 let names: Vec<&str> = self
27 .tools
28 .values()
29 .map(|r| r.definition.name.as_str())
30 .collect();
31 f.debug_struct("ToolRegistry")
32 .field("tool_count", &self.tools.len())
33 .field("tool_names", &names)
34 .finish()
35 }
36}
37
38impl Default for ToolRegistry {
39 fn default() -> Self {
40 Self::new()
41 }
42}
43
44impl ToolRegistry {
45 #[must_use]
47 pub fn new() -> Self {
48 Self {
49 tools: HashMap::new(),
50 }
51 }
52
53 pub fn register<T: RustTool + 'static>(&mut self, tool: T) -> &mut Self {
65 let definition = definition_of(&tool)
66 .unwrap_or_else(|e| panic!("Failed to build definition for tool '{}': {e}", T::NAME));
67 self.tools.insert(
68 T::NAME,
69 RegisteredTool {
70 definition,
71 erased: Box::new(tool),
72 },
73 );
74 self
75 }
76
77 #[must_use]
114 pub fn with_tool<T: RustTool + 'static>(mut self, tool: T) -> Self {
115 self.register(tool);
116 self
117 }
118
119 #[must_use]
123 pub fn definitions(&self) -> Vec<ToolDefinition> {
124 self.tools
125 .values()
126 .map(|entry| entry.definition.clone())
127 .collect()
128 }
129
130 pub async fn dispatch(
136 &self,
137 name: &str,
138 args: serde_json::Value,
139 ctx: &ToolContext,
140 ) -> Result<ToolOutput, ToolError> {
141 let entry = self
142 .tools
143 .get(name)
144 .ok_or_else(|| ToolError::new(format!("Unknown tool: {name}")))?;
145 entry.erased.call_erased(args, ctx).await
146 }
147
148 #[must_use]
150 pub fn len(&self) -> usize {
151 self.tools.len()
152 }
153
154 #[must_use]
156 pub fn is_empty(&self) -> bool {
157 self.tools.is_empty()
158 }
159
160 pub fn iter(&self) -> impl Iterator<Item = (&'static str, ToolDefinition)> + '_ {
164 self.tools
165 .iter()
166 .map(|(name, entry)| (*name, entry.definition.clone()))
167 }
168}
169
170impl<'a> IntoIterator for &'a ToolRegistry {
174 type Item = (&'static str, ToolDefinition);
175 type IntoIter = Box<dyn Iterator<Item = (&'static str, ToolDefinition)> + 'a>;
176
177 fn into_iter(self) -> Self::IntoIter {
178 Box::new(
179 self.tools
180 .iter()
181 .map(|(name, entry)| (*name, entry.definition.clone())),
182 )
183 }
184}
185
186#[cfg(test)]
187mod tests {
188 use serde::Deserialize;
189
190 use super::{
191 super::{EmptyParams, definition_of},
192 *,
193 };
194 use crate::llm_tool;
195
196 fn test_ctx() -> ToolContext {
198 ToolContext::new(None)
199 }
200
201 #[derive(Deserialize, schemars::JsonSchema)]
204 struct PathParams {
205 path: String,
207 }
208
209 struct SampleTool;
210
211 impl RustTool for SampleTool {
212 type Params = PathParams;
213 const NAME: &'static str = "sample";
214 const DESCRIPTION: &'static str = "A sample tool";
215 async fn call(
216 &self,
217 params: Self::Params,
218 _ctx: &ToolContext,
219 ) -> Result<ToolOutput, ToolError> {
220 Ok(params.path.into())
221 }
222 }
223
224 #[derive(Deserialize, schemars::JsonSchema)]
225 struct RunCommandParams {
226 command: String,
228 #[serde(default)]
230 timeout: Option<i64>,
231 #[serde(default)]
233 env: Option<std::collections::HashMap<String, String>>,
234 }
235
236 struct RunCommandTool;
237
238 impl RustTool for RunCommandTool {
239 type Params = RunCommandParams;
240 const NAME: &'static str = "run_command";
241 const DESCRIPTION: &'static str = "Runs a command.";
242 async fn call(
243 &self,
244 params: Self::Params,
245 _ctx: &ToolContext,
246 ) -> Result<ToolOutput, ToolError> {
247 assert!(params.timeout.is_none());
248 assert!(params.env.is_none());
249 Ok(format!("Ran: {}", params.command).into())
250 }
251 }
252
253 #[test]
256 fn tool_definition_serde_roundtrip() {
257 let def = definition_of(&SampleTool).expect("schema");
258 let json = serde_json::to_string(&def).expect("serialize");
259 let parsed: ToolDefinition = serde_json::from_str(&json).expect("deserialize");
260 assert_eq!(parsed.name, def.name);
261 assert_eq!(parsed.description, def.description);
262 assert_eq!(parsed.parameter_schema, def.parameter_schema);
263 }
264
265 struct EmptyParamTool;
266 impl RustTool for EmptyParamTool {
267 type Params = EmptyParams;
268 const NAME: &'static str = "empty";
269 const DESCRIPTION: &'static str = "No params";
270 async fn call(
271 &self,
272 _params: Self::Params,
273 _ctx: &ToolContext,
274 ) -> Result<ToolOutput, ToolError> {
275 Ok("ok".into())
276 }
277 }
278
279 #[test]
280 fn tool_definition_with_empty_schema() {
281 let tool = definition_of(&EmptyParamTool).expect("schema");
282 let json = serde_json::to_string(&tool).expect("serialize");
283 let parsed: ToolDefinition = serde_json::from_str(&json).expect("deserialize");
284 let orig_json = serde_json::to_value(&tool.parameter_schema).unwrap();
286 let parsed_json = serde_json::to_value(&parsed.parameter_schema).unwrap();
287 assert_eq!(orig_json, parsed_json);
288 }
289
290 #[test]
291 fn tool_definition_with_complex_schema() {
292 let tool = definition_of(&RunCommandTool).expect("schema");
293 let schema_json = serde_json::to_value(&tool.parameter_schema).expect("schema to json");
294 let required = schema_json["required"]
296 .as_array()
297 .expect("required should be an array");
298 assert!(
299 required.iter().any(|v| v == "command"),
300 "'command' should be required, got: {required:?}"
301 );
302 }
303
304 #[tokio::test]
307 async fn registry_dispatch_valid_tool() {
308 let mut d = ToolRegistry::new();
309 d.register(SampleTool);
310 let result = d
311 .dispatch(
312 "sample",
313 serde_json::json!({"path": "/tmp/foo"}),
314 &test_ctx(),
315 )
316 .await;
317 assert_eq!(result.unwrap().content(), "/tmp/foo");
318 }
319
320 #[tokio::test]
321 async fn registry_dispatch_unknown_tool() {
322 let d = ToolRegistry::new();
323 let result = d
324 .dispatch("nonexistent", serde_json::json!({}), &test_ctx())
325 .await;
326 assert_eq!(
327 result.unwrap_err(),
328 ToolError::new("Unknown tool: nonexistent")
329 );
330 }
331
332 #[tokio::test]
333 async fn registry_dispatch_invalid_args() {
334 let mut d = ToolRegistry::new();
335 d.register(SampleTool);
336 let result = d
338 .dispatch("sample", serde_json::json!({"path": 42}), &test_ctx())
339 .await;
340 let err = result.unwrap_err();
341 assert!(
342 err.message.contains("deserialize"),
343 "Error should mention deserialization, got: {err}"
344 );
345 }
346
347 #[tokio::test]
348 async fn registry_dispatch_missing_required_field() {
349 let mut d = ToolRegistry::new();
350 d.register(SampleTool);
351 let err = d
353 .dispatch("sample", serde_json::json!({}), &test_ctx())
354 .await
355 .expect_err("Expected error for missing required field");
356 assert!(
357 err.message.contains("missing field"),
358 "Error should mention missing field, got: {err}"
359 );
360 }
361
362 #[test]
363 fn registry_definitions_returns_all() {
364 let mut d = ToolRegistry::new();
365 d.register(SampleTool);
366 d.register(RunCommandTool);
367
368 let defs = d.definitions();
369 assert_eq!(defs.len(), 2);
370
371 let mut names: Vec<&str> = defs.iter().map(|d| d.name.as_str()).collect();
372 names.sort_unstable();
373 assert_eq!(names, vec!["run_command", "sample"]);
374 }
375
376 #[test]
377 fn registry_register_chaining() {
378 let mut d = ToolRegistry::new();
379 d.register(SampleTool).register(RunCommandTool);
380 assert_eq!(d.len(), 2);
381 assert!(!d.is_empty());
382 }
383
384 #[test]
385 fn registry_with_tool_owned_chaining() {
386 let d = ToolRegistry::new()
387 .with_tool(SampleTool)
388 .with_tool(RunCommandTool);
389 assert_eq!(d.len(), 2);
390 assert!(!d.is_empty());
391
392 let defs = d.definitions();
393 let mut names: Vec<&str> = defs.iter().map(|d| d.name.as_str()).collect();
394 names.sort_unstable();
395 assert_eq!(names, vec!["run_command", "sample"]);
396 }
397
398 #[test]
399 fn registry_default_is_empty() {
400 let d = ToolRegistry::default();
401 assert!(d.is_empty());
402 assert_eq!(d.len(), 0);
403 }
404
405 #[tokio::test]
406 async fn registry_replaces_on_duplicate_name() {
407 struct AlternateSample;
408 impl RustTool for AlternateSample {
409 type Params = PathParams;
410 const NAME: &'static str = "sample";
411 const DESCRIPTION: &'static str = "Alternate sample";
412 async fn call(
413 &self,
414 params: Self::Params,
415 _ctx: &ToolContext,
416 ) -> Result<ToolOutput, ToolError> {
417 Ok(format!("alt: {}", params.path).into())
418 }
419 }
420
421 let mut d = ToolRegistry::new();
422 d.register(SampleTool);
423 d.register(AlternateSample);
424 assert_eq!(d.len(), 1);
425
426 let result = d
427 .dispatch("sample", serde_json::json!({"path": "x"}), &test_ctx())
428 .await;
429 assert_eq!(result.unwrap().content(), "alt: x");
430 }
431
432 #[tokio::test]
433 async fn registry_tool_returning_error() {
434 struct FailingTool;
435 impl RustTool for FailingTool {
436 type Params = EmptyParams;
437 const NAME: &'static str = "fail";
438 const DESCRIPTION: &'static str = "Always fails";
439 async fn call(
440 &self,
441 _params: Self::Params,
442 _ctx: &ToolContext,
443 ) -> Result<ToolOutput, ToolError> {
444 Err(ToolError::new("intentional failure"))
445 }
446 }
447
448 let mut d = ToolRegistry::new();
449 d.register(FailingTool);
450 let result = d.dispatch("fail", serde_json::json!({}), &test_ctx()).await;
451 assert_eq!(result.unwrap_err(), ToolError::new("intentional failure"));
452 }
453
454 #[test]
455 fn registry_debug_shows_tool_names() {
456 let mut d = ToolRegistry::new();
457 d.register(SampleTool);
458 let dbg = format!("{d:?}");
459 assert!(dbg.contains("ToolRegistry"));
460 assert!(dbg.contains("sample"));
461 assert!(dbg.contains("tool_count: 1"));
462 }
463
464 struct AsyncSleepTool;
468
469 impl RustTool for AsyncSleepTool {
470 type Params = EmptyParams;
471 const NAME: &'static str = "async_sleep";
472 const DESCRIPTION: &'static str = "Sleeps briefly then returns.";
473
474 async fn call(
475 &self,
476 _params: Self::Params,
477 _ctx: &ToolContext,
478 ) -> Result<ToolOutput, ToolError> {
479 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
480 Ok("slept".into())
481 }
482 }
483
484 #[tokio::test]
485 async fn async_tool_with_tokio_sleep() {
486 let mut d = ToolRegistry::new();
487 d.register(AsyncSleepTool);
488 let result = d
489 .dispatch("async_sleep", serde_json::json!({}), &test_ctx())
490 .await;
491 assert_eq!(result.unwrap().content(), "slept");
492 }
493
494 struct AsyncReadFileTool;
496
497 #[derive(Deserialize, schemars::JsonSchema)]
498 struct ReadFileParams {
499 path: String,
501 }
502
503 impl RustTool for AsyncReadFileTool {
504 type Params = ReadFileParams;
505 const NAME: &'static str = "read_file";
506 const DESCRIPTION: &'static str = "Reads a file asynchronously.";
507
508 async fn call(
509 &self,
510 params: Self::Params,
511 _ctx: &ToolContext,
512 ) -> Result<ToolOutput, ToolError> {
513 tokio::fs::read_to_string(¶ms.path)
514 .await
515 .map(ToolOutput::from)
516 .map_err(|e| ToolError::new(format!("IO error: {e}")))
517 }
518 }
519
520 #[tokio::test]
521 async fn async_tool_with_tokio_fs() {
522 let tmp = tempfile::NamedTempFile::new().expect("create tempfile");
523 std::fs::write(tmp.path(), "hello async").expect("write tempfile");
524
525 let mut d = ToolRegistry::new();
526 d.register(AsyncReadFileTool);
527
528 let path_str = tmp.path().to_str().expect("path to str").to_owned();
529 let result = d
530 .dispatch(
531 "read_file",
532 serde_json::json!({"path": path_str}),
533 &test_ctx(),
534 )
535 .await;
536 assert_eq!(result.unwrap().content(), "hello async");
537 }
538
539 #[tokio::test]
540 async fn async_tool_tokio_fs_missing_file() {
541 let mut d = ToolRegistry::new();
542 d.register(AsyncReadFileTool);
543 let result = d
544 .dispatch(
545 "read_file",
546 serde_json::json!({"path": "/nonexistent/file.txt"}),
547 &test_ctx(),
548 )
549 .await;
550 let err = result.unwrap_err();
551 assert!(
552 err.message.contains("IO error"),
553 "Expected IO error, got: {err}"
554 );
555 }
556
557 struct ChannelTool {
560 tx: tokio::sync::mpsc::Sender<String>,
561 rx: std::sync::Mutex<Option<tokio::sync::mpsc::Receiver<String>>>,
562 }
563
564 impl ChannelTool {
565 fn new() -> Self {
566 let (tx, rx) = tokio::sync::mpsc::channel(1);
567 Self {
568 tx,
569 rx: std::sync::Mutex::new(Some(rx)),
570 }
571 }
572 }
573
574 impl RustTool for ChannelTool {
575 type Params = EmptyParams;
576 const NAME: &'static str = "channel_tool";
577 const DESCRIPTION: &'static str = "Awaits a value from a channel.";
578
579 async fn call(
580 &self,
581 _params: Self::Params,
582 _ctx: &ToolContext,
583 ) -> Result<ToolOutput, ToolError> {
584 let mut rx = self
585 .rx
586 .lock()
587 .unwrap()
588 .take()
589 .ok_or_else(|| ToolError::new("channel already consumed"))?;
590 rx.recv()
591 .await
592 .map(ToolOutput::from)
593 .ok_or_else(|| ToolError::new("channel closed"))
594 }
595 }
596
597 #[tokio::test]
598 async fn async_tool_awaits_channel() {
599 let tool = ChannelTool::new();
600 let tx = tool.tx.clone();
601
602 let mut d = ToolRegistry::new();
603 d.register(tool);
604
605 let ctx = test_ctx();
607 let dispatch_future = d.dispatch("channel_tool", serde_json::json!({}), &ctx);
608 let send_future = async move {
609 tx.send("from_channel".to_string()).await.unwrap();
610 };
611
612 let (result, ()) = tokio::join!(dispatch_future, send_future);
613 assert_eq!(result.unwrap().content(), "from_channel");
614 }
615
616 #[tokio::test]
619 async fn concurrent_dispatches_to_different_tools() {
620 let mut d = ToolRegistry::new();
621 d.register(SampleTool);
622 d.register(AsyncSleepTool);
623 d.register(RunCommandTool);
624
625 let ctx = test_ctx();
626 let (r1, r2, r3) = tokio::join!(
627 d.dispatch("sample", serde_json::json!({"path": "a"}), &ctx),
628 d.dispatch("async_sleep", serde_json::json!({}), &ctx),
629 d.dispatch("run_command", serde_json::json!({"command": "ls"}), &ctx),
630 );
631
632 assert_eq!(r1.unwrap().content(), "a");
633 assert_eq!(r2.unwrap().content(), "slept");
634 assert_eq!(r3.unwrap().content(), "Ran: ls");
635 }
636
637 #[tokio::test]
638 async fn concurrent_dispatches_to_same_tool() {
639 let mut d = ToolRegistry::new();
640 d.register(SampleTool);
641
642 let ctx = test_ctx();
643 let futs: Vec<_> = (0..10)
644 .map(|i| d.dispatch("sample", serde_json::json!({"path": format!("p{i}")}), &ctx))
645 .collect();
646
647 let results = futures::future::join_all(futs).await;
648 for (i, r) in results.into_iter().enumerate() {
649 assert_eq!(r.unwrap().content(), format!("p{i}"));
650 }
651 }
652
653 #[derive(Deserialize, schemars::JsonSchema)]
656 struct DocumentedParams {
657 hostname: String,
659 port: u16,
661 #[serde(default)]
663 timeout: Option<f64>,
664 }
665
666 struct DocumentedTool;
667 impl RustTool for DocumentedTool {
668 type Params = DocumentedParams;
669 const NAME: &'static str = "connect";
670 const DESCRIPTION: &'static str = "Connects to a remote host.";
671 async fn call(&self, p: Self::Params, _ctx: &ToolContext) -> Result<ToolOutput, ToolError> {
672 Ok(format!("{}:{}:{:?}", p.hostname, p.port, p.timeout).into())
673 }
674 }
675
676 #[test]
677 fn schema_contains_field_descriptions() {
678 let def = definition_of(&DocumentedTool).expect("schema");
679 let schema = &def.parameter_schema;
680
681 let props = schema["properties"].as_object().expect("properties object");
683 assert!(props.contains_key("hostname"), "missing hostname");
684 assert!(props.contains_key("port"), "missing port");
685 assert!(props.contains_key("timeout"), "missing timeout");
686
687 let hostname_desc = props["hostname"]["description"]
689 .as_str()
690 .expect("hostname description");
691 assert!(
692 hostname_desc.contains("hostname"),
693 "hostname description should mention 'hostname', got: {hostname_desc}"
694 );
695
696 let port_desc = props["port"]["description"]
697 .as_str()
698 .expect("port description");
699 assert!(
700 port_desc.contains("1-65535"),
701 "port description should mention range, got: {port_desc}"
702 );
703 }
704
705 #[test]
706 fn schema_required_vs_optional_fields() {
707 let def = definition_of(&DocumentedTool).expect("schema");
708 let schema = &def.parameter_schema;
709
710 let required = schema["required"]
711 .as_array()
712 .expect("required should be an array");
713
714 assert!(
716 required.iter().any(|v| v == "hostname"),
717 "hostname required"
718 );
719 assert!(required.iter().any(|v| v == "port"), "port required");
720 assert!(
721 !required.iter().any(|v| v == "timeout"),
722 "timeout should NOT be required"
723 );
724 }
725
726 #[tokio::test]
727 async fn dispatch_with_optional_field_missing() {
728 let mut d = ToolRegistry::new();
729 d.register(DocumentedTool);
730
731 let result = d
733 .dispatch(
734 "connect",
735 serde_json::json!({"hostname": "example.com", "port": 443}),
736 &test_ctx(),
737 )
738 .await;
739 assert_eq!(result.unwrap().content(), "example.com:443:None");
740 }
741
742 #[tokio::test]
743 async fn dispatch_with_optional_field_present() {
744 let mut d = ToolRegistry::new();
745 d.register(DocumentedTool);
746
747 let result = d
748 .dispatch(
749 "connect",
750 serde_json::json!({"hostname": "localhost", "port": 8080, "timeout": 30.0}),
751 &test_ctx(),
752 )
753 .await;
754 assert_eq!(result.unwrap().content(), "localhost:8080:Some(30.0)");
755 }
756
757 #[tokio::test]
758 async fn dispatch_with_extra_fields_ignored() {
759 let mut d = ToolRegistry::new();
761 d.register(SampleTool);
762
763 let result = d
764 .dispatch(
765 "sample",
766 serde_json::json!({"path": "/tmp/x", "unknown_field": 42}),
767 &test_ctx(),
768 )
769 .await;
770 assert_eq!(result.unwrap().content(), "/tmp/x");
771 }
772
773 #[tokio::test]
776 async fn erased_dispatch_preserves_borrow_lifetime() {
777 let mut d = ToolRegistry::new();
780 d.register(AsyncSleepTool);
781 d.register(SampleTool);
782
783 let r1 = d
785 .dispatch("async_sleep", serde_json::json!({}), &test_ctx())
786 .await;
787 let r2 = d
788 .dispatch("sample", serde_json::json!({"path": "test"}), &test_ctx())
789 .await;
790
791 assert_eq!(r1.unwrap().content(), "slept");
792 assert_eq!(r2.unwrap().content(), "test");
793 }
794
795 #[tokio::test]
796 async fn dispatch_returns_meaningful_error_for_wrong_type() {
797 let mut d = ToolRegistry::new();
798 d.register(RunCommandTool);
799
800 let result = d
802 .dispatch(
803 "run_command",
804 serde_json::json!({"command": {"nested": "object"}}),
805 &test_ctx(),
806 )
807 .await;
808 let err = result.unwrap_err();
809 assert!(
810 err.message
811 .contains("Failed to deserialize tool parameters"),
812 "Error should mention deserialization failure, got: {err}"
813 );
814 }
815
816 #[llm_tool]
821 async fn async_delayed_echo(
822 message: String,
824 ) -> Result<String, ToolError> {
825 tokio::time::sleep(std::time::Duration::from_millis(1)).await;
826 Ok(format!("echo: {message}"))
827 }
828
829 #[tokio::test]
830 async fn tool_macro_async_fn_dispatches_with_await() {
831 let mut d = ToolRegistry::new();
832 d.register(AsyncDelayedEcho);
833
834 let result = d
835 .dispatch(
836 "async_delayed_echo",
837 serde_json::json!({"message": "hello async"}),
838 &test_ctx(),
839 )
840 .await;
841 assert_eq!(result.unwrap().content(), "echo: hello async");
842 }
843
844 #[llm_tool]
846 async fn async_file_reader(
847 path: String,
849 ) -> Result<String, ToolError> {
850 tokio::fs::read_to_string(&path)
851 .await
852 .map_err(|e| ToolError::new(format!("IO error: {e}")))
853 }
854
855 #[tokio::test]
856 async fn tool_macro_async_fn_reads_file() {
857 let tmp = tempfile::NamedTempFile::new().expect("create tempfile");
858 std::fs::write(tmp.path(), "async macro content").expect("write");
859
860 let mut d = ToolRegistry::new();
861 d.register(AsyncFileReader);
862
863 let path_str = tmp.path().to_str().expect("path").to_owned();
864 let result = d
865 .dispatch(
866 "async_file_reader",
867 serde_json::json!({"path": path_str}),
868 &test_ctx(),
869 )
870 .await;
871 assert_eq!(result.unwrap().content(), "async macro content");
872 }
873
874 #[llm_tool]
878 fn greet_optional(
879 name: String,
881 greeting: Option<String>,
883 ) -> Result<String, ToolError> {
884 let g = greeting.unwrap_or_else(|| "Hello".to_string());
885 Ok(format!("{g}, {name}!"))
886 }
887
888 #[test]
889 fn tool_macro_option_param_not_in_required() {
890 let def = definition_of(&GreetOptional).expect("schema");
891 let schema = &def.parameter_schema;
892
893 let required = schema["required"]
894 .as_array()
895 .expect("required should be an array");
896
897 assert!(
899 required.iter().any(|v| v == "name"),
900 "'name' should be required, got: {required:?}"
901 );
902 assert!(
903 !required.iter().any(|v| v == "greeting"),
904 "'greeting' (Option<String>) should NOT be required, got: {required:?}"
905 );
906 }
907
908 #[tokio::test]
909 async fn tool_macro_option_param_missing_from_json() {
910 let mut d = ToolRegistry::new();
911 d.register(GreetOptional);
912
913 let result = d
915 .dispatch(
916 "greet_optional",
917 serde_json::json!({"name": "World"}),
918 &test_ctx(),
919 )
920 .await;
921 assert_eq!(result.unwrap().content(), "Hello, World!");
922 }
923
924 #[tokio::test]
925 async fn tool_macro_option_param_provided_in_json() {
926 let mut d = ToolRegistry::new();
927 d.register(GreetOptional);
928
929 let result = d
931 .dispatch(
932 "greet_optional",
933 serde_json::json!({"name": "World", "greeting": "Hi"}),
934 &test_ctx(),
935 )
936 .await;
937 assert_eq!(result.unwrap().content(), "Hi, World!");
938 }
939
940 #[llm_tool]
942 async fn async_optional_tool(
943 input: String,
945 suffix: Option<String>,
947 ) -> Result<String, ToolError> {
948 tokio::time::sleep(std::time::Duration::from_millis(1)).await;
949 let s = suffix.unwrap_or_default();
950 Ok(format!("{input}{s}"))
951 }
952
953 #[tokio::test]
954 async fn tool_macro_async_with_optional_param() {
955 let mut d = ToolRegistry::new();
956 d.register(AsyncOptionalTool);
957
958 let r1 = d
960 .dispatch(
961 "async_optional_tool",
962 serde_json::json!({"input": "base"}),
963 &test_ctx(),
964 )
965 .await;
966 assert_eq!(r1.unwrap().content(), "base");
967
968 let r2 = d
970 .dispatch(
971 "async_optional_tool",
972 serde_json::json!({"input": "base", "suffix": "_ext"}),
973 &test_ctx(),
974 )
975 .await;
976 assert_eq!(r2.unwrap().content(), "base_ext");
977 }
978
979 #[test]
980 fn tool_macro_async_optional_schema_correctness() {
981 let def = definition_of(&AsyncOptionalTool).expect("schema");
982 let schema = &def.parameter_schema;
983
984 let required = schema["required"].as_array().expect("required array");
985 assert!(required.iter().any(|v| v == "input"), "'input' required");
986 assert!(
987 !required.iter().any(|v| v == "suffix"),
988 "'suffix' (Option) should NOT be required"
989 );
990 }
991
992 #[test]
995 fn into_iter_yields_all_tool_name_definition_pairs() {
996 let mut d = ToolRegistry::new();
997 d.register(SampleTool);
998 d.register(RunCommandTool);
999
1000 let mut pairs: Vec<(&str, String)> = (&d)
1001 .into_iter()
1002 .map(|(name, def)| (name, def.name))
1003 .collect();
1004 pairs.sort();
1005
1006 assert_eq!(pairs.len(), 2);
1007 assert_eq!(pairs[0].0, "run_command");
1008 assert_eq!(pairs[0].1, "run_command");
1009 assert_eq!(pairs[1].0, "sample");
1010 assert_eq!(pairs[1].1, "sample");
1011 }
1012
1013 #[test]
1014 fn into_iter_empty_registry_yields_nothing() {
1015 let d = ToolRegistry::new();
1016 let count = (&d).into_iter().count();
1017 assert_eq!(count, 0);
1018 }
1019
1020 #[test]
1021 fn into_iter_for_loop_syntax() {
1022 let mut d = ToolRegistry::new();
1023 d.register(SampleTool);
1024
1025 let mut found = false;
1026 for (name, def) in &d {
1027 if name == "sample" {
1028 assert_eq!(def.description, "A sample tool");
1029 found = true;
1030 }
1031 }
1032 assert!(found, "Expected to find 'sample' tool via for-in loop");
1033 }
1034
1035 #[test]
1038 fn tool_context_conversation_id_none_by_default() {
1039 let ctx = ToolContext::new(None);
1040 assert!(ctx.conversation_id().is_none());
1041 assert!(!ctx.is_idle());
1042 }
1043
1044 #[test]
1045 fn tool_context_conversation_id_returns_value() {
1046 let ctx = ToolContext::new(Some("conv-123".to_owned()));
1047 assert_eq!(ctx.conversation_id(), Some("conv-123"));
1048 }
1049
1050 #[test]
1051 fn tool_context_get_set_state_roundtrip() {
1052 let ctx = ToolContext::new(None);
1053
1054 let val = ctx.get_state("missing", serde_json::json!("fallback"));
1056 assert_eq!(val, serde_json::json!("fallback"));
1057
1058 ctx.set_state("counter", serde_json::json!(42))
1060 .expect("set_state");
1061 let val = ctx.get_state("counter", serde_json::json!(0));
1062 assert_eq!(val, serde_json::json!(42));
1063
1064 ctx.set_state("counter", serde_json::json!(99))
1066 .expect("set_state");
1067 let val = ctx.get_state("counter", serde_json::json!(0));
1068 assert_eq!(val, serde_json::json!(99));
1069 }
1070
1071 #[test]
1072 fn tool_context_state_persists_across_reads() {
1073 let ctx = ToolContext::new(None);
1074 ctx.set_state("key", serde_json::json!({"nested": true}))
1075 .expect("set_state");
1076
1077 let v1 = ctx.get_state("key", serde_json::json!(null));
1079 let v2 = ctx.get_state("key", serde_json::json!(null));
1080 assert_eq!(v1, v2);
1081 assert_eq!(v1, serde_json::json!({"nested": true}));
1082 }
1083
1084 #[tokio::test]
1085 async fn dispatch_passes_context_to_tool() {
1086 struct ContextAwareTool;
1088
1089 impl RustTool for ContextAwareTool {
1090 type Params = EmptyParams;
1091 const NAME: &'static str = "ctx_tool";
1092 const DESCRIPTION: &'static str = "Reads conversation_id from context.";
1093
1094 async fn call(
1095 &self,
1096 _params: Self::Params,
1097 ctx: &ToolContext,
1098 ) -> Result<ToolOutput, ToolError> {
1099 let conv = ctx.conversation_id().unwrap_or("none");
1100 let count = ctx.get_state("call_count", serde_json::json!(0));
1101 let n = count.as_i64().unwrap_or(0);
1102 ctx.set_state("call_count", serde_json::json!(n + 1))
1103 .map_err(|e| ToolError::new(format!("set_state failed: {e}")))?;
1104 Ok(format!("conv={conv}, call={n}").into())
1105 }
1106 }
1107
1108 let mut d = ToolRegistry::new();
1109 d.register(ContextAwareTool);
1110
1111 let ctx = ToolContext::new(Some("test-conv".to_owned()));
1112
1113 let r1 = d.dispatch("ctx_tool", serde_json::json!({}), &ctx).await;
1115 assert_eq!(r1.unwrap().content(), "conv=test-conv, call=0");
1116
1117 let r2 = d.dispatch("ctx_tool", serde_json::json!({}), &ctx).await;
1119 assert_eq!(r2.unwrap().content(), "conv=test-conv, call=1");
1120 }
1121
1122 #[derive(serde::Serialize)]
1125 struct ProcessMeta {
1126 bytes_read: usize,
1127 source: String,
1128 }
1129
1130 struct MetadataTool;
1132
1133 impl RustTool for MetadataTool {
1134 type Params = PathParams;
1135 const NAME: &'static str = "metadata_tool";
1136 const DESCRIPTION: &'static str = "Returns output with metadata.";
1137
1138 async fn call(
1139 &self,
1140 params: Self::Params,
1141 _ctx: &ToolContext,
1142 ) -> Result<ToolOutput, ToolError> {
1143 ToolOutput::new(format!("processed: {}", params.path)).with_metadata(&ProcessMeta {
1144 bytes_read: 1024,
1145 source: params.path,
1146 })
1147 }
1148 }
1149
1150 #[tokio::test]
1151 async fn dispatch_preserves_tool_output_metadata() {
1152 let mut d = ToolRegistry::new();
1153 d.register(MetadataTool);
1154
1155 let result = d
1156 .dispatch(
1157 "metadata_tool",
1158 serde_json::json!({"path": "/etc/hosts"}),
1159 &test_ctx(),
1160 )
1161 .await
1162 .unwrap();
1163
1164 assert_eq!(result.content(), "processed: /etc/hosts");
1165 assert_eq!(result.metadata()["bytes_read"], 1024);
1166 assert_eq!(result.metadata()["source"], "/etc/hosts");
1167 assert_eq!(result.metadata().len(), 2);
1168 }
1169
1170 #[tokio::test]
1171 async fn dispatch_tool_output_display_uses_content() {
1172 let output = ToolOutput::new("hello world").with_meta("ignored", serde_json::json!(true));
1173 assert_eq!(output.to_string(), "hello world");
1174 }
1175
1176 #[tokio::test]
1177 async fn dispatch_tool_output_into_content_consumes() {
1178 let output = ToolOutput::new("owned").with_meta("key", serde_json::json!("val"));
1179 let content: String = output.into_content();
1180 assert_eq!(content, "owned");
1181 }
1182
1183 #[test]
1184 fn tool_output_from_str_has_empty_metadata() {
1185 let output: ToolOutput = "plain".into();
1186 assert_eq!(output.content(), "plain");
1187 assert!(output.metadata().is_empty());
1188 }
1189
1190 #[test]
1191 fn tool_output_from_string_has_empty_metadata() {
1192 let output: ToolOutput = "owned".to_string().into();
1193 assert_eq!(output.content(), "owned");
1194 assert!(output.metadata().is_empty());
1195 }
1196
1197 #[test]
1200 fn tool_error_with_metadata() {
1201 let err = ToolError::new("HTTP request failed")
1202 .with_meta("status_code", serde_json::json!(503))
1203 .with_meta("url", serde_json::json!("https://example.com"));
1204
1205 assert_eq!(err.message, "HTTP request failed");
1206 assert_eq!(err.metadata()["status_code"], 503);
1207 assert_eq!(err.metadata()["url"], "https://example.com");
1208 assert_eq!(err.metadata().len(), 2);
1209 }
1210
1211 #[test]
1212 fn tool_error_without_metadata_is_empty() {
1213 let err = ToolError::new("simple error");
1214 assert!(err.metadata().is_empty());
1215 }
1216
1217 #[test]
1218 fn tool_error_display_ignores_metadata() {
1219 let err = ToolError::new("visible").with_meta("hidden", serde_json::json!(true));
1220 assert_eq!(err.to_string(), "visible");
1221 }
1222
1223 #[test]
1224 fn tool_error_equality_includes_metadata() {
1225 let a = ToolError::new("err").with_meta("k", serde_json::json!(1));
1226 let b = ToolError::new("err").with_meta("k", serde_json::json!(1));
1227 let c = ToolError::new("err").with_meta("k", serde_json::json!(2));
1228 assert_eq!(a, b);
1229 assert_ne!(a, c);
1230 }
1231
1232 struct MetadataErrorTool;
1234
1235 impl RustTool for MetadataErrorTool {
1236 type Params = EmptyParams;
1237 const NAME: &'static str = "metadata_error_tool";
1238 const DESCRIPTION: &'static str = "Always fails with metadata.";
1239
1240 async fn call(
1241 &self,
1242 _params: Self::Params,
1243 _ctx: &ToolContext,
1244 ) -> Result<ToolOutput, ToolError> {
1245 Err(ToolError::new("service unavailable")
1246 .with_meta("retry_after_secs", serde_json::json!(30)))
1247 }
1248 }
1249
1250 #[tokio::test]
1251 async fn dispatch_preserves_tool_error_metadata() {
1252 let mut d = ToolRegistry::new();
1253 d.register(MetadataErrorTool);
1254
1255 let err = d
1256 .dispatch("metadata_error_tool", serde_json::json!({}), &test_ctx())
1257 .await
1258 .unwrap_err();
1259
1260 assert_eq!(err.message, "service unavailable");
1261 assert_eq!(err.metadata()["retry_after_secs"], 30);
1262 }
1263
1264 #[llm_tool]
1268 fn tool_with_metadata(
1269 input: String,
1271 ) -> Result<ToolOutput, ToolError> {
1272 Ok(ToolOutput::new(format!("echoed: {input}"))
1273 .with_meta("input_len", serde_json::json!(input.len())))
1274 }
1275
1276 #[tokio::test]
1277 async fn macro_tool_returning_tool_output_preserves_metadata() {
1278 let mut d = ToolRegistry::new();
1279 d.register(ToolWithMetadata);
1280
1281 let result = d
1282 .dispatch(
1283 "tool_with_metadata",
1284 serde_json::json!({"input": "hello"}),
1285 &test_ctx(),
1286 )
1287 .await
1288 .unwrap();
1289
1290 assert_eq!(result.content(), "echoed: hello");
1291 assert_eq!(result.metadata()["input_len"], 5);
1292 }
1293
1294 #[test]
1297 fn tool_output_with_metadata_struct() {
1298 #[derive(serde::Serialize)]
1299 struct Meta {
1300 status: String,
1301 count: u32,
1302 }
1303
1304 let out = ToolOutput::new("done")
1305 .with_metadata(&Meta {
1306 status: "ok".into(),
1307 count: 42,
1308 })
1309 .unwrap();
1310
1311 assert_eq!(out.metadata()["status"], "ok");
1312 assert_eq!(out.metadata()["count"], 42);
1313 assert_eq!(out.metadata().len(), 2);
1314 }
1315
1316 #[test]
1317 fn tool_output_with_metadata_merges_with_existing() {
1318 #[derive(serde::Serialize)]
1319 struct Extra {
1320 source: String,
1321 }
1322
1323 let out = ToolOutput::new("data")
1324 .with_meta("version", serde_json::json!(1))
1325 .with_metadata(&Extra {
1326 source: "cache".into(),
1327 })
1328 .unwrap();
1329
1330 assert_eq!(out.metadata()["version"], 1);
1331 assert_eq!(out.metadata()["source"], "cache");
1332 assert_eq!(out.metadata().len(), 2);
1333 }
1334
1335 #[test]
1336 fn tool_output_with_metadata_rejects_non_object() {
1337 let err = ToolOutput::new("x").with_metadata(&42_i32).unwrap_err();
1338
1339 assert!(
1340 err.message.contains("JSON object"),
1341 "Expected object error, got: {err}"
1342 );
1343 }
1344
1345 #[test]
1346 fn tool_error_with_metadata_struct() {
1347 #[derive(serde::Serialize)]
1348 struct ErrorMeta {
1349 status_code: u16,
1350 url: String,
1351 }
1352
1353 let err = ToolError::new("HTTP request failed")
1354 .with_metadata(&ErrorMeta {
1355 status_code: 503,
1356 url: "https://example.com".into(),
1357 })
1358 .unwrap();
1359
1360 assert_eq!(err.message, "HTTP request failed");
1361 assert_eq!(err.metadata()["status_code"], 503);
1362 assert_eq!(err.metadata()["url"], "https://example.com");
1363 assert_eq!(err.metadata().len(), 2);
1364 }
1365
1366 #[test]
1369 fn tool_output_from_metadata_populates_both() {
1370 #[derive(serde::Serialize)]
1371 struct Weather {
1372 location: String,
1373 temp_f: i32,
1374 }
1375
1376 let out = ToolOutput::from_metadata(&Weather {
1377 location: "Seattle".into(),
1378 temp_f: 72,
1379 })
1380 .unwrap();
1381
1382 assert!(out.content().contains("Seattle"));
1384 assert!(out.content().contains("72"));
1385
1386 assert_eq!(out.metadata()["location"], "Seattle");
1388 assert_eq!(out.metadata()["temp_f"], 72);
1389 assert_eq!(out.metadata().len(), 2);
1390 }
1391
1392 #[test]
1393 fn tool_output_from_metadata_rejects_non_object() {
1394 let err = ToolOutput::from_metadata(&"just a string").unwrap_err();
1395 assert!(
1396 err.message.contains("JSON object"),
1397 "Expected object error, got: {err}"
1398 );
1399 }
1400}