1use std::borrow::Cow;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13
14use schemars::{JsonSchema, Schema, SchemaGenerator};
15use serde::Serialize;
16use serde::de::DeserializeOwned;
17use serde_json::Value;
18
19use crate::context::RequestContext;
20use crate::error::{Error, Result};
21use crate::protocol::{CallToolResult, ToolAnnotations, ToolDefinition, ToolIcon};
22
23#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
44pub struct NoParams;
45
46impl<'de> serde::Deserialize<'de> for NoParams {
47 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
48 where
49 D: serde::Deserializer<'de>,
50 {
51 struct NoParamsVisitor;
53
54 impl<'de> serde::de::Visitor<'de> for NoParamsVisitor {
55 type Value = NoParams;
56
57 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
58 formatter.write_str("null or an object")
59 }
60
61 fn visit_unit<E>(self) -> std::result::Result<Self::Value, E>
62 where
63 E: serde::de::Error,
64 {
65 Ok(NoParams)
66 }
67
68 fn visit_none<E>(self) -> std::result::Result<Self::Value, E>
69 where
70 E: serde::de::Error,
71 {
72 Ok(NoParams)
73 }
74
75 fn visit_some<D>(self, deserializer: D) -> std::result::Result<Self::Value, D::Error>
76 where
77 D: serde::Deserializer<'de>,
78 {
79 serde::Deserialize::deserialize(deserializer)
80 }
81
82 fn visit_map<A>(self, mut map: A) -> std::result::Result<Self::Value, A::Error>
83 where
84 A: serde::de::MapAccess<'de>,
85 {
86 while map
88 .next_entry::<serde::de::IgnoredAny, serde::de::IgnoredAny>()?
89 .is_some()
90 {}
91 Ok(NoParams)
92 }
93 }
94
95 deserializer.deserialize_any(NoParamsVisitor)
96 }
97}
98
99impl JsonSchema for NoParams {
100 fn schema_name() -> Cow<'static, str> {
101 Cow::Borrowed("NoParams")
102 }
103
104 fn json_schema(_generator: &mut SchemaGenerator) -> Schema {
105 serde_json::json!({
106 "type": "object"
107 })
108 .try_into()
109 .expect("valid schema")
110 }
111}
112
113pub fn validate_tool_name(name: &str) -> Result<()> {
121 if name.is_empty() {
122 return Err(Error::tool("Tool name cannot be empty"));
123 }
124 if name.len() > 128 {
125 return Err(Error::tool(format!(
126 "Tool name '{}' exceeds maximum length of 128 characters (got {})",
127 name,
128 name.len()
129 )));
130 }
131 if let Some(invalid_char) = name
132 .chars()
133 .find(|c| !c.is_ascii_alphanumeric() && *c != '_' && *c != '-' && *c != '.')
134 {
135 return Err(Error::tool(format!(
136 "Tool name '{}' contains invalid character '{}'. Only alphanumeric, underscore, hyphen, and dot are allowed.",
137 name, invalid_char
138 )));
139 }
140 Ok(())
141}
142
143pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
145
146pub trait ToolHandler: Send + Sync {
148 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>>;
150
151 fn call_with_context(
156 &self,
157 _ctx: RequestContext,
158 args: Value,
159 ) -> BoxFuture<'_, Result<CallToolResult>> {
160 self.call(args)
161 }
162
163 fn uses_context(&self) -> bool {
165 false
166 }
167
168 fn input_schema(&self) -> Value;
170}
171
172pub struct Tool {
174 pub name: String,
175 pub title: Option<String>,
176 pub description: Option<String>,
177 pub output_schema: Option<Value>,
178 pub icons: Option<Vec<ToolIcon>>,
179 pub annotations: Option<ToolAnnotations>,
180 handler: Arc<dyn ToolHandler>,
181}
182
183impl std::fmt::Debug for Tool {
184 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185 f.debug_struct("Tool")
186 .field("name", &self.name)
187 .field("title", &self.title)
188 .field("description", &self.description)
189 .field("output_schema", &self.output_schema)
190 .field("icons", &self.icons)
191 .field("annotations", &self.annotations)
192 .finish_non_exhaustive()
193 }
194}
195
196impl Tool {
197 pub fn builder(name: impl Into<String>) -> ToolBuilder {
199 ToolBuilder::new(name)
200 }
201
202 pub fn definition(&self) -> ToolDefinition {
204 ToolDefinition {
205 name: self.name.clone(),
206 title: self.title.clone(),
207 description: self.description.clone(),
208 input_schema: self.handler.input_schema(),
209 output_schema: self.output_schema.clone(),
210 icons: self.icons.clone(),
211 annotations: self.annotations.clone(),
212 }
213 }
214
215 pub fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
217 self.handler.call(args)
218 }
219
220 pub fn call_with_context(
224 &self,
225 ctx: RequestContext,
226 args: Value,
227 ) -> BoxFuture<'_, Result<CallToolResult>> {
228 self.handler.call_with_context(ctx, args)
229 }
230
231 pub fn uses_context(&self) -> bool {
233 self.handler.uses_context()
234 }
235}
236
237pub struct ToolBuilder {
266 name: String,
267 title: Option<String>,
268 description: Option<String>,
269 output_schema: Option<Value>,
270 icons: Option<Vec<ToolIcon>>,
271 annotations: Option<ToolAnnotations>,
272}
273
274impl ToolBuilder {
275 pub fn new(name: impl Into<String>) -> Self {
276 Self {
277 name: name.into(),
278 title: None,
279 description: None,
280 output_schema: None,
281 icons: None,
282 annotations: None,
283 }
284 }
285
286 pub fn title(mut self, title: impl Into<String>) -> Self {
288 self.title = Some(title.into());
289 self
290 }
291
292 pub fn output_schema(mut self, schema: Value) -> Self {
294 self.output_schema = Some(schema);
295 self
296 }
297
298 pub fn icon(mut self, src: impl Into<String>) -> Self {
300 self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
301 src: src.into(),
302 mime_type: None,
303 sizes: None,
304 });
305 self
306 }
307
308 pub fn icon_with_meta(
310 mut self,
311 src: impl Into<String>,
312 mime_type: Option<String>,
313 sizes: Option<Vec<String>>,
314 ) -> Self {
315 self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
316 src: src.into(),
317 mime_type,
318 sizes,
319 });
320 self
321 }
322
323 pub fn description(mut self, description: impl Into<String>) -> Self {
325 self.description = Some(description.into());
326 self
327 }
328
329 pub fn read_only(mut self) -> Self {
331 self.annotations
332 .get_or_insert_with(ToolAnnotations::default)
333 .read_only_hint = true;
334 self
335 }
336
337 pub fn non_destructive(mut self) -> Self {
339 self.annotations
340 .get_or_insert_with(ToolAnnotations::default)
341 .destructive_hint = false;
342 self
343 }
344
345 pub fn idempotent(mut self) -> Self {
347 self.annotations
348 .get_or_insert_with(ToolAnnotations::default)
349 .idempotent_hint = true;
350 self
351 }
352
353 pub fn annotations(mut self, annotations: ToolAnnotations) -> Self {
355 self.annotations = Some(annotations);
356 self
357 }
358
359 pub fn handler<I, F, Fut>(self, handler: F) -> ToolBuilderWithHandler<I, F>
403 where
404 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
405 F: Fn(I) -> Fut + Send + Sync + 'static,
406 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
407 {
408 ToolBuilderWithHandler {
409 name: self.name,
410 title: self.title,
411 description: self.description,
412 output_schema: self.output_schema,
413 icons: self.icons,
414 annotations: self.annotations,
415 handler,
416 _phantom: std::marker::PhantomData,
417 }
418 }
419
420 pub fn handler_with_context<I, F, Fut>(self, handler: F) -> ToolBuilderWithContextHandler<I, F>
453 where
454 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
455 F: Fn(RequestContext, I) -> Fut + Send + Sync + 'static,
456 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
457 {
458 ToolBuilderWithContextHandler {
459 name: self.name,
460 title: self.title,
461 description: self.description,
462 output_schema: self.output_schema,
463 icons: self.icons,
464 annotations: self.annotations,
465 handler,
466 _phantom: std::marker::PhantomData,
467 }
468 }
469
470 pub fn handler_with_state<S, I, F, Fut>(
500 self,
501 state: S,
502 handler: F,
503 ) -> ToolBuilderWithHandler<
504 I,
505 impl Fn(I) -> BoxFuture<'static, Result<CallToolResult>> + Send + Sync + 'static,
506 >
507 where
508 S: Clone + Send + Sync + 'static,
509 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
510 F: Fn(S, I) -> Fut + Send + Sync + 'static,
511 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
512 {
513 let handler = Arc::new(handler);
514 self.handler(move |input: I| {
515 let state = state.clone();
516 let handler = handler.clone();
517 Box::pin(async move { handler(state, input).await })
518 as BoxFuture<'static, Result<CallToolResult>>
519 })
520 }
521
522 pub fn handler_with_state_and_context<S, I, F, Fut>(
552 self,
553 state: S,
554 handler: F,
555 ) -> ToolBuilderWithContextHandler<
556 I,
557 impl Fn(RequestContext, I) -> BoxFuture<'static, Result<CallToolResult>> + Send + Sync + 'static,
558 >
559 where
560 S: Clone + Send + Sync + 'static,
561 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
562 F: Fn(S, RequestContext, I) -> Fut + Send + Sync + 'static,
563 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
564 {
565 let handler = Arc::new(handler);
566 self.handler_with_context(move |ctx: RequestContext, input: I| {
567 let state = state.clone();
568 let handler = handler.clone();
569 Box::pin(async move { handler(state, ctx, input).await })
570 as BoxFuture<'static, Result<CallToolResult>>
571 })
572 }
573
574 pub fn handler_no_params<F, Fut>(self, handler: F) -> Result<Tool>
594 where
595 F: Fn() -> Fut + Send + Sync + 'static,
596 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
597 {
598 validate_tool_name(&self.name)?;
599 Ok(Tool {
600 name: self.name,
601 title: self.title,
602 description: self.description,
603 output_schema: self.output_schema,
604 icons: self.icons,
605 annotations: self.annotations,
606 handler: Arc::new(NoParamsHandler { handler }),
607 })
608 }
609
610 pub fn raw_handler<F, Fut>(self, handler: F) -> Result<Tool>
614 where
615 F: Fn(Value) -> Fut + Send + Sync + 'static,
616 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
617 {
618 validate_tool_name(&self.name)?;
619 Ok(Tool {
620 name: self.name,
621 title: self.title,
622 description: self.description,
623 output_schema: self.output_schema,
624 icons: self.icons,
625 annotations: self.annotations,
626 handler: Arc::new(RawHandler { handler }),
627 })
628 }
629
630 pub fn raw_handler_with_context<F, Fut>(self, handler: F) -> Result<Tool>
637 where
638 F: Fn(RequestContext, Value) -> Fut + Send + Sync + 'static,
639 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
640 {
641 validate_tool_name(&self.name)?;
642 Ok(Tool {
643 name: self.name,
644 title: self.title,
645 description: self.description,
646 output_schema: self.output_schema,
647 icons: self.icons,
648 annotations: self.annotations,
649 handler: Arc::new(RawContextHandler { handler }),
650 })
651 }
652}
653
654pub struct ToolBuilderWithHandler<I, F> {
656 name: String,
657 title: Option<String>,
658 description: Option<String>,
659 output_schema: Option<Value>,
660 icons: Option<Vec<ToolIcon>>,
661 annotations: Option<ToolAnnotations>,
662 handler: F,
663 _phantom: std::marker::PhantomData<I>,
664}
665
666impl<I, F, Fut> ToolBuilderWithHandler<I, F>
667where
668 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
669 F: Fn(I) -> Fut + Send + Sync + 'static,
670 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
671{
672 pub fn build(self) -> Result<Tool> {
676 validate_tool_name(&self.name)?;
677 Ok(Tool {
678 name: self.name,
679 title: self.title,
680 description: self.description,
681 output_schema: self.output_schema,
682 icons: self.icons,
683 annotations: self.annotations,
684 handler: Arc::new(TypedHandler {
685 handler: self.handler,
686 _phantom: std::marker::PhantomData,
687 }),
688 })
689 }
690}
691
692pub struct ToolBuilderWithContextHandler<I, F> {
694 name: String,
695 title: Option<String>,
696 description: Option<String>,
697 output_schema: Option<Value>,
698 icons: Option<Vec<ToolIcon>>,
699 annotations: Option<ToolAnnotations>,
700 handler: F,
701 _phantom: std::marker::PhantomData<I>,
702}
703
704impl<I, F, Fut> ToolBuilderWithContextHandler<I, F>
705where
706 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
707 F: Fn(RequestContext, I) -> Fut + Send + Sync + 'static,
708 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
709{
710 pub fn build(self) -> Result<Tool> {
714 validate_tool_name(&self.name)?;
715 Ok(Tool {
716 name: self.name,
717 title: self.title,
718 description: self.description,
719 output_schema: self.output_schema,
720 icons: self.icons,
721 annotations: self.annotations,
722 handler: Arc::new(ContextAwareHandler {
723 handler: self.handler,
724 _phantom: std::marker::PhantomData,
725 }),
726 })
727 }
728}
729
730struct TypedHandler<I, F> {
736 handler: F,
737 _phantom: std::marker::PhantomData<I>,
738}
739
740impl<I, F, Fut> ToolHandler for TypedHandler<I, F>
741where
742 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
743 F: Fn(I) -> Fut + Send + Sync + 'static,
744 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
745{
746 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
747 Box::pin(async move {
748 let input: I = serde_json::from_value(args)
749 .map_err(|e| Error::tool(format!("Invalid input: {}", e)))?;
750 (self.handler)(input).await
751 })
752 }
753
754 fn input_schema(&self) -> Value {
755 let schema = schemars::schema_for!(I);
756 serde_json::to_value(schema).unwrap_or_else(|_| {
757 serde_json::json!({
758 "type": "object"
759 })
760 })
761 }
762}
763
764struct RawHandler<F> {
766 handler: F,
767}
768
769impl<F, Fut> ToolHandler for RawHandler<F>
770where
771 F: Fn(Value) -> Fut + Send + Sync + 'static,
772 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
773{
774 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
775 Box::pin((self.handler)(args))
776 }
777
778 fn input_schema(&self) -> Value {
779 serde_json::json!({
781 "type": "object",
782 "additionalProperties": true
783 })
784 }
785}
786
787struct RawContextHandler<F> {
789 handler: F,
790}
791
792impl<F, Fut> ToolHandler for RawContextHandler<F>
793where
794 F: Fn(RequestContext, Value) -> Fut + Send + Sync + 'static,
795 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
796{
797 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
798 let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
799 self.call_with_context(ctx, args)
800 }
801
802 fn call_with_context(
803 &self,
804 ctx: RequestContext,
805 args: Value,
806 ) -> BoxFuture<'_, Result<CallToolResult>> {
807 Box::pin((self.handler)(ctx, args))
808 }
809
810 fn uses_context(&self) -> bool {
811 true
812 }
813
814 fn input_schema(&self) -> Value {
815 serde_json::json!({
817 "type": "object",
818 "additionalProperties": true
819 })
820 }
821}
822
823struct ContextAwareHandler<I, F> {
825 handler: F,
826 _phantom: std::marker::PhantomData<I>,
827}
828
829impl<I, F, Fut> ToolHandler for ContextAwareHandler<I, F>
830where
831 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
832 F: Fn(RequestContext, I) -> Fut + Send + Sync + 'static,
833 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
834{
835 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
836 let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
838 self.call_with_context(ctx, args)
839 }
840
841 fn call_with_context(
842 &self,
843 ctx: RequestContext,
844 args: Value,
845 ) -> BoxFuture<'_, Result<CallToolResult>> {
846 Box::pin(async move {
847 let input: I = serde_json::from_value(args)
848 .map_err(|e| Error::tool(format!("Invalid input: {}", e)))?;
849 (self.handler)(ctx, input).await
850 })
851 }
852
853 fn uses_context(&self) -> bool {
854 true
855 }
856
857 fn input_schema(&self) -> Value {
858 let schema = schemars::schema_for!(I);
859 serde_json::to_value(schema).unwrap_or_else(|_| {
860 serde_json::json!({
861 "type": "object"
862 })
863 })
864 }
865}
866
867struct NoParamsHandler<F> {
869 handler: F,
870}
871
872impl<F, Fut> ToolHandler for NoParamsHandler<F>
873where
874 F: Fn() -> Fut + Send + Sync + 'static,
875 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
876{
877 fn call(&self, _args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
878 Box::pin((self.handler)())
879 }
880
881 fn input_schema(&self) -> Value {
882 serde_json::json!({
883 "type": "object",
884 "properties": {}
885 })
886 }
887}
888
889pub trait McpTool: Send + Sync + 'static {
930 const NAME: &'static str;
931 const DESCRIPTION: &'static str;
932
933 type Input: JsonSchema + DeserializeOwned + Send;
934 type Output: Serialize + Send;
935
936 fn call(&self, input: Self::Input) -> impl Future<Output = Result<Self::Output>> + Send;
937
938 fn annotations(&self) -> Option<ToolAnnotations> {
940 None
941 }
942
943 fn into_tool(self) -> Result<Tool>
947 where
948 Self: Sized,
949 {
950 validate_tool_name(Self::NAME)?;
951 let annotations = self.annotations();
952 let tool = Arc::new(self);
953 Ok(Tool {
954 name: Self::NAME.to_string(),
955 title: None,
956 description: Some(Self::DESCRIPTION.to_string()),
957 output_schema: None,
958 icons: None,
959 annotations,
960 handler: Arc::new(McpToolHandler { tool }),
961 })
962 }
963}
964
965struct McpToolHandler<T: McpTool> {
967 tool: Arc<T>,
968}
969
970impl<T: McpTool> ToolHandler for McpToolHandler<T> {
971 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
972 let tool = self.tool.clone();
973 Box::pin(async move {
974 let input: T::Input = serde_json::from_value(args)
975 .map_err(|e| Error::tool(format!("Invalid input: {}", e)))?;
976 let output = tool.call(input).await?;
977 let value = serde_json::to_value(output)
978 .map_err(|e| Error::tool(format!("Failed to serialize output: {}", e)))?;
979 Ok(CallToolResult::json(value))
980 })
981 }
982
983 fn input_schema(&self) -> Value {
984 let schema = schemars::schema_for!(T::Input);
985 serde_json::to_value(schema).unwrap_or_else(|_| {
986 serde_json::json!({
987 "type": "object"
988 })
989 })
990 }
991}
992
#[cfg(test)]
mod tests {
    use super::*;
    use schemars::JsonSchema;
    use serde::Deserialize;

    /// Typed input shared by the greeting-style tests.
    #[derive(Debug, Deserialize, JsonSchema)]
    struct GreetInput {
        name: String,
    }

    /// Raw handler that returns its arguments as a JSON result.
    async fn echo(args: Value) -> Result<CallToolResult> {
        Ok(CallToolResult::json(args))
    }

    #[tokio::test]
    async fn test_builder_tool() {
        let greet = ToolBuilder::new("greet")
            .description("Greet someone")
            .handler(|input: GreetInput| async move {
                Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
            })
            .build()
            .expect("valid tool name");

        assert_eq!(greet.name, "greet");
        assert_eq!(greet.description.as_deref(), Some("Greet someone"));

        let args = serde_json::json!({"name": "World"});
        let outcome = greet.call(args).await.unwrap();

        assert!(!outcome.is_error);
    }

    #[tokio::test]
    async fn test_raw_handler() {
        let tool = ToolBuilder::new("echo")
            .description("Echo input")
            .raw_handler(echo)
            .expect("valid tool name");

        let args = serde_json::json!({"foo": "bar"});
        let outcome = tool.call(args).await.unwrap();

        assert!(!outcome.is_error);
    }

    #[test]
    fn test_invalid_tool_name_empty() {
        let built = ToolBuilder::new("")
            .description("Empty name")
            .raw_handler(echo);

        assert!(built.is_err());
        let message = built.unwrap_err().to_string();
        assert!(message.contains("cannot be empty"));
    }

    #[test]
    fn test_invalid_tool_name_too_long() {
        // One character over the limit.
        let built = ToolBuilder::new("a".repeat(129))
            .description("Too long")
            .raw_handler(echo);

        assert!(built.is_err());
        let message = built.unwrap_err().to_string();
        assert!(message.contains("exceeds maximum"));
    }

    #[test]
    fn test_invalid_tool_name_bad_chars() {
        let built = ToolBuilder::new("my tool!")
            .description("Bad chars")
            .raw_handler(echo);

        assert!(built.is_err());
        let message = built.unwrap_err().to_string();
        assert!(message.contains("invalid character"));
    }

    #[test]
    fn test_valid_tool_names() {
        // Exactly at the length limit.
        let longest = "a".repeat(128);
        let names = [
            "my_tool",
            "my-tool",
            "my.tool",
            "MyTool123",
            "a",
            longest.as_str(),
        ];

        for name in names {
            let built = ToolBuilder::new(name)
                .description("Valid")
                .raw_handler(echo);
            assert!(built.is_ok(), "Expected '{}' to be valid", name);
        }
    }

    #[tokio::test]
    async fn test_context_aware_handler() {
        use crate::context::{RequestContext, notification_channel};
        use crate::protocol::{ProgressToken, RequestId};

        #[derive(Debug, Deserialize, JsonSchema)]
        struct ProcessInput {
            count: i32,
        }

        let tool = ToolBuilder::new("process")
            .description("Process with context")
            .handler_with_context(|ctx: RequestContext, input: ProcessInput| async move {
                let total = input.count;
                for step in 0..total {
                    if ctx.is_cancelled() {
                        return Ok(CallToolResult::error("Cancelled"));
                    }
                    ctx.report_progress(step as f64, Some(total as f64), None)
                        .await;
                }
                Ok(CallToolResult::text(format!("Processed {} items", total)))
            })
            .build()
            .expect("valid tool name");

        assert_eq!(tool.name, "process");
        assert!(tool.uses_context());

        let (tx, mut rx) = notification_channel(10);
        let ctx = RequestContext::new(RequestId::Number(1))
            .with_progress_token(ProgressToken::Number(42))
            .with_notification_sender(tx);

        let outcome = tool
            .call_with_context(ctx, serde_json::json!({"count": 3}))
            .await
            .unwrap();

        assert!(!outcome.is_error);

        // Drain whatever the handler sent and count it.
        let progress_count = std::iter::from_fn(|| rx.try_recv().ok()).count();
        assert_eq!(progress_count, 3);
    }

    #[tokio::test]
    async fn test_context_aware_handler_cancellation() {
        use crate::context::RequestContext;
        use crate::protocol::RequestId;
        use std::sync::Arc;
        use std::sync::atomic::{AtomicI32, Ordering};

        #[derive(Debug, Deserialize, JsonSchema)]
        struct LongRunningInput {
            iterations: i32,
        }

        let iterations_completed = Arc::new(AtomicI32::new(0));
        let counter_for_handler = Arc::clone(&iterations_completed);

        let tool = ToolBuilder::new("long_running")
            .description("Long running task")
            .handler_with_context(move |ctx: RequestContext, input: LongRunningInput| {
                let counter = Arc::clone(&counter_for_handler);
                async move {
                    for round in 0..input.iterations {
                        if ctx.is_cancelled() {
                            return Ok(CallToolResult::error("Cancelled"));
                        }
                        counter.fetch_add(1, Ordering::SeqCst);
                        tokio::task::yield_now().await;
                        // Cancel after the third iteration has been counted.
                        if round == 2 {
                            ctx.cancellation_token().cancel();
                        }
                    }
                    Ok(CallToolResult::text("Done"))
                }
            })
            .build()
            .expect("valid tool name");

        let ctx = RequestContext::new(RequestId::Number(1));
        let outcome = tool
            .call_with_context(ctx, serde_json::json!({"iterations": 10}))
            .await
            .unwrap();

        assert!(outcome.is_error);
        assert_eq!(iterations_completed.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_tool_builder_with_enhanced_fields() {
        let output_schema = serde_json::json!({
            "type": "object",
            "properties": {
                "greeting": {"type": "string"}
            }
        });

        let tool = ToolBuilder::new("greet")
            .title("Greeting Tool")
            .description("Greet someone")
            .output_schema(output_schema.clone())
            .icon("https://example.com/icon.png")
            .icon_with_meta(
                "https://example.com/icon-large.png",
                Some(String::from("image/png")),
                Some(vec![String::from("96x96")]),
            )
            .handler(|input: GreetInput| async move {
                Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
            })
            .build()
            .expect("valid tool name");

        assert_eq!(tool.name, "greet");
        assert_eq!(tool.title.as_deref(), Some("Greeting Tool"));
        assert_eq!(tool.description.as_deref(), Some("Greet someone"));
        assert_eq!(tool.output_schema, Some(output_schema));
        assert!(tool.icons.is_some());
        assert_eq!(tool.icons.as_ref().unwrap().len(), 2);

        let definition = tool.definition();
        assert_eq!(definition.title.as_deref(), Some("Greeting Tool"));
        assert!(definition.output_schema.is_some());
        assert!(definition.icons.is_some());
    }

    #[tokio::test]
    async fn test_handler_with_state() {
        let shared = Arc::new(String::from("shared-state"));

        let tool = ToolBuilder::new("stateful")
            .description("Uses shared state")
            .handler_with_state(shared, |state: Arc<String>, input: GreetInput| async move {
                Ok(CallToolResult::text(format!(
                    "{}: Hello, {}!",
                    state, input.name
                )))
            })
            .build()
            .expect("valid tool name");

        let args = serde_json::json!({"name": "World"});
        let outcome = tool.call(args).await.unwrap();
        assert!(!outcome.is_error);
    }

    #[tokio::test]
    async fn test_handler_with_state_and_context() {
        use crate::context::RequestContext;
        use crate::protocol::RequestId;

        let shared = Arc::new(42_i32);

        let tool = ToolBuilder::new("stateful_ctx")
            .description("Uses state and context")
            .handler_with_state_and_context(
                shared,
                |state: Arc<i32>, _ctx: RequestContext, input: GreetInput| async move {
                    Ok(CallToolResult::text(format!(
                        "{}: Hello, {}!",
                        state, input.name
                    )))
                },
            )
            .build()
            .expect("valid tool name");

        assert!(tool.uses_context());

        let ctx = RequestContext::new(RequestId::Number(1));
        let args = serde_json::json!({"name": "World"});
        let outcome = tool.call_with_context(ctx, args).await.unwrap();
        assert!(!outcome.is_error);
    }

    #[tokio::test]
    async fn test_handler_no_params() {
        let tool = ToolBuilder::new("no_params")
            .description("Takes no parameters")
            .handler_no_params(|| async { Ok(CallToolResult::text("no params result")) })
            .expect("valid tool name");

        assert_eq!(tool.name, "no_params");

        // An empty object and an object with stray fields must both succeed.
        let inputs = [
            serde_json::json!({}),
            serde_json::json!({"unexpected": "value"}),
        ];
        for args in inputs {
            let outcome = tool.call(args).await.unwrap();
            assert!(!outcome.is_error);
        }

        let schema = tool.definition().input_schema;
        assert_eq!(schema.get("type").unwrap().as_str().unwrap(), "object");
        let properties = schema.get("properties").unwrap().as_object().unwrap();
        assert!(properties.is_empty());
    }

    #[test]
    fn test_no_params_schema() {
        let schema = schemars::schema_for!(NoParams);
        let schema_value = serde_json::to_value(&schema).unwrap();

        assert_eq!(
            schema_value.get("type").and_then(Value::as_str),
            Some("object"),
            "NoParams should generate type: object schema"
        );
    }

    #[test]
    fn test_no_params_deserialize() {
        let from_empty_object: NoParams = serde_json::from_str("{}").unwrap();
        assert_eq!(from_empty_object, NoParams);

        let from_null: NoParams = serde_json::from_str("null").unwrap();
        assert_eq!(from_null, NoParams);

        let from_object_with_fields: NoParams =
            serde_json::from_str(r#"{"unexpected": "value"}"#).unwrap();
        assert_eq!(from_object_with_fields, NoParams);
    }

    #[tokio::test]
    async fn test_no_params_type_in_handler() {
        let tool = ToolBuilder::new("status")
            .description("Get status")
            .handler(|_input: NoParams| async move { Ok(CallToolResult::text("OK")) })
            .build()
            .expect("valid tool name");

        let schema = tool.definition().input_schema;
        assert_eq!(
            schema.get("type").and_then(Value::as_str),
            Some("object"),
            "NoParams handler should produce type: object schema"
        );

        let outcome = tool.call(serde_json::json!({})).await.unwrap();
        assert!(!outcome.is_error);
    }
}