1use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12
13use schemars::JsonSchema;
14use serde::Serialize;
15use serde::de::DeserializeOwned;
16use serde_json::Value;
17
18use crate::context::RequestContext;
19use crate::error::{Error, Result};
20use crate::protocol::{CallToolResult, ToolAnnotations, ToolDefinition, ToolIcon};
21
22pub fn validate_tool_name(name: &str) -> Result<()> {
30 if name.is_empty() {
31 return Err(Error::tool("Tool name cannot be empty"));
32 }
33 if name.len() > 128 {
34 return Err(Error::tool(format!(
35 "Tool name '{}' exceeds maximum length of 128 characters (got {})",
36 name,
37 name.len()
38 )));
39 }
40 if let Some(invalid_char) = name
41 .chars()
42 .find(|c| !c.is_ascii_alphanumeric() && *c != '_' && *c != '-' && *c != '.')
43 {
44 return Err(Error::tool(format!(
45 "Tool name '{}' contains invalid character '{}'. Only alphanumeric, underscore, hyphen, and dot are allowed.",
46 name, invalid_char
47 )));
48 }
49 Ok(())
50}
51
/// A pinned, heap-allocated, type-erased future that is `Send` and may borrow data
/// for the lifetime `'a`.
///
/// This is the return type of the `ToolHandler` methods. It lets differently-typed
/// handlers be stored behind a single trait object.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
54
55pub trait ToolHandler: Send + Sync {
57 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>>;
59
60 fn call_with_context(
65 &self,
66 _ctx: RequestContext,
67 args: Value,
68 ) -> BoxFuture<'_, Result<CallToolResult>> {
69 self.call(args)
70 }
71
72 fn uses_context(&self) -> bool {
74 false
75 }
76
77 fn input_schema(&self) -> Value;
79}
80
81pub struct Tool {
83 pub name: String,
84 pub title: Option<String>,
85 pub description: Option<String>,
86 pub output_schema: Option<Value>,
87 pub icons: Option<Vec<ToolIcon>>,
88 pub annotations: Option<ToolAnnotations>,
89 handler: Arc<dyn ToolHandler>,
90}
91
92impl std::fmt::Debug for Tool {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 f.debug_struct("Tool")
95 .field("name", &self.name)
96 .field("title", &self.title)
97 .field("description", &self.description)
98 .field("output_schema", &self.output_schema)
99 .field("icons", &self.icons)
100 .field("annotations", &self.annotations)
101 .finish_non_exhaustive()
102 }
103}
104
105impl Tool {
106 pub fn builder(name: impl Into<String>) -> ToolBuilder {
108 ToolBuilder::new(name)
109 }
110
111 pub fn definition(&self) -> ToolDefinition {
113 ToolDefinition {
114 name: self.name.clone(),
115 title: self.title.clone(),
116 description: self.description.clone(),
117 input_schema: self.handler.input_schema(),
118 output_schema: self.output_schema.clone(),
119 icons: self.icons.clone(),
120 annotations: self.annotations.clone(),
121 }
122 }
123
124 pub fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
126 self.handler.call(args)
127 }
128
129 pub fn call_with_context(
133 &self,
134 ctx: RequestContext,
135 args: Value,
136 ) -> BoxFuture<'_, Result<CallToolResult>> {
137 self.handler.call_with_context(ctx, args)
138 }
139
140 pub fn uses_context(&self) -> bool {
142 self.handler.uses_context()
143 }
144}
145
146pub struct ToolBuilder {
175 name: String,
176 title: Option<String>,
177 description: Option<String>,
178 output_schema: Option<Value>,
179 icons: Option<Vec<ToolIcon>>,
180 annotations: Option<ToolAnnotations>,
181}
182
183impl ToolBuilder {
184 pub fn new(name: impl Into<String>) -> Self {
185 Self {
186 name: name.into(),
187 title: None,
188 description: None,
189 output_schema: None,
190 icons: None,
191 annotations: None,
192 }
193 }
194
195 pub fn title(mut self, title: impl Into<String>) -> Self {
197 self.title = Some(title.into());
198 self
199 }
200
201 pub fn output_schema(mut self, schema: Value) -> Self {
203 self.output_schema = Some(schema);
204 self
205 }
206
207 pub fn icon(mut self, src: impl Into<String>) -> Self {
209 self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
210 src: src.into(),
211 mime_type: None,
212 sizes: None,
213 });
214 self
215 }
216
217 pub fn icon_with_meta(
219 mut self,
220 src: impl Into<String>,
221 mime_type: Option<String>,
222 sizes: Option<Vec<String>>,
223 ) -> Self {
224 self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
225 src: src.into(),
226 mime_type,
227 sizes,
228 });
229 self
230 }
231
232 pub fn description(mut self, description: impl Into<String>) -> Self {
234 self.description = Some(description.into());
235 self
236 }
237
238 pub fn read_only(mut self) -> Self {
240 self.annotations
241 .get_or_insert_with(ToolAnnotations::default)
242 .read_only_hint = true;
243 self
244 }
245
246 pub fn non_destructive(mut self) -> Self {
248 self.annotations
249 .get_or_insert_with(ToolAnnotations::default)
250 .destructive_hint = false;
251 self
252 }
253
254 pub fn idempotent(mut self) -> Self {
256 self.annotations
257 .get_or_insert_with(ToolAnnotations::default)
258 .idempotent_hint = true;
259 self
260 }
261
262 pub fn annotations(mut self, annotations: ToolAnnotations) -> Self {
264 self.annotations = Some(annotations);
265 self
266 }
267
268 pub fn handler<I, F, Fut>(self, handler: F) -> ToolBuilderWithHandler<I, F>
312 where
313 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
314 F: Fn(I) -> Fut + Send + Sync + 'static,
315 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
316 {
317 ToolBuilderWithHandler {
318 name: self.name,
319 title: self.title,
320 description: self.description,
321 output_schema: self.output_schema,
322 icons: self.icons,
323 annotations: self.annotations,
324 handler,
325 _phantom: std::marker::PhantomData,
326 }
327 }
328
329 pub fn handler_with_context<I, F, Fut>(self, handler: F) -> ToolBuilderWithContextHandler<I, F>
362 where
363 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
364 F: Fn(RequestContext, I) -> Fut + Send + Sync + 'static,
365 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
366 {
367 ToolBuilderWithContextHandler {
368 name: self.name,
369 title: self.title,
370 description: self.description,
371 output_schema: self.output_schema,
372 icons: self.icons,
373 annotations: self.annotations,
374 handler,
375 _phantom: std::marker::PhantomData,
376 }
377 }
378
379 pub fn handler_with_state<S, I, F, Fut>(
409 self,
410 state: S,
411 handler: F,
412 ) -> ToolBuilderWithHandler<
413 I,
414 impl Fn(I) -> BoxFuture<'static, Result<CallToolResult>> + Send + Sync + 'static,
415 >
416 where
417 S: Clone + Send + Sync + 'static,
418 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
419 F: Fn(S, I) -> Fut + Send + Sync + 'static,
420 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
421 {
422 let handler = Arc::new(handler);
423 self.handler(move |input: I| {
424 let state = state.clone();
425 let handler = handler.clone();
426 Box::pin(async move { handler(state, input).await })
427 as BoxFuture<'static, Result<CallToolResult>>
428 })
429 }
430
431 pub fn handler_with_state_and_context<S, I, F, Fut>(
461 self,
462 state: S,
463 handler: F,
464 ) -> ToolBuilderWithContextHandler<
465 I,
466 impl Fn(RequestContext, I) -> BoxFuture<'static, Result<CallToolResult>> + Send + Sync + 'static,
467 >
468 where
469 S: Clone + Send + Sync + 'static,
470 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
471 F: Fn(S, RequestContext, I) -> Fut + Send + Sync + 'static,
472 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
473 {
474 let handler = Arc::new(handler);
475 self.handler_with_context(move |ctx: RequestContext, input: I| {
476 let state = state.clone();
477 let handler = handler.clone();
478 Box::pin(async move { handler(state, ctx, input).await })
479 as BoxFuture<'static, Result<CallToolResult>>
480 })
481 }
482
483 pub fn handler_no_params<F, Fut>(self, handler: F) -> Result<Tool>
503 where
504 F: Fn() -> Fut + Send + Sync + 'static,
505 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
506 {
507 validate_tool_name(&self.name)?;
508 Ok(Tool {
509 name: self.name,
510 title: self.title,
511 description: self.description,
512 output_schema: self.output_schema,
513 icons: self.icons,
514 annotations: self.annotations,
515 handler: Arc::new(NoParamsHandler { handler }),
516 })
517 }
518
519 pub fn raw_handler<F, Fut>(self, handler: F) -> Result<Tool>
523 where
524 F: Fn(Value) -> Fut + Send + Sync + 'static,
525 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
526 {
527 validate_tool_name(&self.name)?;
528 Ok(Tool {
529 name: self.name,
530 title: self.title,
531 description: self.description,
532 output_schema: self.output_schema,
533 icons: self.icons,
534 annotations: self.annotations,
535 handler: Arc::new(RawHandler { handler }),
536 })
537 }
538
539 pub fn raw_handler_with_context<F, Fut>(self, handler: F) -> Result<Tool>
546 where
547 F: Fn(RequestContext, Value) -> Fut + Send + Sync + 'static,
548 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
549 {
550 validate_tool_name(&self.name)?;
551 Ok(Tool {
552 name: self.name,
553 title: self.title,
554 description: self.description,
555 output_schema: self.output_schema,
556 icons: self.icons,
557 annotations: self.annotations,
558 handler: Arc::new(RawContextHandler { handler }),
559 })
560 }
561}
562
563pub struct ToolBuilderWithHandler<I, F> {
565 name: String,
566 title: Option<String>,
567 description: Option<String>,
568 output_schema: Option<Value>,
569 icons: Option<Vec<ToolIcon>>,
570 annotations: Option<ToolAnnotations>,
571 handler: F,
572 _phantom: std::marker::PhantomData<I>,
573}
574
575impl<I, F, Fut> ToolBuilderWithHandler<I, F>
576where
577 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
578 F: Fn(I) -> Fut + Send + Sync + 'static,
579 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
580{
581 pub fn build(self) -> Result<Tool> {
585 validate_tool_name(&self.name)?;
586 Ok(Tool {
587 name: self.name,
588 title: self.title,
589 description: self.description,
590 output_schema: self.output_schema,
591 icons: self.icons,
592 annotations: self.annotations,
593 handler: Arc::new(TypedHandler {
594 handler: self.handler,
595 _phantom: std::marker::PhantomData,
596 }),
597 })
598 }
599}
600
601pub struct ToolBuilderWithContextHandler<I, F> {
603 name: String,
604 title: Option<String>,
605 description: Option<String>,
606 output_schema: Option<Value>,
607 icons: Option<Vec<ToolIcon>>,
608 annotations: Option<ToolAnnotations>,
609 handler: F,
610 _phantom: std::marker::PhantomData<I>,
611}
612
613impl<I, F, Fut> ToolBuilderWithContextHandler<I, F>
614where
615 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
616 F: Fn(RequestContext, I) -> Fut + Send + Sync + 'static,
617 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
618{
619 pub fn build(self) -> Result<Tool> {
623 validate_tool_name(&self.name)?;
624 Ok(Tool {
625 name: self.name,
626 title: self.title,
627 description: self.description,
628 output_schema: self.output_schema,
629 icons: self.icons,
630 annotations: self.annotations,
631 handler: Arc::new(ContextAwareHandler {
632 handler: self.handler,
633 _phantom: std::marker::PhantomData,
634 }),
635 })
636 }
637}
638
639struct TypedHandler<I, F> {
645 handler: F,
646 _phantom: std::marker::PhantomData<I>,
647}
648
649impl<I, F, Fut> ToolHandler for TypedHandler<I, F>
650where
651 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
652 F: Fn(I) -> Fut + Send + Sync + 'static,
653 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
654{
655 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
656 Box::pin(async move {
657 let input: I = serde_json::from_value(args)
658 .map_err(|e| Error::tool(format!("Invalid input: {}", e)))?;
659 (self.handler)(input).await
660 })
661 }
662
663 fn input_schema(&self) -> Value {
664 let schema = schemars::schema_for!(I);
665 serde_json::to_value(schema).unwrap_or_else(|_| {
666 serde_json::json!({
667 "type": "object"
668 })
669 })
670 }
671}
672
673struct RawHandler<F> {
675 handler: F,
676}
677
678impl<F, Fut> ToolHandler for RawHandler<F>
679where
680 F: Fn(Value) -> Fut + Send + Sync + 'static,
681 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
682{
683 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
684 Box::pin((self.handler)(args))
685 }
686
687 fn input_schema(&self) -> Value {
688 serde_json::json!({
690 "type": "object",
691 "additionalProperties": true
692 })
693 }
694}
695
696struct RawContextHandler<F> {
698 handler: F,
699}
700
701impl<F, Fut> ToolHandler for RawContextHandler<F>
702where
703 F: Fn(RequestContext, Value) -> Fut + Send + Sync + 'static,
704 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
705{
706 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
707 let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
708 self.call_with_context(ctx, args)
709 }
710
711 fn call_with_context(
712 &self,
713 ctx: RequestContext,
714 args: Value,
715 ) -> BoxFuture<'_, Result<CallToolResult>> {
716 Box::pin((self.handler)(ctx, args))
717 }
718
719 fn uses_context(&self) -> bool {
720 true
721 }
722
723 fn input_schema(&self) -> Value {
724 serde_json::json!({
726 "type": "object",
727 "additionalProperties": true
728 })
729 }
730}
731
732struct ContextAwareHandler<I, F> {
734 handler: F,
735 _phantom: std::marker::PhantomData<I>,
736}
737
738impl<I, F, Fut> ToolHandler for ContextAwareHandler<I, F>
739where
740 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
741 F: Fn(RequestContext, I) -> Fut + Send + Sync + 'static,
742 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
743{
744 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
745 let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
747 self.call_with_context(ctx, args)
748 }
749
750 fn call_with_context(
751 &self,
752 ctx: RequestContext,
753 args: Value,
754 ) -> BoxFuture<'_, Result<CallToolResult>> {
755 Box::pin(async move {
756 let input: I = serde_json::from_value(args)
757 .map_err(|e| Error::tool(format!("Invalid input: {}", e)))?;
758 (self.handler)(ctx, input).await
759 })
760 }
761
762 fn uses_context(&self) -> bool {
763 true
764 }
765
766 fn input_schema(&self) -> Value {
767 let schema = schemars::schema_for!(I);
768 serde_json::to_value(schema).unwrap_or_else(|_| {
769 serde_json::json!({
770 "type": "object"
771 })
772 })
773 }
774}
775
776struct NoParamsHandler<F> {
778 handler: F,
779}
780
781impl<F, Fut> ToolHandler for NoParamsHandler<F>
782where
783 F: Fn() -> Fut + Send + Sync + 'static,
784 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
785{
786 fn call(&self, _args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
787 Box::pin((self.handler)())
788 }
789
790 fn input_schema(&self) -> Value {
791 serde_json::json!({
792 "type": "object",
793 "properties": {}
794 })
795 }
796}
797
798pub trait McpTool: Send + Sync + 'static {
839 const NAME: &'static str;
840 const DESCRIPTION: &'static str;
841
842 type Input: JsonSchema + DeserializeOwned + Send;
843 type Output: Serialize + Send;
844
845 fn call(&self, input: Self::Input) -> impl Future<Output = Result<Self::Output>> + Send;
846
847 fn annotations(&self) -> Option<ToolAnnotations> {
849 None
850 }
851
852 fn into_tool(self) -> Result<Tool>
856 where
857 Self: Sized,
858 {
859 validate_tool_name(Self::NAME)?;
860 let annotations = self.annotations();
861 let tool = Arc::new(self);
862 Ok(Tool {
863 name: Self::NAME.to_string(),
864 title: None,
865 description: Some(Self::DESCRIPTION.to_string()),
866 output_schema: None,
867 icons: None,
868 annotations,
869 handler: Arc::new(McpToolHandler { tool }),
870 })
871 }
872}
873
874struct McpToolHandler<T: McpTool> {
876 tool: Arc<T>,
877}
878
879impl<T: McpTool> ToolHandler for McpToolHandler<T> {
880 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
881 let tool = self.tool.clone();
882 Box::pin(async move {
883 let input: T::Input = serde_json::from_value(args)
884 .map_err(|e| Error::tool(format!("Invalid input: {}", e)))?;
885 let output = tool.call(input).await?;
886 let value = serde_json::to_value(output)
887 .map_err(|e| Error::tool(format!("Failed to serialize output: {}", e)))?;
888 Ok(CallToolResult::json(value))
889 })
890 }
891
892 fn input_schema(&self) -> Value {
893 let schema = schemars::schema_for!(T::Input);
894 serde_json::to_value(schema).unwrap_or_else(|_| {
895 serde_json::json!({
896 "type": "object"
897 })
898 })
899 }
900}
901
// Unit tests for the tool builders, handler adapters and name validation.
#[cfg(test)]
mod tests {
    use super::*;
    use schemars::JsonSchema;
    use serde::Deserialize;

    // Typed input used by several tests in this module.
    #[derive(Debug, Deserialize, JsonSchema)]
    struct GreetInput {
        name: String,
    }

    // A typed handler built via `handler` + `build` deserializes its input and succeeds.
    #[tokio::test]
    async fn test_builder_tool() {
        let tool = ToolBuilder::new("greet")
            .description("Greet someone")
            .handler(|input: GreetInput| async move {
                Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
            })
            .build()
            .expect("valid tool name");

        assert_eq!(tool.name, "greet");
        assert_eq!(tool.description.as_deref(), Some("Greet someone"));

        let result = tool
            .call(serde_json::json!({"name": "World"}))
            .await
            .unwrap();

        assert!(!result.is_error);
    }

    // A raw handler receives the JSON arguments unchanged.
    #[tokio::test]
    async fn test_raw_handler() {
        let tool = ToolBuilder::new("echo")
            .description("Echo input")
            .raw_handler(|args: Value| async move { Ok(CallToolResult::json(args)) })
            .expect("valid tool name");

        let result = tool.call(serde_json::json!({"foo": "bar"})).await.unwrap();

        assert!(!result.is_error);
    }

    // Empty names are rejected at build time.
    #[test]
    fn test_invalid_tool_name_empty() {
        let result = ToolBuilder::new("")
            .description("Empty name")
            .raw_handler(|args: Value| async move { Ok(CallToolResult::json(args)) });

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("cannot be empty"));
    }

    // One character past the 128-character limit is rejected.
    #[test]
    fn test_invalid_tool_name_too_long() {
        let long_name = "a".repeat(129);
        let result = ToolBuilder::new(long_name)
            .description("Too long")
            .raw_handler(|args: Value| async move { Ok(CallToolResult::json(args)) });

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("exceeds maximum"));
    }

    // Spaces and punctuation outside `_`, `-`, `.` are rejected.
    #[test]
    fn test_invalid_tool_name_bad_chars() {
        let result = ToolBuilder::new("my tool!")
            .description("Bad chars")
            .raw_handler(|args: Value| async move { Ok(CallToolResult::json(args)) });

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("invalid character")
        );
    }

    // Each allowed character class, a single character, and exactly 128 characters
    // are all accepted.
    #[test]
    fn test_valid_tool_names() {
        let names = [
            "my_tool",
            "my-tool",
            "my.tool",
            "MyTool123",
            "a",
            &"a".repeat(128),
        ];
        for name in names {
            let result = ToolBuilder::new(name)
                .description("Valid")
                .raw_handler(|args: Value| async move { Ok(CallToolResult::json(args)) });
            assert!(result.is_ok(), "Expected '{}' to be valid", name);
        }
    }

    // A context-aware handler can report progress through the notification channel
    // attached to the caller's context.
    #[tokio::test]
    async fn test_context_aware_handler() {
        use crate::context::{RequestContext, notification_channel};
        use crate::protocol::{ProgressToken, RequestId};

        #[derive(Debug, Deserialize, JsonSchema)]
        struct ProcessInput {
            count: i32,
        }

        let tool = ToolBuilder::new("process")
            .description("Process with context")
            .handler_with_context(|ctx: RequestContext, input: ProcessInput| async move {
                for i in 0..input.count {
                    // Stop early if the request has been cancelled.
                    if ctx.is_cancelled() {
                        return Ok(CallToolResult::error("Cancelled"));
                    }
                    ctx.report_progress(i as f64, Some(input.count as f64), None)
                        .await;
                }
                Ok(CallToolResult::text(format!(
                    "Processed {} items",
                    input.count
                )))
            })
            .build()
            .expect("valid tool name");

        assert_eq!(tool.name, "process");
        assert!(tool.uses_context());

        // Wire the context to a channel so that progress notifications can be observed.
        let (tx, mut rx) = notification_channel(10);
        let ctx = RequestContext::new(RequestId::Number(1))
            .with_progress_token(ProgressToken::Number(42))
            .with_notification_sender(tx);

        let result = tool
            .call_with_context(ctx, serde_json::json!({"count": 3}))
            .await
            .unwrap();

        assert!(!result.is_error);

        // Drain the channel: one progress notification is expected per loop iteration.
        let mut progress_count = 0;
        while rx.try_recv().is_ok() {
            progress_count += 1;
        }
        assert_eq!(progress_count, 3);
    }

    // Cancelling the context's token mid-run makes the handler stop at its next
    // cancellation check.
    #[tokio::test]
    async fn test_context_aware_handler_cancellation() {
        use crate::context::RequestContext;
        use crate::protocol::RequestId;
        use std::sync::Arc;
        use std::sync::atomic::{AtomicI32, Ordering};

        #[derive(Debug, Deserialize, JsonSchema)]
        struct LongRunningInput {
            iterations: i32,
        }

        // Counts how many iterations ran before cancellation took effect.
        let iterations_completed = Arc::new(AtomicI32::new(0));
        let iterations_ref = iterations_completed.clone();

        let tool = ToolBuilder::new("long_running")
            .description("Long running task")
            .handler_with_context(move |ctx: RequestContext, input: LongRunningInput| {
                let completed = iterations_ref.clone();
                async move {
                    for i in 0..input.iterations {
                        if ctx.is_cancelled() {
                            return Ok(CallToolResult::error("Cancelled"));
                        }
                        completed.fetch_add(1, Ordering::SeqCst);
                        tokio::task::yield_now().await;
                        // Simulate an external cancellation after the third iteration
                        // (i == 2) has been counted.
                        if i == 2 {
                            ctx.cancellation_token().cancel();
                        }
                    }
                    Ok(CallToolResult::text("Done"))
                }
            })
            .build()
            .expect("valid tool name");

        let ctx = RequestContext::new(RequestId::Number(1));

        let result = tool
            .call_with_context(ctx, serde_json::json!({"iterations": 10}))
            .await
            .unwrap();

        assert!(result.is_error);
        // Iterations 0, 1 and 2 ran. The check at the top of iteration 3 saw the
        // cancellation and returned.
        assert_eq!(iterations_completed.load(Ordering::SeqCst), 3);
    }

    // Title, output schema and icons set on the builder are carried into both the
    // `Tool` and its `ToolDefinition`.
    #[tokio::test]
    async fn test_tool_builder_with_enhanced_fields() {
        let output_schema = serde_json::json!({
            "type": "object",
            "properties": {
                "greeting": {"type": "string"}
            }
        });

        let tool = ToolBuilder::new("greet")
            .title("Greeting Tool")
            .description("Greet someone")
            .output_schema(output_schema.clone())
            .icon("https://example.com/icon.png")
            .icon_with_meta(
                "https://example.com/icon-large.png",
                Some("image/png".to_string()),
                Some(vec!["96x96".to_string()]),
            )
            .handler(|input: GreetInput| async move {
                Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
            })
            .build()
            .expect("valid tool name");

        assert_eq!(tool.name, "greet");
        assert_eq!(tool.title.as_deref(), Some("Greeting Tool"));
        assert_eq!(tool.description.as_deref(), Some("Greet someone"));
        assert_eq!(tool.output_schema, Some(output_schema));
        assert!(tool.icons.is_some());
        assert_eq!(tool.icons.as_ref().unwrap().len(), 2);

        // The protocol definition must expose the same enhanced metadata.
        let def = tool.definition();
        assert_eq!(def.title.as_deref(), Some("Greeting Tool"));
        assert!(def.output_schema.is_some());
        assert!(def.icons.is_some());
    }

    // `handler_with_state` passes the shared state into each invocation.
    #[tokio::test]
    async fn test_handler_with_state() {
        let shared = Arc::new("shared-state".to_string());

        let tool = ToolBuilder::new("stateful")
            .description("Uses shared state")
            .handler_with_state(shared, |state: Arc<String>, input: GreetInput| async move {
                Ok(CallToolResult::text(format!(
                    "{}: Hello, {}!",
                    state, input.name
                )))
            })
            .build()
            .expect("valid tool name");

        let result = tool
            .call(serde_json::json!({"name": "World"}))
            .await
            .unwrap();
        assert!(!result.is_error);
    }

    // `handler_with_state_and_context` yields a context-aware tool that also receives
    // the shared state.
    #[tokio::test]
    async fn test_handler_with_state_and_context() {
        use crate::context::RequestContext;
        use crate::protocol::RequestId;

        let shared = Arc::new(42_i32);

        let tool = ToolBuilder::new("stateful_ctx")
            .description("Uses state and context")
            .handler_with_state_and_context(
                shared,
                |state: Arc<i32>, _ctx: RequestContext, input: GreetInput| async move {
                    Ok(CallToolResult::text(format!(
                        "{}: Hello, {}!",
                        state, input.name
                    )))
                },
            )
            .build()
            .expect("valid tool name");

        assert!(tool.uses_context());

        let ctx = RequestContext::new(RequestId::Number(1));
        let result = tool
            .call_with_context(ctx, serde_json::json!({"name": "World"}))
            .await
            .unwrap();
        assert!(!result.is_error);
    }

    // A no-params tool ignores any arguments and advertises an empty object schema.
    #[tokio::test]
    async fn test_handler_no_params() {
        let tool = ToolBuilder::new("no_params")
            .description("Takes no parameters")
            .handler_no_params(|| async { Ok(CallToolResult::text("no params result")) })
            .expect("valid tool name");

        assert_eq!(tool.name, "no_params");

        // An empty argument object succeeds.
        let result = tool.call(serde_json::json!({})).await.unwrap();
        assert!(!result.is_error);

        // Unexpected arguments are ignored rather than rejected.
        let result = tool
            .call(serde_json::json!({"unexpected": "value"}))
            .await
            .unwrap();
        assert!(!result.is_error);

        // The advertised schema is an object with an empty `properties` map.
        let schema = tool.definition().input_schema;
        assert_eq!(schema.get("type").unwrap().as_str().unwrap(), "object");
        assert!(
            schema
                .get("properties")
                .unwrap()
                .as_object()
                .unwrap()
                .is_empty()
        );
    }
}