1#![forbid(unsafe_code)]
54
55use std::collections::HashMap;
56
57use proc_macro::TokenStream;
58use proc_macro2::TokenStream as TokenStream2;
59use quote::{format_ident, quote};
60use syn::spanned::Spanned as _;
61use syn::{
62 Attribute, FnArg, Ident, ItemFn, Lit, LitStr, Meta, Pat, Token, Type, parse::Parse,
63 parse::ParseStream, parse_macro_input,
64};
65
66fn extract_doc_comments(attrs: &[Attribute]) -> Option<String> {
68 let docs: Vec<String> = attrs
69 .iter()
70 .filter_map(|attr| {
71 if attr.path().is_ident("doc") {
72 if let Meta::NameValue(nv) = &attr.meta {
73 if let syn::Expr::Lit(syn::ExprLit {
74 lit: Lit::Str(s), ..
75 }) = &nv.value
76 {
77 return Some(s.value().trim().to_string());
78 }
79 }
80 }
81 None
82 })
83 .collect();
84
85 if docs.is_empty() {
86 None
87 } else {
88 Some(docs.join("\n"))
89 }
90}
91
92fn is_mcp_context_ref(ty: &Type) -> bool {
94 if let Type::Reference(type_ref) = ty {
95 if let Type::Path(type_path) = type_ref.elem.as_ref() {
96 return type_path
97 .path
98 .segments
99 .last()
100 .is_some_and(|s| s.ident == "McpContext");
101 }
102 }
103 false
104}
105
106fn is_option_type(ty: &Type) -> bool {
108 if let Type::Path(type_path) = ty {
109 return type_path
110 .path
111 .segments
112 .last()
113 .is_some_and(|s| s.ident == "Option");
114 }
115 false
116}
117
118fn option_inner_type(ty: &Type) -> Option<&Type> {
120 if let Type::Path(type_path) = ty {
121 if let Some(segment) = type_path.path.segments.last() {
122 if segment.ident == "Option" {
123 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
124 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
125 return Some(inner_ty);
126 }
127 }
128 }
129 }
130 }
131 None
132}
133
134fn is_string_type(ty: &Type) -> bool {
136 if let Type::Path(type_path) = ty {
137 return type_path
138 .path
139 .segments
140 .last()
141 .is_some_and(|s| s.ident == "String");
142 }
143 false
144}
145
146fn default_lit_expr_for_type(lit: &Lit, ty: &Type) -> syn::Result<TokenStream2> {
147 if is_option_type(ty) {
148 let inner = option_inner_type(ty).ok_or_else(|| {
149 syn::Error::new(
150 ty.span(),
151 "Option<T> default requires a concrete inner type",
152 )
153 })?;
154 let inner_expr = default_lit_expr_for_type(lit, inner)?;
155 return Ok(quote! { Some(#inner_expr) });
156 }
157
158 if is_string_type(ty) {
159 let Lit::Str(s) = lit else {
160 return Err(syn::Error::new(
161 lit.span(),
162 "default for String must be a string literal",
163 ));
164 };
165 return Ok(quote! { #s.to_string() });
166 }
167
168 Ok(quote! { #lit })
169}
170
/// Parses a human-readable duration such as `"30s"`, `"1h30m"`, `"1h 30m"` or
/// `"500ms"` into milliseconds.
///
/// The input is trimmed. It is a sequence of `<digits><unit>` components, where
/// the unit is `h`, `m`, `s` or `ms`. Whitespace is skipped anywhere, and the
/// component values are summed.
///
/// # Errors
///
/// Returns a message for any of the following:
/// * empty input;
/// * a unit with no preceding number;
/// * an unknown unit or unexpected character;
/// * a trailing number with no unit;
/// * `u64` overflow;
/// * a total of zero.
fn parse_duration_to_millis(s: &str) -> Result<u64, String> {
    let input = s.trim();
    if input.is_empty() {
        return Err("empty string".to_string());
    }

    let mut total: u64 = 0;
    let mut digits = String::new();
    let mut iter = input.chars().peekable();

    while let Some(ch) = iter.next() {
        if ch.is_ascii_digit() {
            digits.push(ch);
            continue;
        }
        if ch.is_whitespace() {
            continue;
        }
        if !ch.is_ascii_alphabetic() {
            return Err(format!("unexpected character '{ch}'"));
        }

        // `ch` is a unit letter: it must close a pending number.
        if digits.is_empty() {
            return Err(format!(
                "unexpected unit character '{ch}' without preceding number"
            ));
        }
        let value: u64 = digits
            .parse()
            .map_err(|_| format!("invalid number: {digits}"))?;

        // `m` immediately followed by `s` is the two-letter `ms` unit.
        let component = if ch == 'm' && iter.next_if_eq(&'s').is_some() {
            value
        } else {
            let factor: u64 = match ch {
                'h' => 3_600_000,
                'm' => 60_000,
                's' => 1000,
                _ => return Err(format!("unknown unit '{ch}'")),
            };
            value
                .checked_mul(factor)
                .ok_or_else(|| format!("duration overflow for component: {value}{ch}"))?
        };

        total = total
            .checked_add(component)
            .ok_or_else(|| "duration overflow".to_string())?;
        digits.clear();
    }

    if !digits.is_empty() {
        return Err(format!(
            "number '{digits}' missing unit (use s, m, h, or ms)"
        ));
    }
    if total == 0 {
        return Err("duration must be greater than zero".to_string());
    }
    Ok(total)
}
249
#[cfg(test)]
#[allow(clippy::items_after_test_module)]
mod duration_parse_tests {
    //! Unit tests for `parse_duration_to_millis`, the `timeout = "..."` parser.

    use super::parse_duration_to_millis;

    /// Asserts that `input` is rejected and that the error mentions `needle`.
    fn assert_rejected(input: &str, needle: &str) {
        let Err(message) = parse_duration_to_millis(input) else {
            panic!("expected {input:?} to be rejected");
        };
        assert!(
            message.contains(needle),
            "error for {input:?} was {message:?}; expected it to mention {needle:?}"
        );
    }

    #[test]
    fn parse_duration_compound_values() {
        for (input, expected) in [("1h30m", 5_400_000), ("500ms", 500)] {
            assert_eq!(parse_duration_to_millis(input), Ok(expected), "input {input:?}");
        }
    }

    #[test]
    fn parse_duration_component_overflow_returns_error() {
        assert_rejected(&format!("{}s", u64::MAX), "overflow");
    }

    #[test]
    fn parse_duration_total_overflow_returns_error() {
        assert_rejected(&format!("{}ms1ms", u64::MAX), "overflow");
    }

    #[test]
    fn parse_single_units() {
        let cases = [
            ("30s", 30_000),
            ("5m", 300_000),
            ("2h", 7_200_000),
            ("100ms", 100),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_duration_to_millis(input), Ok(expected), "input {input:?}");
        }
    }

    #[test]
    fn parse_empty_and_whitespace() {
        for input in ["", " "] {
            assert!(parse_duration_to_millis(input).is_err(), "input {input:?}");
        }
    }

    #[test]
    fn parse_missing_unit() {
        assert_rejected("42", "missing unit");
    }

    #[test]
    fn parse_unit_without_number() {
        assert_rejected("s", "without preceding number");
    }

    #[test]
    fn parse_unknown_unit() {
        assert_rejected("10x", "unknown unit");
    }

    #[test]
    fn parse_unexpected_character() {
        assert_rejected("10s$", "unexpected character");
    }

    #[test]
    fn parse_zero_duration() {
        assert_rejected("0s", "greater than zero");
    }

    #[test]
    fn parse_whitespace_between_components() {
        assert_eq!(parse_duration_to_millis("1h 30m"), Ok(5_400_000));
    }

    #[test]
    fn parse_trimmed_input() {
        assert_eq!(parse_duration_to_millis(" 10s "), Ok(10_000));
    }
}
333
#[cfg(test)]
#[allow(clippy::items_after_test_module)]
mod helper_tests {
    //! Unit tests for `extract_template_params` and `to_pascal_case`.

    use super::{extract_template_params, to_pascal_case};

    /// Shorthand for the parameter names extracted from `uri`.
    fn params_of(uri: &str) -> Vec<String> {
        extract_template_params(uri)
    }

    #[test]
    fn template_params_basic() {
        assert_eq!(params_of("users/{id}/posts/{post_id}"), ["id", "post_id"]);
    }

    #[test]
    fn template_params_none() {
        assert!(params_of("static/path/no/params").is_empty());
    }

    #[test]
    fn template_params_single() {
        assert_eq!(params_of("config://{name}"), ["name"]);
    }

    #[test]
    fn template_params_adjacent_braces() {
        assert_eq!(params_of("{a}{b}"), ["a", "b"]);
    }

    #[test]
    fn template_params_empty_braces_skipped() {
        assert!(params_of("prefix/{}").is_empty());
    }

    #[test]
    fn pascal_case_single_word() {
        assert_eq!(to_pascal_case("hello"), "Hello");
    }

    #[test]
    fn pascal_case_snake_case() {
        assert_eq!(to_pascal_case("my_tool_handler"), "MyToolHandler");
    }

    #[test]
    fn pascal_case_already_pascal() {
        assert_eq!(to_pascal_case("Hello"), "Hello");
    }

    #[test]
    fn pascal_case_empty() {
        assert_eq!(to_pascal_case(""), "");
    }

    #[test]
    fn pascal_case_leading_underscore() {
        assert_eq!(to_pascal_case("_private"), "Private");
    }
}
395
/// Returns the names of `{placeholder}` segments in a URI template, in order.
///
/// Everything between a `{` and the next `}` is taken verbatim as the name.
/// Empty `{}` pairs are skipped. An unterminated `{` takes the rest of the
/// string as its name.
fn extract_template_params(uri: &str) -> Vec<String> {
    let mut found = Vec::new();
    let mut rest = uri;
    while let Some(open) = rest.find('{') {
        // `{` is a single byte, so `open + 1` is a valid char boundary.
        let after_open = &rest[open + 1..];
        let (name, remainder) = match after_open.find('}') {
            Some(close) => (&after_open[..close], &after_open[close + 1..]),
            None => (after_open, ""),
        };
        if !name.is_empty() {
            found.push(name.to_string());
        }
        rest = remainder;
    }
    found
}
418
/// Converts a `snake_case` identifier to `PascalCase`.
///
/// The first character of each `_`-separated word is uppercased and the rest
/// is kept as-is. Empty words (from leading, trailing or repeated underscores)
/// are dropped.
fn to_pascal_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for word in s.split('_') {
        let mut chars = word.chars();
        if let Some(first) = chars.next() {
            out.extend(first.to_uppercase());
            out.push_str(chars.as_str());
        }
    }
    out
}
431
/// Classification of a tool function's return type.
///
/// `generate_result_conversion` uses it to choose how the call result becomes
/// `McpResult<Vec<Content>>`.
enum ReturnTypeKind {
    /// `Vec<Content>`: returned wrapped in `Ok`.
    VecContent,
    /// `String`: wrapped in a single text content item.
    String,
    /// `Result<Vec<Content>, E>`: the error is mapped to an internal MCP error.
    ResultVecContent,
    /// `Result<String, E>`: text-wrapped, with the error mapped to an internal MCP error.
    ResultString,
    /// `McpResult<Vec<Content>>`: passed through unchanged.
    McpResultVecContent,
    /// `McpResult<String>`: text-wrapped on success.
    McpResultString,
    /// Any other type: formatted with `Display` into a text content item.
    Other,
    /// No declared return type: yields an empty content list.
    Unit,
}
451
452fn analyze_return_type(output: &syn::ReturnType) -> ReturnTypeKind {
454 match output {
455 syn::ReturnType::Default => ReturnTypeKind::Unit,
456 syn::ReturnType::Type(_, ty) => analyze_type(ty),
457 }
458}
459
460fn analyze_type(ty: &Type) -> ReturnTypeKind {
462 if let Type::Path(type_path) = ty {
463 if let Some(segment) = type_path.path.segments.last() {
464 let type_name = segment.ident.to_string();
465
466 match type_name.as_str() {
467 "String" => return ReturnTypeKind::String,
468 "Vec" => {
469 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
471 if let Some(syn::GenericArgument::Type(Type::Path(inner_path))) =
472 args.args.first()
473 {
474 if inner_path
475 .path
476 .segments
477 .last()
478 .is_some_and(|s| s.ident == "Content")
479 {
480 return ReturnTypeKind::VecContent;
481 }
482 }
483 }
484 }
485 "Result" | "McpResult" => {
486 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
488 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
489 let inner_kind = analyze_type(inner_ty);
490 return match inner_kind {
491 ReturnTypeKind::VecContent => {
492 if type_name == "McpResult" {
493 ReturnTypeKind::McpResultVecContent
494 } else {
495 ReturnTypeKind::ResultVecContent
496 }
497 }
498 ReturnTypeKind::String => {
499 if type_name == "McpResult" {
500 ReturnTypeKind::McpResultString
501 } else {
502 ReturnTypeKind::ResultString
503 }
504 }
505 _ => ReturnTypeKind::Other,
506 };
507 }
508 }
509 }
510 _ => {}
511 }
512 }
513 }
514 ReturnTypeKind::Other
515}
516
517fn generate_result_conversion(output: &syn::ReturnType) -> TokenStream2 {
519 let kind = analyze_return_type(output);
520
521 match kind {
522 ReturnTypeKind::Unit => quote! {
523 Ok(vec![])
524 },
525 ReturnTypeKind::VecContent => quote! {
526 Ok(result)
527 },
528 ReturnTypeKind::String => quote! {
529 Ok(vec![fastmcp_protocol::Content::Text { text: result }])
530 },
531 ReturnTypeKind::ResultVecContent => quote! {
532 result.map_err(|e| fastmcp_core::McpError::internal_error(e.to_string()))
533 },
534 ReturnTypeKind::McpResultVecContent => quote! {
535 result
536 },
537 ReturnTypeKind::ResultString => quote! {
538 result
539 .map(|s| vec![fastmcp_protocol::Content::Text { text: s }])
540 .map_err(|e| fastmcp_core::McpError::internal_error(e.to_string()))
541 },
542 ReturnTypeKind::McpResultString => quote! {
543 result.map(|s| vec![fastmcp_protocol::Content::Text { text: s }])
544 },
545 ReturnTypeKind::Other => quote! {
546 let text = format!("{}", result);
548 Ok(vec![fastmcp_protocol::Content::Text { text }])
549 },
550 }
551}
552
/// Classification of a prompt function's return type.
///
/// `generate_prompt_result_conversion` uses it to choose how the call result
/// is converted.
enum PromptReturnTypeKind {
    /// `Vec<PromptMessage>`: wrapped in `Ok`.
    VecPromptMessage,
    /// `Result<Vec<PromptMessage>, E>`: the error is mapped to an internal MCP error.
    ResultVecPromptMessage,
    /// `McpResult<Vec<PromptMessage>>`: passed through unchanged.
    McpResultVecPromptMessage,
    /// Anything else, including a missing return type: wrapped in `Ok` as-is.
    Other,
}
568
569fn analyze_prompt_return_type(output: &syn::ReturnType) -> PromptReturnTypeKind {
571 match output {
572 syn::ReturnType::Default => PromptReturnTypeKind::Other, syn::ReturnType::Type(_, ty) => analyze_prompt_type(ty),
574 }
575}
576
577fn analyze_prompt_type(ty: &Type) -> PromptReturnTypeKind {
579 if let Type::Path(type_path) = ty {
580 if let Some(segment) = type_path.path.segments.last() {
581 let type_name = segment.ident.to_string();
582
583 match type_name.as_str() {
584 "Vec" => {
585 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
587 if let Some(syn::GenericArgument::Type(Type::Path(inner_path))) =
588 args.args.first()
589 {
590 if inner_path
591 .path
592 .segments
593 .last()
594 .is_some_and(|s| s.ident == "PromptMessage")
595 {
596 return PromptReturnTypeKind::VecPromptMessage;
597 }
598 }
599 }
600 }
601 "Result" | "McpResult" => {
602 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
604 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
605 let inner_kind = analyze_prompt_type(inner_ty);
606 return match inner_kind {
607 PromptReturnTypeKind::VecPromptMessage => {
608 if type_name == "McpResult" {
609 PromptReturnTypeKind::McpResultVecPromptMessage
610 } else {
611 PromptReturnTypeKind::ResultVecPromptMessage
612 }
613 }
614 _ => PromptReturnTypeKind::Other,
615 };
616 }
617 }
618 }
619 _ => {}
620 }
621 }
622 }
623 PromptReturnTypeKind::Other
624}
625
626fn generate_prompt_result_conversion(output: &syn::ReturnType) -> TokenStream2 {
628 let kind = analyze_prompt_return_type(output);
629
630 match kind {
631 PromptReturnTypeKind::VecPromptMessage => quote! {
632 Ok(result)
633 },
634 PromptReturnTypeKind::ResultVecPromptMessage => quote! {
635 result.map_err(|e| fastmcp_core::McpError::internal_error(e.to_string()))
636 },
637 PromptReturnTypeKind::McpResultVecPromptMessage => quote! {
638 result
639 },
640 PromptReturnTypeKind::Other => quote! {
641 Ok(result)
643 },
644 }
645}
646
/// Classification of a resource function's return type.
///
/// `generate_resource_result_conversion` uses it to build the
/// `Vec<ResourceContent>` result.
enum ResourceReturnTypeKind {
    /// `String`: becomes a single text resource item.
    String,
    /// `Vec<ResourceContent>`: wrapped in `Ok`.
    VecResourceContent,
    /// `Result<String, E>`: the error is mapped to an internal MCP error, and the text is wrapped.
    ResultString,
    /// `McpResult<String>`: the error is propagated with `?`, and the text is wrapped.
    McpResultString,
    /// `Result<Vec<ResourceContent>, E>`: the error is mapped to an internal MCP error.
    ResultVecResourceContent,
    /// `McpResult<Vec<ResourceContent>>`: passed through unchanged.
    McpResultVecResourceContent,
    /// Anything else, including a missing return type: converted with `to_string()`.
    Other,
}
668
669fn analyze_resource_return_type(output: &syn::ReturnType) -> ResourceReturnTypeKind {
671 match output {
672 syn::ReturnType::Default => ResourceReturnTypeKind::Other, syn::ReturnType::Type(_, ty) => analyze_resource_type(ty),
674 }
675}
676
677fn analyze_resource_type(ty: &Type) -> ResourceReturnTypeKind {
679 if let Type::Path(type_path) = ty {
680 if let Some(segment) = type_path.path.segments.last() {
681 let type_name = segment.ident.to_string();
682
683 match type_name.as_str() {
684 "String" => return ResourceReturnTypeKind::String,
685 "Vec" => {
686 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
688 if let Some(syn::GenericArgument::Type(Type::Path(inner_path))) =
689 args.args.first()
690 {
691 if inner_path
692 .path
693 .segments
694 .last()
695 .is_some_and(|s| s.ident == "ResourceContent")
696 {
697 return ResourceReturnTypeKind::VecResourceContent;
698 }
699 }
700 }
701 }
702 "Result" | "McpResult" => {
703 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
705 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
706 let inner_kind = analyze_resource_type(inner_ty);
707 return match inner_kind {
708 ResourceReturnTypeKind::String => {
709 if type_name == "McpResult" {
710 ResourceReturnTypeKind::McpResultString
711 } else {
712 ResourceReturnTypeKind::ResultString
713 }
714 }
715 ResourceReturnTypeKind::VecResourceContent => {
716 if type_name == "McpResult" {
717 ResourceReturnTypeKind::McpResultVecResourceContent
718 } else {
719 ResourceReturnTypeKind::ResultVecResourceContent
720 }
721 }
722 _ => ResourceReturnTypeKind::Other,
723 };
724 }
725 }
726 }
727 _ => {}
728 }
729 }
730 }
731 ResourceReturnTypeKind::Other
732}
733
734fn generate_resource_result_conversion(output: &syn::ReturnType, mime_type: &str) -> TokenStream2 {
744 let kind = analyze_resource_return_type(output);
745
746 match kind {
747 ResourceReturnTypeKind::String => quote! {
748 let text = result;
749 Ok(vec![fastmcp_protocol::ResourceContent {
750 uri: uri.to_string(),
751 mime_type: Some(#mime_type.to_string()),
752 text: Some(text),
753 blob: None,
754 }])
755 },
756 ResourceReturnTypeKind::VecResourceContent => quote! {
757 Ok(result)
758 },
759 ResourceReturnTypeKind::ResultString => quote! {
760 let text = result.map_err(|e| fastmcp_core::McpError::internal_error(e.to_string()))?;
761 Ok(vec![fastmcp_protocol::ResourceContent {
762 uri: uri.to_string(),
763 mime_type: Some(#mime_type.to_string()),
764 text: Some(text),
765 blob: None,
766 }])
767 },
768 ResourceReturnTypeKind::McpResultString => quote! {
769 let text = result?;
770 Ok(vec![fastmcp_protocol::ResourceContent {
771 uri: uri.to_string(),
772 mime_type: Some(#mime_type.to_string()),
773 text: Some(text),
774 blob: None,
775 }])
776 },
777 ResourceReturnTypeKind::ResultVecResourceContent => quote! {
778 result.map_err(|e| fastmcp_core::McpError::internal_error(e.to_string()))
779 },
780 ResourceReturnTypeKind::McpResultVecResourceContent => quote! {
781 result
782 },
783 ResourceReturnTypeKind::Other => quote! {
784 let text = result.to_string();
786 Ok(vec![fastmcp_protocol::ResourceContent {
787 uri: uri.to_string(),
788 mime_type: Some(#mime_type.to_string()),
789 text: Some(text),
790 blob: None,
791 }])
792 },
793 }
794}
795
796fn type_to_json_schema(ty: &Type) -> TokenStream2 {
798 let Type::Path(type_path) = ty else {
799 return quote! { serde_json::json!({}) };
800 };
801
802 let segment = type_path.path.segments.last().unwrap();
803 let type_name = segment.ident.to_string();
804
805 match type_name.as_str() {
806 "String" | "str" => quote! {
807 serde_json::json!({ "type": "string" })
808 },
809 "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32" | "u64" | "u128"
810 | "usize" => quote! {
811 serde_json::json!({ "type": "integer" })
812 },
813 "f32" | "f64" => quote! {
814 serde_json::json!({ "type": "number" })
815 },
816 "bool" => quote! {
817 serde_json::json!({ "type": "boolean" })
818 },
819 "Option" => {
820 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
822 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
823 return type_to_json_schema(inner_ty);
824 }
825 }
826 quote! { serde_json::json!({}) }
827 }
828 "Vec" => {
829 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
831 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
832 let inner_schema = type_to_json_schema(inner_ty);
833 return quote! {
834 serde_json::json!({
835 "type": "array",
836 "items": #inner_schema
837 })
838 };
839 }
840 }
841 quote! { serde_json::json!({ "type": "array" }) }
842 }
843 "HashSet" | "BTreeSet" => {
844 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
846 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
847 let inner_schema = type_to_json_schema(inner_ty);
848 return quote! {
849 serde_json::json!({
850 "type": "array",
851 "items": #inner_schema,
852 "uniqueItems": true
853 })
854 };
855 }
856 }
857 quote! { serde_json::json!({ "type": "array", "uniqueItems": true }) }
858 }
859 "HashMap" | "BTreeMap" => {
860 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
862 if args.args.len() >= 2 {
865 if let Some(syn::GenericArgument::Type(value_ty)) = args.args.iter().nth(1) {
866 let value_schema = type_to_json_schema(value_ty);
867 return quote! {
868 serde_json::json!({
869 "type": "object",
870 "additionalProperties": #value_schema
871 })
872 };
873 }
874 }
875 }
876 quote! { serde_json::json!({ "type": "object" }) }
877 }
878 "serde_json::Value" | "Value" => {
879 quote! { serde_json::json!({}) }
881 }
882 _ => {
883 quote! { <#ty>::json_schema() }
886 }
887 }
888}
889
/// Parsed arguments of the `#[tool(...)]` attribute.
struct ToolAttrs {
    /// `name = "..."`: tool name; `tool` falls back to the function name when absent.
    name: Option<String>,
    /// `description = "..."`: overrides the function's doc comments when present.
    description: Option<String>,
    /// `timeout = "..."`: raw duration text, later parsed by `parse_duration_to_millis`.
    timeout: Option<String>,
    /// `tags = ["...", ...]`: string tags, in the order written.
    tags: Vec<String>,
    /// `defaults(param = literal, ...)`: default literals keyed by parameter name.
    defaults: HashMap<String, Lit>,
    /// `output_schema = <expr>`: an expression emitted as `Some(#expr)` in an
    /// `Option<serde_json::Value>` position, so it must evaluate to a `serde_json::Value`.
    output_schema: Option<syn::Expr>,
    /// `version = "..."`.
    version: Option<String>,
    /// `Some(true)` when `annotations(read_only)` is given, otherwise `None`.
    annotations_read_only: Option<bool>,
    /// `Some(true)` when `annotations(idempotent)` is given, otherwise `None`.
    annotations_idempotent: Option<bool>,
    /// `Some(true)` when `annotations(destructive)` is given, otherwise `None`.
    annotations_destructive: Option<bool>,
    /// `annotations(open_world_hint = "...")`: the string value as written.
    annotations_open_world_hint: Option<String>,
}
911
912impl Parse for ToolAttrs {
913 fn parse(input: ParseStream) -> syn::Result<Self> {
914 let mut name = None;
915 let mut description = None;
916 let mut timeout = None;
917 let mut tags = Vec::new();
918 let mut defaults: HashMap<String, Lit> = HashMap::new();
919 let mut output_schema = None;
920 let mut version = None;
921 let mut annotations_read_only = None;
922 let mut annotations_idempotent = None;
923 let mut annotations_destructive = None;
924 let mut annotations_open_world_hint = None;
925
926 while !input.is_empty() {
927 let ident: Ident = input.parse()?;
928
929 match ident.to_string().as_str() {
930 "name" => {
931 input.parse::<Token![=]>()?;
932 let lit: LitStr = input.parse()?;
933 name = Some(lit.value());
934 }
935 "description" => {
936 input.parse::<Token![=]>()?;
937 let lit: LitStr = input.parse()?;
938 description = Some(lit.value());
939 }
940 "timeout" => {
941 input.parse::<Token![=]>()?;
942 let lit: LitStr = input.parse()?;
943 timeout = Some(lit.value());
944 }
945 "version" => {
946 input.parse::<Token![=]>()?;
947 let lit: LitStr = input.parse()?;
948 version = Some(lit.value());
949 }
950 "tags" => {
951 input.parse::<Token![=]>()?;
952 let expr_array: syn::ExprArray = input.parse()?;
953 for expr in expr_array.elems {
954 match expr {
955 syn::Expr::Lit(syn::ExprLit {
956 lit: Lit::Str(tag), ..
957 }) => tags.push(tag.value()),
958 other => {
959 return Err(syn::Error::new_spanned(
960 other,
961 "tags entries must be string literals",
962 ));
963 }
964 }
965 }
966 }
967 "defaults" => {
968 let content;
969 syn::parenthesized!(content in input);
970 while !content.is_empty() {
971 let key: Ident = content.parse()?;
972 content.parse::<Token![=]>()?;
973 let lit: Lit = content.parse()?;
974 defaults.insert(key.to_string(), lit);
975 if !content.is_empty() {
976 content.parse::<Token![,]>()?;
977 }
978 }
979 }
980 "output_schema" => {
981 input.parse::<Token![=]>()?;
982 let expr: syn::Expr = input.parse()?;
984 output_schema = Some(expr);
985 }
986 "annotations" => {
987 let content;
988 syn::parenthesized!(content in input);
989 while !content.is_empty() {
990 let ann_ident: Ident = content.parse()?;
991 match ann_ident.to_string().as_str() {
992 "read_only" => annotations_read_only = Some(true),
993 "idempotent" => annotations_idempotent = Some(true),
994 "destructive" => annotations_destructive = Some(true),
995 "open_world_hint" => {
996 content.parse::<Token![=]>()?;
997 let lit: LitStr = content.parse()?;
998 annotations_open_world_hint = Some(lit.value());
999 }
1000 other => {
1001 return Err(syn::Error::new(
1002 ann_ident.span(),
1003 format!(
1004 "unknown annotation: {other}; expected read_only, idempotent, destructive, or open_world_hint"
1005 ),
1006 ));
1007 }
1008 }
1009 if !content.is_empty() {
1010 content.parse::<Token![,]>()?;
1011 }
1012 }
1013 }
1014 _ => {
1015 return Err(syn::Error::new(ident.span(), "unknown attribute"));
1016 }
1017 }
1018
1019 if !input.is_empty() {
1020 input.parse::<Token![,]>()?;
1021 }
1022 }
1023
1024 Ok(Self {
1025 name,
1026 description,
1027 timeout,
1028 tags,
1029 defaults,
1030 output_schema,
1031 version,
1032 annotations_read_only,
1033 annotations_idempotent,
1034 annotations_destructive,
1035 annotations_open_world_hint,
1036 })
1037 }
1038}
1039
1040#[proc_macro_attribute]
1067#[allow(clippy::too_many_lines)]
1068pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
1069 let attrs = parse_macro_input!(attr as ToolAttrs);
1070 let input_fn = parse_macro_input!(item as ItemFn);
1071
1072 let fn_name = &input_fn.sig.ident;
1073 let fn_name_str = fn_name.to_string();
1074
1075 let handler_name = format_ident!("{}", to_pascal_case(&fn_name_str));
1077
1078 let tool_name = attrs.name.unwrap_or_else(|| fn_name_str.clone());
1080
1081 let description = attrs
1083 .description
1084 .or_else(|| extract_doc_comments(&input_fn.attrs));
1085 let description_tokens = description.as_ref().map_or_else(
1086 || quote! { None },
1087 |desc| quote! { Some(#desc.to_string()) },
1088 );
1089
1090 let timeout_tokens = if let Some(ref timeout_str) = attrs.timeout {
1092 match parse_duration_to_millis(timeout_str) {
1093 Ok(millis) => {
1094 quote! {
1095 fn timeout(&self) -> Option<std::time::Duration> {
1096 Some(std::time::Duration::from_millis(#millis))
1097 }
1098 }
1099 }
1100 Err(e) => {
1101 return syn::Error::new_spanned(
1102 &input_fn.sig.ident,
1103 format!("invalid timeout: {e}"),
1104 )
1105 .to_compile_error()
1106 .into();
1107 }
1108 }
1109 } else {
1110 quote! {}
1111 };
1112
1113 let (output_schema_field, output_schema_method) =
1115 if let Some(ref schema_expr) = attrs.output_schema {
1116 (
1117 quote! { Some(#schema_expr) },
1118 quote! {
1119 fn output_schema(&self) -> Option<serde_json::Value> {
1120 Some(#schema_expr)
1121 }
1122 },
1123 )
1124 } else {
1125 (quote! { None }, quote! {})
1126 };
1127
1128 let tag_entries: Vec<TokenStream2> = attrs
1129 .tags
1130 .iter()
1131 .map(|tag| quote! { #tag.to_string() })
1132 .collect();
1133
1134 let version_tokens = attrs
1136 .version
1137 .as_ref()
1138 .map_or_else(|| quote! { None }, |v| quote! { Some(#v.to_string()) });
1139
1140 let has_annotations = attrs.annotations_read_only.is_some()
1142 || attrs.annotations_idempotent.is_some()
1143 || attrs.annotations_destructive.is_some()
1144 || attrs.annotations_open_world_hint.is_some();
1145
1146 let annotations_tokens = if has_annotations {
1147 let ro = attrs
1148 .annotations_read_only
1149 .map_or_else(|| quote! { None }, |v| quote! { Some(#v) });
1150 let idem = attrs
1151 .annotations_idempotent
1152 .map_or_else(|| quote! { None }, |v| quote! { Some(#v) });
1153 let destr = attrs
1154 .annotations_destructive
1155 .map_or_else(|| quote! { None }, |v| quote! { Some(#v) });
1156 let owh = attrs
1157 .annotations_open_world_hint
1158 .as_ref()
1159 .map_or_else(|| quote! { None }, |v| quote! { Some(#v.to_string()) });
1160 quote! {
1161 Some(fastmcp_protocol::ToolAnnotations {
1162 read_only: #ro,
1163 idempotent: #idem,
1164 destructive: #destr,
1165 open_world_hint: #owh,
1166 })
1167 }
1168 } else {
1169 quote! { None }
1170 };
1171
1172 let mut params: Vec<(&Ident, &Type, Option<String>, Option<Lit>)> = Vec::new();
1174 let mut required_params: Vec<String> = Vec::new();
1175 let mut expects_context = false;
1176
1177 for (i, arg) in input_fn.sig.inputs.iter().enumerate() {
1178 if let FnArg::Typed(pat_type) = arg {
1179 if i == 0 && is_mcp_context_ref(pat_type.ty.as_ref()) {
1181 expects_context = true;
1182 continue;
1183 }
1184
1185 if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
1186 let param_name = &pat_ident.ident;
1187 let param_type = pat_type.ty.as_ref();
1188 let param_doc = extract_doc_comments(&pat_type.attrs);
1189 let param_default = attrs.defaults.get(¶m_name.to_string()).cloned();
1190
1191 let is_optional = is_option_type(param_type);
1193
1194 if !is_optional && param_default.is_none() {
1195 required_params.push(param_name.to_string());
1196 }
1197
1198 params.push((param_name, param_type, param_doc, param_default));
1199 }
1200 }
1201 }
1202
1203 let property_entries: Vec<TokenStream2> = params
1205 .iter()
1206 .map(|(name, ty, doc, default_expr)| {
1207 let name_str = name.to_string();
1208 let schema = type_to_json_schema(ty);
1209
1210 let default_insert = default_expr.as_ref().map_or_else(
1211 || quote! {},
1212 |lit| {
1213 quote! {
1214 obj.insert("default".to_string(), serde_json::json!(#lit));
1215 }
1216 },
1217 );
1218
1219 match (doc.as_ref(), default_expr.as_ref()) {
1220 (None, None) => quote! {
1221 (#name_str.to_string(), #schema)
1222 },
1223 (Some(desc), _) => quote! {
1224 (#name_str.to_string(), {
1225 let mut s = #schema;
1226 if let Some(obj) = s.as_object_mut() {
1227 obj.insert("description".to_string(), serde_json::json!(#desc));
1228 #default_insert
1229 }
1230 s
1231 })
1232 },
1233 (None, Some(_)) => quote! {
1234 (#name_str.to_string(), {
1235 let mut s = #schema;
1236 if let Some(obj) = s.as_object_mut() {
1237 #default_insert
1238 }
1239 s
1240 })
1241 },
1242 }
1243 })
1244 .collect();
1245
1246 let mut param_extractions: Vec<TokenStream2> = Vec::new();
1248 for (name, ty, _, default_lit) in ¶ms {
1249 let name_str = name.to_string();
1250 let is_optional = is_option_type(ty);
1251
1252 if is_optional {
1253 if let Some(default_lit) = default_lit {
1254 let default_expr = match default_lit_expr_for_type(default_lit, ty) {
1255 Ok(v) => v,
1256 Err(e) => return e.to_compile_error().into(),
1257 };
1258 param_extractions.push(quote! {
1259 let #name: #ty = match arguments.get(#name_str) {
1260 Some(value) => Some(
1261 serde_json::from_value(value.clone()).map_err(|e| {
1262 fastmcp_core::McpError::invalid_params(e.to_string())
1263 })?,
1264 ),
1265 None => #default_expr,
1266 };
1267 });
1268 } else {
1269 param_extractions.push(quote! {
1270 let #name: #ty = match arguments.get(#name_str) {
1271 Some(value) => Some(
1272 serde_json::from_value(value.clone()).map_err(|e| {
1273 fastmcp_core::McpError::invalid_params(e.to_string())
1274 })?,
1275 ),
1276 None => None,
1277 };
1278 });
1279 }
1280 } else if let Some(default_lit) = default_lit {
1281 let default_expr = match default_lit_expr_for_type(default_lit, ty) {
1282 Ok(v) => v,
1283 Err(e) => return e.to_compile_error().into(),
1284 };
1285 param_extractions.push(quote! {
1286 let #name: #ty = match arguments.get(#name_str) {
1287 Some(v) => serde_json::from_value(v.clone())
1288 .map_err(|e| fastmcp_core::McpError::invalid_params(e.to_string()))?,
1289 None => #default_expr,
1290 };
1291 });
1292 } else {
1293 param_extractions.push(quote! {
1294 let #name: #ty = arguments.get(#name_str)
1295 .ok_or_else(|| fastmcp_core::McpError::invalid_params(
1296 format!("missing required parameter: {}", #name_str)
1297 ))
1298 .and_then(|v| serde_json::from_value(v.clone())
1299 .map_err(|e| fastmcp_core::McpError::invalid_params(e.to_string())))?;
1300 });
1301 }
1302 }
1303
1304 let param_names: Vec<&Ident> = params.iter().map(|(name, _, _, _)| *name).collect();
1306
1307 let is_async = input_fn.sig.asyncness.is_some();
1309
1310 let return_type = &input_fn.sig.output;
1312 let result_conversion = generate_result_conversion(return_type);
1313
1314 let call_expr = if is_async {
1316 if expects_context {
1317 quote! {
1318 fastmcp_core::runtime::block_on(async move {
1319 #fn_name(ctx, #(#param_names),*).await
1320 })
1321 }
1322 } else {
1323 quote! {
1324 fastmcp_core::runtime::block_on(async move {
1325 #fn_name(#(#param_names),*).await
1326 })
1327 }
1328 }
1329 } else {
1330 if expects_context {
1331 quote! {
1332 #fn_name(ctx, #(#param_names),*)
1333 }
1334 } else {
1335 quote! {
1336 #fn_name(#(#param_names),*)
1337 }
1338 }
1339 };
1340
1341 let expanded = quote! {
1343 #input_fn
1345
1346 #[derive(Clone)]
1348 pub struct #handler_name;
1349
1350 impl fastmcp_server::ToolHandler for #handler_name {
1351 fn definition(&self) -> fastmcp_protocol::Tool {
1352 let properties: std::collections::HashMap<String, serde_json::Value> = vec![
1353 #(#property_entries),*
1354 ].into_iter().collect();
1355
1356 let required: Vec<String> = vec![#(#required_params.to_string()),*];
1357
1358 fastmcp_protocol::Tool {
1359 name: #tool_name.to_string(),
1360 description: #description_tokens,
1361 input_schema: serde_json::json!({
1362 "type": "object",
1363 "properties": properties,
1364 "required": required,
1365 }),
1366 output_schema: #output_schema_field,
1367 icon: None,
1368 version: #version_tokens,
1369 tags: vec![#(#tag_entries),*],
1370 annotations: #annotations_tokens,
1371 }
1372 }
1373
1374 #timeout_tokens
1375
1376 #output_schema_method
1377
1378 fn call(
1379 &self,
1380 ctx: &fastmcp_core::McpContext,
1381 arguments: serde_json::Value,
1382 ) -> fastmcp_core::McpResult<Vec<fastmcp_protocol::Content>> {
1383 let arguments = arguments.as_object()
1385 .cloned()
1386 .unwrap_or_default();
1387
1388 #(#param_extractions)*
1390
1391 let result = #call_expr;
1393
1394 #result_conversion
1396 }
1397 }
1398 };
1399
1400 TokenStream::from(expanded)
1401}
1402
/// Arguments accepted by the `#[resource(...)]` attribute.
///
/// Everything except `tags` is optional at parse time; `uri` is enforced as
/// mandatory later, when the attribute is expanded.
struct ResourceAttrs {
    /// URI of the resource; `{param}` placeholders make it a URI template.
    uri: Option<String>,
    /// Display name; expansion falls back to the function name.
    name: Option<String>,
    /// Description; expansion falls back to the function's doc comments.
    description: Option<String>,
    /// MIME type; expansion falls back to `text/plain`.
    mime_type: Option<String>,
    /// Raw timeout text, converted later by `parse_duration_to_millis`.
    timeout: Option<String>,
    /// Optional version string copied into the generated definitions.
    version: Option<String>,
    /// Tags copied into the generated definitions, in declaration order.
    tags: Vec<String>,
}
1417
1418impl Parse for ResourceAttrs {
1419 fn parse(input: ParseStream) -> syn::Result<Self> {
1420 let mut uri = None;
1421 let mut name = None;
1422 let mut description = None;
1423 let mut mime_type = None;
1424 let mut timeout = None;
1425 let mut version = None;
1426 let mut tags = Vec::new();
1427
1428 while !input.is_empty() {
1429 let ident: Ident = input.parse()?;
1430
1431 match ident.to_string().as_str() {
1432 "tags" => {
1433 input.parse::<Token![=]>()?;
1434 let expr_array: syn::ExprArray = input.parse()?;
1435 for expr in expr_array.elems {
1436 match expr {
1437 syn::Expr::Lit(syn::ExprLit {
1438 lit: Lit::Str(tag), ..
1439 }) => tags.push(tag.value()),
1440 other => {
1441 return Err(syn::Error::new_spanned(
1442 other,
1443 "tags entries must be string literals",
1444 ));
1445 }
1446 }
1447 }
1448 }
1449 _ => {
1450 input.parse::<Token![=]>()?;
1451 match ident.to_string().as_str() {
1452 "uri" => {
1453 let lit: LitStr = input.parse()?;
1454 uri = Some(lit.value());
1455 }
1456 "name" => {
1457 let lit: LitStr = input.parse()?;
1458 name = Some(lit.value());
1459 }
1460 "description" => {
1461 let lit: LitStr = input.parse()?;
1462 description = Some(lit.value());
1463 }
1464 "mime_type" => {
1465 let lit: LitStr = input.parse()?;
1466 mime_type = Some(lit.value());
1467 }
1468 "timeout" => {
1469 let lit: LitStr = input.parse()?;
1470 timeout = Some(lit.value());
1471 }
1472 "version" => {
1473 let lit: LitStr = input.parse()?;
1474 version = Some(lit.value());
1475 }
1476 _ => {
1477 return Err(syn::Error::new(ident.span(), "unknown attribute"));
1478 }
1479 }
1480 }
1481 }
1482
1483 if !input.is_empty() {
1484 input.parse::<Token![,]>()?;
1485 }
1486 }
1487
1488 Ok(Self {
1489 uri,
1490 name,
1491 description,
1492 mime_type,
1493 timeout,
1494 version,
1495 tags,
1496 })
1497 }
1498}
1499
1500#[proc_macro_attribute]
1509#[allow(clippy::too_many_lines)]
1510pub fn resource(attr: TokenStream, item: TokenStream) -> TokenStream {
1511 let attrs = parse_macro_input!(attr as ResourceAttrs);
1512 let input_fn = parse_macro_input!(item as ItemFn);
1513
1514 let fn_name = &input_fn.sig.ident;
1515 let fn_name_str = fn_name.to_string();
1516
1517 let handler_name = format_ident!("{}Resource", to_pascal_case(&fn_name_str));
1519
1520 let Some(uri) = attrs.uri else {
1522 return syn::Error::new_spanned(&input_fn.sig.ident, "resource requires uri attribute")
1523 .to_compile_error()
1524 .into();
1525 };
1526
1527 let resource_name = attrs.name.unwrap_or_else(|| fn_name_str.clone());
1529 let description = attrs
1530 .description
1531 .or_else(|| extract_doc_comments(&input_fn.attrs));
1532 let mime_type = attrs.mime_type.unwrap_or_else(|| "text/plain".to_string());
1533
1534 let description_tokens = description.as_ref().map_or_else(
1535 || quote! { None },
1536 |desc| quote! { Some(#desc.to_string()) },
1537 );
1538
1539 let timeout_tokens = if let Some(ref timeout_str) = attrs.timeout {
1541 match parse_duration_to_millis(timeout_str) {
1542 Ok(millis) => {
1543 quote! {
1544 fn timeout(&self) -> Option<std::time::Duration> {
1545 Some(std::time::Duration::from_millis(#millis))
1546 }
1547 }
1548 }
1549 Err(e) => {
1550 return syn::Error::new_spanned(
1551 &input_fn.sig.ident,
1552 format!("invalid timeout: {e}"),
1553 )
1554 .to_compile_error()
1555 .into();
1556 }
1557 }
1558 } else {
1559 quote! {}
1560 };
1561
1562 let version_tokens = attrs
1564 .version
1565 .as_ref()
1566 .map_or_else(|| quote! { None }, |v| quote! { Some(#v.to_string()) });
1567
1568 let tag_entries: Vec<TokenStream2> = attrs
1570 .tags
1571 .iter()
1572 .map(|tag| quote! { #tag.to_string() })
1573 .collect();
1574
1575 let template_params = extract_template_params(&uri);
1576
1577 let mut params: Vec<(&Ident, &Type)> = Vec::new();
1579 let mut expects_context = false;
1580
1581 for (i, arg) in input_fn.sig.inputs.iter().enumerate() {
1582 if let FnArg::Typed(pat_type) = arg {
1583 if i == 0 && is_mcp_context_ref(pat_type.ty.as_ref()) {
1584 expects_context = true;
1585 continue;
1586 }
1587
1588 if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
1589 let param_name = &pat_ident.ident;
1590 let param_type = pat_type.ty.as_ref();
1591 params.push((param_name, param_type));
1592 }
1593 }
1594 }
1595
1596 if template_params.is_empty() && !params.is_empty() {
1597 return syn::Error::new_spanned(
1598 &input_fn.sig.ident,
1599 "resource parameters require a URI template with matching {params}",
1600 )
1601 .to_compile_error()
1602 .into();
1603 }
1604
1605 let missing_params: Vec<String> = params
1606 .iter()
1607 .map(|(name, _)| name.to_string())
1608 .filter(|name| !template_params.contains(name))
1609 .collect();
1610
1611 if !missing_params.is_empty() {
1612 return syn::Error::new_spanned(
1613 &input_fn.sig.ident,
1614 format!(
1615 "resource parameters missing from uri template: {}",
1616 missing_params.join(", ")
1617 ),
1618 )
1619 .to_compile_error()
1620 .into();
1621 }
1622
1623 let is_template = !template_params.is_empty();
1624
1625 let param_extractions: Vec<TokenStream2> = params
1626 .iter()
1627 .map(|(name, ty)| {
1628 let name_str = name.to_string();
1629 if let Some(inner_ty) = option_inner_type(ty) {
1630 if is_string_type(inner_ty) {
1631 quote! {
1632 let #name: #ty = uri_params.get(#name_str).cloned();
1633 }
1634 } else {
1635 quote! {
1636 let #name: #ty = match uri_params.get(#name_str) {
1637 Some(value) => Some(value.parse().map_err(|_| {
1638 fastmcp_core::McpError::invalid_params(
1639 format!("invalid uri parameter: {}", #name_str)
1640 )
1641 })?),
1642 None => None,
1643 };
1644 }
1645 }
1646 } else if is_string_type(ty) {
1647 quote! {
1648 let #name: #ty = uri_params
1649 .get(#name_str)
1650 .ok_or_else(|| fastmcp_core::McpError::invalid_params(
1651 format!("missing uri parameter: {}", #name_str)
1652 ))?
1653 .clone();
1654 }
1655 } else {
1656 quote! {
1657 let #name: #ty = uri_params
1658 .get(#name_str)
1659 .ok_or_else(|| fastmcp_core::McpError::invalid_params(
1660 format!("missing uri parameter: {}", #name_str)
1661 ))?
1662 .parse()
1663 .map_err(|_| fastmcp_core::McpError::invalid_params(
1664 format!("invalid uri parameter: {}", #name_str)
1665 ))?;
1666 }
1667 }
1668 })
1669 .collect();
1670
1671 let param_names: Vec<&Ident> = params.iter().map(|(name, _)| *name).collect();
1672 let call_args = if expects_context {
1673 quote! { ctx, #(#param_names),* }
1674 } else {
1675 quote! { #(#param_names),* }
1676 };
1677
1678 let is_async = input_fn.sig.asyncness.is_some();
1679 let call_expr = if is_async {
1680 quote! {
1681 fastmcp_core::runtime::block_on(async move {
1682 #fn_name(#call_args).await
1683 })
1684 }
1685 } else {
1686 quote! {
1687 #fn_name(#call_args)
1688 }
1689 };
1690
1691 let template_tokens = if is_template {
1692 quote! {
1693 Some(fastmcp_protocol::ResourceTemplate {
1694 uri_template: #uri.to_string(),
1695 name: #resource_name.to_string(),
1696 description: #description_tokens,
1697 mime_type: Some(#mime_type.to_string()),
1698 icon: None,
1699 version: #version_tokens,
1700 tags: vec![#(#tag_entries),*],
1701 })
1702 }
1703 } else {
1704 quote! { None }
1705 };
1706
1707 let return_type = &input_fn.sig.output;
1709 let resource_result_conversion = generate_resource_result_conversion(return_type, &mime_type);
1710
1711 let expanded = quote! {
1712 #input_fn
1714
1715 #[derive(Clone)]
1717 pub struct #handler_name;
1718
1719 impl fastmcp_server::ResourceHandler for #handler_name {
1720 fn definition(&self) -> fastmcp_protocol::Resource {
1721 fastmcp_protocol::Resource {
1722 uri: #uri.to_string(),
1723 name: #resource_name.to_string(),
1724 description: #description_tokens,
1725 mime_type: Some(#mime_type.to_string()),
1726 icon: None,
1727 version: #version_tokens,
1728 tags: vec![#(#tag_entries),*],
1729 }
1730 }
1731
1732 fn template(&self) -> Option<fastmcp_protocol::ResourceTemplate> {
1733 #template_tokens
1734 }
1735
1736 #timeout_tokens
1737
1738 fn read(
1739 &self,
1740 ctx: &fastmcp_core::McpContext,
1741 ) -> fastmcp_core::McpResult<Vec<fastmcp_protocol::ResourceContent>> {
1742 let uri_params = std::collections::HashMap::new();
1743 self.read_with_uri(ctx, #uri, &uri_params)
1744 }
1745
1746 fn read_with_uri(
1747 &self,
1748 ctx: &fastmcp_core::McpContext,
1749 uri: &str,
1750 uri_params: &std::collections::HashMap<String, String>,
1751 ) -> fastmcp_core::McpResult<Vec<fastmcp_protocol::ResourceContent>> {
1752 #(#param_extractions)*
1753 let result = #call_expr;
1754 #resource_result_conversion
1755 }
1756
1757 fn read_async_with_uri<'a>(
1758 &'a self,
1759 ctx: &'a fastmcp_core::McpContext,
1760 uri: &'a str,
1761 uri_params: &'a std::collections::HashMap<String, String>,
1762 ) -> fastmcp_server::BoxFuture<'a, fastmcp_core::McpOutcome<Vec<fastmcp_protocol::ResourceContent>>> {
1763 Box::pin(async move {
1764 match self.read_with_uri(ctx, uri, uri_params) {
1765 Ok(value) => fastmcp_core::Outcome::Ok(value),
1766 Err(error) => fastmcp_core::Outcome::Err(error),
1767 }
1768 })
1769 }
1770 }
1771 };
1772
1773 TokenStream::from(expanded)
1774}
1775
1776struct PromptAttrs {
1782 name: Option<String>,
1783 description: Option<String>,
1784 timeout: Option<String>,
1785 defaults: HashMap<String, Lit>,
1786 version: Option<String>,
1787 tags: Vec<String>,
1788}
1789
1790impl Parse for PromptAttrs {
1791 fn parse(input: ParseStream) -> syn::Result<Self> {
1792 let mut name = None;
1793 let mut description = None;
1794 let mut timeout = None;
1795 let mut defaults: HashMap<String, Lit> = HashMap::new();
1796 let mut version = None;
1797 let mut tags = Vec::new();
1798
1799 while !input.is_empty() {
1800 let ident: Ident = input.parse()?;
1801
1802 match ident.to_string().as_str() {
1803 "name" => {
1804 input.parse::<Token![=]>()?;
1805 let lit: LitStr = input.parse()?;
1806 name = Some(lit.value());
1807 }
1808 "description" => {
1809 input.parse::<Token![=]>()?;
1810 let lit: LitStr = input.parse()?;
1811 description = Some(lit.value());
1812 }
1813 "timeout" => {
1814 input.parse::<Token![=]>()?;
1815 let lit: LitStr = input.parse()?;
1816 timeout = Some(lit.value());
1817 }
1818 "version" => {
1819 input.parse::<Token![=]>()?;
1820 let lit: LitStr = input.parse()?;
1821 version = Some(lit.value());
1822 }
1823 "tags" => {
1824 input.parse::<Token![=]>()?;
1825 let expr_array: syn::ExprArray = input.parse()?;
1826 for expr in expr_array.elems {
1827 match expr {
1828 syn::Expr::Lit(syn::ExprLit {
1829 lit: Lit::Str(tag), ..
1830 }) => tags.push(tag.value()),
1831 other => {
1832 return Err(syn::Error::new_spanned(
1833 other,
1834 "tags entries must be string literals",
1835 ));
1836 }
1837 }
1838 }
1839 }
1840 "defaults" => {
1841 let content;
1842 syn::parenthesized!(content in input);
1843 while !content.is_empty() {
1844 let key: Ident = content.parse()?;
1845 content.parse::<Token![=]>()?;
1846 let lit: Lit = content.parse()?;
1847 defaults.insert(key.to_string(), lit);
1848 if !content.is_empty() {
1849 content.parse::<Token![,]>()?;
1850 }
1851 }
1852 }
1853 _ => {
1854 return Err(syn::Error::new(ident.span(), "unknown attribute"));
1855 }
1856 }
1857
1858 if !input.is_empty() {
1859 input.parse::<Token![,]>()?;
1860 }
1861 }
1862
1863 Ok(Self {
1864 name,
1865 description,
1866 timeout,
1867 defaults,
1868 version,
1869 tags,
1870 })
1871 }
1872}
1873
1874#[proc_macro_attribute]
1893#[allow(clippy::too_many_lines)]
1894pub fn prompt(attr: TokenStream, item: TokenStream) -> TokenStream {
1895 let attrs = parse_macro_input!(attr as PromptAttrs);
1896 let input_fn = parse_macro_input!(item as ItemFn);
1897
1898 let fn_name = &input_fn.sig.ident;
1899 let fn_name_str = fn_name.to_string();
1900
1901 let handler_name = format_ident!("{}Prompt", to_pascal_case(&fn_name_str));
1903
1904 let prompt_name = attrs.name.unwrap_or_else(|| fn_name_str.clone());
1906
1907 let description = attrs
1909 .description
1910 .or_else(|| extract_doc_comments(&input_fn.attrs));
1911 let description_tokens = description.as_ref().map_or_else(
1912 || quote! { None },
1913 |desc| quote! { Some(#desc.to_string()) },
1914 );
1915
1916 let timeout_tokens = if let Some(ref timeout_str) = attrs.timeout {
1918 match parse_duration_to_millis(timeout_str) {
1919 Ok(millis) => {
1920 quote! {
1921 fn timeout(&self) -> Option<std::time::Duration> {
1922 Some(std::time::Duration::from_millis(#millis))
1923 }
1924 }
1925 }
1926 Err(e) => {
1927 return syn::Error::new_spanned(
1928 &input_fn.sig.ident,
1929 format!("invalid timeout: {e}"),
1930 )
1931 .to_compile_error()
1932 .into();
1933 }
1934 }
1935 } else {
1936 quote! {}
1937 };
1938
1939 let mut prompt_args: Vec<TokenStream2> = Vec::new();
1941 let mut expects_context = false;
1942
1943 for (i, arg) in input_fn.sig.inputs.iter().enumerate() {
1944 if let FnArg::Typed(pat_type) = arg {
1945 if i == 0 && is_mcp_context_ref(pat_type.ty.as_ref()) {
1947 expects_context = true;
1948 continue;
1949 }
1950
1951 if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
1952 let param_name = pat_ident.ident.to_string();
1953 let param_doc = extract_doc_comments(&pat_type.attrs);
1954 let is_optional = is_option_type(pat_type.ty.as_ref());
1955 let has_default = attrs.defaults.contains_key(¶m_name);
1956 let required = !(is_optional || has_default);
1957
1958 let desc_tokens = param_doc
1959 .as_ref()
1960 .map_or_else(|| quote! { None }, |d| quote! { Some(#d.to_string()) });
1961
1962 prompt_args.push(quote! {
1963 fastmcp_protocol::PromptArgument {
1964 name: #param_name.to_string(),
1965 description: #desc_tokens,
1966 required: #required,
1967 }
1968 });
1969 }
1970 }
1971 }
1972
1973 let mut param_extractions: Vec<TokenStream2> = Vec::new();
1975 let mut param_names: Vec<Ident> = Vec::new();
1976
1977 for (i, arg) in input_fn.sig.inputs.iter().enumerate() {
1978 if let FnArg::Typed(pat_type) = arg {
1979 if i == 0 && is_mcp_context_ref(pat_type.ty.as_ref()) {
1981 continue;
1982 }
1983
1984 if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
1985 let param_name = &pat_ident.ident;
1986 let param_name_str = param_name.to_string();
1987 let param_ty = pat_type.ty.as_ref();
1988 let is_optional = is_option_type(param_ty);
1989 let default_lit = attrs.defaults.get(¶m_name_str).cloned();
1990
1991 param_names.push(param_name.clone());
1992
1993 if is_optional {
1994 if let Some(default_lit) = default_lit {
1995 let default_expr = match default_lit_expr_for_type(&default_lit, param_ty) {
1996 Ok(v) => v,
1997 Err(e) => return e.to_compile_error().into(),
1998 };
1999 param_extractions.push(quote! {
2001 let #param_name: #param_ty = match arguments.get(#param_name_str) {
2002 Some(v) => Some(v.clone()),
2003 None => #default_expr,
2004 };
2005 });
2006 } else {
2007 param_extractions.push(quote! {
2009 let #param_name: #param_ty = arguments.get(#param_name_str).cloned();
2010 });
2011 }
2012 } else {
2013 if let Some(default_lit) = default_lit {
2014 let default_expr = match default_lit_expr_for_type(&default_lit, param_ty) {
2015 Ok(v) => v,
2016 Err(e) => return e.to_compile_error().into(),
2017 };
2018 param_extractions.push(quote! {
2020 let #param_name: #param_ty = match arguments.get(#param_name_str) {
2021 Some(v) => v.clone(),
2022 None => #default_expr,
2023 };
2024 });
2025 } else {
2026 param_extractions.push(quote! {
2028 let #param_name: #param_ty = arguments.get(#param_name_str)
2029 .cloned()
2030 .ok_or_else(|| fastmcp_core::McpError::invalid_params(
2031 format!("missing required argument: {}", #param_name_str)
2032 ))?;
2033 });
2034 }
2035 }
2036 }
2037 }
2038 }
2039
2040 let is_async = input_fn.sig.asyncness.is_some();
2041 let call_expr = if is_async {
2042 if expects_context {
2043 quote! {
2044 fastmcp_core::runtime::block_on(async move {
2045 #fn_name(ctx, #(#param_names),*).await
2046 })
2047 }
2048 } else {
2049 quote! {
2050 fastmcp_core::runtime::block_on(async move {
2051 #fn_name(#(#param_names),*).await
2052 })
2053 }
2054 }
2055 } else {
2056 if expects_context {
2057 quote! {
2058 #fn_name(ctx, #(#param_names),*)
2059 }
2060 } else {
2061 quote! {
2062 #fn_name(#(#param_names),*)
2063 }
2064 }
2065 };
2066
2067 let return_type = &input_fn.sig.output;
2069 let prompt_result_conversion = generate_prompt_result_conversion(return_type);
2070
2071 let version_tokens = attrs
2073 .version
2074 .as_ref()
2075 .map_or_else(|| quote! { None }, |v| quote! { Some(#v.to_string()) });
2076
2077 let tag_entries: Vec<TokenStream2> = attrs
2079 .tags
2080 .iter()
2081 .map(|tag| quote! { #tag.to_string() })
2082 .collect();
2083
2084 let expanded = quote! {
2085 #input_fn
2087
2088 #[derive(Clone)]
2090 pub struct #handler_name;
2091
2092 impl fastmcp_server::PromptHandler for #handler_name {
2093 fn definition(&self) -> fastmcp_protocol::Prompt {
2094 fastmcp_protocol::Prompt {
2095 name: #prompt_name.to_string(),
2096 description: #description_tokens,
2097 arguments: vec![#(#prompt_args),*],
2098 icon: None,
2099 version: #version_tokens,
2100 tags: vec![#(#tag_entries),*],
2101 }
2102 }
2103
2104 #timeout_tokens
2105
2106 fn get(
2107 &self,
2108 ctx: &fastmcp_core::McpContext,
2109 arguments: std::collections::HashMap<String, String>,
2110 ) -> fastmcp_core::McpResult<Vec<fastmcp_protocol::PromptMessage>> {
2111 #(#param_extractions)*
2112 let result = #call_expr;
2113 #prompt_result_conversion
2114 }
2115 }
2116 };
2117
2118 TokenStream::from(expanded)
2119}
2120
2121#[proc_macro_derive(JsonSchema, attributes(json_schema))]
2170pub fn derive_json_schema(input: TokenStream) -> TokenStream {
2171 let input = parse_macro_input!(input as syn::DeriveInput);
2172
2173 let name = &input.ident;
2174 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
2175
2176 let type_description = extract_doc_comments(&input.attrs);
2178 let type_desc_tokens = type_description
2179 .as_ref()
2180 .map_or_else(|| quote! { None::<&str> }, |desc| quote! { Some(#desc) });
2181
2182 let schema_impl = match &input.data {
2184 syn::Data::Struct(data_struct) => generate_struct_schema(data_struct, &type_desc_tokens),
2185 syn::Data::Enum(data_enum) => generate_enum_schema(data_enum, &type_desc_tokens),
2186 syn::Data::Union(_) => {
2187 return syn::Error::new_spanned(input, "JsonSchema cannot be derived for unions")
2188 .to_compile_error()
2189 .into();
2190 }
2191 };
2192
2193 let expanded = quote! {
2194 impl #impl_generics #name #ty_generics #where_clause {
2195 pub fn json_schema() -> serde_json::Value {
2197 #schema_impl
2198 }
2199 }
2200 };
2201
2202 TokenStream::from(expanded)
2203}
2204
2205fn generate_struct_schema(data: &syn::DataStruct, type_desc_tokens: &TokenStream2) -> TokenStream2 {
2207 match &data.fields {
2208 syn::Fields::Named(fields) => {
2209 let mut property_entries = Vec::new();
2210 let mut required_fields = Vec::new();
2211
2212 for field in &fields.named {
2213 if has_json_schema_attr(&field.attrs, "skip") {
2215 continue;
2216 }
2217
2218 let field_name = field.ident.as_ref().unwrap();
2219
2220 let schema_name =
2222 get_json_schema_rename(&field.attrs).unwrap_or_else(|| field_name.to_string());
2223
2224 let field_doc = extract_doc_comments(&field.attrs);
2226
2227 let field_type = &field.ty;
2229 let is_optional = is_option_type(field_type);
2230
2231 let field_schema = type_to_json_schema(field_type);
2233
2234 let property_value = if let Some(desc) = &field_doc {
2236 quote! {
2237 {
2238 let mut schema = #field_schema;
2239 if let Some(obj) = schema.as_object_mut() {
2240 obj.insert("description".to_string(), serde_json::json!(#desc));
2241 }
2242 schema
2243 }
2244 }
2245 } else {
2246 field_schema
2247 };
2248
2249 property_entries.push(quote! {
2250 (#schema_name.to_string(), #property_value)
2251 });
2252
2253 if !is_optional {
2255 required_fields.push(schema_name);
2256 }
2257 }
2258
2259 quote! {
2260 {
2261 let properties: std::collections::HashMap<String, serde_json::Value> = vec![
2262 #(#property_entries),*
2263 ].into_iter().collect();
2264
2265 let required: Vec<String> = vec![#(#required_fields.to_string()),*];
2266
2267 let mut schema = serde_json::json!({
2268 "type": "object",
2269 "properties": properties,
2270 "required": required,
2271 });
2272
2273 if let Some(desc) = #type_desc_tokens {
2275 if let Some(obj) = schema.as_object_mut() {
2276 obj.insert("description".to_string(), serde_json::json!(desc));
2277 }
2278 }
2279
2280 schema
2281 }
2282 }
2283 }
2284 syn::Fields::Unnamed(fields) => {
2285 if fields.unnamed.len() == 1 {
2287 let inner_type = &fields.unnamed.first().unwrap().ty;
2289 let inner_schema = type_to_json_schema(inner_type);
2290 quote! { #inner_schema }
2291 } else {
2292 let item_schemas: Vec<_> = fields
2294 .unnamed
2295 .iter()
2296 .map(|f| type_to_json_schema(&f.ty))
2297 .collect();
2298 let num_items = item_schemas.len();
2299 quote! {
2300 {
2301 let items: Vec<serde_json::Value> = vec![#(#item_schemas),*];
2302 serde_json::json!({
2303 "type": "array",
2304 "prefixItems": items,
2305 "minItems": #num_items,
2306 "maxItems": #num_items,
2307 })
2308 }
2309 }
2310 }
2311 }
2312 syn::Fields::Unit => {
2313 quote! { serde_json::json!({ "type": "null" }) }
2315 }
2316 }
2317}
2318
2319fn generate_enum_schema(data: &syn::DataEnum, type_desc_tokens: &TokenStream2) -> TokenStream2 {
2321 let all_unit = data
2323 .variants
2324 .iter()
2325 .all(|v| matches!(v.fields, syn::Fields::Unit));
2326
2327 if all_unit {
2328 let variant_names: Vec<String> =
2330 data.variants.iter().map(|v| v.ident.to_string()).collect();
2331
2332 quote! {
2333 {
2334 let mut schema = serde_json::json!({
2335 "type": "string",
2336 "enum": [#(#variant_names),*]
2337 });
2338
2339 if let Some(desc) = #type_desc_tokens {
2340 if let Some(obj) = schema.as_object_mut() {
2341 obj.insert("description".to_string(), serde_json::json!(desc));
2342 }
2343 }
2344
2345 schema
2346 }
2347 }
2348 } else {
2349 let variant_schemas: Vec<TokenStream2> = data
2351 .variants
2352 .iter()
2353 .map(|variant| {
2354 let variant_name = variant.ident.to_string();
2355 match &variant.fields {
2356 syn::Fields::Unit => {
2357 quote! {
2358 serde_json::json!({
2359 "type": "object",
2360 "properties": {
2361 #variant_name: { "type": "null" }
2362 },
2363 "required": [#variant_name]
2364 })
2365 }
2366 }
2367 syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
2368 let inner_type = &fields.unnamed.first().unwrap().ty;
2369 let inner_schema = type_to_json_schema(inner_type);
2370 quote! {
2371 serde_json::json!({
2372 "type": "object",
2373 "properties": {
2374 #variant_name: #inner_schema
2375 },
2376 "required": [#variant_name]
2377 })
2378 }
2379 }
2380 _ => {
2381 quote! {
2383 serde_json::json!({
2384 "type": "object",
2385 "properties": {
2386 #variant_name: { "type": "object" }
2387 },
2388 "required": [#variant_name]
2389 })
2390 }
2391 }
2392 }
2393 })
2394 .collect();
2395
2396 quote! {
2397 {
2398 let mut schema = serde_json::json!({
2399 "oneOf": [#(#variant_schemas),*]
2400 });
2401
2402 if let Some(desc) = #type_desc_tokens {
2403 if let Some(obj) = schema.as_object_mut() {
2404 obj.insert("description".to_string(), serde_json::json!(desc));
2405 }
2406 }
2407
2408 schema
2409 }
2410 }
2411 }
2412}
2413
2414fn has_json_schema_attr(attrs: &[Attribute], attr_name: &str) -> bool {
2416 for attr in attrs {
2417 if attr.path().is_ident("json_schema") {
2418 if let Meta::List(meta_list) = &attr.meta {
2419 if let Ok(nested) = meta_list.parse_args::<Ident>() {
2420 if nested == attr_name {
2421 return true;
2422 }
2423 }
2424 }
2425 }
2426 }
2427 false
2428}
2429
2430fn get_json_schema_rename(attrs: &[Attribute]) -> Option<String> {
2432 for attr in attrs {
2433 if attr.path().is_ident("json_schema") {
2434 if let Meta::List(meta_list) = &attr.meta {
2435 let result: syn::Result<(Ident, LitStr)> =
2437 meta_list.parse_args_with(|input: ParseStream| {
2438 let ident: Ident = input.parse()?;
2439 let _: Token![=] = input.parse()?;
2440 let lit: LitStr = input.parse()?;
2441 Ok((ident, lit))
2442 });
2443
2444 if let Ok((ident, lit)) = result {
2445 if ident == "rename" {
2446 return Some(lit.value());
2447 }
2448 }
2449 }
2450 }
2451 }
2452 None
2453}