Skip to main content

fastmcp_derive/
lib.rs

1//! Procedural macros for FastMCP.
2//!
3//! This crate provides attribute macros for defining MCP handlers:
4//! - `#[tool]` - Define a tool handler
5//! - `#[resource]` - Define a resource handler
6//! - `#[prompt]` - Define a prompt handler
7//!
8//! # Example
9//!
10//! ```ignore
11//! use fastmcp::prelude::*;
12//!
13//! /// Greets a user by name.
14//! #[tool]
15//! async fn greet(
16//!     ctx: &McpContext,
17//!     /// The name to greet
18//!     name: String,
19//! ) -> String {
20//!     format!("Hello, {name}!")
21//! }
22//!
23//! /// A configuration file resource.
24//! #[resource(uri = "config://app")]
25//! async fn app_config(ctx: &McpContext) -> String {
26//!     std::fs::read_to_string("config.json").unwrap()
27//! }
28//!
29//! /// A code review prompt.
30//! #[prompt]
31//! async fn code_review(
32//!     ctx: &McpContext,
33//!     /// The code to review
34//!     code: String,
35//! ) -> Vec<PromptMessage> {
36//!     vec![PromptMessage {
37//!         role: Role::User,
38//!         content: Content::Text { text: format!("Review this code:\n\n{code}") },
39//!     }]
40//! }
41//! ```
42//!
43//! # Role in the System
44//!
45//! `fastmcp-derive` is the **ergonomics layer** of FastMCP. The attribute
46//! macros expand handler functions into the trait implementations used by
47//! `fastmcp-server`, and they also generate JSON Schema metadata consumed by
48//! `fastmcp-protocol` during tool registration.
49//!
50//! Most users never need to depend on this crate directly; it is re-exported
51//! by the `fastmcp` façade for `use fastmcp::prelude::*`.
52
53#![forbid(unsafe_code)]
54
55use proc_macro::TokenStream;
56use proc_macro2::TokenStream as TokenStream2;
57use quote::{format_ident, quote};
58use syn::{
59    Attribute, FnArg, Ident, ItemFn, Lit, LitStr, Meta, Pat, Token, Type, parse::Parse,
60    parse::ParseStream, parse_macro_input,
61};
62
63/// Extracts documentation from attributes.
64fn extract_doc_comments(attrs: &[Attribute]) -> Option<String> {
65    let docs: Vec<String> = attrs
66        .iter()
67        .filter_map(|attr| {
68            if attr.path().is_ident("doc") {
69                if let Meta::NameValue(nv) = &attr.meta {
70                    if let syn::Expr::Lit(syn::ExprLit {
71                        lit: Lit::Str(s), ..
72                    }) = &nv.value
73                    {
74                        return Some(s.value().trim().to_string());
75                    }
76                }
77            }
78            None
79        })
80        .collect();
81
82    if docs.is_empty() {
83        None
84    } else {
85        Some(docs.join("\n"))
86    }
87}
88
89/// Checks if a type is `&McpContext`.
90fn is_mcp_context_ref(ty: &Type) -> bool {
91    if let Type::Reference(type_ref) = ty {
92        if let Type::Path(type_path) = type_ref.elem.as_ref() {
93            return type_path
94                .path
95                .segments
96                .last()
97                .is_some_and(|s| s.ident == "McpContext");
98        }
99    }
100    false
101}
102
103/// Checks if a type is `Option<T>`.
104fn is_option_type(ty: &Type) -> bool {
105    if let Type::Path(type_path) = ty {
106        return type_path
107            .path
108            .segments
109            .last()
110            .is_some_and(|s| s.ident == "Option");
111    }
112    false
113}
114
115/// Returns the inner type if `ty` is `Option<T>`.
116fn option_inner_type(ty: &Type) -> Option<&Type> {
117    if let Type::Path(type_path) = ty {
118        if let Some(segment) = type_path.path.segments.last() {
119            if segment.ident == "Option" {
120                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
121                    if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
122                        return Some(inner_ty);
123                    }
124                }
125            }
126        }
127    }
128    None
129}
130
131/// Returns true if the type is `String`.
132fn is_string_type(ty: &Type) -> bool {
133    if let Type::Path(type_path) = ty {
134        return type_path
135            .path
136            .segments
137            .last()
138            .is_some_and(|s| s.ident == "String");
139    }
140    false
141}
142
/// Parses a human-readable duration string and returns it in milliseconds.
///
/// The input is a sequence of `<number><unit>` components whose values are
/// summed, e.g. `"30s"`, `"5m"`, `"1h"`, `"500ms"` or `"1h30m"`. Supported units
/// are `h` (hours), `m` (minutes), `s` (seconds) and `ms` (milliseconds).
/// Whitespace between components, or between a number and its unit, is
/// ignored (`"1h 30m"`, `"30 s"`).
///
/// # Errors
///
/// Returns a human-readable message in any of these cases:
/// - the input is empty;
/// - it contains an unknown unit or character;
/// - a number has no unit, or a unit has no number;
/// - whitespace splits a number (`"1 2h"`);
/// - the total overflows `u64` milliseconds;
/// - the total is zero.
fn parse_duration_to_millis(s: &str) -> Result<u64, String> {
    let s = s.trim();
    if s.is_empty() {
        return Err("empty string".to_string());
    }

    let mut total_millis: u64 = 0;
    let mut current_num = String::new();
    // True once whitespace follows the digits collected in `current_num`; a
    // further digit would otherwise be glued on ("1 2h" read as 12h).
    let mut space_after_digits = false;
    let mut chars = s.chars().peekable();

    while let Some(c) = chars.next() {
        if c.is_ascii_digit() {
            if space_after_digits {
                return Err(format!(
                    "unexpected whitespace inside number '{current_num}'"
                ));
            }
            current_num.push(c);
        } else if c.is_ascii_alphabetic() {
            if current_num.is_empty() {
                return Err(format!(
                    "unexpected unit character '{c}' without preceding number"
                ));
            }

            let num: u64 = current_num
                .parse()
                .map_err(|_| format!("invalid number: {current_num}"))?;

            // Milliseconds per unit. `ms` is the only two-character unit, so it
            // must be recognised before the plain `m` arm.
            let millis_per_unit: u64 = match c {
                'm' if chars.peek() == Some(&'s') => {
                    chars.next(); // consume 's'
                    1
                }
                'h' => 60 * 60 * 1000,
                'm' => 60 * 1000,
                's' => 1000,
                _ => return Err(format!("unknown unit '{c}'")),
            };

            // Checked arithmetic: an oversized value such as "9999999999999999h"
            // must yield a diagnostic rather than an overflow panic in the macro.
            total_millis = num
                .checked_mul(millis_per_unit)
                .and_then(|millis| total_millis.checked_add(millis))
                .ok_or_else(|| "duration is too large".to_string())?;
            current_num.clear();
            space_after_digits = false;
        } else if c.is_whitespace() {
            space_after_digits = !current_num.is_empty();
        } else {
            return Err(format!("unexpected character '{c}'"));
        }
    }

    if !current_num.is_empty() {
        return Err(format!(
            "number '{current_num}' missing unit (use s, m, h, or ms)"
        ));
    }

    if total_millis == 0 {
        return Err("duration must be greater than zero".to_string());
    }

    Ok(total_millis)
}
213
/// Returns the names written between `{` and `}` in a URI template, in order
/// of appearance.
///
/// Empty placeholders (`{}`) are skipped. A `{` with no closing `}` takes the
/// rest of the string as its name.
fn extract_template_params(uri: &str) -> Vec<String> {
    let mut found = Vec::new();
    let mut rest = uri;

    while let Some(open) = rest.find('{') {
        // `{` and `}` are single-byte, so these slice offsets are char boundaries.
        let after_open = &rest[open + 1..];
        let (name, remainder) = match after_open.find('}') {
            Some(close) => (&after_open[..close], &after_open[close + 1..]),
            None => (after_open, ""),
        };
        if !name.is_empty() {
            found.push(name.to_string());
        }
        rest = remainder;
    }

    found
}
236
/// Converts a snake_case identifier to PascalCase (e.g. `code_review` →
/// `CodeReview`).
///
/// Empty segments from leading, trailing or repeated underscores are dropped.
/// A leading raw-identifier prefix (`r#`, as produced by stringifying
/// `r#type`) is stripped first, because `#` cannot appear in an identifier.
fn to_pascal_case(s: &str) -> String {
    let s = s.strip_prefix("r#").unwrap_or(s);
    s.split('_')
        .map(|word| {
            let mut chars = word.chars();
            match chars.next() {
                Some(first) => first.to_uppercase().chain(chars).collect::<String>(),
                None => String::new(),
            }
        })
        .collect()
}
249
/// Strategy for converting a tool handler's return value into the
/// `Vec<Content>` the generated handler produces.
///
/// Produced by `analyze_return_type` and consumed by
/// `generate_result_conversion`.
enum ReturnTypeKind {
    /// `Vec<Content>`, passed through unchanged.
    VecContent,
    /// `String`, wrapped in a single `Content::Text`.
    String,
    /// `Result<Vec<Content>, E>` for a displayable error type `E`.
    ResultVecContent,
    /// `Result<String, E>`; the `Ok` text is wrapped in `Content::Text`.
    ResultString,
    /// `McpResult<Vec<Content>>`.
    McpResultVecContent,
    /// `McpResult<String>`; the `Ok` text is wrapped in `Content::Text`.
    McpResultString,
    /// Any other type; rendered through its `Display` implementation into a
    /// single `Content::Text`.
    Other,
    /// No explicit return type (implicit `()`); produces no content.
    Unit,
}
269
270/// Analyzes a function's return type and determines conversion strategy.
271fn analyze_return_type(output: &syn::ReturnType) -> ReturnTypeKind {
272    match output {
273        syn::ReturnType::Default => ReturnTypeKind::Unit,
274        syn::ReturnType::Type(_, ty) => analyze_type(ty),
275    }
276}
277
278/// Analyzes a type and determines what kind of return it is.
279fn analyze_type(ty: &Type) -> ReturnTypeKind {
280    if let Type::Path(type_path) = ty {
281        if let Some(segment) = type_path.path.segments.last() {
282            let type_name = segment.ident.to_string();
283
284            match type_name.as_str() {
285                "String" => return ReturnTypeKind::String,
286                "Vec" => {
287                    // Check if it's Vec<Content>
288                    if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
289                        if let Some(syn::GenericArgument::Type(Type::Path(inner_path))) =
290                            args.args.first()
291                        {
292                            if inner_path
293                                .path
294                                .segments
295                                .last()
296                                .is_some_and(|s| s.ident == "Content")
297                            {
298                                return ReturnTypeKind::VecContent;
299                            }
300                        }
301                    }
302                }
303                "Result" | "McpResult" => {
304                    // Check the Ok type
305                    if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
306                        if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
307                            let inner_kind = analyze_type(inner_ty);
308                            return match inner_kind {
309                                ReturnTypeKind::VecContent => {
310                                    if type_name == "McpResult" {
311                                        ReturnTypeKind::McpResultVecContent
312                                    } else {
313                                        ReturnTypeKind::ResultVecContent
314                                    }
315                                }
316                                ReturnTypeKind::String => {
317                                    if type_name == "McpResult" {
318                                        ReturnTypeKind::McpResultString
319                                    } else {
320                                        ReturnTypeKind::ResultString
321                                    }
322                                }
323                                _ => ReturnTypeKind::Other,
324                            };
325                        }
326                    }
327                }
328                _ => {}
329            }
330        }
331    }
332    ReturnTypeKind::Other
333}
334
335/// Generates code to convert a function result to Vec<Content>.
336fn generate_result_conversion(output: &syn::ReturnType) -> TokenStream2 {
337    let kind = analyze_return_type(output);
338
339    match kind {
340        ReturnTypeKind::Unit => quote! {
341            Ok(vec![])
342        },
343        ReturnTypeKind::VecContent => quote! {
344            Ok(result)
345        },
346        ReturnTypeKind::String => quote! {
347            Ok(vec![fastmcp_protocol::Content::Text { text: result }])
348        },
349        ReturnTypeKind::ResultVecContent | ReturnTypeKind::McpResultVecContent => quote! {
350            result.map_err(|e| fastmcp_core::McpError::internal_error(e.to_string()))
351        },
352        ReturnTypeKind::ResultString | ReturnTypeKind::McpResultString => quote! {
353            result
354                .map(|s| vec![fastmcp_protocol::Content::Text { text: s }])
355                .map_err(|e| fastmcp_core::McpError::internal_error(e.to_string()))
356        },
357        ReturnTypeKind::Other => quote! {
358            // Convert via ToString or Debug as fallback
359            let text = format!("{}", result);
360            Ok(vec![fastmcp_protocol::Content::Text { text }])
361        },
362    }
363}
364
365// ============================================================================
366// Prompt Return Type Analysis
367// ============================================================================
368
/// How a prompt handler's return value is turned into
/// `McpResult<Vec<PromptMessage>>`.
///
/// Produced by `analyze_prompt_return_type` and consumed by
/// `generate_prompt_result_conversion`.
enum PromptReturnTypeKind {
    /// `Vec<PromptMessage>`, used as-is.
    VecPromptMessage,
    /// `McpResult<Vec<PromptMessage>>`.
    McpResultVecPromptMessage,
    /// `Result<Vec<PromptMessage>, E>` with an arbitrary error type `E`.
    ResultVecPromptMessage,
    /// Anything else, including a missing return type. The generated code
    /// passes the value through as if it were `Vec<PromptMessage>`, so a
    /// mismatched type is expected to fail to compile.
    Other,
}
380
381/// Analyzes a prompt function's return type.
382fn analyze_prompt_return_type(output: &syn::ReturnType) -> PromptReturnTypeKind {
383    match output {
384        syn::ReturnType::Default => PromptReturnTypeKind::Other, // () not valid for prompts
385        syn::ReturnType::Type(_, ty) => analyze_prompt_type(ty),
386    }
387}
388
389/// Analyzes a type for prompt return type classification.
390fn analyze_prompt_type(ty: &Type) -> PromptReturnTypeKind {
391    if let Type::Path(type_path) = ty {
392        if let Some(segment) = type_path.path.segments.last() {
393            let type_name = segment.ident.to_string();
394
395            match type_name.as_str() {
396                "Vec" => {
397                    // Check if it's Vec<PromptMessage>
398                    if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
399                        if let Some(syn::GenericArgument::Type(Type::Path(inner_path))) =
400                            args.args.first()
401                        {
402                            if inner_path
403                                .path
404                                .segments
405                                .last()
406                                .is_some_and(|s| s.ident == "PromptMessage")
407                            {
408                                return PromptReturnTypeKind::VecPromptMessage;
409                            }
410                        }
411                    }
412                }
413                "Result" | "McpResult" => {
414                    // Check the Ok type
415                    if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
416                        if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
417                            let inner_kind = analyze_prompt_type(inner_ty);
418                            return match inner_kind {
419                                PromptReturnTypeKind::VecPromptMessage => {
420                                    if type_name == "McpResult" {
421                                        PromptReturnTypeKind::McpResultVecPromptMessage
422                                    } else {
423                                        PromptReturnTypeKind::ResultVecPromptMessage
424                                    }
425                                }
426                                _ => PromptReturnTypeKind::Other,
427                            };
428                        }
429                    }
430                }
431                _ => {}
432            }
433        }
434    }
435    PromptReturnTypeKind::Other
436}
437
438/// Generates code to convert a prompt function result to McpResult<Vec<PromptMessage>>.
439fn generate_prompt_result_conversion(output: &syn::ReturnType) -> TokenStream2 {
440    let kind = analyze_prompt_return_type(output);
441
442    match kind {
443        PromptReturnTypeKind::VecPromptMessage => quote! {
444            Ok(result)
445        },
446        PromptReturnTypeKind::ResultVecPromptMessage
447        | PromptReturnTypeKind::McpResultVecPromptMessage => quote! {
448            result.map_err(|e| fastmcp_core::McpError::internal_error(e.to_string()))
449        },
450        PromptReturnTypeKind::Other => quote! {
451            // Fallback: assume the result is Vec<PromptMessage>
452            Ok(result)
453        },
454    }
455}
456
457// ============================================================================
458// Resource Return Type Analysis
459// ============================================================================
460
/// How a resource handler's return value is turned into resource text.
///
/// Produced by `analyze_resource_return_type` and consumed by
/// `generate_resource_result_conversion`.
enum ResourceReturnTypeKind {
    /// `McpResult<String>`.
    McpResultString,
    /// `Result<String, E>` with an arbitrary error type `E`.
    ResultString,
    /// A plain `String`.
    String,
    /// Anything else, including a missing return type; converted with
    /// `ToString`.
    Other,
}
472
473/// Analyzes a resource function's return type.
474fn analyze_resource_return_type(output: &syn::ReturnType) -> ResourceReturnTypeKind {
475    match output {
476        syn::ReturnType::Default => ResourceReturnTypeKind::Other, // () not typical for resources
477        syn::ReturnType::Type(_, ty) => analyze_resource_type(ty),
478    }
479}
480
481/// Analyzes a type for resource return type classification.
482fn analyze_resource_type(ty: &Type) -> ResourceReturnTypeKind {
483    if let Type::Path(type_path) = ty {
484        if let Some(segment) = type_path.path.segments.last() {
485            let type_name = segment.ident.to_string();
486
487            match type_name.as_str() {
488                "String" => return ResourceReturnTypeKind::String,
489                "Result" | "McpResult" => {
490                    // Check the Ok type
491                    if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
492                        if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
493                            let inner_kind = analyze_resource_type(inner_ty);
494                            return match inner_kind {
495                                ResourceReturnTypeKind::String => {
496                                    if type_name == "McpResult" {
497                                        ResourceReturnTypeKind::McpResultString
498                                    } else {
499                                        ResourceReturnTypeKind::ResultString
500                                    }
501                                }
502                                _ => ResourceReturnTypeKind::Other,
503                            };
504                        }
505                    }
506                }
507                _ => {}
508            }
509        }
510    }
511    ResourceReturnTypeKind::Other
512}
513
514/// Generates code to convert a resource function result to McpResult<Vec<ResourceContent>>.
515///
516/// The generated code handles:
517/// - `String` → wrap in ResourceContent
518/// - `Result<String, E>` → unwrap result, then wrap in ResourceContent
519/// - `McpResult<String>` → unwrap result, then wrap in ResourceContent
520/// - Other types → use ToString trait
521///
522/// The generated code uses `uri` and `mime_type` variables that must be in scope.
523fn generate_resource_result_conversion(output: &syn::ReturnType, mime_type: &str) -> TokenStream2 {
524    let kind = analyze_resource_return_type(output);
525
526    match kind {
527        ResourceReturnTypeKind::String => quote! {
528            let text = result;
529            Ok(vec![fastmcp_protocol::ResourceContent {
530                uri: uri.to_string(),
531                mime_type: Some(#mime_type.to_string()),
532                text: Some(text),
533                blob: None,
534            }])
535        },
536        ResourceReturnTypeKind::ResultString | ResourceReturnTypeKind::McpResultString => quote! {
537            let text = result.map_err(|e| fastmcp_core::McpError::internal_error(e.to_string()))?;
538            Ok(vec![fastmcp_protocol::ResourceContent {
539                uri: uri.to_string(),
540                mime_type: Some(#mime_type.to_string()),
541                text: Some(text),
542                blob: None,
543            }])
544        },
545        ResourceReturnTypeKind::Other => quote! {
546            // Fallback: use ToString trait
547            let text = result.to_string();
548            Ok(vec![fastmcp_protocol::ResourceContent {
549                uri: uri.to_string(),
550                mime_type: Some(#mime_type.to_string()),
551                text: Some(text),
552                blob: None,
553            }])
554        },
555    }
556}
557
558/// Generates a JSON schema type for a Rust type.
559fn type_to_json_schema(ty: &Type) -> TokenStream2 {
560    let Type::Path(type_path) = ty else {
561        return quote! { serde_json::json!({}) };
562    };
563
564    let segment = type_path.path.segments.last().unwrap();
565    let type_name = segment.ident.to_string();
566
567    match type_name.as_str() {
568        "String" | "str" => quote! {
569            serde_json::json!({ "type": "string" })
570        },
571        "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32" | "u64" | "u128"
572        | "usize" => quote! {
573            serde_json::json!({ "type": "integer" })
574        },
575        "f32" | "f64" => quote! {
576            serde_json::json!({ "type": "number" })
577        },
578        "bool" => quote! {
579            serde_json::json!({ "type": "boolean" })
580        },
581        "Option" => {
582            // For Option<T>, get the inner type
583            if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
584                if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
585                    return type_to_json_schema(inner_ty);
586                }
587            }
588            quote! { serde_json::json!({}) }
589        }
590        "Vec" => {
591            // For Vec<T>, create array schema
592            if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
593                if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
594                    let inner_schema = type_to_json_schema(inner_ty);
595                    return quote! {
596                        serde_json::json!({
597                            "type": "array",
598                            "items": #inner_schema
599                        })
600                    };
601                }
602            }
603            quote! { serde_json::json!({ "type": "array" }) }
604        }
605        "HashSet" | "BTreeSet" => {
606            // For Set<T>, create array schema with uniqueItems
607            if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
608                if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
609                    let inner_schema = type_to_json_schema(inner_ty);
610                    return quote! {
611                        serde_json::json!({
612                            "type": "array",
613                            "items": #inner_schema,
614                            "uniqueItems": true
615                        })
616                    };
617                }
618            }
619            quote! { serde_json::json!({ "type": "array", "uniqueItems": true }) }
620        }
621        "HashMap" | "BTreeMap" => {
622            // For Map<K, V>, create object schema with additionalProperties
623            if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
624                // Check if key is String-like (implied for JSON object keys)
625                // We mainly care about the value type (second arg)
626                if args.args.len() >= 2 {
627                    if let Some(syn::GenericArgument::Type(value_ty)) = args.args.iter().nth(1) {
628                        let value_schema = type_to_json_schema(value_ty);
629                        return quote! {
630                            serde_json::json!({
631                                "type": "object",
632                                "additionalProperties": #value_schema
633                            })
634                        };
635                    }
636                }
637            }
638            quote! { serde_json::json!({ "type": "object" }) }
639        }
640        "serde_json::Value" | "Value" => {
641            // Any JSON value
642            quote! { serde_json::json!({}) }
643        }
644        _ => {
645            // For other types, assume they implement a json_schema() method
646            // (e.g. via #[derive(JsonSchema)] or manual implementation)
647            quote! { <#ty>::json_schema() }
648        }
649    }
650}
651
652// ============================================================================
653// Tool Macro
654// ============================================================================
655
656/// Parsed attributes for #[tool].
657struct ToolAttrs {
658    name: Option<String>,
659    description: Option<String>,
660    timeout: Option<String>,
661    /// Output schema as a JSON literal or type name
662    output_schema: Option<syn::Expr>,
663}
664
665impl Parse for ToolAttrs {
666    fn parse(input: ParseStream) -> syn::Result<Self> {
667        let mut name = None;
668        let mut description = None;
669        let mut timeout = None;
670        let mut output_schema = None;
671
672        while !input.is_empty() {
673            let ident: Ident = input.parse()?;
674            input.parse::<Token![=]>()?;
675
676            match ident.to_string().as_str() {
677                "name" => {
678                    let lit: LitStr = input.parse()?;
679                    name = Some(lit.value());
680                }
681                "description" => {
682                    let lit: LitStr = input.parse()?;
683                    description = Some(lit.value());
684                }
685                "timeout" => {
686                    let lit: LitStr = input.parse()?;
687                    timeout = Some(lit.value());
688                }
689                "output_schema" => {
690                    // Accept any expression (json!(...), type name, etc.)
691                    let expr: syn::Expr = input.parse()?;
692                    output_schema = Some(expr);
693                }
694                _ => {
695                    return Err(syn::Error::new(ident.span(), "unknown attribute"));
696                }
697            }
698
699            if !input.is_empty() {
700                input.parse::<Token![,]>()?;
701            }
702        }
703
704        Ok(Self {
705            name,
706            description,
707            timeout,
708            output_schema,
709        })
710    }
711}
712
713/// Defines a tool handler.
714///
715/// The function signature should be:
716/// ```ignore
717/// async fn tool_name(ctx: &McpContext, args...) -> Result
718/// ```
719///
720/// # Attributes
721///
722/// - `name` - Override the tool name (default: function name)
723/// - `description` - Tool description (default: doc comment)
724#[proc_macro_attribute]
725#[allow(clippy::too_many_lines)]
726pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
727    let attrs = parse_macro_input!(attr as ToolAttrs);
728    let input_fn = parse_macro_input!(item as ItemFn);
729
730    let fn_name = &input_fn.sig.ident;
731    let fn_name_str = fn_name.to_string();
732
733    // Generate handler struct name (PascalCase)
734    let handler_name = format_ident!("{}", to_pascal_case(&fn_name_str));
735
736    // Get tool name (from attr or function name)
737    let tool_name = attrs.name.unwrap_or_else(|| fn_name_str.clone());
738
739    // Get description (from attr or doc comments)
740    let description = attrs
741        .description
742        .or_else(|| extract_doc_comments(&input_fn.attrs));
743    let description_tokens = description.as_ref().map_or_else(
744        || quote! { None },
745        |desc| quote! { Some(#desc.to_string()) },
746    );
747
748    // Parse timeout attribute
749    let timeout_tokens = if let Some(ref timeout_str) = attrs.timeout {
750        match parse_duration_to_millis(timeout_str) {
751            Ok(millis) => {
752                quote! {
753                    fn timeout(&self) -> Option<std::time::Duration> {
754                        Some(std::time::Duration::from_millis(#millis))
755                    }
756                }
757            }
758            Err(e) => {
759                return syn::Error::new_spanned(
760                    &input_fn.sig.ident,
761                    format!("invalid timeout: {e}"),
762                )
763                .to_compile_error()
764                .into();
765            }
766        }
767    } else {
768        quote! {}
769    };
770
771    // Parse output_schema attribute
772    let (output_schema_field, output_schema_method) =
773        if let Some(ref schema_expr) = attrs.output_schema {
774            (
775                quote! { Some(#schema_expr) },
776                quote! {
777                    fn output_schema(&self) -> Option<serde_json::Value> {
778                        Some(#schema_expr)
779                    }
780                },
781            )
782        } else {
783            (quote! { None }, quote! {})
784        };
785
786    // Parse parameters (skip first if it's &McpContext)
787    let mut params: Vec<(&Ident, &Type, Option<String>)> = Vec::new();
788    let mut required_params: Vec<String> = Vec::new();
789    let mut expects_context = false;
790
791    for (i, arg) in input_fn.sig.inputs.iter().enumerate() {
792        if let FnArg::Typed(pat_type) = arg {
793            // Skip the first parameter if it looks like a context
794            if i == 0 && is_mcp_context_ref(pat_type.ty.as_ref()) {
795                expects_context = true;
796                continue;
797            }
798
799            if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
800                let param_name = &pat_ident.ident;
801                let param_type = pat_type.ty.as_ref();
802                let param_doc = extract_doc_comments(&pat_type.attrs);
803
804                // Check if parameter is required (not Option<T>)
805                let is_optional = is_option_type(param_type);
806
807                if !is_optional {
808                    required_params.push(param_name.to_string());
809                }
810
811                params.push((param_name, param_type, param_doc));
812            }
813        }
814    }
815
816    // Generate JSON schema for input
817    let property_entries: Vec<TokenStream2> = params
818        .iter()
819        .map(|(name, ty, doc)| {
820            let name_str = name.to_string();
821            let schema = type_to_json_schema(ty);
822            if let Some(desc) = doc {
823                quote! {
824                    (#name_str.to_string(), {
825                        let mut s = #schema;
826                        if let Some(obj) = s.as_object_mut() {
827                            obj.insert("description".to_string(), serde_json::json!(#desc));
828                        }
829                        s
830                    })
831                }
832            } else {
833                quote! {
834                    (#name_str.to_string(), #schema)
835                }
836            }
837        })
838        .collect();
839
840    // Generate parameter extraction code
841    let param_extractions: Vec<TokenStream2> = params
842        .iter()
843        .map(|(name, ty, _)| {
844            let name_str = name.to_string();
845            let is_optional = is_option_type(ty);
846
847            if is_optional {
848                quote! {
849                    let #name: #ty = match arguments.get(#name_str) {
850                        Some(value) => Some(
851                            serde_json::from_value(value.clone()).map_err(|e| {
852                                fastmcp_core::McpError::invalid_params(e.to_string())
853                            })?,
854                        ),
855                        None => None,
856                    };
857                }
858            } else {
859                quote! {
860                    let #name: #ty = arguments.get(#name_str)
861                        .ok_or_else(|| fastmcp_core::McpError::invalid_params(
862                            format!("missing required parameter: {}", #name_str)
863                        ))
864                        .and_then(|v| serde_json::from_value(v.clone())
865                            .map_err(|e| fastmcp_core::McpError::invalid_params(e.to_string())))?;
866                }
867            }
868        })
869        .collect();
870
871    // Generate parameter names for function call
872    let param_names: Vec<&Ident> = params.iter().map(|(name, _, _)| *name).collect();
873
874    // Check if function is async
875    let is_async = input_fn.sig.asyncness.is_some();
876
877    // Analyze return type to determine conversion strategy
878    let return_type = &input_fn.sig.output;
879    let result_conversion = generate_result_conversion(return_type);
880
881    // Generate the call expression (async functions are executed via block_on)
882    let call_expr = if is_async {
883        if expects_context {
884            quote! {
885                fastmcp_core::runtime::block_on(async move {
886                    #fn_name(ctx, #(#param_names),*).await
887                })
888            }
889        } else {
890            quote! {
891                fastmcp_core::runtime::block_on(async move {
892                    #fn_name(#(#param_names),*).await
893                })
894            }
895        }
896    } else {
897        if expects_context {
898            quote! {
899                #fn_name(ctx, #(#param_names),*)
900            }
901        } else {
902            quote! {
903                #fn_name(#(#param_names),*)
904            }
905        }
906    };
907
908    // Generate the handler implementation
909    let expanded = quote! {
910        // Keep the original function
911        #input_fn
912
913        /// Handler for the #fn_name tool.
914        #[derive(Clone)]
915        pub struct #handler_name;
916
917        impl fastmcp_server::ToolHandler for #handler_name {
918            fn definition(&self) -> fastmcp_protocol::Tool {
919                let properties: std::collections::HashMap<String, serde_json::Value> = vec![
920                    #(#property_entries),*
921                ].into_iter().collect();
922
923                let required: Vec<String> = vec![#(#required_params.to_string()),*];
924
925                fastmcp_protocol::Tool {
926                    name: #tool_name.to_string(),
927                    description: #description_tokens,
928                    input_schema: serde_json::json!({
929                        "type": "object",
930                        "properties": properties,
931                        "required": required,
932                    }),
933                    output_schema: #output_schema_field,
934                    icon: None,
935                    version: None,
936                    tags: vec![],
937                    annotations: None,
938                }
939            }
940
941            #timeout_tokens
942
943            #output_schema_method
944
945            fn call(
946                &self,
947                ctx: &fastmcp_core::McpContext,
948                arguments: serde_json::Value,
949            ) -> fastmcp_core::McpResult<Vec<fastmcp_protocol::Content>> {
950                // Parse arguments as object
951                let arguments = arguments.as_object()
952                    .cloned()
953                    .unwrap_or_default();
954
955                // Extract parameters
956                #(#param_extractions)*
957
958                // Call the function
959                let result = #call_expr;
960
961                // Convert result to Vec<Content> based on return type
962                #result_conversion
963            }
964        }
965    };
966
967    TokenStream::from(expanded)
968}
969
970// ============================================================================
971// Resource Macro
972// ============================================================================
973
/// Parsed arguments of the `#[resource(...)]` attribute.
///
/// Every key is optional at parse time; the `resource` macro itself reports a
/// compile error when `uri` is absent.
struct ResourceAttrs {
    /// `uri = "..."`: the resource URI, possibly containing template parameters.
    uri: Option<String>,
    /// `name = "..."`: display-name override (the macro falls back to the fn name).
    name: Option<String>,
    /// `description = "..."`: description override (the macro falls back to the doc comment).
    description: Option<String>,
    /// `mime_type = "..."`: MIME type override (the macro falls back to `text/plain`).
    mime_type: Option<String>,
    /// `timeout = "..."`: raw duration string, validated later by the macro.
    timeout: Option<String>,
}
982
983impl Parse for ResourceAttrs {
984    fn parse(input: ParseStream) -> syn::Result<Self> {
985        let mut uri = None;
986        let mut name = None;
987        let mut description = None;
988        let mut mime_type = None;
989        let mut timeout = None;
990
991        while !input.is_empty() {
992            let ident: Ident = input.parse()?;
993            input.parse::<Token![=]>()?;
994
995            match ident.to_string().as_str() {
996                "uri" => {
997                    let lit: LitStr = input.parse()?;
998                    uri = Some(lit.value());
999                }
1000                "name" => {
1001                    let lit: LitStr = input.parse()?;
1002                    name = Some(lit.value());
1003                }
1004                "description" => {
1005                    let lit: LitStr = input.parse()?;
1006                    description = Some(lit.value());
1007                }
1008                "mime_type" => {
1009                    let lit: LitStr = input.parse()?;
1010                    mime_type = Some(lit.value());
1011                }
1012                "timeout" => {
1013                    let lit: LitStr = input.parse()?;
1014                    timeout = Some(lit.value());
1015                }
1016                _ => {
1017                    return Err(syn::Error::new(ident.span(), "unknown attribute"));
1018                }
1019            }
1020
1021            if !input.is_empty() {
1022                input.parse::<Token![,]>()?;
1023            }
1024        }
1025
1026        Ok(Self {
1027            uri,
1028            name,
1029            description,
1030            mime_type,
1031            timeout,
1032        })
1033    }
1034}
1035
1036/// Defines a resource handler.
1037///
1038/// # Attributes
1039///
1040/// - `uri` - The resource URI (required)
1041/// - `name` - Display name (default: function name)
1042/// - `description` - Resource description (default: doc comment)
1043/// - `mime_type` - MIME type (default: "text/plain")
1044#[proc_macro_attribute]
1045#[allow(clippy::too_many_lines)]
1046pub fn resource(attr: TokenStream, item: TokenStream) -> TokenStream {
1047    let attrs = parse_macro_input!(attr as ResourceAttrs);
1048    let input_fn = parse_macro_input!(item as ItemFn);
1049
1050    let fn_name = &input_fn.sig.ident;
1051    let fn_name_str = fn_name.to_string();
1052
1053    // Generate handler struct name
1054    let handler_name = format_ident!("{}Resource", to_pascal_case(&fn_name_str));
1055
1056    // Get resource URI (required)
1057    let Some(uri) = attrs.uri else {
1058        return syn::Error::new_spanned(&input_fn.sig.ident, "resource requires uri attribute")
1059            .to_compile_error()
1060            .into();
1061    };
1062
1063    // Get name and description
1064    let resource_name = attrs.name.unwrap_or_else(|| fn_name_str.clone());
1065    let description = attrs
1066        .description
1067        .or_else(|| extract_doc_comments(&input_fn.attrs));
1068    let mime_type = attrs.mime_type.unwrap_or_else(|| "text/plain".to_string());
1069
1070    let description_tokens = description.as_ref().map_or_else(
1071        || quote! { None },
1072        |desc| quote! { Some(#desc.to_string()) },
1073    );
1074
1075    // Parse timeout attribute
1076    let timeout_tokens = if let Some(ref timeout_str) = attrs.timeout {
1077        match parse_duration_to_millis(timeout_str) {
1078            Ok(millis) => {
1079                quote! {
1080                    fn timeout(&self) -> Option<std::time::Duration> {
1081                        Some(std::time::Duration::from_millis(#millis))
1082                    }
1083                }
1084            }
1085            Err(e) => {
1086                return syn::Error::new_spanned(
1087                    &input_fn.sig.ident,
1088                    format!("invalid timeout: {e}"),
1089                )
1090                .to_compile_error()
1091                .into();
1092            }
1093        }
1094    } else {
1095        quote! {}
1096    };
1097
1098    let template_params = extract_template_params(&uri);
1099
1100    // Parse parameters (skip first if it's &McpContext)
1101    let mut params: Vec<(&Ident, &Type)> = Vec::new();
1102    let mut expects_context = false;
1103
1104    for (i, arg) in input_fn.sig.inputs.iter().enumerate() {
1105        if let FnArg::Typed(pat_type) = arg {
1106            if i == 0 && is_mcp_context_ref(pat_type.ty.as_ref()) {
1107                expects_context = true;
1108                continue;
1109            }
1110
1111            if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
1112                let param_name = &pat_ident.ident;
1113                let param_type = pat_type.ty.as_ref();
1114                params.push((param_name, param_type));
1115            }
1116        }
1117    }
1118
1119    if template_params.is_empty() && !params.is_empty() {
1120        return syn::Error::new_spanned(
1121            &input_fn.sig.ident,
1122            "resource parameters require a URI template with matching {params}",
1123        )
1124        .to_compile_error()
1125        .into();
1126    }
1127
1128    let missing_params: Vec<String> = params
1129        .iter()
1130        .map(|(name, _)| name.to_string())
1131        .filter(|name| !template_params.contains(name))
1132        .collect();
1133
1134    if !missing_params.is_empty() {
1135        return syn::Error::new_spanned(
1136            &input_fn.sig.ident,
1137            format!(
1138                "resource parameters missing from uri template: {}",
1139                missing_params.join(", ")
1140            ),
1141        )
1142        .to_compile_error()
1143        .into();
1144    }
1145
1146    let is_template = !template_params.is_empty();
1147
1148    let param_extractions: Vec<TokenStream2> = params
1149        .iter()
1150        .map(|(name, ty)| {
1151            let name_str = name.to_string();
1152            if let Some(inner_ty) = option_inner_type(ty) {
1153                if is_string_type(inner_ty) {
1154                    quote! {
1155                        let #name: #ty = uri_params.get(#name_str).cloned();
1156                    }
1157                } else {
1158                    quote! {
1159                        let #name: #ty = match uri_params.get(#name_str) {
1160                            Some(value) => Some(value.parse().map_err(|_| {
1161                                fastmcp_core::McpError::invalid_params(
1162                                    format!("invalid uri parameter: {}", #name_str)
1163                                )
1164                            })?),
1165                            None => None,
1166                        };
1167                    }
1168                }
1169            } else if is_string_type(ty) {
1170                quote! {
1171                    let #name: #ty = uri_params
1172                        .get(#name_str)
1173                        .ok_or_else(|| fastmcp_core::McpError::invalid_params(
1174                            format!("missing uri parameter: {}", #name_str)
1175                        ))?
1176                        .clone();
1177                }
1178            } else {
1179                quote! {
1180                    let #name: #ty = uri_params
1181                        .get(#name_str)
1182                        .ok_or_else(|| fastmcp_core::McpError::invalid_params(
1183                            format!("missing uri parameter: {}", #name_str)
1184                        ))?
1185                        .parse()
1186                        .map_err(|_| fastmcp_core::McpError::invalid_params(
1187                            format!("invalid uri parameter: {}", #name_str)
1188                        ))?;
1189                }
1190            }
1191        })
1192        .collect();
1193
1194    let param_names: Vec<&Ident> = params.iter().map(|(name, _)| *name).collect();
1195    let call_args = if expects_context {
1196        quote! { ctx, #(#param_names),* }
1197    } else {
1198        quote! { #(#param_names),* }
1199    };
1200
1201    let is_async = input_fn.sig.asyncness.is_some();
1202    let call_expr = if is_async {
1203        quote! {
1204            fastmcp_core::runtime::block_on(async move {
1205                #fn_name(#call_args).await
1206            })
1207        }
1208    } else {
1209        quote! {
1210            #fn_name(#call_args)
1211        }
1212    };
1213
1214    let template_tokens = if is_template {
1215        quote! {
1216            Some(fastmcp_protocol::ResourceTemplate {
1217                uri_template: #uri.to_string(),
1218                name: #resource_name.to_string(),
1219                description: #description_tokens,
1220                mime_type: Some(#mime_type.to_string()),
1221                icon: None,
1222                version: None,
1223                tags: vec![],
1224            })
1225        }
1226    } else {
1227        quote! { None }
1228    };
1229
1230    // Generate result conversion based on return type (supports Result<String, E>)
1231    let return_type = &input_fn.sig.output;
1232    let resource_result_conversion = generate_resource_result_conversion(return_type, &mime_type);
1233
1234    let expanded = quote! {
1235        // Keep the original function
1236        #input_fn
1237
1238        /// Handler for the #fn_name resource.
1239        #[derive(Clone)]
1240        pub struct #handler_name;
1241
1242        impl fastmcp_server::ResourceHandler for #handler_name {
1243            fn definition(&self) -> fastmcp_protocol::Resource {
1244                fastmcp_protocol::Resource {
1245                    uri: #uri.to_string(),
1246                    name: #resource_name.to_string(),
1247                    description: #description_tokens,
1248                    mime_type: Some(#mime_type.to_string()),
1249                    icon: None,
1250                    version: None,
1251                    tags: vec![],
1252                }
1253            }
1254
1255            fn template(&self) -> Option<fastmcp_protocol::ResourceTemplate> {
1256                #template_tokens
1257            }
1258
1259            #timeout_tokens
1260
1261            fn read(
1262                &self,
1263                ctx: &fastmcp_core::McpContext,
1264            ) -> fastmcp_core::McpResult<Vec<fastmcp_protocol::ResourceContent>> {
1265                let uri_params = std::collections::HashMap::new();
1266                self.read_with_uri(ctx, #uri, &uri_params)
1267            }
1268
1269            fn read_with_uri(
1270                &self,
1271                ctx: &fastmcp_core::McpContext,
1272                uri: &str,
1273                uri_params: &std::collections::HashMap<String, String>,
1274            ) -> fastmcp_core::McpResult<Vec<fastmcp_protocol::ResourceContent>> {
1275                #(#param_extractions)*
1276                let result = #call_expr;
1277                #resource_result_conversion
1278            }
1279
1280            fn read_async_with_uri<'a>(
1281                &'a self,
1282                ctx: &'a fastmcp_core::McpContext,
1283                uri: &'a str,
1284                uri_params: &'a std::collections::HashMap<String, String>,
1285            ) -> fastmcp_server::BoxFuture<'a, fastmcp_core::McpOutcome<Vec<fastmcp_protocol::ResourceContent>>> {
1286                Box::pin(async move {
1287                    match self.read_with_uri(ctx, uri, uri_params) {
1288                        Ok(value) => fastmcp_core::Outcome::Ok(value),
1289                        Err(error) => fastmcp_core::Outcome::Err(error),
1290                    }
1291                })
1292            }
1293        }
1294    };
1295
1296    TokenStream::from(expanded)
1297}
1298
1299// ============================================================================
1300// Prompt Macro
1301// ============================================================================
1302
/// Parsed arguments of the `#[prompt(...)]` attribute.
///
/// All keys are optional; the `prompt` macro supplies the fallbacks.
struct PromptAttrs {
    /// `name = "..."`: prompt-name override (the macro falls back to the fn name).
    name: Option<String>,
    /// `description = "..."`: description override (the macro falls back to the doc comment).
    description: Option<String>,
    /// `timeout = "..."`: raw duration string, validated later by the macro.
    timeout: Option<String>,
}
1309
1310impl Parse for PromptAttrs {
1311    fn parse(input: ParseStream) -> syn::Result<Self> {
1312        let mut name = None;
1313        let mut description = None;
1314        let mut timeout = None;
1315
1316        while !input.is_empty() {
1317            let ident: Ident = input.parse()?;
1318            input.parse::<Token![=]>()?;
1319
1320            match ident.to_string().as_str() {
1321                "name" => {
1322                    let lit: LitStr = input.parse()?;
1323                    name = Some(lit.value());
1324                }
1325                "description" => {
1326                    let lit: LitStr = input.parse()?;
1327                    description = Some(lit.value());
1328                }
1329                "timeout" => {
1330                    let lit: LitStr = input.parse()?;
1331                    timeout = Some(lit.value());
1332                }
1333                _ => {
1334                    return Err(syn::Error::new(ident.span(), "unknown attribute"));
1335                }
1336            }
1337
1338            if !input.is_empty() {
1339                input.parse::<Token![,]>()?;
1340            }
1341        }
1342
1343        Ok(Self {
1344            name,
1345            description,
1346            timeout,
1347        })
1348    }
1349}
1350
1351/// Defines a prompt handler.
1352///
1353/// # Attributes
1354///
1355/// - `name` - Override the prompt name (default: function name)
1356/// - `description` - Prompt description (default: doc comment)
1357#[proc_macro_attribute]
1358#[allow(clippy::too_many_lines)]
1359pub fn prompt(attr: TokenStream, item: TokenStream) -> TokenStream {
1360    let attrs = parse_macro_input!(attr as PromptAttrs);
1361    let input_fn = parse_macro_input!(item as ItemFn);
1362
1363    let fn_name = &input_fn.sig.ident;
1364    let fn_name_str = fn_name.to_string();
1365
1366    // Generate handler struct name
1367    let handler_name = format_ident!("{}Prompt", to_pascal_case(&fn_name_str));
1368
1369    // Get prompt name
1370    let prompt_name = attrs.name.unwrap_or_else(|| fn_name_str.clone());
1371
1372    // Get description
1373    let description = attrs
1374        .description
1375        .or_else(|| extract_doc_comments(&input_fn.attrs));
1376    let description_tokens = description.as_ref().map_or_else(
1377        || quote! { None },
1378        |desc| quote! { Some(#desc.to_string()) },
1379    );
1380
1381    // Parse timeout attribute
1382    let timeout_tokens = if let Some(ref timeout_str) = attrs.timeout {
1383        match parse_duration_to_millis(timeout_str) {
1384            Ok(millis) => {
1385                quote! {
1386                    fn timeout(&self) -> Option<std::time::Duration> {
1387                        Some(std::time::Duration::from_millis(#millis))
1388                    }
1389                }
1390            }
1391            Err(e) => {
1392                return syn::Error::new_spanned(
1393                    &input_fn.sig.ident,
1394                    format!("invalid timeout: {e}"),
1395                )
1396                .to_compile_error()
1397                .into();
1398            }
1399        }
1400    } else {
1401        quote! {}
1402    };
1403
1404    // Parse parameters for prompt arguments (skip first if it's &McpContext)
1405    let mut prompt_args: Vec<TokenStream2> = Vec::new();
1406    let mut expects_context = false;
1407
1408    for (i, arg) in input_fn.sig.inputs.iter().enumerate() {
1409        if let FnArg::Typed(pat_type) = arg {
1410            // Skip the context parameter
1411            if i == 0 && is_mcp_context_ref(pat_type.ty.as_ref()) {
1412                expects_context = true;
1413                continue;
1414            }
1415
1416            if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
1417                let param_name = pat_ident.ident.to_string();
1418                let param_doc = extract_doc_comments(&pat_type.attrs);
1419                let is_optional = is_option_type(pat_type.ty.as_ref());
1420
1421                let desc_tokens = param_doc
1422                    .as_ref()
1423                    .map_or_else(|| quote! { None }, |d| quote! { Some(#d.to_string()) });
1424
1425                prompt_args.push(quote! {
1426                    fastmcp_protocol::PromptArgument {
1427                        name: #param_name.to_string(),
1428                        description: #desc_tokens,
1429                        required: !#is_optional,
1430                    }
1431                });
1432            }
1433        }
1434    }
1435
1436    // Generate parameter extraction for the get method
1437    let mut param_extractions: Vec<TokenStream2> = Vec::new();
1438    let mut param_names: Vec<Ident> = Vec::new();
1439
1440    for (i, arg) in input_fn.sig.inputs.iter().enumerate() {
1441        if let FnArg::Typed(pat_type) = arg {
1442            // Skip context
1443            if i == 0 && is_mcp_context_ref(pat_type.ty.as_ref()) {
1444                continue;
1445            }
1446
1447            if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
1448                let param_name = &pat_ident.ident;
1449                let param_name_str = param_name.to_string();
1450                let is_optional = is_option_type(pat_type.ty.as_ref());
1451
1452                param_names.push(param_name.clone());
1453
1454                if is_optional {
1455                    // Optional parameters: return None if not provided
1456                    param_extractions.push(quote! {
1457                        let #param_name = arguments.get(#param_name_str).cloned();
1458                    });
1459                } else {
1460                    // Required parameters: return an error if missing
1461                    param_extractions.push(quote! {
1462                        let #param_name = arguments.get(#param_name_str)
1463                            .cloned()
1464                            .ok_or_else(|| fastmcp_core::McpError::invalid_params(
1465                                format!("missing required argument: {}", #param_name_str)
1466                            ))?;
1467                    });
1468                }
1469            }
1470        }
1471    }
1472
1473    let is_async = input_fn.sig.asyncness.is_some();
1474    let call_expr = if is_async {
1475        if expects_context {
1476            quote! {
1477                fastmcp_core::runtime::block_on(async move {
1478                    #fn_name(ctx, #(#param_names),*).await
1479                })
1480            }
1481        } else {
1482            quote! {
1483                fastmcp_core::runtime::block_on(async move {
1484                    #fn_name(#(#param_names),*).await
1485                })
1486            }
1487        }
1488    } else {
1489        if expects_context {
1490            quote! {
1491                #fn_name(ctx, #(#param_names),*)
1492            }
1493        } else {
1494            quote! {
1495                #fn_name(#(#param_names),*)
1496            }
1497        }
1498    };
1499
1500    // Generate result conversion based on return type (supports Result<Vec<PromptMessage>, E>)
1501    let return_type = &input_fn.sig.output;
1502    let prompt_result_conversion = generate_prompt_result_conversion(return_type);
1503
1504    let expanded = quote! {
1505        // Keep the original function
1506        #input_fn
1507
1508        /// Handler for the #fn_name prompt.
1509        #[derive(Clone)]
1510        pub struct #handler_name;
1511
1512        impl fastmcp_server::PromptHandler for #handler_name {
1513            fn definition(&self) -> fastmcp_protocol::Prompt {
1514                fastmcp_protocol::Prompt {
1515                    name: #prompt_name.to_string(),
1516                    description: #description_tokens,
1517                    arguments: vec![#(#prompt_args),*],
1518                    icon: None,
1519                    version: None,
1520                    tags: vec![],
1521                }
1522            }
1523
1524            #timeout_tokens
1525
1526            fn get(
1527                &self,
1528                ctx: &fastmcp_core::McpContext,
1529                arguments: std::collections::HashMap<String, String>,
1530            ) -> fastmcp_core::McpResult<Vec<fastmcp_protocol::PromptMessage>> {
1531                #(#param_extractions)*
1532                let result = #call_expr;
1533                #prompt_result_conversion
1534            }
1535        }
1536    };
1537
1538    TokenStream::from(expanded)
1539}
1540
1541/// Derives JSON Schema for a type.
1542///
1543/// Used for generating input schemas for tools. Generates a `json_schema()` method
1544/// that returns the JSON Schema representation of the type.
1545///
1546/// # Example
1547///
1548/// ```ignore
1549/// use fastmcp::JsonSchema;
1550///
1551/// #[derive(JsonSchema)]
1552/// struct MyToolInput {
1553///     /// The name of the person
1554///     name: String,
1555///     /// Optional age
1556///     age: Option<u32>,
1557///     /// List of tags
1558///     tags: Vec<String>,
1559/// }
1560///
1561/// // Generated schema:
1562/// // {
1563/// //   "type": "object",
1564/// //   "properties": {
1565/// //     "name": { "type": "string", "description": "The name of the person" },
1566/// //     "age": { "type": "integer", "description": "Optional age" },
1567/// //     "tags": { "type": "array", "items": { "type": "string" }, "description": "List of tags" }
1568/// //   },
1569/// //   "required": ["name", "tags"]
1570/// // }
1571/// ```
1572///
1573/// # Supported Types
1574///
1575/// - `String`, `&str` → `"string"`
1576/// - `i8`..`i128`, `u8`..`u128`, `isize`, `usize` → `"integer"`
1577/// - `f32`, `f64` → `"number"`
1578/// - `bool` → `"boolean"`
1579/// - `Option<T>` → schema for T, field not required
1580/// - `Vec<T>` → `"array"` with items schema
1581/// - `HashMap<String, T>` → `"object"` with additionalProperties
1582/// - Other types → `"object"` (custom types should derive JsonSchema)
1583///
1584/// # Attributes
1585///
1586/// - `#[json_schema(rename = "...")]` - Rename the field in the schema
1587/// - `#[json_schema(skip)]` - Skip this field
1588/// - `#[json_schema(flatten)]` - Flatten nested object properties
1589#[proc_macro_derive(JsonSchema, attributes(json_schema))]
1590pub fn derive_json_schema(input: TokenStream) -> TokenStream {
1591    let input = parse_macro_input!(input as syn::DeriveInput);
1592
1593    let name = &input.ident;
1594    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1595
1596    // Extract type-level doc comments for schema description
1597    let type_description = extract_doc_comments(&input.attrs);
1598    let type_desc_tokens = type_description
1599        .as_ref()
1600        .map_or_else(|| quote! { None::<&str> }, |desc| quote! { Some(#desc) });
1601
1602    // Process fields based on data type
1603    let schema_impl = match &input.data {
1604        syn::Data::Struct(data_struct) => generate_struct_schema(data_struct, &type_desc_tokens),
1605        syn::Data::Enum(data_enum) => generate_enum_schema(data_enum, &type_desc_tokens),
1606        syn::Data::Union(_) => {
1607            return syn::Error::new_spanned(input, "JsonSchema cannot be derived for unions")
1608                .to_compile_error()
1609                .into();
1610        }
1611    };
1612
1613    let expanded = quote! {
1614        impl #impl_generics #name #ty_generics #where_clause {
1615            /// Returns the JSON Schema for this type.
1616            pub fn json_schema() -> serde_json::Value {
1617                #schema_impl
1618            }
1619        }
1620    };
1621
1622    TokenStream::from(expanded)
1623}
1624
1625/// Generates JSON Schema for a struct.
1626fn generate_struct_schema(data: &syn::DataStruct, type_desc_tokens: &TokenStream2) -> TokenStream2 {
1627    match &data.fields {
1628        syn::Fields::Named(fields) => {
1629            let mut property_entries = Vec::new();
1630            let mut required_fields = Vec::new();
1631
1632            for field in &fields.named {
1633                // Check for skip attribute
1634                if has_json_schema_attr(&field.attrs, "skip") {
1635                    continue;
1636                }
1637
1638                let field_name = field.ident.as_ref().unwrap();
1639
1640                // Check for rename attribute
1641                let schema_name =
1642                    get_json_schema_rename(&field.attrs).unwrap_or_else(|| field_name.to_string());
1643
1644                // Get field doc comment
1645                let field_doc = extract_doc_comments(&field.attrs);
1646
1647                // Generate schema for this field's type
1648                let field_type = &field.ty;
1649                let is_optional = is_option_type(field_type);
1650
1651                // Generate the base schema
1652                let field_schema = type_to_json_schema(field_type);
1653
1654                // Add description if available
1655                let property_value = if let Some(desc) = &field_doc {
1656                    quote! {
1657                        {
1658                            let mut schema = #field_schema;
1659                            if let Some(obj) = schema.as_object_mut() {
1660                                obj.insert("description".to_string(), serde_json::json!(#desc));
1661                            }
1662                            schema
1663                        }
1664                    }
1665                } else {
1666                    field_schema
1667                };
1668
1669                property_entries.push(quote! {
1670                    (#schema_name.to_string(), #property_value)
1671                });
1672
1673                // Add to required if not optional
1674                if !is_optional {
1675                    required_fields.push(schema_name);
1676                }
1677            }
1678
1679            quote! {
1680                {
1681                    let properties: std::collections::HashMap<String, serde_json::Value> = vec![
1682                        #(#property_entries),*
1683                    ].into_iter().collect();
1684
1685                    let required: Vec<String> = vec![#(#required_fields.to_string()),*];
1686
1687                    let mut schema = serde_json::json!({
1688                        "type": "object",
1689                        "properties": properties,
1690                        "required": required,
1691                    });
1692
1693                    // Add description if available
1694                    if let Some(desc) = #type_desc_tokens {
1695                        if let Some(obj) = schema.as_object_mut() {
1696                            obj.insert("description".to_string(), serde_json::json!(desc));
1697                        }
1698                    }
1699
1700                    schema
1701                }
1702            }
1703        }
1704        syn::Fields::Unnamed(fields) => {
1705            // Tuple struct - generate as array
1706            if fields.unnamed.len() == 1 {
1707                // Newtype pattern - just use inner type's schema
1708                let inner_type = &fields.unnamed.first().unwrap().ty;
1709                let inner_schema = type_to_json_schema(inner_type);
1710                quote! { #inner_schema }
1711            } else {
1712                // Multiple fields - tuple represented as array with prefixItems
1713                let item_schemas: Vec<_> = fields
1714                    .unnamed
1715                    .iter()
1716                    .map(|f| type_to_json_schema(&f.ty))
1717                    .collect();
1718                let num_items = item_schemas.len();
1719                quote! {
1720                    {
1721                        let items: Vec<serde_json::Value> = vec![#(#item_schemas),*];
1722                        serde_json::json!({
1723                            "type": "array",
1724                            "prefixItems": items,
1725                            "minItems": #num_items,
1726                            "maxItems": #num_items,
1727                        })
1728                    }
1729                }
1730            }
1731        }
1732        syn::Fields::Unit => {
1733            // Unit struct - null type
1734            quote! { serde_json::json!({ "type": "null" }) }
1735        }
1736    }
1737}
1738
1739/// Generates JSON Schema for an enum.
1740fn generate_enum_schema(data: &syn::DataEnum, type_desc_tokens: &TokenStream2) -> TokenStream2 {
1741    // Check if all variants are unit variants (string enum)
1742    let all_unit = data
1743        .variants
1744        .iter()
1745        .all(|v| matches!(v.fields, syn::Fields::Unit));
1746
1747    if all_unit {
1748        // Simple string enum
1749        let variant_names: Vec<String> =
1750            data.variants.iter().map(|v| v.ident.to_string()).collect();
1751
1752        quote! {
1753            {
1754                let mut schema = serde_json::json!({
1755                    "type": "string",
1756                    "enum": [#(#variant_names),*]
1757                });
1758
1759                if let Some(desc) = #type_desc_tokens {
1760                    if let Some(obj) = schema.as_object_mut() {
1761                        obj.insert("description".to_string(), serde_json::json!(desc));
1762                    }
1763                }
1764
1765                schema
1766            }
1767        }
1768    } else {
1769        // Tagged union - use oneOf
1770        let variant_schemas: Vec<TokenStream2> = data
1771            .variants
1772            .iter()
1773            .map(|variant| {
1774                let variant_name = variant.ident.to_string();
1775                match &variant.fields {
1776                    syn::Fields::Unit => {
1777                        quote! {
1778                            serde_json::json!({
1779                                "type": "object",
1780                                "properties": {
1781                                    #variant_name: { "type": "null" }
1782                                },
1783                                "required": [#variant_name]
1784                            })
1785                        }
1786                    }
1787                    syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
1788                        let inner_type = &fields.unnamed.first().unwrap().ty;
1789                        let inner_schema = type_to_json_schema(inner_type);
1790                        quote! {
1791                            serde_json::json!({
1792                                "type": "object",
1793                                "properties": {
1794                                    #variant_name: #inner_schema
1795                                },
1796                                "required": [#variant_name]
1797                            })
1798                        }
1799                    }
1800                    _ => {
1801                        // Complex variant - just mark as object
1802                        quote! {
1803                            serde_json::json!({
1804                                "type": "object",
1805                                "properties": {
1806                                    #variant_name: { "type": "object" }
1807                                },
1808                                "required": [#variant_name]
1809                            })
1810                        }
1811                    }
1812                }
1813            })
1814            .collect();
1815
1816        quote! {
1817            {
1818                let mut schema = serde_json::json!({
1819                    "oneOf": [#(#variant_schemas),*]
1820                });
1821
1822                if let Some(desc) = #type_desc_tokens {
1823                    if let Some(obj) = schema.as_object_mut() {
1824                        obj.insert("description".to_string(), serde_json::json!(desc));
1825                    }
1826                }
1827
1828                schema
1829            }
1830        }
1831    }
1832}
1833
1834/// Checks if a field has a specific json_schema attribute.
1835fn has_json_schema_attr(attrs: &[Attribute], attr_name: &str) -> bool {
1836    for attr in attrs {
1837        if attr.path().is_ident("json_schema") {
1838            if let Meta::List(meta_list) = &attr.meta {
1839                if let Ok(nested) = meta_list.parse_args::<Ident>() {
1840                    if nested == attr_name {
1841                        return true;
1842                    }
1843                }
1844            }
1845        }
1846    }
1847    false
1848}
1849
1850/// Gets the rename value from json_schema attribute if present.
1851fn get_json_schema_rename(attrs: &[Attribute]) -> Option<String> {
1852    for attr in attrs {
1853        if attr.path().is_ident("json_schema") {
1854            if let Meta::List(meta_list) = &attr.meta {
1855                // Parse as ident = "value"
1856                let result: syn::Result<(Ident, LitStr)> =
1857                    meta_list.parse_args_with(|input: ParseStream| {
1858                        let ident: Ident = input.parse()?;
1859                        let _: Token![=] = input.parse()?;
1860                        let lit: LitStr = input.parse()?;
1861                        Ok((ident, lit))
1862                    });
1863
1864                if let Ok((ident, lit)) = result {
1865                    if ident == "rename" {
1866                        return Some(lit.value());
1867                    }
1868                }
1869            }
1870        }
1871    }
1872    None
1873}