1use std::borrow::Cow;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13
14use schemars::{JsonSchema, Schema, SchemaGenerator};
15use serde::Serialize;
16use serde::de::DeserializeOwned;
17use serde_json::Value;
18
19use crate::context::RequestContext;
20use crate::error::{Error, Result};
21use crate::protocol::{CallToolResult, ToolAnnotations, ToolDefinition, ToolIcon};
22
/// Input type for tools that accept no parameters.
///
/// Deserializes successfully from `null`, an absent value, or any JSON object (the
/// object's entries are read and discarded), and advertises a `{"type": "object"}`
/// JSON Schema. Use it as the input type of a typed handler for argument-less tools.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct NoParams;
45
46impl<'de> serde::Deserialize<'de> for NoParams {
47 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
48 where
49 D: serde::Deserializer<'de>,
50 {
51 struct NoParamsVisitor;
53
54 impl<'de> serde::de::Visitor<'de> for NoParamsVisitor {
55 type Value = NoParams;
56
57 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
58 formatter.write_str("null or an object")
59 }
60
61 fn visit_unit<E>(self) -> std::result::Result<Self::Value, E>
62 where
63 E: serde::de::Error,
64 {
65 Ok(NoParams)
66 }
67
68 fn visit_none<E>(self) -> std::result::Result<Self::Value, E>
69 where
70 E: serde::de::Error,
71 {
72 Ok(NoParams)
73 }
74
75 fn visit_some<D>(self, deserializer: D) -> std::result::Result<Self::Value, D::Error>
76 where
77 D: serde::Deserializer<'de>,
78 {
79 serde::Deserialize::deserialize(deserializer)
80 }
81
82 fn visit_map<A>(self, mut map: A) -> std::result::Result<Self::Value, A::Error>
83 where
84 A: serde::de::MapAccess<'de>,
85 {
86 while map
88 .next_entry::<serde::de::IgnoredAny, serde::de::IgnoredAny>()?
89 .is_some()
90 {}
91 Ok(NoParams)
92 }
93 }
94
95 deserializer.deserialize_any(NoParamsVisitor)
96 }
97}
98
99impl JsonSchema for NoParams {
100 fn schema_name() -> Cow<'static, str> {
101 Cow::Borrowed("NoParams")
102 }
103
104 fn json_schema(_generator: &mut SchemaGenerator) -> Schema {
105 serde_json::json!({
106 "type": "object"
107 })
108 .try_into()
109 .expect("valid schema")
110 }
111}
112
113pub fn validate_tool_name(name: &str) -> Result<()> {
121 if name.is_empty() {
122 return Err(Error::tool("Tool name cannot be empty"));
123 }
124 if name.len() > 128 {
125 return Err(Error::tool(format!(
126 "Tool name '{}' exceeds maximum length of 128 characters (got {})",
127 name,
128 name.len()
129 )));
130 }
131 if let Some(invalid_char) = name
132 .chars()
133 .find(|c| !c.is_ascii_alphanumeric() && *c != '_' && *c != '-' && *c != '.')
134 {
135 return Err(Error::tool(format!(
136 "Tool name '{}' contains invalid character '{}'. Only alphanumeric, underscore, hyphen, and dot are allowed.",
137 name, invalid_char
138 )));
139 }
140 Ok(())
141}
142
/// A heap-allocated, pinned, type-erased future that is `Send` and may borrow data for `'a`.
///
/// This is the return type of the [`ToolHandler`] methods, so handlers with different
/// concrete future types can all sit behind `dyn ToolHandler`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
145
146pub trait ToolHandler: Send + Sync {
148 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>>;
150
151 fn call_with_context(
156 &self,
157 _ctx: RequestContext,
158 args: Value,
159 ) -> BoxFuture<'_, Result<CallToolResult>> {
160 self.call(args)
161 }
162
163 fn uses_context(&self) -> bool {
165 false
166 }
167
168 fn input_schema(&self) -> Value;
170}
171
172pub struct Tool {
174 pub name: String,
175 pub title: Option<String>,
176 pub description: Option<String>,
177 pub output_schema: Option<Value>,
178 pub icons: Option<Vec<ToolIcon>>,
179 pub annotations: Option<ToolAnnotations>,
180 handler: Arc<dyn ToolHandler>,
181}
182
183impl std::fmt::Debug for Tool {
184 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185 f.debug_struct("Tool")
186 .field("name", &self.name)
187 .field("title", &self.title)
188 .field("description", &self.description)
189 .field("output_schema", &self.output_schema)
190 .field("icons", &self.icons)
191 .field("annotations", &self.annotations)
192 .finish_non_exhaustive()
193 }
194}
195
196impl Tool {
197 pub fn builder(name: impl Into<String>) -> ToolBuilder {
199 ToolBuilder::new(name)
200 }
201
202 pub fn definition(&self) -> ToolDefinition {
204 ToolDefinition {
205 name: self.name.clone(),
206 title: self.title.clone(),
207 description: self.description.clone(),
208 input_schema: self.handler.input_schema(),
209 output_schema: self.output_schema.clone(),
210 icons: self.icons.clone(),
211 annotations: self.annotations.clone(),
212 }
213 }
214
215 pub fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
217 self.handler.call(args)
218 }
219
220 pub fn call_with_context(
224 &self,
225 ctx: RequestContext,
226 args: Value,
227 ) -> BoxFuture<'_, Result<CallToolResult>> {
228 self.handler.call_with_context(ctx, args)
229 }
230
231 pub fn uses_context(&self) -> bool {
233 self.handler.uses_context()
234 }
235}
236
237pub struct ToolBuilder {
266 name: String,
267 title: Option<String>,
268 description: Option<String>,
269 output_schema: Option<Value>,
270 icons: Option<Vec<ToolIcon>>,
271 annotations: Option<ToolAnnotations>,
272}
273
274impl ToolBuilder {
275 pub fn new(name: impl Into<String>) -> Self {
276 Self {
277 name: name.into(),
278 title: None,
279 description: None,
280 output_schema: None,
281 icons: None,
282 annotations: None,
283 }
284 }
285
286 pub fn title(mut self, title: impl Into<String>) -> Self {
288 self.title = Some(title.into());
289 self
290 }
291
292 pub fn output_schema(mut self, schema: Value) -> Self {
294 self.output_schema = Some(schema);
295 self
296 }
297
298 pub fn icon(mut self, src: impl Into<String>) -> Self {
300 self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
301 src: src.into(),
302 mime_type: None,
303 sizes: None,
304 });
305 self
306 }
307
308 pub fn icon_with_meta(
310 mut self,
311 src: impl Into<String>,
312 mime_type: Option<String>,
313 sizes: Option<Vec<String>>,
314 ) -> Self {
315 self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
316 src: src.into(),
317 mime_type,
318 sizes,
319 });
320 self
321 }
322
323 pub fn description(mut self, description: impl Into<String>) -> Self {
325 self.description = Some(description.into());
326 self
327 }
328
329 pub fn read_only(mut self) -> Self {
331 self.annotations
332 .get_or_insert_with(ToolAnnotations::default)
333 .read_only_hint = true;
334 self
335 }
336
337 pub fn non_destructive(mut self) -> Self {
339 self.annotations
340 .get_or_insert_with(ToolAnnotations::default)
341 .destructive_hint = false;
342 self
343 }
344
345 pub fn idempotent(mut self) -> Self {
347 self.annotations
348 .get_or_insert_with(ToolAnnotations::default)
349 .idempotent_hint = true;
350 self
351 }
352
353 pub fn annotations(mut self, annotations: ToolAnnotations) -> Self {
355 self.annotations = Some(annotations);
356 self
357 }
358
359 pub fn handler<I, F, Fut>(self, handler: F) -> ToolBuilderWithHandler<I, F>
403 where
404 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
405 F: Fn(I) -> Fut + Send + Sync + 'static,
406 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
407 {
408 ToolBuilderWithHandler {
409 name: self.name,
410 title: self.title,
411 description: self.description,
412 output_schema: self.output_schema,
413 icons: self.icons,
414 annotations: self.annotations,
415 handler,
416 _phantom: std::marker::PhantomData,
417 }
418 }
419
420 pub fn handler_with_context<I, F, Fut>(self, handler: F) -> ToolBuilderWithContextHandler<I, F>
453 where
454 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
455 F: Fn(RequestContext, I) -> Fut + Send + Sync + 'static,
456 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
457 {
458 ToolBuilderWithContextHandler {
459 name: self.name,
460 title: self.title,
461 description: self.description,
462 output_schema: self.output_schema,
463 icons: self.icons,
464 annotations: self.annotations,
465 handler,
466 _phantom: std::marker::PhantomData,
467 }
468 }
469
470 pub fn handler_with_state<S, I, F, Fut>(
500 self,
501 state: S,
502 handler: F,
503 ) -> ToolBuilderWithHandler<
504 I,
505 impl Fn(I) -> BoxFuture<'static, Result<CallToolResult>> + Send + Sync + 'static,
506 >
507 where
508 S: Clone + Send + Sync + 'static,
509 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
510 F: Fn(S, I) -> Fut + Send + Sync + 'static,
511 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
512 {
513 let handler = Arc::new(handler);
514 self.handler(move |input: I| {
515 let state = state.clone();
516 let handler = handler.clone();
517 Box::pin(async move { handler(state, input).await })
518 as BoxFuture<'static, Result<CallToolResult>>
519 })
520 }
521
522 pub fn handler_with_state_and_context<S, I, F, Fut>(
552 self,
553 state: S,
554 handler: F,
555 ) -> ToolBuilderWithContextHandler<
556 I,
557 impl Fn(RequestContext, I) -> BoxFuture<'static, Result<CallToolResult>> + Send + Sync + 'static,
558 >
559 where
560 S: Clone + Send + Sync + 'static,
561 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
562 F: Fn(S, RequestContext, I) -> Fut + Send + Sync + 'static,
563 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
564 {
565 let handler = Arc::new(handler);
566 self.handler_with_context(move |ctx: RequestContext, input: I| {
567 let state = state.clone();
568 let handler = handler.clone();
569 Box::pin(async move { handler(state, ctx, input).await })
570 as BoxFuture<'static, Result<CallToolResult>>
571 })
572 }
573
574 pub fn handler_no_params<F, Fut>(self, handler: F) -> Result<Tool>
594 where
595 F: Fn() -> Fut + Send + Sync + 'static,
596 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
597 {
598 validate_tool_name(&self.name)?;
599 Ok(Tool {
600 name: self.name,
601 title: self.title,
602 description: self.description,
603 output_schema: self.output_schema,
604 icons: self.icons,
605 annotations: self.annotations,
606 handler: Arc::new(NoParamsHandler { handler }),
607 })
608 }
609
610 pub fn handler_no_params_with_state<S, F, Fut>(self, state: S, handler: F) -> Result<Tool>
637 where
638 S: Clone + Send + Sync + 'static,
639 F: Fn(S) -> Fut + Send + Sync + 'static,
640 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
641 {
642 validate_tool_name(&self.name)?;
643 Ok(Tool {
644 name: self.name,
645 title: self.title,
646 description: self.description,
647 output_schema: self.output_schema,
648 icons: self.icons,
649 annotations: self.annotations,
650 handler: Arc::new(NoParamsWithStateHandler { state, handler }),
651 })
652 }
653
654 pub fn raw_handler<F, Fut>(self, handler: F) -> Result<Tool>
658 where
659 F: Fn(Value) -> Fut + Send + Sync + 'static,
660 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
661 {
662 validate_tool_name(&self.name)?;
663 Ok(Tool {
664 name: self.name,
665 title: self.title,
666 description: self.description,
667 output_schema: self.output_schema,
668 icons: self.icons,
669 annotations: self.annotations,
670 handler: Arc::new(RawHandler { handler }),
671 })
672 }
673
674 pub fn raw_handler_with_context<F, Fut>(self, handler: F) -> Result<Tool>
681 where
682 F: Fn(RequestContext, Value) -> Fut + Send + Sync + 'static,
683 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
684 {
685 validate_tool_name(&self.name)?;
686 Ok(Tool {
687 name: self.name,
688 title: self.title,
689 description: self.description,
690 output_schema: self.output_schema,
691 icons: self.icons,
692 annotations: self.annotations,
693 handler: Arc::new(RawContextHandler { handler }),
694 })
695 }
696}
697
698pub struct ToolBuilderWithHandler<I, F> {
700 name: String,
701 title: Option<String>,
702 description: Option<String>,
703 output_schema: Option<Value>,
704 icons: Option<Vec<ToolIcon>>,
705 annotations: Option<ToolAnnotations>,
706 handler: F,
707 _phantom: std::marker::PhantomData<I>,
708}
709
710impl<I, F, Fut> ToolBuilderWithHandler<I, F>
711where
712 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
713 F: Fn(I) -> Fut + Send + Sync + 'static,
714 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
715{
716 pub fn build(self) -> Result<Tool> {
720 validate_tool_name(&self.name)?;
721 Ok(Tool {
722 name: self.name,
723 title: self.title,
724 description: self.description,
725 output_schema: self.output_schema,
726 icons: self.icons,
727 annotations: self.annotations,
728 handler: Arc::new(TypedHandler {
729 handler: self.handler,
730 _phantom: std::marker::PhantomData,
731 }),
732 })
733 }
734}
735
736pub struct ToolBuilderWithContextHandler<I, F> {
738 name: String,
739 title: Option<String>,
740 description: Option<String>,
741 output_schema: Option<Value>,
742 icons: Option<Vec<ToolIcon>>,
743 annotations: Option<ToolAnnotations>,
744 handler: F,
745 _phantom: std::marker::PhantomData<I>,
746}
747
748impl<I, F, Fut> ToolBuilderWithContextHandler<I, F>
749where
750 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
751 F: Fn(RequestContext, I) -> Fut + Send + Sync + 'static,
752 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
753{
754 pub fn build(self) -> Result<Tool> {
758 validate_tool_name(&self.name)?;
759 Ok(Tool {
760 name: self.name,
761 title: self.title,
762 description: self.description,
763 output_schema: self.output_schema,
764 icons: self.icons,
765 annotations: self.annotations,
766 handler: Arc::new(ContextAwareHandler {
767 handler: self.handler,
768 _phantom: std::marker::PhantomData,
769 }),
770 })
771 }
772}
773
774struct TypedHandler<I, F> {
780 handler: F,
781 _phantom: std::marker::PhantomData<I>,
782}
783
784impl<I, F, Fut> ToolHandler for TypedHandler<I, F>
785where
786 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
787 F: Fn(I) -> Fut + Send + Sync + 'static,
788 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
789{
790 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
791 Box::pin(async move {
792 let input: I = serde_json::from_value(args)
793 .map_err(|e| Error::tool(format!("Invalid input: {}", e)))?;
794 (self.handler)(input).await
795 })
796 }
797
798 fn input_schema(&self) -> Value {
799 let schema = schemars::schema_for!(I);
800 serde_json::to_value(schema).unwrap_or_else(|_| {
801 serde_json::json!({
802 "type": "object"
803 })
804 })
805 }
806}
807
808struct RawHandler<F> {
810 handler: F,
811}
812
813impl<F, Fut> ToolHandler for RawHandler<F>
814where
815 F: Fn(Value) -> Fut + Send + Sync + 'static,
816 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
817{
818 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
819 Box::pin((self.handler)(args))
820 }
821
822 fn input_schema(&self) -> Value {
823 serde_json::json!({
825 "type": "object",
826 "additionalProperties": true
827 })
828 }
829}
830
831struct RawContextHandler<F> {
833 handler: F,
834}
835
836impl<F, Fut> ToolHandler for RawContextHandler<F>
837where
838 F: Fn(RequestContext, Value) -> Fut + Send + Sync + 'static,
839 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
840{
841 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
842 let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
843 self.call_with_context(ctx, args)
844 }
845
846 fn call_with_context(
847 &self,
848 ctx: RequestContext,
849 args: Value,
850 ) -> BoxFuture<'_, Result<CallToolResult>> {
851 Box::pin((self.handler)(ctx, args))
852 }
853
854 fn uses_context(&self) -> bool {
855 true
856 }
857
858 fn input_schema(&self) -> Value {
859 serde_json::json!({
861 "type": "object",
862 "additionalProperties": true
863 })
864 }
865}
866
867struct ContextAwareHandler<I, F> {
869 handler: F,
870 _phantom: std::marker::PhantomData<I>,
871}
872
873impl<I, F, Fut> ToolHandler for ContextAwareHandler<I, F>
874where
875 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
876 F: Fn(RequestContext, I) -> Fut + Send + Sync + 'static,
877 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
878{
879 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
880 let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
882 self.call_with_context(ctx, args)
883 }
884
885 fn call_with_context(
886 &self,
887 ctx: RequestContext,
888 args: Value,
889 ) -> BoxFuture<'_, Result<CallToolResult>> {
890 Box::pin(async move {
891 let input: I = serde_json::from_value(args)
892 .map_err(|e| Error::tool(format!("Invalid input: {}", e)))?;
893 (self.handler)(ctx, input).await
894 })
895 }
896
897 fn uses_context(&self) -> bool {
898 true
899 }
900
901 fn input_schema(&self) -> Value {
902 let schema = schemars::schema_for!(I);
903 serde_json::to_value(schema).unwrap_or_else(|_| {
904 serde_json::json!({
905 "type": "object"
906 })
907 })
908 }
909}
910
911struct NoParamsHandler<F> {
913 handler: F,
914}
915
916impl<F, Fut> ToolHandler for NoParamsHandler<F>
917where
918 F: Fn() -> Fut + Send + Sync + 'static,
919 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
920{
921 fn call(&self, _args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
922 Box::pin((self.handler)())
923 }
924
925 fn input_schema(&self) -> Value {
926 serde_json::json!({
927 "type": "object",
928 "properties": {}
929 })
930 }
931}
932
933struct NoParamsWithStateHandler<S, F> {
935 state: S,
936 handler: F,
937}
938
939impl<S, F, Fut> ToolHandler for NoParamsWithStateHandler<S, F>
940where
941 S: Clone + Send + Sync + 'static,
942 F: Fn(S) -> Fut + Send + Sync + 'static,
943 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
944{
945 fn call(&self, _args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
946 let state = self.state.clone();
947 let fut = (self.handler)(state);
948 Box::pin(fut)
949 }
950
951 fn input_schema(&self) -> Value {
952 serde_json::json!({
953 "type": "object",
954 "properties": {}
955 })
956 }
957}
958
959pub trait McpTool: Send + Sync + 'static {
1000 const NAME: &'static str;
1001 const DESCRIPTION: &'static str;
1002
1003 type Input: JsonSchema + DeserializeOwned + Send;
1004 type Output: Serialize + Send;
1005
1006 fn call(&self, input: Self::Input) -> impl Future<Output = Result<Self::Output>> + Send;
1007
1008 fn annotations(&self) -> Option<ToolAnnotations> {
1010 None
1011 }
1012
1013 fn into_tool(self) -> Result<Tool>
1017 where
1018 Self: Sized,
1019 {
1020 validate_tool_name(Self::NAME)?;
1021 let annotations = self.annotations();
1022 let tool = Arc::new(self);
1023 Ok(Tool {
1024 name: Self::NAME.to_string(),
1025 title: None,
1026 description: Some(Self::DESCRIPTION.to_string()),
1027 output_schema: None,
1028 icons: None,
1029 annotations,
1030 handler: Arc::new(McpToolHandler { tool }),
1031 })
1032 }
1033}
1034
1035struct McpToolHandler<T: McpTool> {
1037 tool: Arc<T>,
1038}
1039
1040impl<T: McpTool> ToolHandler for McpToolHandler<T> {
1041 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
1042 let tool = self.tool.clone();
1043 Box::pin(async move {
1044 let input: T::Input = serde_json::from_value(args)
1045 .map_err(|e| Error::tool(format!("Invalid input: {}", e)))?;
1046 let output = tool.call(input).await?;
1047 let value = serde_json::to_value(output)
1048 .map_err(|e| Error::tool(format!("Failed to serialize output: {}", e)))?;
1049 Ok(CallToolResult::json(value))
1050 })
1051 }
1052
1053 fn input_schema(&self) -> Value {
1054 let schema = schemars::schema_for!(T::Input);
1055 serde_json::to_value(schema).unwrap_or_else(|_| {
1056 serde_json::json!({
1057 "type": "object"
1058 })
1059 })
1060 }
1061}
1062
// Unit tests for the tool builders, handler adapters, name validation and `NoParams`.
#[cfg(test)]
mod tests {
    use super::*;
    use schemars::JsonSchema;
    use serde::Deserialize;

    // Shared typed input used by several builder tests.
    #[derive(Debug, Deserialize, JsonSchema)]
    struct GreetInput {
        name: String,
    }

    // Typed handler: metadata is recorded and a well-formed call succeeds.
    #[tokio::test]
    async fn test_builder_tool() {
        let tool = ToolBuilder::new("greet")
            .description("Greet someone")
            .handler(|input: GreetInput| async move {
                Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
            })
            .build()
            .expect("valid tool name");

        assert_eq!(tool.name, "greet");
        assert_eq!(tool.description.as_deref(), Some("Greet someone"));

        let result = tool
            .call(serde_json::json!({"name": "World"}))
            .await
            .unwrap();

        assert!(!result.is_error);
    }

    // Raw handler receives the JSON arguments unparsed.
    #[tokio::test]
    async fn test_raw_handler() {
        let tool = ToolBuilder::new("echo")
            .description("Echo input")
            .raw_handler(|args: Value| async move { Ok(CallToolResult::json(args)) })
            .expect("valid tool name");

        let result = tool.call(serde_json::json!({"foo": "bar"})).await.unwrap();

        assert!(!result.is_error);
    }

    // Name validation: empty names are rejected.
    #[test]
    fn test_invalid_tool_name_empty() {
        let result = ToolBuilder::new("")
            .description("Empty name")
            .raw_handler(|args: Value| async move { Ok(CallToolResult::json(args)) });

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("cannot be empty"));
    }

    // Name validation: 129 characters is one over the 128-character limit.
    #[test]
    fn test_invalid_tool_name_too_long() {
        let long_name = "a".repeat(129);
        let result = ToolBuilder::new(long_name)
            .description("Too long")
            .raw_handler(|args: Value| async move { Ok(CallToolResult::json(args)) });

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("exceeds maximum"));
    }

    // Name validation: space and '!' are outside the allowed character set.
    #[test]
    fn test_invalid_tool_name_bad_chars() {
        let result = ToolBuilder::new("my tool!")
            .description("Bad chars")
            .raw_handler(|args: Value| async move { Ok(CallToolResult::json(args)) });

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("invalid character")
        );
    }

    // Name validation: each allowed character class and the exact length limit pass.
    #[test]
    fn test_valid_tool_names() {
        let names = [
            "my_tool",
            "my-tool",
            "my.tool",
            "MyTool123",
            "a",
            &"a".repeat(128),
        ];
        for name in names {
            let result = ToolBuilder::new(name)
                .description("Valid")
                .raw_handler(|args: Value| async move { Ok(CallToolResult::json(args)) });
            assert!(result.is_ok(), "Expected '{}' to be valid", name);
        }
    }

    // Context-aware handler: one progress notification per iteration reaches the channel.
    #[tokio::test]
    async fn test_context_aware_handler() {
        use crate::context::{RequestContext, notification_channel};
        use crate::protocol::{ProgressToken, RequestId};

        #[derive(Debug, Deserialize, JsonSchema)]
        struct ProcessInput {
            count: i32,
        }

        let tool = ToolBuilder::new("process")
            .description("Process with context")
            .handler_with_context(|ctx: RequestContext, input: ProcessInput| async move {
                for i in 0..input.count {
                    // Cooperative cancellation check before each unit of work.
                    if ctx.is_cancelled() {
                        return Ok(CallToolResult::error("Cancelled"));
                    }
                    ctx.report_progress(i as f64, Some(input.count as f64), None)
                        .await;
                }
                Ok(CallToolResult::text(format!(
                    "Processed {} items",
                    input.count
                )))
            })
            .build()
            .expect("valid tool name");

        assert_eq!(tool.name, "process");
        assert!(tool.uses_context());

        // The context carries a progress token and a notification sender, so
        // `report_progress` calls can be observed on `rx`.
        let (tx, mut rx) = notification_channel(10);
        let ctx = RequestContext::new(RequestId::Number(1))
            .with_progress_token(ProgressToken::Number(42))
            .with_notification_sender(tx);

        let result = tool
            .call_with_context(ctx, serde_json::json!({"count": 3}))
            .await
            .unwrap();

        assert!(!result.is_error);

        // Drain whatever was queued; expect exactly one notification per iteration.
        let mut progress_count = 0;
        while rx.try_recv().is_ok() {
            progress_count += 1;
        }
        assert_eq!(progress_count, 3);
    }

    // Cancellation: the handler cancels its own token after iteration 2, and the check at
    // the top of the next iteration stops it. Statement order inside the loop matters here.
    #[tokio::test]
    async fn test_context_aware_handler_cancellation() {
        use crate::context::RequestContext;
        use crate::protocol::RequestId;
        use std::sync::Arc;
        use std::sync::atomic::{AtomicI32, Ordering};

        #[derive(Debug, Deserialize, JsonSchema)]
        struct LongRunningInput {
            iterations: i32,
        }

        // Counts completed iterations so the test can see where the loop stopped.
        let iterations_completed = Arc::new(AtomicI32::new(0));
        let iterations_ref = iterations_completed.clone();

        let tool = ToolBuilder::new("long_running")
            .description("Long running task")
            .handler_with_context(move |ctx: RequestContext, input: LongRunningInput| {
                let completed = iterations_ref.clone();
                async move {
                    for i in 0..input.iterations {
                        if ctx.is_cancelled() {
                            return Ok(CallToolResult::error("Cancelled"));
                        }
                        completed.fetch_add(1, Ordering::SeqCst);
                        tokio::task::yield_now().await;
                        // Simulate an external cancel arriving after the third iteration.
                        if i == 2 {
                            // Cancels through the context's own token.
                            ctx.cancellation_token().cancel();
                        }
                    }
                    Ok(CallToolResult::text("Done"))
                }
            })
            .build()
            .expect("valid tool name");

        let ctx = RequestContext::new(RequestId::Number(1));

        let result = tool
            .call_with_context(ctx, serde_json::json!({"iterations": 10}))
            .await
            .unwrap();

        // The handler reports cancellation as an error result.
        assert!(result.is_error);
        // Iterations 0, 1 and 2 completed; the loop stopped before the fourth one.
        assert_eq!(iterations_completed.load(Ordering::SeqCst), 3);
    }

    // Optional metadata (title, output schema, icons) is stored and exposed in `definition()`.
    #[tokio::test]
    async fn test_tool_builder_with_enhanced_fields() {
        let output_schema = serde_json::json!({
            "type": "object",
            "properties": {
                "greeting": {"type": "string"}
            }
        });

        let tool = ToolBuilder::new("greet")
            .title("Greeting Tool")
            .description("Greet someone")
            .output_schema(output_schema.clone())
            .icon("https://example.com/icon.png")
            .icon_with_meta(
                "https://example.com/icon-large.png",
                Some("image/png".to_string()),
                Some(vec!["96x96".to_string()]),
            )
            .handler(|input: GreetInput| async move {
                Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
            })
            .build()
            .expect("valid tool name");

        assert_eq!(tool.name, "greet");
        assert_eq!(tool.title.as_deref(), Some("Greeting Tool"));
        assert_eq!(tool.description.as_deref(), Some("Greet someone"));
        assert_eq!(tool.output_schema, Some(output_schema));
        assert!(tool.icons.is_some());
        assert_eq!(tool.icons.as_ref().unwrap().len(), 2);

        // The protocol definition carries the same metadata.
        let def = tool.definition();
        assert_eq!(def.title.as_deref(), Some("Greeting Tool"));
        assert!(def.output_schema.is_some());
        assert!(def.icons.is_some());
    }

    // Shared state is passed to the typed handler on each call.
    #[tokio::test]
    async fn test_handler_with_state() {
        let shared = Arc::new("shared-state".to_string());

        let tool = ToolBuilder::new("stateful")
            .description("Uses shared state")
            .handler_with_state(shared, |state: Arc<String>, input: GreetInput| async move {
                Ok(CallToolResult::text(format!(
                    "{}: Hello, {}!",
                    state, input.name
                )))
            })
            .build()
            .expect("valid tool name");

        let result = tool
            .call(serde_json::json!({"name": "World"}))
            .await
            .unwrap();
        assert!(!result.is_error);
    }

    // State plus context: the resulting tool reports that it uses context.
    #[tokio::test]
    async fn test_handler_with_state_and_context() {
        use crate::context::RequestContext;
        use crate::protocol::RequestId;

        let shared = Arc::new(42_i32);

        let tool = ToolBuilder::new("stateful_ctx")
            .description("Uses state and context")
            .handler_with_state_and_context(
                shared,
                |state: Arc<i32>, _ctx: RequestContext, input: GreetInput| async move {
                    Ok(CallToolResult::text(format!(
                        "{}: Hello, {}!",
                        state, input.name
                    )))
                },
            )
            .build()
            .expect("valid tool name");

        assert!(tool.uses_context());

        let ctx = RequestContext::new(RequestId::Number(1));
        let result = tool
            .call_with_context(ctx, serde_json::json!({"name": "World"}))
            .await
            .unwrap();
        assert!(!result.is_error);
    }

    // No-params handler: ignores any arguments and advertises an empty-properties schema.
    #[tokio::test]
    async fn test_handler_no_params() {
        let tool = ToolBuilder::new("no_params")
            .description("Takes no parameters")
            .handler_no_params(|| async { Ok(CallToolResult::text("no params result")) })
            .expect("valid tool name");

        assert_eq!(tool.name, "no_params");

        // Empty object.
        let result = tool.call(serde_json::json!({})).await.unwrap();
        assert!(!result.is_error);

        // Unexpected fields are ignored rather than rejected.
        let result = tool
            .call(serde_json::json!({"unexpected": "value"}))
            .await
            .unwrap();
        assert!(!result.is_error);

        // Schema is an object with an empty `properties` map.
        let schema = tool.definition().input_schema;
        assert_eq!(schema.get("type").unwrap().as_str().unwrap(), "object");
        assert!(
            schema
                .get("properties")
                .unwrap()
                .as_object()
                .unwrap()
                .is_empty()
        );
    }

    // No-params handler with state: the state reaches the handler.
    #[tokio::test]
    async fn test_handler_no_params_with_state() {
        let shared = Arc::new("shared_value".to_string());

        let tool = ToolBuilder::new("no_params_with_state")
            .description("Takes no parameters but has state")
            .handler_no_params_with_state(shared, |state: Arc<String>| async move {
                Ok(CallToolResult::text(format!("state: {}", state)))
            })
            .expect("valid tool name");

        assert_eq!(tool.name, "no_params_with_state");

        // The handler output embeds the state, proving it was passed through.
        let result = tool.call(serde_json::json!({})).await.unwrap();
        assert!(!result.is_error);
        assert_eq!(result.first_text().unwrap(), "state: shared_value");

        // Same empty-properties object schema as the stateless variant.
        let schema = tool.definition().input_schema;
        assert_eq!(schema.get("type").unwrap().as_str().unwrap(), "object");
        assert!(
            schema
                .get("properties")
                .unwrap()
                .as_object()
                .unwrap()
                .is_empty()
        );
    }

    // The hand-written `JsonSchema` impl for `NoParams` yields `type: object`.
    #[test]
    fn test_no_params_schema() {
        let schema = schemars::schema_for!(NoParams);
        let schema_value = serde_json::to_value(&schema).unwrap();
        assert_eq!(
            schema_value.get("type").and_then(|v| v.as_str()),
            Some("object"),
            "NoParams should generate type: object schema"
        );
    }

    // `NoParams` deserializes from `{}`, `null`, and objects with arbitrary fields.
    #[test]
    fn test_no_params_deserialize() {
        let from_empty_object: NoParams = serde_json::from_str("{}").unwrap();
        assert_eq!(from_empty_object, NoParams);

        let from_null: NoParams = serde_json::from_str("null").unwrap();
        assert_eq!(from_null, NoParams);

        // Extra fields are drained and discarded.
        let from_object_with_fields: NoParams =
            serde_json::from_str(r#"{"unexpected": "value"}"#).unwrap();
        assert_eq!(from_object_with_fields, NoParams);
    }

    // `NoParams` works as the input type of an ordinary typed handler.
    #[tokio::test]
    async fn test_no_params_type_in_handler() {
        let tool = ToolBuilder::new("status")
            .description("Get status")
            .handler(|_input: NoParams| async move { Ok(CallToolResult::text("OK")) })
            .build()
            .expect("valid tool name");

        let schema = tool.definition().input_schema;
        assert_eq!(
            schema.get("type").and_then(|v| v.as_str()),
            Some("object"),
            "NoParams handler should produce type: object schema"
        );

        let result = tool.call(serde_json::json!({})).await.unwrap();
        assert!(!result.is_error);
    }
}