1use async_trait::async_trait;
26use serde::{Deserialize, Serialize};
27
28use crate::error::Result;
29use crate::types::ToolCallId;
30
31#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
60pub struct ToolDefinition {
61 pub name: String,
65
66 pub description: String,
70
71 pub parameters: serde_json::Value,
76}
77
78impl ToolDefinition {
79 pub fn new(
81 name: impl Into<String>,
82 description: impl Into<String>,
83 parameters: serde_json::Value,
84 ) -> Self {
85 Self {
86 name: name.into(),
87 description: description.into(),
88 parameters,
89 }
90 }
91
92 pub fn no_params(name: impl Into<String>, description: impl Into<String>) -> Self {
96 Self {
97 name: name.into(),
98 description: description.into(),
99 parameters: serde_json::json!({
100 "type": "object",
101 "properties": {}
102 }),
103 }
104 }
105}
106
107#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
133pub struct ToolOutput {
134 #[serde(default)]
138 pub title: String,
139
140 pub content: String,
144
145 #[serde(default = "default_metadata")]
149 pub metadata: serde_json::Value,
150
151 #[serde(default)]
157 pub is_error: bool,
158}
159
160fn default_metadata() -> serde_json::Value {
161 serde_json::Value::Object(serde_json::Map::new())
162}
163
164impl ToolOutput {
165 pub fn success(content: impl Into<String>) -> Self {
167 Self {
168 title: String::new(),
169 content: content.into(),
170 metadata: default_metadata(),
171 is_error: false,
172 }
173 }
174
175 pub fn success_with_title(title: impl Into<String>, content: impl Into<String>) -> Self {
177 Self {
178 title: title.into(),
179 content: content.into(),
180 metadata: default_metadata(),
181 is_error: false,
182 }
183 }
184
185 pub fn error(content: impl Into<String>) -> Self {
190 Self {
191 title: String::new(),
192 content: content.into(),
193 metadata: default_metadata(),
194 is_error: true,
195 }
196 }
197
198 pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
200 self.metadata = metadata;
201 self
202 }
203
204 pub fn with_title(mut self, title: impl Into<String>) -> Self {
206 self.title = title.into();
207 self
208 }
209}
210
211#[derive(Debug, Clone, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
236#[serde(tag = "type", rename_all = "snake_case")]
237pub enum ToolChoice {
238 #[default]
240 Auto,
241
242 None,
244
245 Required,
247
248 Specific {
250 name: String,
252 },
253}
254
255impl ToolChoice {
256 pub fn specific(name: impl Into<String>) -> Self {
258 Self::Specific { name: name.into() }
259 }
260
261 pub fn allows_tools(&self) -> bool {
263 !matches!(self, Self::None)
264 }
265
266 pub fn required_tool(&self) -> Option<&str> {
268 match self {
269 Self::Specific { name } => Some(name.as_str()),
270 _ => Option::None,
271 }
272 }
273
274 pub fn is_forced(&self) -> bool {
276 matches!(self, Self::Required | Self::Specific { .. })
277 }
278}
279
280
281pub use tokio_util::sync::CancellationToken;
310
311#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
330#[serde(rename_all = "snake_case")]
331pub enum ConcurrencyMode {
332 #[default]
334 Shared,
335 Exclusive,
337}
338
339pub struct ToolCallContext {
363 pub call_id: ToolCallId,
365
366 pub cancellation: CancellationToken,
368
369 pub extra: serde_json::Value,
374}
375
376impl ToolCallContext {
377 pub fn new(call_id: ToolCallId) -> Self {
379 Self {
380 call_id,
381 cancellation: CancellationToken::new(),
382 extra: serde_json::Value::Object(serde_json::Map::new()),
383 }
384 }
385
386 pub fn with_cancellation(mut self, token: CancellationToken) -> Self {
388 self.cancellation = token;
389 self
390 }
391
392 pub fn with_extra(mut self, extra: serde_json::Value) -> Self {
394 self.extra = extra;
395 self
396 }
397}
398
399#[async_trait]
447pub trait Tool: Send + Sync {
448 fn definition(&self) -> &ToolDefinition;
450
451 async fn validate(&self, _args: &serde_json::Value, _ctx: &ToolCallContext) -> Result<()> {
459 Ok(())
460 }
461
462 async fn execute(&self, args: serde_json::Value, ctx: &ToolCallContext) -> Result<ToolOutput>;
468
469 fn concurrency_mode(&self) -> ConcurrencyMode {
474 ConcurrencyMode::Shared
475 }
476
477 fn permission_key(&self) -> &str {
482 &self.definition().name
483 }
484
485 fn check_permissions(
508 &self,
509 _args: &serde_json::Value,
510 _ctx: &ToolCallContext,
511 ) -> crate::permission::PermissionResult {
512 crate::permission::PermissionResult::Passthrough
513 }
514
515 fn permission_request(
550 &self,
551 _args: &serde_json::Value,
552 _ctx: &ToolCallContext,
553 ) -> Option<crate::permission::PermissionRequest> {
554 None
555 }
556}
557
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::{Arc, LazyLock};

    // --- ToolDefinition ----------------------------------------------------

    #[test]
    fn test_tool_definition_new() {
        let def = ToolDefinition::new(
            "read_file",
            "Read file contents",
            json!({
                "type": "object",
                "properties": {
                    "path": { "type": "string" }
                },
                "required": ["path"]
            }),
        );
        assert_eq!(def.name, "read_file");
        assert_eq!(def.description, "Read file contents");
        // `assert_eq!` prints both sides on failure, unlike `assert!(a == b)`.
        assert_eq!(def.parameters["properties"]["path"]["type"], "string");
    }

    #[test]
    fn test_tool_definition_no_params() {
        let def = ToolDefinition::no_params("get_time", "Get current time");
        assert_eq!(def.name, "get_time");
        assert_eq!(def.parameters["type"], "object");
        assert_eq!(def.parameters["properties"], json!({}));
    }

    #[test]
    fn test_tool_definition_serde_roundtrip() {
        let def = ToolDefinition::new(
            "bash",
            "Run a shell command",
            json!({
                "type": "object",
                "properties": {
                    "command": { "type": "string" }
                },
                "required": ["command"]
            }),
        );
        let json_str = serde_json::to_string(&def).unwrap();
        let restored: ToolDefinition = serde_json::from_str(&json_str).unwrap();
        assert_eq!(def, restored);
    }

    // --- ToolOutput --------------------------------------------------------

    #[test]
    fn test_tool_output_success() {
        let out = ToolOutput::success("hello world");
        assert_eq!(out.content, "hello world");
        assert!(!out.is_error);
        assert!(out.title.is_empty());
    }

    #[test]
    fn test_tool_output_success_with_title() {
        let out = ToolOutput::success_with_title("Read file", "file contents");
        assert_eq!(out.title, "Read file");
        assert_eq!(out.content, "file contents");
        assert!(!out.is_error);
    }

    #[test]
    fn test_tool_output_error() {
        let out = ToolOutput::error("not found");
        assert_eq!(out.content, "not found");
        assert!(out.is_error);
    }

    #[test]
    fn test_tool_output_builder() {
        let out = ToolOutput::success("ok")
            .with_title("Done")
            .with_metadata(json!({"elapsed_ms": 42}));
        assert_eq!(out.title, "Done");
        assert_eq!(out.metadata["elapsed_ms"], 42);
        assert!(!out.is_error);
    }

    #[test]
    fn test_tool_output_serde_roundtrip() {
        let out = ToolOutput::success_with_title("Read", "contents")
            .with_metadata(json!({"lines": 100}));
        let json_str = serde_json::to_string(&out).unwrap();
        let restored: ToolOutput = serde_json::from_str(&json_str).unwrap();
        assert_eq!(out, restored);
    }

    #[test]
    fn test_tool_output_serde_defaults() {
        // Only `content` is required; everything else falls back to defaults.
        let json_str = r#"{"content":"hello"}"#;
        let out: ToolOutput = serde_json::from_str(json_str).unwrap();
        assert_eq!(out.content, "hello");
        assert!(!out.is_error);
        assert!(out.title.is_empty());
        assert!(out.metadata.is_object());
    }

    // --- ToolChoice --------------------------------------------------------

    #[test]
    fn test_tool_choice_auto_default() {
        let choice = ToolChoice::default();
        assert_eq!(choice, ToolChoice::Auto);
    }

    #[test]
    fn test_tool_choice_allows_tools() {
        assert!(ToolChoice::Auto.allows_tools());
        assert!(!ToolChoice::None.allows_tools());
        assert!(ToolChoice::Required.allows_tools());
        assert!(ToolChoice::specific("bash").allows_tools());
    }

    #[test]
    fn test_tool_choice_required_tool() {
        assert_eq!(ToolChoice::Auto.required_tool(), Option::None);
        assert_eq!(ToolChoice::None.required_tool(), Option::None);
        assert_eq!(ToolChoice::Required.required_tool(), Option::None);
        assert_eq!(ToolChoice::specific("bash").required_tool(), Some("bash"));
    }

    #[test]
    fn test_tool_choice_is_forced() {
        assert!(!ToolChoice::Auto.is_forced());
        assert!(!ToolChoice::None.is_forced());
        assert!(ToolChoice::Required.is_forced());
        assert!(ToolChoice::specific("bash").is_forced());
    }

    #[test]
    fn test_tool_choice_serde_roundtrip() {
        for choice in [
            ToolChoice::Auto,
            ToolChoice::None,
            ToolChoice::Required,
            ToolChoice::specific("read_file"),
        ] {
            let json_str = serde_json::to_string(&choice).unwrap();
            let restored: ToolChoice = serde_json::from_str(&json_str).unwrap();
            assert_eq!(choice, restored);
        }
    }

    #[test]
    fn test_tool_choice_serde_format() {
        let json_str = serde_json::to_string(&ToolChoice::Auto).unwrap();
        assert!(json_str.contains(r#""type":"auto""#));

        // The `None` variant must not be confused with JSON `null`.
        let json_str = serde_json::to_string(&ToolChoice::None).unwrap();
        assert!(json_str.contains(r#""type":"none""#));

        let json_str = serde_json::to_string(&ToolChoice::specific("bash")).unwrap();
        assert!(json_str.contains(r#""type":"specific""#));
        assert!(json_str.contains(r#""name":"bash""#));
    }

    // --- CancellationToken -------------------------------------------------

    #[test]
    fn test_cancellation_token_new() {
        let token = CancellationToken::new();
        assert!(!token.is_cancelled());
    }

    #[test]
    fn test_cancellation_token_cancel() {
        let token = CancellationToken::new();
        token.cancel();
        assert!(token.is_cancelled());
    }

    #[test]
    fn test_cancellation_token_clone_shares_state() {
        let token = CancellationToken::new();
        let token2 = token.clone();

        assert!(!token.is_cancelled());
        assert!(!token2.is_cancelled());

        token.cancel();

        assert!(token.is_cancelled());
        assert!(token2.is_cancelled());
    }

    #[test]
    fn test_cancellation_token_child() {
        let parent = CancellationToken::new();
        let child = parent.child_token();
        assert!(!child.is_cancelled());
        parent.cancel();
        assert!(child.is_cancelled());
    }

    // --- ConcurrencyMode ---------------------------------------------------

    #[test]
    fn test_concurrency_mode_default() {
        assert_eq!(ConcurrencyMode::default(), ConcurrencyMode::Shared);
    }

    #[test]
    fn test_concurrency_mode_serde_roundtrip() {
        for mode in [ConcurrencyMode::Shared, ConcurrencyMode::Exclusive] {
            let json_str = serde_json::to_string(&mode).unwrap();
            let restored: ConcurrencyMode = serde_json::from_str(&json_str).unwrap();
            assert_eq!(mode, restored);
        }
    }

    #[test]
    fn test_concurrency_mode_serde_format() {
        assert_eq!(
            serde_json::to_string(&ConcurrencyMode::Shared).unwrap(),
            r#""shared""#
        );
        assert_eq!(
            serde_json::to_string(&ConcurrencyMode::Exclusive).unwrap(),
            r#""exclusive""#
        );
    }

    // --- ToolCallContext ---------------------------------------------------

    #[test]
    fn test_tool_call_context_new() {
        let ctx = ToolCallContext::new(ToolCallId::new("call_1"));
        assert_eq!(ctx.call_id.as_str(), "call_1");
        assert!(!ctx.cancellation.is_cancelled());
        assert!(ctx.extra.is_object());
    }

    #[test]
    fn test_tool_call_context_builder() {
        let token = CancellationToken::new();
        let ctx = ToolCallContext::new(ToolCallId::new("call_2"))
            .with_cancellation(token.clone())
            .with_extra(json!({"cwd": "/tmp", "env": {"DEBUG": "1"}}));

        assert_eq!(ctx.call_id.as_str(), "call_2");
        assert_eq!(ctx.extra["cwd"], "/tmp");
        assert_eq!(ctx.extra["env"]["DEBUG"], "1");

        token.cancel();
        // The context holds a clone of the same token, so it sees the cancel.
        assert!(ctx.cancellation.is_cancelled());
    }

    // --- Tool trait: test implementations ----------------------------------

    /// Minimal shared-mode tool that returns its `message` argument.
    struct EchoTool;

    static ECHO_DEF: LazyLock<ToolDefinition> = LazyLock::new(|| {
        ToolDefinition::new(
            "echo",
            "Echoes the input message",
            json!({
                "type": "object",
                "properties": {
                    "message": { "type": "string" }
                },
                "required": ["message"]
            }),
        )
    });

    #[async_trait]
    impl Tool for EchoTool {
        fn definition(&self) -> &ToolDefinition {
            &ECHO_DEF
        }

        async fn execute(
            &self,
            args: serde_json::Value,
            _ctx: &ToolCallContext,
        ) -> Result<ToolOutput> {
            let message = args["message"].as_str().unwrap_or("(no message)");
            Ok(ToolOutput::success(message))
        }
    }

    /// Exclusive-mode tool with custom validation and cancellation handling.
    struct ExclusiveTool;

    static EXCLUSIVE_DEF: LazyLock<ToolDefinition> = LazyLock::new(|| {
        ToolDefinition::new(
            "write_file",
            "Write content to a file",
            json!({
                "type": "object",
                "properties": {
                    "path": { "type": "string" },
                    "content": { "type": "string" }
                },
                "required": ["path", "content"]
            }),
        )
    });

    #[async_trait]
    impl Tool for ExclusiveTool {
        fn definition(&self) -> &ToolDefinition {
            &EXCLUSIVE_DEF
        }

        async fn validate(&self, args: &serde_json::Value, ctx: &ToolCallContext) -> Result<()> {
            let path = args["path"].as_str().unwrap_or("");
            if path.starts_with("/etc/") {
                return Err(crate::Error::tool(
                    "write_file",
                    ctx.call_id.clone(),
                    "cannot write to /etc/",
                ));
            }
            Ok(())
        }

        async fn execute(
            &self,
            args: serde_json::Value,
            ctx: &ToolCallContext,
        ) -> Result<ToolOutput> {
            if ctx.cancellation.is_cancelled() {
                return Err(crate::Error::Cancelled);
            }
            let path = args["path"].as_str().unwrap_or("?");
            Ok(ToolOutput::success(format!("wrote to {path}"))
                .with_title(format!("Write: {path}")))
        }

        fn concurrency_mode(&self) -> ConcurrencyMode {
            ConcurrencyMode::Exclusive
        }
    }

    // --- Tool trait: behavior ----------------------------------------------

    #[tokio::test]
    async fn test_tool_echo_execute() {
        let tool = EchoTool;
        let ctx = ToolCallContext::new(ToolCallId::new("c1"));
        let result = tool
            .execute(json!({"message": "hello"}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.content, "hello");
        assert!(!result.is_error);
    }

    // Synchronous trait methods need no async runtime.
    #[test]
    fn test_tool_echo_default_concurrency() {
        let tool = EchoTool;
        assert_eq!(tool.concurrency_mode(), ConcurrencyMode::Shared);
    }

    #[tokio::test]
    async fn test_tool_echo_default_validate() {
        let tool = EchoTool;
        let ctx = ToolCallContext::new(ToolCallId::new("c1"));
        // The default `validate` accepts anything, even an empty object.
        assert!(tool.validate(&json!({}), &ctx).await.is_ok());
    }

    #[test]
    fn test_tool_definition_matches() {
        let tool = EchoTool;
        assert_eq!(tool.definition().name, "echo");
        assert!(!tool.definition().description.is_empty());
    }

    #[test]
    fn test_tool_default_permission_key_is_name() {
        assert_eq!(EchoTool.permission_key(), "echo");
        assert_eq!(ExclusiveTool.permission_key(), "write_file");
    }

    #[test]
    fn test_tool_exclusive_concurrency() {
        let tool = ExclusiveTool;
        assert_eq!(tool.concurrency_mode(), ConcurrencyMode::Exclusive);
    }

    #[tokio::test]
    async fn test_tool_validate_rejects_invalid() {
        let tool = ExclusiveTool;
        let ctx = ToolCallContext::new(ToolCallId::new("c2"));
        let result = tool
            .validate(&json!({"path": "/etc/shadow", "content": "x"}), &ctx)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_tool_validate_accepts_valid() {
        let tool = ExclusiveTool;
        let ctx = ToolCallContext::new(ToolCallId::new("c3"));
        let result = tool
            .validate(&json!({"path": "/tmp/test.txt", "content": "x"}), &ctx)
            .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_tool_execute_with_cancellation() {
        let tool = ExclusiveTool;
        let token = CancellationToken::new();
        let ctx = ToolCallContext::new(ToolCallId::new("c4")).with_cancellation(token.clone());

        // Not cancelled yet: the write goes through.
        let result = tool
            .execute(json!({"path": "/tmp/a.txt", "content": "hi"}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.content, "wrote to /tmp/a.txt");

        // Once cancelled, a context sharing the token is rejected.
        token.cancel();
        let ctx2 = ToolCallContext::new(ToolCallId::new("c5")).with_cancellation(token.clone());
        let result = tool
            .execute(json!({"path": "/tmp/b.txt", "content": "hi"}), &ctx2)
            .await;
        assert!(matches!(result, Err(crate::Error::Cancelled)));
    }

    #[tokio::test]
    async fn test_tool_dyn_dispatch() {
        // The trait must stay object-safe so tools can be stored as `dyn Tool`.
        let tool: Arc<dyn Tool> = Arc::new(EchoTool);
        let ctx = ToolCallContext::new(ToolCallId::new("c6"));
        let result = tool
            .execute(json!({"message": "dynamic"}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.content, "dynamic");
    }
}