1use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12
13use schemars::JsonSchema;
14use serde::Serialize;
15use serde::de::DeserializeOwned;
16use serde_json::Value;
17
18use crate::context::RequestContext;
19use crate::error::{Error, Result};
20use crate::protocol::{CallToolResult, ToolAnnotations, ToolDefinition, ToolIcon};
21
22pub fn validate_tool_name(name: &str) -> Result<()> {
30 if name.is_empty() {
31 return Err(Error::tool("Tool name cannot be empty"));
32 }
33 if name.len() > 128 {
34 return Err(Error::tool(format!(
35 "Tool name '{}' exceeds maximum length of 128 characters (got {})",
36 name,
37 name.len()
38 )));
39 }
40 if let Some(invalid_char) = name
41 .chars()
42 .find(|c| !c.is_ascii_alphanumeric() && *c != '_' && *c != '-' && *c != '.')
43 {
44 return Err(Error::tool(format!(
45 "Tool name '{}' contains invalid character '{}'. Only alphanumeric, underscore, hyphen, and dot are allowed.",
46 name, invalid_char
47 )));
48 }
49 Ok(())
50}
51
/// A pinned, heap-allocated, `Send` future that may borrow data for `'a`.
///
/// Used as the return type of the [`ToolHandler`] methods so that the trait can
/// be used as `dyn ToolHandler` while still returning asynchronous work.
pub type BoxFuture<'a, T> = Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;
54
55pub trait ToolHandler: Send + Sync {
57 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>>;
59
60 fn call_with_context(
65 &self,
66 _ctx: RequestContext,
67 args: Value,
68 ) -> BoxFuture<'_, Result<CallToolResult>> {
69 self.call(args)
70 }
71
72 fn uses_context(&self) -> bool {
74 false
75 }
76
77 fn input_schema(&self) -> Value;
79}
80
81pub struct Tool {
83 pub name: String,
84 pub title: Option<String>,
85 pub description: Option<String>,
86 pub output_schema: Option<Value>,
87 pub icons: Option<Vec<ToolIcon>>,
88 pub annotations: Option<ToolAnnotations>,
89 handler: Arc<dyn ToolHandler>,
90}
91
92impl std::fmt::Debug for Tool {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 f.debug_struct("Tool")
95 .field("name", &self.name)
96 .field("title", &self.title)
97 .field("description", &self.description)
98 .field("output_schema", &self.output_schema)
99 .field("icons", &self.icons)
100 .field("annotations", &self.annotations)
101 .finish_non_exhaustive()
102 }
103}
104
105impl Tool {
106 pub fn builder(name: impl Into<String>) -> ToolBuilder {
108 ToolBuilder::new(name)
109 }
110
111 pub fn definition(&self) -> ToolDefinition {
113 ToolDefinition {
114 name: self.name.clone(),
115 title: self.title.clone(),
116 description: self.description.clone(),
117 input_schema: self.handler.input_schema(),
118 output_schema: self.output_schema.clone(),
119 icons: self.icons.clone(),
120 annotations: self.annotations.clone(),
121 }
122 }
123
124 pub fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
126 self.handler.call(args)
127 }
128
129 pub fn call_with_context(
133 &self,
134 ctx: RequestContext,
135 args: Value,
136 ) -> BoxFuture<'_, Result<CallToolResult>> {
137 self.handler.call_with_context(ctx, args)
138 }
139
140 pub fn uses_context(&self) -> bool {
142 self.handler.uses_context()
143 }
144}
145
146pub struct ToolBuilder {
175 name: String,
176 title: Option<String>,
177 description: Option<String>,
178 output_schema: Option<Value>,
179 icons: Option<Vec<ToolIcon>>,
180 annotations: Option<ToolAnnotations>,
181}
182
183impl ToolBuilder {
184 pub fn new(name: impl Into<String>) -> Self {
185 Self {
186 name: name.into(),
187 title: None,
188 description: None,
189 output_schema: None,
190 icons: None,
191 annotations: None,
192 }
193 }
194
195 pub fn title(mut self, title: impl Into<String>) -> Self {
197 self.title = Some(title.into());
198 self
199 }
200
201 pub fn output_schema(mut self, schema: Value) -> Self {
203 self.output_schema = Some(schema);
204 self
205 }
206
207 pub fn icon(mut self, src: impl Into<String>) -> Self {
209 self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
210 src: src.into(),
211 mime_type: None,
212 sizes: None,
213 });
214 self
215 }
216
217 pub fn icon_with_meta(
219 mut self,
220 src: impl Into<String>,
221 mime_type: Option<String>,
222 sizes: Option<Vec<String>>,
223 ) -> Self {
224 self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
225 src: src.into(),
226 mime_type,
227 sizes,
228 });
229 self
230 }
231
232 pub fn description(mut self, description: impl Into<String>) -> Self {
234 self.description = Some(description.into());
235 self
236 }
237
238 pub fn read_only(mut self) -> Self {
240 self.annotations
241 .get_or_insert_with(ToolAnnotations::default)
242 .read_only_hint = true;
243 self
244 }
245
246 pub fn non_destructive(mut self) -> Self {
248 self.annotations
249 .get_or_insert_with(ToolAnnotations::default)
250 .destructive_hint = false;
251 self
252 }
253
254 pub fn idempotent(mut self) -> Self {
256 self.annotations
257 .get_or_insert_with(ToolAnnotations::default)
258 .idempotent_hint = true;
259 self
260 }
261
262 pub fn annotations(mut self, annotations: ToolAnnotations) -> Self {
264 self.annotations = Some(annotations);
265 self
266 }
267
268 pub fn handler<I, F, Fut>(self, handler: F) -> ToolBuilderWithHandler<I, F>
312 where
313 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
314 F: Fn(I) -> Fut + Send + Sync + 'static,
315 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
316 {
317 ToolBuilderWithHandler {
318 name: self.name,
319 title: self.title,
320 description: self.description,
321 output_schema: self.output_schema,
322 icons: self.icons,
323 annotations: self.annotations,
324 handler,
325 _phantom: std::marker::PhantomData,
326 }
327 }
328
329 pub fn handler_with_context<I, F, Fut>(self, handler: F) -> ToolBuilderWithContextHandler<I, F>
362 where
363 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
364 F: Fn(RequestContext, I) -> Fut + Send + Sync + 'static,
365 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
366 {
367 ToolBuilderWithContextHandler {
368 name: self.name,
369 title: self.title,
370 description: self.description,
371 output_schema: self.output_schema,
372 icons: self.icons,
373 annotations: self.annotations,
374 handler,
375 _phantom: std::marker::PhantomData,
376 }
377 }
378
379 pub fn raw_handler<F, Fut>(self, handler: F) -> Result<Tool>
383 where
384 F: Fn(Value) -> Fut + Send + Sync + 'static,
385 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
386 {
387 validate_tool_name(&self.name)?;
388 Ok(Tool {
389 name: self.name,
390 title: self.title,
391 description: self.description,
392 output_schema: self.output_schema,
393 icons: self.icons,
394 annotations: self.annotations,
395 handler: Arc::new(RawHandler { handler }),
396 })
397 }
398}
399
400pub struct ToolBuilderWithHandler<I, F> {
402 name: String,
403 title: Option<String>,
404 description: Option<String>,
405 output_schema: Option<Value>,
406 icons: Option<Vec<ToolIcon>>,
407 annotations: Option<ToolAnnotations>,
408 handler: F,
409 _phantom: std::marker::PhantomData<I>,
410}
411
412impl<I, F, Fut> ToolBuilderWithHandler<I, F>
413where
414 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
415 F: Fn(I) -> Fut + Send + Sync + 'static,
416 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
417{
418 pub fn build(self) -> Result<Tool> {
422 validate_tool_name(&self.name)?;
423 Ok(Tool {
424 name: self.name,
425 title: self.title,
426 description: self.description,
427 output_schema: self.output_schema,
428 icons: self.icons,
429 annotations: self.annotations,
430 handler: Arc::new(TypedHandler {
431 handler: self.handler,
432 _phantom: std::marker::PhantomData,
433 }),
434 })
435 }
436}
437
438pub struct ToolBuilderWithContextHandler<I, F> {
440 name: String,
441 title: Option<String>,
442 description: Option<String>,
443 output_schema: Option<Value>,
444 icons: Option<Vec<ToolIcon>>,
445 annotations: Option<ToolAnnotations>,
446 handler: F,
447 _phantom: std::marker::PhantomData<I>,
448}
449
450impl<I, F, Fut> ToolBuilderWithContextHandler<I, F>
451where
452 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
453 F: Fn(RequestContext, I) -> Fut + Send + Sync + 'static,
454 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
455{
456 pub fn build(self) -> Result<Tool> {
460 validate_tool_name(&self.name)?;
461 Ok(Tool {
462 name: self.name,
463 title: self.title,
464 description: self.description,
465 output_schema: self.output_schema,
466 icons: self.icons,
467 annotations: self.annotations,
468 handler: Arc::new(ContextAwareHandler {
469 handler: self.handler,
470 _phantom: std::marker::PhantomData,
471 }),
472 })
473 }
474}
475
476struct TypedHandler<I, F> {
482 handler: F,
483 _phantom: std::marker::PhantomData<I>,
484}
485
486impl<I, F, Fut> ToolHandler for TypedHandler<I, F>
487where
488 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
489 F: Fn(I) -> Fut + Send + Sync + 'static,
490 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
491{
492 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
493 Box::pin(async move {
494 let input: I = serde_json::from_value(args)
495 .map_err(|e| Error::tool(format!("Invalid input: {}", e)))?;
496 (self.handler)(input).await
497 })
498 }
499
500 fn input_schema(&self) -> Value {
501 let schema = schemars::schema_for!(I);
502 serde_json::to_value(schema).unwrap_or_else(|_| {
503 serde_json::json!({
504 "type": "object"
505 })
506 })
507 }
508}
509
510struct RawHandler<F> {
512 handler: F,
513}
514
515impl<F, Fut> ToolHandler for RawHandler<F>
516where
517 F: Fn(Value) -> Fut + Send + Sync + 'static,
518 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
519{
520 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
521 Box::pin((self.handler)(args))
522 }
523
524 fn input_schema(&self) -> Value {
525 serde_json::json!({
527 "type": "object",
528 "additionalProperties": true
529 })
530 }
531}
532
533struct ContextAwareHandler<I, F> {
535 handler: F,
536 _phantom: std::marker::PhantomData<I>,
537}
538
539impl<I, F, Fut> ToolHandler for ContextAwareHandler<I, F>
540where
541 I: JsonSchema + DeserializeOwned + Send + Sync + 'static,
542 F: Fn(RequestContext, I) -> Fut + Send + Sync + 'static,
543 Fut: Future<Output = Result<CallToolResult>> + Send + 'static,
544{
545 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
546 let ctx = RequestContext::new(crate::protocol::RequestId::Number(0));
548 self.call_with_context(ctx, args)
549 }
550
551 fn call_with_context(
552 &self,
553 ctx: RequestContext,
554 args: Value,
555 ) -> BoxFuture<'_, Result<CallToolResult>> {
556 Box::pin(async move {
557 let input: I = serde_json::from_value(args)
558 .map_err(|e| Error::tool(format!("Invalid input: {}", e)))?;
559 (self.handler)(ctx, input).await
560 })
561 }
562
563 fn uses_context(&self) -> bool {
564 true
565 }
566
567 fn input_schema(&self) -> Value {
568 let schema = schemars::schema_for!(I);
569 serde_json::to_value(schema).unwrap_or_else(|_| {
570 serde_json::json!({
571 "type": "object"
572 })
573 })
574 }
575}
576
577pub trait McpTool: Send + Sync + 'static {
618 const NAME: &'static str;
619 const DESCRIPTION: &'static str;
620
621 type Input: JsonSchema + DeserializeOwned + Send;
622 type Output: Serialize + Send;
623
624 fn call(&self, input: Self::Input) -> impl Future<Output = Result<Self::Output>> + Send;
625
626 fn annotations(&self) -> Option<ToolAnnotations> {
628 None
629 }
630
631 fn into_tool(self) -> Result<Tool>
635 where
636 Self: Sized,
637 {
638 validate_tool_name(Self::NAME)?;
639 let annotations = self.annotations();
640 let tool = Arc::new(self);
641 Ok(Tool {
642 name: Self::NAME.to_string(),
643 title: None,
644 description: Some(Self::DESCRIPTION.to_string()),
645 output_schema: None,
646 icons: None,
647 annotations,
648 handler: Arc::new(McpToolHandler { tool }),
649 })
650 }
651}
652
653struct McpToolHandler<T: McpTool> {
655 tool: Arc<T>,
656}
657
658impl<T: McpTool> ToolHandler for McpToolHandler<T> {
659 fn call(&self, args: Value) -> BoxFuture<'_, Result<CallToolResult>> {
660 let tool = self.tool.clone();
661 Box::pin(async move {
662 let input: T::Input = serde_json::from_value(args)
663 .map_err(|e| Error::tool(format!("Invalid input: {}", e)))?;
664 let output = tool.call(input).await?;
665 let value = serde_json::to_value(output)
666 .map_err(|e| Error::tool(format!("Failed to serialize output: {}", e)))?;
667 Ok(CallToolResult::json(value))
668 })
669 }
670
671 fn input_schema(&self) -> Value {
672 let schema = schemars::schema_for!(T::Input);
673 serde_json::to_value(schema).unwrap_or_else(|_| {
674 serde_json::json!({
675 "type": "object"
676 })
677 })
678 }
679}
680
#[cfg(test)]
mod tests {
    use super::*;
    use schemars::JsonSchema;
    use serde::Deserialize;

    #[derive(Debug, Deserialize, JsonSchema)]
    struct GreetInput {
        name: String,
    }

    /// Builds a raw-handler tool that echoes its arguments back as JSON.
    ///
    /// The builder's `Result` is returned unchanged so tests can assert on name validation.
    fn echo_tool(name: impl Into<String>, description: &str) -> Result<Tool> {
        ToolBuilder::new(name)
            .description(description)
            .raw_handler(|args: Value| async move { Ok(CallToolResult::json(args)) })
    }

    /// Asserts that `result` is an error whose message contains `needle`.
    fn assert_err_contains(result: Result<Tool>, needle: &str) {
        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains(needle));
    }

    #[tokio::test]
    async fn test_builder_tool() {
        let greet = ToolBuilder::new("greet")
            .description("Greet someone")
            .handler(|input: GreetInput| async move {
                Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
            })
            .build()
            .expect("valid tool name");

        assert_eq!(greet.name, "greet");
        assert_eq!(greet.description.as_deref(), Some("Greet someone"));

        let args = serde_json::json!({ "name": "World" });
        let outcome = greet.call(args).await.unwrap();
        assert!(!outcome.is_error);
    }

    #[tokio::test]
    async fn test_raw_handler() {
        let echo = echo_tool("echo", "Echo input").expect("valid tool name");
        let outcome = echo
            .call(serde_json::json!({ "foo": "bar" }))
            .await
            .unwrap();
        assert!(!outcome.is_error);
    }

    #[test]
    fn test_invalid_tool_name_empty() {
        assert_err_contains(echo_tool("", "Empty name"), "cannot be empty");
    }

    #[test]
    fn test_invalid_tool_name_too_long() {
        assert_err_contains(echo_tool("a".repeat(129), "Too long"), "exceeds maximum");
    }

    #[test]
    fn test_invalid_tool_name_bad_chars() {
        assert_err_contains(echo_tool("my tool!", "Bad chars"), "invalid character");
    }

    #[test]
    fn test_valid_tool_names() {
        // The 128-character name sits exactly on the length limit.
        let longest = "a".repeat(128);
        let names = [
            "my_tool",
            "my-tool",
            "my.tool",
            "MyTool123",
            "a",
            longest.as_str(),
        ];
        for name in names {
            assert!(
                echo_tool(name, "Valid").is_ok(),
                "Expected '{}' to be valid",
                name
            );
        }
    }

    #[tokio::test]
    async fn test_context_aware_handler() {
        use crate::context::{RequestContext, notification_channel};
        use crate::protocol::{ProgressToken, RequestId};

        #[derive(Debug, Deserialize, JsonSchema)]
        struct ProcessInput {
            count: i32,
        }

        let tool = ToolBuilder::new("process")
            .description("Process with context")
            .handler_with_context(|ctx: RequestContext, input: ProcessInput| async move {
                for step in 0..input.count {
                    // Stop early if the request has been cancelled.
                    if ctx.is_cancelled() {
                        return Ok(CallToolResult::error("Cancelled"));
                    }
                    ctx.report_progress(step as f64, Some(input.count as f64), None)
                        .await;
                }
                Ok(CallToolResult::text(format!(
                    "Processed {} items",
                    input.count
                )))
            })
            .build()
            .expect("valid tool name");

        assert_eq!(tool.name, "process");
        assert!(tool.uses_context());

        let (tx, mut rx) = notification_channel(10);
        let ctx = RequestContext::new(RequestId::Number(1))
            .with_progress_token(ProgressToken::Number(42))
            .with_notification_sender(tx);

        let outcome = tool
            .call_with_context(ctx, serde_json::json!({ "count": 3 }))
            .await
            .unwrap();
        assert!(!outcome.is_error);

        // Drain every notification that was queued; one is expected per item.
        let progress_count = std::iter::from_fn(|| rx.try_recv().ok()).count();
        assert_eq!(progress_count, 3);
    }

    #[tokio::test]
    async fn test_context_aware_handler_cancellation() {
        use crate::context::RequestContext;
        use crate::protocol::RequestId;
        use std::sync::Arc;
        use std::sync::atomic::{AtomicI32, Ordering};

        #[derive(Debug, Deserialize, JsonSchema)]
        struct LongRunningInput {
            iterations: i32,
        }

        let completed = Arc::new(AtomicI32::new(0));
        let completed_in_handler = Arc::clone(&completed);

        let tool = ToolBuilder::new("long_running")
            .description("Long running task")
            .handler_with_context(move |ctx: RequestContext, input: LongRunningInput| {
                let counter = Arc::clone(&completed_in_handler);
                async move {
                    for i in 0..input.iterations {
                        if ctx.is_cancelled() {
                            return Ok(CallToolResult::error("Cancelled"));
                        }
                        counter.fetch_add(1, Ordering::SeqCst);
                        tokio::task::yield_now().await;
                        // Cancel from inside the handler after the third iteration so
                        // the next loop check observes it.
                        if i == 2 {
                            ctx.cancellation_token().cancel();
                        }
                    }
                    Ok(CallToolResult::text("Done"))
                }
            })
            .build()
            .expect("valid tool name");

        let ctx = RequestContext::new(RequestId::Number(1));
        let outcome = tool
            .call_with_context(ctx, serde_json::json!({ "iterations": 10 }))
            .await
            .unwrap();

        assert!(outcome.is_error);
        // Iterations 0, 1 and 2 ran before the cancellation was observed.
        assert_eq!(completed.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_tool_builder_with_enhanced_fields() {
        let expected_output_schema = serde_json::json!({
            "type": "object",
            "properties": {
                "greeting": {"type": "string"}
            }
        });

        let tool = ToolBuilder::new("greet")
            .title("Greeting Tool")
            .description("Greet someone")
            .output_schema(expected_output_schema.clone())
            .icon("https://example.com/icon.png")
            .icon_with_meta(
                "https://example.com/icon-large.png",
                Some("image/png".to_string()),
                Some(vec!["96x96".to_string()]),
            )
            .handler(|input: GreetInput| async move {
                Ok(CallToolResult::text(format!("Hello, {}!", input.name)))
            })
            .build()
            .expect("valid tool name");

        assert_eq!(tool.name, "greet");
        assert_eq!(tool.title.as_deref(), Some("Greeting Tool"));
        assert_eq!(tool.description.as_deref(), Some("Greet someone"));
        assert_eq!(tool.output_schema, Some(expected_output_schema));
        let icons = tool.icons.as_ref();
        assert!(icons.is_some());
        assert_eq!(icons.unwrap().len(), 2);

        // The enhanced fields must also survive into the protocol definition.
        let definition = tool.definition();
        assert_eq!(definition.title.as_deref(), Some("Greeting Tool"));
        assert!(definition.output_schema.is_some());
        assert!(definition.icons.is_some());
    }
}