1#![forbid(unsafe_code)]
54
55use proc_macro::TokenStream;
56use proc_macro2::TokenStream as TokenStream2;
57use quote::{format_ident, quote};
58use syn::{
59 Attribute, FnArg, Ident, ItemFn, Lit, LitStr, Meta, Pat, Token, Type, parse::Parse,
60 parse::ParseStream, parse_macro_input,
61};
62
63fn extract_doc_comments(attrs: &[Attribute]) -> Option<String> {
65 let docs: Vec<String> = attrs
66 .iter()
67 .filter_map(|attr| {
68 if attr.path().is_ident("doc") {
69 if let Meta::NameValue(nv) = &attr.meta {
70 if let syn::Expr::Lit(syn::ExprLit {
71 lit: Lit::Str(s), ..
72 }) = &nv.value
73 {
74 return Some(s.value().trim().to_string());
75 }
76 }
77 }
78 None
79 })
80 .collect();
81
82 if docs.is_empty() {
83 None
84 } else {
85 Some(docs.join("\n"))
86 }
87}
88
89fn is_mcp_context_ref(ty: &Type) -> bool {
91 if let Type::Reference(type_ref) = ty {
92 if let Type::Path(type_path) = type_ref.elem.as_ref() {
93 return type_path
94 .path
95 .segments
96 .last()
97 .is_some_and(|s| s.ident == "McpContext");
98 }
99 }
100 false
101}
102
103fn is_option_type(ty: &Type) -> bool {
105 if let Type::Path(type_path) = ty {
106 return type_path
107 .path
108 .segments
109 .last()
110 .is_some_and(|s| s.ident == "Option");
111 }
112 false
113}
114
115fn option_inner_type(ty: &Type) -> Option<&Type> {
117 if let Type::Path(type_path) = ty {
118 if let Some(segment) = type_path.path.segments.last() {
119 if segment.ident == "Option" {
120 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
121 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
122 return Some(inner_ty);
123 }
124 }
125 }
126 }
127 }
128 None
129}
130
131fn is_string_type(ty: &Type) -> bool {
133 if let Type::Path(type_path) = ty {
134 return type_path
135 .path
136 .segments
137 .last()
138 .is_some_and(|s| s.ident == "String");
139 }
140 false
141}
142
/// Parses a human-readable duration such as `"30s"`, `"1m30s"`, `"2h"` or
/// `"500ms"` into milliseconds.
///
/// A component is a run of ASCII digits followed by a unit: `h`, `m`, `s` or
/// `ms`. Components are summed. Whitespace may separate components, or a
/// number from its unit, but may not split a number.
///
/// # Errors
///
/// Returns a human-readable message when the input is empty, or contains an
/// unknown unit or character. It also errors when:
/// - a number has no unit, or a unit has no number;
/// - whitespace appears inside a number;
/// - the total does not fit in `u64` milliseconds;
/// - the total is zero.
fn parse_duration_to_millis(s: &str) -> Result<u64, String> {
    const MS_PER_SECOND: u64 = 1000;
    const MS_PER_MINUTE: u64 = 60 * MS_PER_SECOND;
    const MS_PER_HOUR: u64 = 60 * MS_PER_MINUTE;

    let s = s.trim();
    if s.is_empty() {
        return Err("empty string".to_string());
    }

    let mut total_millis: u64 = 0;
    let mut current_num = String::new();
    // Set once whitespace follows digits that have not yet received a unit, so
    // that "1 0s" is rejected instead of being read as "10s".
    let mut space_after_digits = false;
    let mut chars = s.chars().peekable();

    while let Some(c) = chars.next() {
        if c.is_ascii_digit() {
            if space_after_digits {
                return Err(format!(
                    "unexpected whitespace inside number '{current_num}'"
                ));
            }
            current_num.push(c);
        } else if c.is_ascii_alphabetic() {
            if current_num.is_empty() {
                return Err(format!(
                    "unexpected unit character '{c}' without preceding number"
                ));
            }

            let num: u64 = current_num
                .parse()
                .map_err(|_| format!("invalid number: {current_num}"))?;

            // `m` immediately followed by `s` means milliseconds, not minutes.
            let unit_millis = if c == 'm' && chars.peek() == Some(&'s') {
                chars.next();
                1
            } else {
                match c {
                    'h' => MS_PER_HOUR,
                    'm' => MS_PER_MINUTE,
                    's' => MS_PER_SECOND,
                    _ => return Err(format!("unknown unit '{c}'")),
                }
            };

            // Checked arithmetic: this runs inside the compiler, where an
            // unchecked multiply would panic (debug) or silently wrap (release).
            total_millis = num
                .checked_mul(unit_millis)
                .and_then(|millis| total_millis.checked_add(millis))
                .ok_or_else(|| format!("duration '{s}' is too large"))?;

            current_num.clear();
            space_after_digits = false;
        } else if c.is_whitespace() {
            if !current_num.is_empty() {
                space_after_digits = true;
            }
        } else {
            return Err(format!("unexpected character '{c}'"));
        }
    }

    if !current_num.is_empty() {
        return Err(format!(
            "number '{current_num}' missing unit (use s, m, h, or ms)"
        ));
    }

    if total_millis == 0 {
        return Err("duration must be greater than zero".to_string());
    }

    Ok(total_millis)
}
213
/// Returns the placeholder names of a URI template, in order of appearance.
///
/// `"file://{dir}/{name}"` yields `["dir", "name"]`. Empty placeholders (`{}`)
/// are skipped. The text after an unterminated `{` is taken as a final name.
fn extract_template_params(uri: &str) -> Vec<String> {
    let mut names = Vec::new();
    let mut rest = uri;

    while let Some(open) = rest.find('{') {
        let after_open = &rest[open + 1..];
        let Some(close) = after_open.find('}') else {
            // Unterminated placeholder: everything that remains is the name.
            if !after_open.is_empty() {
                names.push(after_open.to_string());
            }
            break;
        };

        let name = &after_open[..close];
        if !name.is_empty() {
            names.push(name.to_string());
        }
        rest = &after_open[close + 1..];
    }

    names
}
236
/// Converts a `snake_case` identifier to `PascalCase` (`get_weather` -> `GetWeather`).
///
/// A leading `r#` raw-identifier prefix is dropped first. Without this,
/// `r#type` would become `R#type`, which is not a valid identifier, and
/// `format_ident!` would panic on it. Empty segments produced by leading or
/// repeated underscores contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let s = s.strip_prefix("r#").unwrap_or(s);
    s.split('_')
        .map(|word| {
            let mut chars = word.chars();
            match chars.next() {
                Some(first) => first.to_uppercase().collect::<String>() + chars.as_str(),
                None => String::new(),
            }
        })
        .collect()
}
249
/// Classification of a `#[tool]` function's return type.
///
/// It selects how the returned value becomes the handler's `Vec<Content>`.
enum ReturnTypeKind {
    /// The function returns nothing.
    Unit,
    /// `String`.
    String,
    /// `Vec<Content>`.
    VecContent,
    /// `Result<String, E>`.
    ResultString,
    /// `Result<Vec<Content>, E>`.
    ResultVecContent,
    /// `McpResult<String>`.
    McpResultString,
    /// `McpResult<Vec<Content>>`.
    McpResultVecContent,
    /// Any other type.
    Other,
}
269
270fn analyze_return_type(output: &syn::ReturnType) -> ReturnTypeKind {
272 match output {
273 syn::ReturnType::Default => ReturnTypeKind::Unit,
274 syn::ReturnType::Type(_, ty) => analyze_type(ty),
275 }
276}
277
278fn analyze_type(ty: &Type) -> ReturnTypeKind {
280 if let Type::Path(type_path) = ty {
281 if let Some(segment) = type_path.path.segments.last() {
282 let type_name = segment.ident.to_string();
283
284 match type_name.as_str() {
285 "String" => return ReturnTypeKind::String,
286 "Vec" => {
287 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
289 if let Some(syn::GenericArgument::Type(Type::Path(inner_path))) =
290 args.args.first()
291 {
292 if inner_path
293 .path
294 .segments
295 .last()
296 .is_some_and(|s| s.ident == "Content")
297 {
298 return ReturnTypeKind::VecContent;
299 }
300 }
301 }
302 }
303 "Result" | "McpResult" => {
304 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
306 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
307 let inner_kind = analyze_type(inner_ty);
308 return match inner_kind {
309 ReturnTypeKind::VecContent => {
310 if type_name == "McpResult" {
311 ReturnTypeKind::McpResultVecContent
312 } else {
313 ReturnTypeKind::ResultVecContent
314 }
315 }
316 ReturnTypeKind::String => {
317 if type_name == "McpResult" {
318 ReturnTypeKind::McpResultString
319 } else {
320 ReturnTypeKind::ResultString
321 }
322 }
323 _ => ReturnTypeKind::Other,
324 };
325 }
326 }
327 }
328 _ => {}
329 }
330 }
331 }
332 ReturnTypeKind::Other
333}
334
335fn generate_result_conversion(output: &syn::ReturnType) -> TokenStream2 {
337 let kind = analyze_return_type(output);
338
339 match kind {
340 ReturnTypeKind::Unit => quote! {
341 Ok(vec![])
342 },
343 ReturnTypeKind::VecContent => quote! {
344 Ok(result)
345 },
346 ReturnTypeKind::String => quote! {
347 Ok(vec![fastmcp_protocol::Content::Text { text: result }])
348 },
349 ReturnTypeKind::ResultVecContent | ReturnTypeKind::McpResultVecContent => quote! {
350 result.map_err(|e| fastmcp_core::McpError::internal_error(e.to_string()))
351 },
352 ReturnTypeKind::ResultString | ReturnTypeKind::McpResultString => quote! {
353 result
354 .map(|s| vec![fastmcp_protocol::Content::Text { text: s }])
355 .map_err(|e| fastmcp_core::McpError::internal_error(e.to_string()))
356 },
357 ReturnTypeKind::Other => quote! {
358 let text = format!("{}", result);
360 Ok(vec![fastmcp_protocol::Content::Text { text }])
361 },
362 }
363}
364
/// Classification of a `#[prompt]` function's return type.
///
/// It selects how the returned value becomes the prompt handler's result.
enum PromptReturnTypeKind {
    /// `Vec<PromptMessage>`.
    VecPromptMessage,
    /// `McpResult<Vec<PromptMessage>>`.
    McpResultVecPromptMessage,
    /// `Result<Vec<PromptMessage>, E>`.
    ResultVecPromptMessage,
    /// Any other type, or no return type at all.
    Other,
}
380
381fn analyze_prompt_return_type(output: &syn::ReturnType) -> PromptReturnTypeKind {
383 match output {
384 syn::ReturnType::Default => PromptReturnTypeKind::Other, syn::ReturnType::Type(_, ty) => analyze_prompt_type(ty),
386 }
387}
388
389fn analyze_prompt_type(ty: &Type) -> PromptReturnTypeKind {
391 if let Type::Path(type_path) = ty {
392 if let Some(segment) = type_path.path.segments.last() {
393 let type_name = segment.ident.to_string();
394
395 match type_name.as_str() {
396 "Vec" => {
397 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
399 if let Some(syn::GenericArgument::Type(Type::Path(inner_path))) =
400 args.args.first()
401 {
402 if inner_path
403 .path
404 .segments
405 .last()
406 .is_some_and(|s| s.ident == "PromptMessage")
407 {
408 return PromptReturnTypeKind::VecPromptMessage;
409 }
410 }
411 }
412 }
413 "Result" | "McpResult" => {
414 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
416 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
417 let inner_kind = analyze_prompt_type(inner_ty);
418 return match inner_kind {
419 PromptReturnTypeKind::VecPromptMessage => {
420 if type_name == "McpResult" {
421 PromptReturnTypeKind::McpResultVecPromptMessage
422 } else {
423 PromptReturnTypeKind::ResultVecPromptMessage
424 }
425 }
426 _ => PromptReturnTypeKind::Other,
427 };
428 }
429 }
430 }
431 _ => {}
432 }
433 }
434 }
435 PromptReturnTypeKind::Other
436}
437
438fn generate_prompt_result_conversion(output: &syn::ReturnType) -> TokenStream2 {
440 let kind = analyze_prompt_return_type(output);
441
442 match kind {
443 PromptReturnTypeKind::VecPromptMessage => quote! {
444 Ok(result)
445 },
446 PromptReturnTypeKind::ResultVecPromptMessage
447 | PromptReturnTypeKind::McpResultVecPromptMessage => quote! {
448 result.map_err(|e| fastmcp_core::McpError::internal_error(e.to_string()))
449 },
450 PromptReturnTypeKind::Other => quote! {
451 Ok(result)
453 },
454 }
455}
456
/// Classification of a `#[resource]` function's return type.
///
/// It selects how the returned value becomes the resource's text content.
enum ResourceReturnTypeKind {
    /// `String`.
    String,
    /// `McpResult<String>`.
    McpResultString,
    /// `Result<String, E>`.
    ResultString,
    /// Any other type, or no return type at all.
    Other,
}
472
473fn analyze_resource_return_type(output: &syn::ReturnType) -> ResourceReturnTypeKind {
475 match output {
476 syn::ReturnType::Default => ResourceReturnTypeKind::Other, syn::ReturnType::Type(_, ty) => analyze_resource_type(ty),
478 }
479}
480
481fn analyze_resource_type(ty: &Type) -> ResourceReturnTypeKind {
483 if let Type::Path(type_path) = ty {
484 if let Some(segment) = type_path.path.segments.last() {
485 let type_name = segment.ident.to_string();
486
487 match type_name.as_str() {
488 "String" => return ResourceReturnTypeKind::String,
489 "Result" | "McpResult" => {
490 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
492 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
493 let inner_kind = analyze_resource_type(inner_ty);
494 return match inner_kind {
495 ResourceReturnTypeKind::String => {
496 if type_name == "McpResult" {
497 ResourceReturnTypeKind::McpResultString
498 } else {
499 ResourceReturnTypeKind::ResultString
500 }
501 }
502 _ => ResourceReturnTypeKind::Other,
503 };
504 }
505 }
506 }
507 _ => {}
508 }
509 }
510 }
511 ResourceReturnTypeKind::Other
512}
513
514fn generate_resource_result_conversion(output: &syn::ReturnType, mime_type: &str) -> TokenStream2 {
524 let kind = analyze_resource_return_type(output);
525
526 match kind {
527 ResourceReturnTypeKind::String => quote! {
528 let text = result;
529 Ok(vec![fastmcp_protocol::ResourceContent {
530 uri: uri.to_string(),
531 mime_type: Some(#mime_type.to_string()),
532 text: Some(text),
533 blob: None,
534 }])
535 },
536 ResourceReturnTypeKind::ResultString | ResourceReturnTypeKind::McpResultString => quote! {
537 let text = result.map_err(|e| fastmcp_core::McpError::internal_error(e.to_string()))?;
538 Ok(vec![fastmcp_protocol::ResourceContent {
539 uri: uri.to_string(),
540 mime_type: Some(#mime_type.to_string()),
541 text: Some(text),
542 blob: None,
543 }])
544 },
545 ResourceReturnTypeKind::Other => quote! {
546 let text = result.to_string();
548 Ok(vec![fastmcp_protocol::ResourceContent {
549 uri: uri.to_string(),
550 mime_type: Some(#mime_type.to_string()),
551 text: Some(text),
552 blob: None,
553 }])
554 },
555 }
556}
557
558fn type_to_json_schema(ty: &Type) -> TokenStream2 {
560 let Type::Path(type_path) = ty else {
561 return quote! { serde_json::json!({}) };
562 };
563
564 let segment = type_path.path.segments.last().unwrap();
565 let type_name = segment.ident.to_string();
566
567 match type_name.as_str() {
568 "String" | "str" => quote! {
569 serde_json::json!({ "type": "string" })
570 },
571 "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32" | "u64" | "u128"
572 | "usize" => quote! {
573 serde_json::json!({ "type": "integer" })
574 },
575 "f32" | "f64" => quote! {
576 serde_json::json!({ "type": "number" })
577 },
578 "bool" => quote! {
579 serde_json::json!({ "type": "boolean" })
580 },
581 "Option" => {
582 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
584 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
585 return type_to_json_schema(inner_ty);
586 }
587 }
588 quote! { serde_json::json!({}) }
589 }
590 "Vec" => {
591 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
593 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
594 let inner_schema = type_to_json_schema(inner_ty);
595 return quote! {
596 serde_json::json!({
597 "type": "array",
598 "items": #inner_schema
599 })
600 };
601 }
602 }
603 quote! { serde_json::json!({ "type": "array" }) }
604 }
605 "HashSet" | "BTreeSet" => {
606 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
608 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
609 let inner_schema = type_to_json_schema(inner_ty);
610 return quote! {
611 serde_json::json!({
612 "type": "array",
613 "items": #inner_schema,
614 "uniqueItems": true
615 })
616 };
617 }
618 }
619 quote! { serde_json::json!({ "type": "array", "uniqueItems": true }) }
620 }
621 "HashMap" | "BTreeMap" => {
622 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
624 if args.args.len() >= 2 {
627 if let Some(syn::GenericArgument::Type(value_ty)) = args.args.iter().nth(1) {
628 let value_schema = type_to_json_schema(value_ty);
629 return quote! {
630 serde_json::json!({
631 "type": "object",
632 "additionalProperties": #value_schema
633 })
634 };
635 }
636 }
637 }
638 quote! { serde_json::json!({ "type": "object" }) }
639 }
640 "serde_json::Value" | "Value" => {
641 quote! { serde_json::json!({}) }
643 }
644 _ => {
645 quote! { <#ty>::json_schema() }
648 }
649 }
650}
651
652struct ToolAttrs {
658 name: Option<String>,
659 description: Option<String>,
660 timeout: Option<String>,
661 output_schema: Option<syn::Expr>,
663}
664
665impl Parse for ToolAttrs {
666 fn parse(input: ParseStream) -> syn::Result<Self> {
667 let mut name = None;
668 let mut description = None;
669 let mut timeout = None;
670 let mut output_schema = None;
671
672 while !input.is_empty() {
673 let ident: Ident = input.parse()?;
674 input.parse::<Token![=]>()?;
675
676 match ident.to_string().as_str() {
677 "name" => {
678 let lit: LitStr = input.parse()?;
679 name = Some(lit.value());
680 }
681 "description" => {
682 let lit: LitStr = input.parse()?;
683 description = Some(lit.value());
684 }
685 "timeout" => {
686 let lit: LitStr = input.parse()?;
687 timeout = Some(lit.value());
688 }
689 "output_schema" => {
690 let expr: syn::Expr = input.parse()?;
692 output_schema = Some(expr);
693 }
694 _ => {
695 return Err(syn::Error::new(ident.span(), "unknown attribute"));
696 }
697 }
698
699 if !input.is_empty() {
700 input.parse::<Token![,]>()?;
701 }
702 }
703
704 Ok(Self {
705 name,
706 description,
707 timeout,
708 output_schema,
709 })
710 }
711}
712
713#[proc_macro_attribute]
725#[allow(clippy::too_many_lines)]
726pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
727 let attrs = parse_macro_input!(attr as ToolAttrs);
728 let input_fn = parse_macro_input!(item as ItemFn);
729
730 let fn_name = &input_fn.sig.ident;
731 let fn_name_str = fn_name.to_string();
732
733 let handler_name = format_ident!("{}", to_pascal_case(&fn_name_str));
735
736 let tool_name = attrs.name.unwrap_or_else(|| fn_name_str.clone());
738
739 let description = attrs
741 .description
742 .or_else(|| extract_doc_comments(&input_fn.attrs));
743 let description_tokens = description.as_ref().map_or_else(
744 || quote! { None },
745 |desc| quote! { Some(#desc.to_string()) },
746 );
747
748 let timeout_tokens = if let Some(ref timeout_str) = attrs.timeout {
750 match parse_duration_to_millis(timeout_str) {
751 Ok(millis) => {
752 quote! {
753 fn timeout(&self) -> Option<std::time::Duration> {
754 Some(std::time::Duration::from_millis(#millis))
755 }
756 }
757 }
758 Err(e) => {
759 return syn::Error::new_spanned(
760 &input_fn.sig.ident,
761 format!("invalid timeout: {e}"),
762 )
763 .to_compile_error()
764 .into();
765 }
766 }
767 } else {
768 quote! {}
769 };
770
771 let (output_schema_field, output_schema_method) =
773 if let Some(ref schema_expr) = attrs.output_schema {
774 (
775 quote! { Some(#schema_expr) },
776 quote! {
777 fn output_schema(&self) -> Option<serde_json::Value> {
778 Some(#schema_expr)
779 }
780 },
781 )
782 } else {
783 (quote! { None }, quote! {})
784 };
785
786 let mut params: Vec<(&Ident, &Type, Option<String>)> = Vec::new();
788 let mut required_params: Vec<String> = Vec::new();
789 let mut expects_context = false;
790
791 for (i, arg) in input_fn.sig.inputs.iter().enumerate() {
792 if let FnArg::Typed(pat_type) = arg {
793 if i == 0 && is_mcp_context_ref(pat_type.ty.as_ref()) {
795 expects_context = true;
796 continue;
797 }
798
799 if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
800 let param_name = &pat_ident.ident;
801 let param_type = pat_type.ty.as_ref();
802 let param_doc = extract_doc_comments(&pat_type.attrs);
803
804 let is_optional = is_option_type(param_type);
806
807 if !is_optional {
808 required_params.push(param_name.to_string());
809 }
810
811 params.push((param_name, param_type, param_doc));
812 }
813 }
814 }
815
816 let property_entries: Vec<TokenStream2> = params
818 .iter()
819 .map(|(name, ty, doc)| {
820 let name_str = name.to_string();
821 let schema = type_to_json_schema(ty);
822 if let Some(desc) = doc {
823 quote! {
824 (#name_str.to_string(), {
825 let mut s = #schema;
826 if let Some(obj) = s.as_object_mut() {
827 obj.insert("description".to_string(), serde_json::json!(#desc));
828 }
829 s
830 })
831 }
832 } else {
833 quote! {
834 (#name_str.to_string(), #schema)
835 }
836 }
837 })
838 .collect();
839
840 let param_extractions: Vec<TokenStream2> = params
842 .iter()
843 .map(|(name, ty, _)| {
844 let name_str = name.to_string();
845 let is_optional = is_option_type(ty);
846
847 if is_optional {
848 quote! {
849 let #name: #ty = match arguments.get(#name_str) {
850 Some(value) => Some(
851 serde_json::from_value(value.clone()).map_err(|e| {
852 fastmcp_core::McpError::invalid_params(e.to_string())
853 })?,
854 ),
855 None => None,
856 };
857 }
858 } else {
859 quote! {
860 let #name: #ty = arguments.get(#name_str)
861 .ok_or_else(|| fastmcp_core::McpError::invalid_params(
862 format!("missing required parameter: {}", #name_str)
863 ))
864 .and_then(|v| serde_json::from_value(v.clone())
865 .map_err(|e| fastmcp_core::McpError::invalid_params(e.to_string())))?;
866 }
867 }
868 })
869 .collect();
870
871 let param_names: Vec<&Ident> = params.iter().map(|(name, _, _)| *name).collect();
873
874 let is_async = input_fn.sig.asyncness.is_some();
876
877 let return_type = &input_fn.sig.output;
879 let result_conversion = generate_result_conversion(return_type);
880
881 let call_expr = if is_async {
883 if expects_context {
884 quote! {
885 fastmcp_core::runtime::block_on(async move {
886 #fn_name(ctx, #(#param_names),*).await
887 })
888 }
889 } else {
890 quote! {
891 fastmcp_core::runtime::block_on(async move {
892 #fn_name(#(#param_names),*).await
893 })
894 }
895 }
896 } else {
897 if expects_context {
898 quote! {
899 #fn_name(ctx, #(#param_names),*)
900 }
901 } else {
902 quote! {
903 #fn_name(#(#param_names),*)
904 }
905 }
906 };
907
908 let expanded = quote! {
910 #input_fn
912
913 #[derive(Clone)]
915 pub struct #handler_name;
916
917 impl fastmcp_server::ToolHandler for #handler_name {
918 fn definition(&self) -> fastmcp_protocol::Tool {
919 let properties: std::collections::HashMap<String, serde_json::Value> = vec![
920 #(#property_entries),*
921 ].into_iter().collect();
922
923 let required: Vec<String> = vec![#(#required_params.to_string()),*];
924
925 fastmcp_protocol::Tool {
926 name: #tool_name.to_string(),
927 description: #description_tokens,
928 input_schema: serde_json::json!({
929 "type": "object",
930 "properties": properties,
931 "required": required,
932 }),
933 output_schema: #output_schema_field,
934 icon: None,
935 version: None,
936 tags: vec![],
937 annotations: None,
938 }
939 }
940
941 #timeout_tokens
942
943 #output_schema_method
944
945 fn call(
946 &self,
947 ctx: &fastmcp_core::McpContext,
948 arguments: serde_json::Value,
949 ) -> fastmcp_core::McpResult<Vec<fastmcp_protocol::Content>> {
950 let arguments = arguments.as_object()
952 .cloned()
953 .unwrap_or_default();
954
955 #(#param_extractions)*
957
958 let result = #call_expr;
960
961 #result_conversion
963 }
964 }
965 };
966
967 TokenStream::from(expanded)
968}
969
/// Arguments accepted by `#[resource(...)]`.
struct ResourceAttrs {
    /// `uri = "..."`: resource URI, possibly a `{param}` template. Mandatory;
    /// the macro reports an error when it is missing.
    uri: Option<String>,
    /// `mime_type = "..."`: defaults to `text/plain`.
    mime_type: Option<String>,
    /// `name = "..."`: defaults to the function name.
    name: Option<String>,
    /// `description = "..."`: takes precedence over the function's doc comments.
    description: Option<String>,
    /// `timeout = "..."`: duration string such as `"30s"`.
    timeout: Option<String>,
}
982
983impl Parse for ResourceAttrs {
984 fn parse(input: ParseStream) -> syn::Result<Self> {
985 let mut uri = None;
986 let mut name = None;
987 let mut description = None;
988 let mut mime_type = None;
989 let mut timeout = None;
990
991 while !input.is_empty() {
992 let ident: Ident = input.parse()?;
993 input.parse::<Token![=]>()?;
994
995 match ident.to_string().as_str() {
996 "uri" => {
997 let lit: LitStr = input.parse()?;
998 uri = Some(lit.value());
999 }
1000 "name" => {
1001 let lit: LitStr = input.parse()?;
1002 name = Some(lit.value());
1003 }
1004 "description" => {
1005 let lit: LitStr = input.parse()?;
1006 description = Some(lit.value());
1007 }
1008 "mime_type" => {
1009 let lit: LitStr = input.parse()?;
1010 mime_type = Some(lit.value());
1011 }
1012 "timeout" => {
1013 let lit: LitStr = input.parse()?;
1014 timeout = Some(lit.value());
1015 }
1016 _ => {
1017 return Err(syn::Error::new(ident.span(), "unknown attribute"));
1018 }
1019 }
1020
1021 if !input.is_empty() {
1022 input.parse::<Token![,]>()?;
1023 }
1024 }
1025
1026 Ok(Self {
1027 uri,
1028 name,
1029 description,
1030 mime_type,
1031 timeout,
1032 })
1033 }
1034}
1035
1036#[proc_macro_attribute]
1045#[allow(clippy::too_many_lines)]
1046pub fn resource(attr: TokenStream, item: TokenStream) -> TokenStream {
1047 let attrs = parse_macro_input!(attr as ResourceAttrs);
1048 let input_fn = parse_macro_input!(item as ItemFn);
1049
1050 let fn_name = &input_fn.sig.ident;
1051 let fn_name_str = fn_name.to_string();
1052
1053 let handler_name = format_ident!("{}Resource", to_pascal_case(&fn_name_str));
1055
1056 let Some(uri) = attrs.uri else {
1058 return syn::Error::new_spanned(&input_fn.sig.ident, "resource requires uri attribute")
1059 .to_compile_error()
1060 .into();
1061 };
1062
1063 let resource_name = attrs.name.unwrap_or_else(|| fn_name_str.clone());
1065 let description = attrs
1066 .description
1067 .or_else(|| extract_doc_comments(&input_fn.attrs));
1068 let mime_type = attrs.mime_type.unwrap_or_else(|| "text/plain".to_string());
1069
1070 let description_tokens = description.as_ref().map_or_else(
1071 || quote! { None },
1072 |desc| quote! { Some(#desc.to_string()) },
1073 );
1074
1075 let timeout_tokens = if let Some(ref timeout_str) = attrs.timeout {
1077 match parse_duration_to_millis(timeout_str) {
1078 Ok(millis) => {
1079 quote! {
1080 fn timeout(&self) -> Option<std::time::Duration> {
1081 Some(std::time::Duration::from_millis(#millis))
1082 }
1083 }
1084 }
1085 Err(e) => {
1086 return syn::Error::new_spanned(
1087 &input_fn.sig.ident,
1088 format!("invalid timeout: {e}"),
1089 )
1090 .to_compile_error()
1091 .into();
1092 }
1093 }
1094 } else {
1095 quote! {}
1096 };
1097
1098 let template_params = extract_template_params(&uri);
1099
1100 let mut params: Vec<(&Ident, &Type)> = Vec::new();
1102 let mut expects_context = false;
1103
1104 for (i, arg) in input_fn.sig.inputs.iter().enumerate() {
1105 if let FnArg::Typed(pat_type) = arg {
1106 if i == 0 && is_mcp_context_ref(pat_type.ty.as_ref()) {
1107 expects_context = true;
1108 continue;
1109 }
1110
1111 if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
1112 let param_name = &pat_ident.ident;
1113 let param_type = pat_type.ty.as_ref();
1114 params.push((param_name, param_type));
1115 }
1116 }
1117 }
1118
1119 if template_params.is_empty() && !params.is_empty() {
1120 return syn::Error::new_spanned(
1121 &input_fn.sig.ident,
1122 "resource parameters require a URI template with matching {params}",
1123 )
1124 .to_compile_error()
1125 .into();
1126 }
1127
1128 let missing_params: Vec<String> = params
1129 .iter()
1130 .map(|(name, _)| name.to_string())
1131 .filter(|name| !template_params.contains(name))
1132 .collect();
1133
1134 if !missing_params.is_empty() {
1135 return syn::Error::new_spanned(
1136 &input_fn.sig.ident,
1137 format!(
1138 "resource parameters missing from uri template: {}",
1139 missing_params.join(", ")
1140 ),
1141 )
1142 .to_compile_error()
1143 .into();
1144 }
1145
1146 let is_template = !template_params.is_empty();
1147
1148 let param_extractions: Vec<TokenStream2> = params
1149 .iter()
1150 .map(|(name, ty)| {
1151 let name_str = name.to_string();
1152 if let Some(inner_ty) = option_inner_type(ty) {
1153 if is_string_type(inner_ty) {
1154 quote! {
1155 let #name: #ty = uri_params.get(#name_str).cloned();
1156 }
1157 } else {
1158 quote! {
1159 let #name: #ty = match uri_params.get(#name_str) {
1160 Some(value) => Some(value.parse().map_err(|_| {
1161 fastmcp_core::McpError::invalid_params(
1162 format!("invalid uri parameter: {}", #name_str)
1163 )
1164 })?),
1165 None => None,
1166 };
1167 }
1168 }
1169 } else if is_string_type(ty) {
1170 quote! {
1171 let #name: #ty = uri_params
1172 .get(#name_str)
1173 .ok_or_else(|| fastmcp_core::McpError::invalid_params(
1174 format!("missing uri parameter: {}", #name_str)
1175 ))?
1176 .clone();
1177 }
1178 } else {
1179 quote! {
1180 let #name: #ty = uri_params
1181 .get(#name_str)
1182 .ok_or_else(|| fastmcp_core::McpError::invalid_params(
1183 format!("missing uri parameter: {}", #name_str)
1184 ))?
1185 .parse()
1186 .map_err(|_| fastmcp_core::McpError::invalid_params(
1187 format!("invalid uri parameter: {}", #name_str)
1188 ))?;
1189 }
1190 }
1191 })
1192 .collect();
1193
1194 let param_names: Vec<&Ident> = params.iter().map(|(name, _)| *name).collect();
1195 let call_args = if expects_context {
1196 quote! { ctx, #(#param_names),* }
1197 } else {
1198 quote! { #(#param_names),* }
1199 };
1200
1201 let is_async = input_fn.sig.asyncness.is_some();
1202 let call_expr = if is_async {
1203 quote! {
1204 fastmcp_core::runtime::block_on(async move {
1205 #fn_name(#call_args).await
1206 })
1207 }
1208 } else {
1209 quote! {
1210 #fn_name(#call_args)
1211 }
1212 };
1213
1214 let template_tokens = if is_template {
1215 quote! {
1216 Some(fastmcp_protocol::ResourceTemplate {
1217 uri_template: #uri.to_string(),
1218 name: #resource_name.to_string(),
1219 description: #description_tokens,
1220 mime_type: Some(#mime_type.to_string()),
1221 icon: None,
1222 version: None,
1223 tags: vec![],
1224 })
1225 }
1226 } else {
1227 quote! { None }
1228 };
1229
1230 let return_type = &input_fn.sig.output;
1232 let resource_result_conversion = generate_resource_result_conversion(return_type, &mime_type);
1233
1234 let expanded = quote! {
1235 #input_fn
1237
1238 #[derive(Clone)]
1240 pub struct #handler_name;
1241
1242 impl fastmcp_server::ResourceHandler for #handler_name {
1243 fn definition(&self) -> fastmcp_protocol::Resource {
1244 fastmcp_protocol::Resource {
1245 uri: #uri.to_string(),
1246 name: #resource_name.to_string(),
1247 description: #description_tokens,
1248 mime_type: Some(#mime_type.to_string()),
1249 icon: None,
1250 version: None,
1251 tags: vec![],
1252 }
1253 }
1254
1255 fn template(&self) -> Option<fastmcp_protocol::ResourceTemplate> {
1256 #template_tokens
1257 }
1258
1259 #timeout_tokens
1260
1261 fn read(
1262 &self,
1263 ctx: &fastmcp_core::McpContext,
1264 ) -> fastmcp_core::McpResult<Vec<fastmcp_protocol::ResourceContent>> {
1265 let uri_params = std::collections::HashMap::new();
1266 self.read_with_uri(ctx, #uri, &uri_params)
1267 }
1268
1269 fn read_with_uri(
1270 &self,
1271 ctx: &fastmcp_core::McpContext,
1272 uri: &str,
1273 uri_params: &std::collections::HashMap<String, String>,
1274 ) -> fastmcp_core::McpResult<Vec<fastmcp_protocol::ResourceContent>> {
1275 #(#param_extractions)*
1276 let result = #call_expr;
1277 #resource_result_conversion
1278 }
1279
1280 fn read_async_with_uri<'a>(
1281 &'a self,
1282 ctx: &'a fastmcp_core::McpContext,
1283 uri: &'a str,
1284 uri_params: &'a std::collections::HashMap<String, String>,
1285 ) -> fastmcp_server::BoxFuture<'a, fastmcp_core::McpOutcome<Vec<fastmcp_protocol::ResourceContent>>> {
1286 Box::pin(async move {
1287 match self.read_with_uri(ctx, uri, uri_params) {
1288 Ok(value) => fastmcp_core::Outcome::Ok(value),
1289 Err(error) => fastmcp_core::Outcome::Err(error),
1290 }
1291 })
1292 }
1293 }
1294 };
1295
1296 TokenStream::from(expanded)
1297}
1298
/// Arguments accepted by `#[prompt(...)]`; every key is optional.
struct PromptAttrs {
    /// `timeout = "..."`: duration string such as `"30s"`.
    timeout: Option<String>,
    /// `name = "..."`: prompt name; defaults to the function name.
    name: Option<String>,
    /// `description = "..."`: takes precedence over the function's doc comments.
    description: Option<String>,
}
1309
1310impl Parse for PromptAttrs {
1311 fn parse(input: ParseStream) -> syn::Result<Self> {
1312 let mut name = None;
1313 let mut description = None;
1314 let mut timeout = None;
1315
1316 while !input.is_empty() {
1317 let ident: Ident = input.parse()?;
1318 input.parse::<Token![=]>()?;
1319
1320 match ident.to_string().as_str() {
1321 "name" => {
1322 let lit: LitStr = input.parse()?;
1323 name = Some(lit.value());
1324 }
1325 "description" => {
1326 let lit: LitStr = input.parse()?;
1327 description = Some(lit.value());
1328 }
1329 "timeout" => {
1330 let lit: LitStr = input.parse()?;
1331 timeout = Some(lit.value());
1332 }
1333 _ => {
1334 return Err(syn::Error::new(ident.span(), "unknown attribute"));
1335 }
1336 }
1337
1338 if !input.is_empty() {
1339 input.parse::<Token![,]>()?;
1340 }
1341 }
1342
1343 Ok(Self {
1344 name,
1345 description,
1346 timeout,
1347 })
1348 }
1349}
1350
1351#[proc_macro_attribute]
1358#[allow(clippy::too_many_lines)]
1359pub fn prompt(attr: TokenStream, item: TokenStream) -> TokenStream {
1360 let attrs = parse_macro_input!(attr as PromptAttrs);
1361 let input_fn = parse_macro_input!(item as ItemFn);
1362
1363 let fn_name = &input_fn.sig.ident;
1364 let fn_name_str = fn_name.to_string();
1365
1366 let handler_name = format_ident!("{}Prompt", to_pascal_case(&fn_name_str));
1368
1369 let prompt_name = attrs.name.unwrap_or_else(|| fn_name_str.clone());
1371
1372 let description = attrs
1374 .description
1375 .or_else(|| extract_doc_comments(&input_fn.attrs));
1376 let description_tokens = description.as_ref().map_or_else(
1377 || quote! { None },
1378 |desc| quote! { Some(#desc.to_string()) },
1379 );
1380
1381 let timeout_tokens = if let Some(ref timeout_str) = attrs.timeout {
1383 match parse_duration_to_millis(timeout_str) {
1384 Ok(millis) => {
1385 quote! {
1386 fn timeout(&self) -> Option<std::time::Duration> {
1387 Some(std::time::Duration::from_millis(#millis))
1388 }
1389 }
1390 }
1391 Err(e) => {
1392 return syn::Error::new_spanned(
1393 &input_fn.sig.ident,
1394 format!("invalid timeout: {e}"),
1395 )
1396 .to_compile_error()
1397 .into();
1398 }
1399 }
1400 } else {
1401 quote! {}
1402 };
1403
1404 let mut prompt_args: Vec<TokenStream2> = Vec::new();
1406 let mut expects_context = false;
1407
1408 for (i, arg) in input_fn.sig.inputs.iter().enumerate() {
1409 if let FnArg::Typed(pat_type) = arg {
1410 if i == 0 && is_mcp_context_ref(pat_type.ty.as_ref()) {
1412 expects_context = true;
1413 continue;
1414 }
1415
1416 if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
1417 let param_name = pat_ident.ident.to_string();
1418 let param_doc = extract_doc_comments(&pat_type.attrs);
1419 let is_optional = is_option_type(pat_type.ty.as_ref());
1420
1421 let desc_tokens = param_doc
1422 .as_ref()
1423 .map_or_else(|| quote! { None }, |d| quote! { Some(#d.to_string()) });
1424
1425 prompt_args.push(quote! {
1426 fastmcp_protocol::PromptArgument {
1427 name: #param_name.to_string(),
1428 description: #desc_tokens,
1429 required: !#is_optional,
1430 }
1431 });
1432 }
1433 }
1434 }
1435
1436 let mut param_extractions: Vec<TokenStream2> = Vec::new();
1438 let mut param_names: Vec<Ident> = Vec::new();
1439
1440 for (i, arg) in input_fn.sig.inputs.iter().enumerate() {
1441 if let FnArg::Typed(pat_type) = arg {
1442 if i == 0 && is_mcp_context_ref(pat_type.ty.as_ref()) {
1444 continue;
1445 }
1446
1447 if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
1448 let param_name = &pat_ident.ident;
1449 let param_name_str = param_name.to_string();
1450 let is_optional = is_option_type(pat_type.ty.as_ref());
1451
1452 param_names.push(param_name.clone());
1453
1454 if is_optional {
1455 param_extractions.push(quote! {
1457 let #param_name = arguments.get(#param_name_str).cloned();
1458 });
1459 } else {
1460 param_extractions.push(quote! {
1462 let #param_name = arguments.get(#param_name_str)
1463 .cloned()
1464 .ok_or_else(|| fastmcp_core::McpError::invalid_params(
1465 format!("missing required argument: {}", #param_name_str)
1466 ))?;
1467 });
1468 }
1469 }
1470 }
1471 }
1472
1473 let is_async = input_fn.sig.asyncness.is_some();
1474 let call_expr = if is_async {
1475 if expects_context {
1476 quote! {
1477 fastmcp_core::runtime::block_on(async move {
1478 #fn_name(ctx, #(#param_names),*).await
1479 })
1480 }
1481 } else {
1482 quote! {
1483 fastmcp_core::runtime::block_on(async move {
1484 #fn_name(#(#param_names),*).await
1485 })
1486 }
1487 }
1488 } else {
1489 if expects_context {
1490 quote! {
1491 #fn_name(ctx, #(#param_names),*)
1492 }
1493 } else {
1494 quote! {
1495 #fn_name(#(#param_names),*)
1496 }
1497 }
1498 };
1499
1500 let return_type = &input_fn.sig.output;
1502 let prompt_result_conversion = generate_prompt_result_conversion(return_type);
1503
1504 let expanded = quote! {
1505 #input_fn
1507
1508 #[derive(Clone)]
1510 pub struct #handler_name;
1511
1512 impl fastmcp_server::PromptHandler for #handler_name {
1513 fn definition(&self) -> fastmcp_protocol::Prompt {
1514 fastmcp_protocol::Prompt {
1515 name: #prompt_name.to_string(),
1516 description: #description_tokens,
1517 arguments: vec![#(#prompt_args),*],
1518 icon: None,
1519 version: None,
1520 tags: vec![],
1521 }
1522 }
1523
1524 #timeout_tokens
1525
1526 fn get(
1527 &self,
1528 ctx: &fastmcp_core::McpContext,
1529 arguments: std::collections::HashMap<String, String>,
1530 ) -> fastmcp_core::McpResult<Vec<fastmcp_protocol::PromptMessage>> {
1531 #(#param_extractions)*
1532 let result = #call_expr;
1533 #prompt_result_conversion
1534 }
1535 }
1536 };
1537
1538 TokenStream::from(expanded)
1539}
1540
1541#[proc_macro_derive(JsonSchema, attributes(json_schema))]
1590pub fn derive_json_schema(input: TokenStream) -> TokenStream {
1591 let input = parse_macro_input!(input as syn::DeriveInput);
1592
1593 let name = &input.ident;
1594 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1595
1596 let type_description = extract_doc_comments(&input.attrs);
1598 let type_desc_tokens = type_description
1599 .as_ref()
1600 .map_or_else(|| quote! { None::<&str> }, |desc| quote! { Some(#desc) });
1601
1602 let schema_impl = match &input.data {
1604 syn::Data::Struct(data_struct) => generate_struct_schema(data_struct, &type_desc_tokens),
1605 syn::Data::Enum(data_enum) => generate_enum_schema(data_enum, &type_desc_tokens),
1606 syn::Data::Union(_) => {
1607 return syn::Error::new_spanned(input, "JsonSchema cannot be derived for unions")
1608 .to_compile_error()
1609 .into();
1610 }
1611 };
1612
1613 let expanded = quote! {
1614 impl #impl_generics #name #ty_generics #where_clause {
1615 pub fn json_schema() -> serde_json::Value {
1617 #schema_impl
1618 }
1619 }
1620 };
1621
1622 TokenStream::from(expanded)
1623}
1624
1625fn generate_struct_schema(data: &syn::DataStruct, type_desc_tokens: &TokenStream2) -> TokenStream2 {
1627 match &data.fields {
1628 syn::Fields::Named(fields) => {
1629 let mut property_entries = Vec::new();
1630 let mut required_fields = Vec::new();
1631
1632 for field in &fields.named {
1633 if has_json_schema_attr(&field.attrs, "skip") {
1635 continue;
1636 }
1637
1638 let field_name = field.ident.as_ref().unwrap();
1639
1640 let schema_name =
1642 get_json_schema_rename(&field.attrs).unwrap_or_else(|| field_name.to_string());
1643
1644 let field_doc = extract_doc_comments(&field.attrs);
1646
1647 let field_type = &field.ty;
1649 let is_optional = is_option_type(field_type);
1650
1651 let field_schema = type_to_json_schema(field_type);
1653
1654 let property_value = if let Some(desc) = &field_doc {
1656 quote! {
1657 {
1658 let mut schema = #field_schema;
1659 if let Some(obj) = schema.as_object_mut() {
1660 obj.insert("description".to_string(), serde_json::json!(#desc));
1661 }
1662 schema
1663 }
1664 }
1665 } else {
1666 field_schema
1667 };
1668
1669 property_entries.push(quote! {
1670 (#schema_name.to_string(), #property_value)
1671 });
1672
1673 if !is_optional {
1675 required_fields.push(schema_name);
1676 }
1677 }
1678
1679 quote! {
1680 {
1681 let properties: std::collections::HashMap<String, serde_json::Value> = vec![
1682 #(#property_entries),*
1683 ].into_iter().collect();
1684
1685 let required: Vec<String> = vec![#(#required_fields.to_string()),*];
1686
1687 let mut schema = serde_json::json!({
1688 "type": "object",
1689 "properties": properties,
1690 "required": required,
1691 });
1692
1693 if let Some(desc) = #type_desc_tokens {
1695 if let Some(obj) = schema.as_object_mut() {
1696 obj.insert("description".to_string(), serde_json::json!(desc));
1697 }
1698 }
1699
1700 schema
1701 }
1702 }
1703 }
1704 syn::Fields::Unnamed(fields) => {
1705 if fields.unnamed.len() == 1 {
1707 let inner_type = &fields.unnamed.first().unwrap().ty;
1709 let inner_schema = type_to_json_schema(inner_type);
1710 quote! { #inner_schema }
1711 } else {
1712 let item_schemas: Vec<_> = fields
1714 .unnamed
1715 .iter()
1716 .map(|f| type_to_json_schema(&f.ty))
1717 .collect();
1718 let num_items = item_schemas.len();
1719 quote! {
1720 {
1721 let items: Vec<serde_json::Value> = vec![#(#item_schemas),*];
1722 serde_json::json!({
1723 "type": "array",
1724 "prefixItems": items,
1725 "minItems": #num_items,
1726 "maxItems": #num_items,
1727 })
1728 }
1729 }
1730 }
1731 }
1732 syn::Fields::Unit => {
1733 quote! { serde_json::json!({ "type": "null" }) }
1735 }
1736 }
1737}
1738
1739fn generate_enum_schema(data: &syn::DataEnum, type_desc_tokens: &TokenStream2) -> TokenStream2 {
1741 let all_unit = data
1743 .variants
1744 .iter()
1745 .all(|v| matches!(v.fields, syn::Fields::Unit));
1746
1747 if all_unit {
1748 let variant_names: Vec<String> =
1750 data.variants.iter().map(|v| v.ident.to_string()).collect();
1751
1752 quote! {
1753 {
1754 let mut schema = serde_json::json!({
1755 "type": "string",
1756 "enum": [#(#variant_names),*]
1757 });
1758
1759 if let Some(desc) = #type_desc_tokens {
1760 if let Some(obj) = schema.as_object_mut() {
1761 obj.insert("description".to_string(), serde_json::json!(desc));
1762 }
1763 }
1764
1765 schema
1766 }
1767 }
1768 } else {
1769 let variant_schemas: Vec<TokenStream2> = data
1771 .variants
1772 .iter()
1773 .map(|variant| {
1774 let variant_name = variant.ident.to_string();
1775 match &variant.fields {
1776 syn::Fields::Unit => {
1777 quote! {
1778 serde_json::json!({
1779 "type": "object",
1780 "properties": {
1781 #variant_name: { "type": "null" }
1782 },
1783 "required": [#variant_name]
1784 })
1785 }
1786 }
1787 syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
1788 let inner_type = &fields.unnamed.first().unwrap().ty;
1789 let inner_schema = type_to_json_schema(inner_type);
1790 quote! {
1791 serde_json::json!({
1792 "type": "object",
1793 "properties": {
1794 #variant_name: #inner_schema
1795 },
1796 "required": [#variant_name]
1797 })
1798 }
1799 }
1800 _ => {
1801 quote! {
1803 serde_json::json!({
1804 "type": "object",
1805 "properties": {
1806 #variant_name: { "type": "object" }
1807 },
1808 "required": [#variant_name]
1809 })
1810 }
1811 }
1812 }
1813 })
1814 .collect();
1815
1816 quote! {
1817 {
1818 let mut schema = serde_json::json!({
1819 "oneOf": [#(#variant_schemas),*]
1820 });
1821
1822 if let Some(desc) = #type_desc_tokens {
1823 if let Some(obj) = schema.as_object_mut() {
1824 obj.insert("description".to_string(), serde_json::json!(desc));
1825 }
1826 }
1827
1828 schema
1829 }
1830 }
1831 }
1832}
1833
1834fn has_json_schema_attr(attrs: &[Attribute], attr_name: &str) -> bool {
1836 for attr in attrs {
1837 if attr.path().is_ident("json_schema") {
1838 if let Meta::List(meta_list) = &attr.meta {
1839 if let Ok(nested) = meta_list.parse_args::<Ident>() {
1840 if nested == attr_name {
1841 return true;
1842 }
1843 }
1844 }
1845 }
1846 }
1847 false
1848}
1849
1850fn get_json_schema_rename(attrs: &[Attribute]) -> Option<String> {
1852 for attr in attrs {
1853 if attr.path().is_ident("json_schema") {
1854 if let Meta::List(meta_list) = &attr.meta {
1855 let result: syn::Result<(Ident, LitStr)> =
1857 meta_list.parse_args_with(|input: ParseStream| {
1858 let ident: Ident = input.parse()?;
1859 let _: Token![=] = input.parse()?;
1860 let lit: LitStr = input.parse()?;
1861 Ok((ident, lit))
1862 });
1863
1864 if let Ok((ident, lit)) = result {
1865 if ident == "rename" {
1866 return Some(lit.value());
1867 }
1868 }
1869 }
1870 }
1871 }
1872 None
1873}