Skip to main content

mcp_kit_macros/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use proc_macro2::{Span, TokenStream as TokenStream2};
5use quote::quote;
6use syn::{
7    parse_macro_input, punctuated::Punctuated, spanned::Spanned, Attribute, FnArg, ItemFn, Lit,
8    Meta, MetaNameValue, Pat, PatType, Token, Type,
9};
10
11// ─── #[tool] ─────────────────────────────────────────────────────────────────
12
13/// Marks an async function as an MCP tool and generates a companion
14/// `{fn_name}_tool_def()` function that returns a `mcp::ToolDef`.
15///
16/// # Attributes
17/// - `description = "..."` — human-readable description (required)
18/// - `name = "..."` — tool name (defaults to the function name)
19///
20/// # Example
21/// ```rust,ignore
22/// use mcp_kit::prelude::*;
23///
24/// /// Add two numbers together.
25/// #[tool(description = "Add two numbers")]
26/// async fn add(a: f64, b: f64) -> String {
27///     format!("{}", a + b)
28/// }
29///
30/// #[tokio::main]
31/// async fn main() -> anyhow::Result<()> {
32///     // Register with the builder:
33///     let _server = McpServer::builder()
34///         .name("example")
35///         .version("1.0.0")
36///         .tool_def(add_tool_def())
37///         .build();
38///     Ok(())
39/// }
40/// ```
41///
42/// Each function parameter must implement `serde::Deserialize` and will be
43/// extracted from the tool call's `arguments` JSON object by field name.
44#[proc_macro_attribute]
45pub fn tool(args: TokenStream, input: TokenStream) -> TokenStream {
46    let attr_args = parse_macro_input!(args with Punctuated::<Meta, Token![,]>::parse_terminated);
47    let func = parse_macro_input!(input as ItemFn);
48
49    match tool_impl(attr_args, func) {
50        Ok(ts) => ts.into(),
51        Err(e) => e.into_compile_error().into(),
52    }
53}
54
55fn tool_impl(attr_args: Punctuated<Meta, Token![,]>, func: ItemFn) -> syn::Result<TokenStream2> {
56    // ── Parse attribute options ───────────────────────────────────────────────
57    let mut description: Option<String> = None;
58    let mut tool_name: Option<String> = None;
59
60    for meta in &attr_args {
61        match meta {
62            Meta::NameValue(MetaNameValue { path, value, .. }) => {
63                let key = path.get_ident().map(|i| i.to_string()).unwrap_or_default();
64                if let syn::Expr::Lit(syn::ExprLit {
65                    lit: Lit::Str(s), ..
66                }) = value
67                {
68                    match key.as_str() {
69                        "description" => description = Some(s.value()),
70                        "name" => tool_name = Some(s.value()),
71                        other => {
72                            return Err(syn::Error::new(
73                                path.span(),
74                                format!("Unknown attribute: {other}"),
75                            ));
76                        }
77                    }
78                }
79            }
80            other => {
81                return Err(syn::Error::new(
82                    other.span(),
83                    "Expected key = \"value\" pairs",
84                ));
85            }
86        }
87    }
88
89    // Fall back to doc comment for description
90    if description.is_none() {
91        description = extract_doc_comment(&func.attrs);
92    }
93
94    let description = description.ok_or_else(|| {
95        syn::Error::new(
96            Span::call_site(),
97            "#[tool] requires `description = \"...\"`",
98        )
99    })?;
100
101    let fn_ident = &func.sig.ident;
102    let fn_name_str = tool_name.unwrap_or_else(|| fn_ident.to_string().replace('_', "-"));
103    let def_fn_ident = syn::Ident::new(&format!("{fn_ident}_tool_def"), fn_ident.span());
104
105    // ── Parse function parameters ─────────────────────────────────────────────
106    struct Param {
107        name: String,
108        ty: Type,
109        doc: String,
110    }
111
112    let mut params: Vec<Param> = Vec::new();
113    let mut has_auth_param = false;
114
115    for arg in &func.sig.inputs {
116        match arg {
117            FnArg::Typed(PatType { pat, ty, attrs, .. }) => {
118                let name = match pat.as_ref() {
119                    Pat::Ident(id) => id.ident.to_string(),
120                    _ => {
121                        return Err(syn::Error::new(
122                            pat.span(),
123                            "Only simple identifiers supported",
124                        ));
125                    }
126                };
127                // Detect `Auth` extractor — exclude it from the JSON schema and
128                // argument extraction; it will be injected via `Auth::from_context()`.
129                if type_is_auth(ty) {
130                    has_auth_param = true;
131                    continue;
132                }
133                let doc = extract_doc_comment(attrs).unwrap_or_default();
134                params.push(Param {
135                    name,
136                    ty: *ty.clone(),
137                    doc,
138                });
139            }
140            FnArg::Receiver(r) => {
141                return Err(syn::Error::new(
142                    r.span(),
143                    "#[tool] functions must not take `self`",
144                ));
145            }
146        }
147    }
148
149    // ── Build JSON Schema for input parameters ────────────────────────────────
150    let prop_inserts: Vec<TokenStream2> = params
151        .iter()
152        .map(|p| {
153            let name = &p.name;
154            let doc = &p.doc;
155            let ty = &p.ty;
156            quote! {
157                {
158                    let mut schema = ::mcp_kit::__private::schemars::schema_for!(#ty).schema;
159                    // Inline the schema as JSON
160                    let schema_val = ::mcp_kit::__private::serde_json::to_value(&schema)
161                        .expect("schema serialization failed");
162                    let final_val = if !#doc.is_empty() {
163                        // Wrap with description
164                        let mut obj = match schema_val {
165                            ::mcp_kit::__private::serde_json::Value::Object(m) => m,
166                            other => {
167                                let mut m = ::mcp_kit::__private::serde_json::Map::new();
168                                m.insert("type".to_string(), other);
169                                m
170                            }
171                        };
172                        obj.insert("description".to_string(), ::mcp_kit::__private::serde_json::Value::String(#doc.to_string()));
173                        ::mcp_kit::__private::serde_json::Value::Object(obj)
174                    } else {
175                        schema_val
176                    };
177                    properties.insert(#name.to_string(), final_val);
178                }
179            }
180        })
181        .collect();
182
183    let required_entries: Vec<String> = params.iter().map(|p| p.name.clone()).collect();
184
185    let param_extracts: Vec<TokenStream2> = params
186        .iter()
187        .map(|p| {
188            let name_str = &p.name;
189            let name_ident = syn::Ident::new(name_str, Span::call_site());
190            let ty = &p.ty;
191            quote! {
192                let #name_ident: #ty = ::mcp_kit::__private::serde_json::from_value(
193                    args.get(#name_str)
194                        .cloned()
195                        .unwrap_or(::mcp_kit::__private::serde_json::Value::Null)
196                ).map_err(|e| ::mcp_kit::McpError::InvalidParams(
197                    format!("param `{}`: {}", #name_str, e)
198                ))?;
199            }
200        })
201        .collect();
202
203    let param_names: Vec<syn::Ident> = params
204        .iter()
205        .map(|p| syn::Ident::new(&p.name, Span::call_site()))
206        .collect();
207
208    let fn_vis = &func.vis;
209
210    // When the handler declares an `Auth` parameter, inject it from the
211    // task-local auth context before calling the user function.
212    let auth_extract = if has_auth_param {
213        quote! {
214            let auth = ::mcp_kit::__private::Auth::from_context()?;
215        }
216    } else {
217        quote! {}
218    };
219
220    // Build the call arguments - if we have regular params plus auth, use comma separator
221    let call_args = if has_auth_param {
222        if param_names.is_empty() {
223            quote! { auth }
224        } else {
225            quote! { #(#param_names),*, auth }
226        }
227    } else {
228        quote! { #(#param_names),* }
229    };
230
231    let expanded = quote! {
232        // Keep the original function unchanged
233        #func
234
235        /// Auto-generated tool definition (from `#[tool]` macro).
236        #fn_vis fn #def_fn_ident() -> ::mcp_kit::ToolDef {
237            use ::mcp_kit::__private::serde_json;
238
239            // Build the input schema
240            let mut properties = serde_json::Map::new();
241            #(#prop_inserts)*
242
243            let input_schema = serde_json::json!({
244                "type": "object",
245                "properties": properties,
246                "required": [ #(#required_entries),* ],
247            });
248
249            let tool = ::mcp_kit::Tool::new(
250                #fn_name_str,
251                #description,
252                input_schema,
253            );
254
255            let handler = ::std::sync::Arc::new(move |req: ::mcp_kit::__private::CallToolRequest| {
256                Box::pin(async move {
257                    let args = match req.arguments {
258                        serde_json::Value::Object(m) => m,
259                        serde_json::Value::Null => serde_json::Map::new(),
260                        other => {
261                            return Err(::mcp_kit::McpError::InvalidParams(
262                                format!("expected object, got: {other}")
263                            ));
264                        }
265                    };
266                    #auth_extract
267                    #(#param_extracts)*
268                    let result = #fn_ident(#call_args).await;
269                    Ok(::mcp_kit::__private::IntoToolResult::into_tool_result(result))
270                }) as ::mcp_kit::__private::BoxFuture<'static, ::mcp_kit::__private::McpResult<::mcp_kit::CallToolResult>>
271            });
272
273            ::mcp_kit::ToolDef::new(tool, handler)
274        }
275    };
276
277    Ok(expanded)
278}
279
280// ─── #[resource] ─────────────────────────────────────────────────────────────
281
282/// Marks an async function as an MCP resource handler and generates a companion
283/// `{fn_name}_resource_def()` function that returns a `mcp::ResourceDef`.
284///
285/// # Attributes
286/// - `uri = "..."` — Resource URI (required). Use `{variable}` for templates.
287/// - `name = "..."` — Human-readable name (required)
288/// - `description = "..."` — Optional description
289/// - `mime_type = "..."` — Optional MIME type (e.g., "application/json")
290///
291/// # Examples
292///
293/// Static resource:
294/// ```rust,ignore
295/// use mcp_kit::prelude::*;
296///
297/// #[resource(uri = "config://app", name = "App Config", description = "Application configuration")]
298/// async fn app_config(_req: ReadResourceRequest) -> McpResult<ReadResourceResult> {
299///     Ok(ReadResourceResult::text("config://app", r#"{"version": "1.0"}"#))
300/// }
301/// ```
302///
303/// Template resource:
304/// ```rust,ignore
305/// use mcp_kit::prelude::*;
306///
307/// #[resource(uri = "file://{path}", name = "File System")]
308/// async fn read_file(req: ReadResourceRequest) -> McpResult<ReadResourceResult> {
309///     let path = req.uri.trim_start_matches("file://");
310///     let content = tokio::fs::read_to_string(path).await?;
311///     Ok(ReadResourceResult::text(req.uri.clone(), content))
312/// }
313/// ```
314#[proc_macro_attribute]
315pub fn resource(args: TokenStream, input: TokenStream) -> TokenStream {
316    let attr_args = parse_macro_input!(args with Punctuated::<Meta, Token![,]>::parse_terminated);
317    let func = parse_macro_input!(input as ItemFn);
318
319    match resource_impl(attr_args, func) {
320        Ok(ts) => ts.into(),
321        Err(e) => e.into_compile_error().into(),
322    }
323}
324
325fn resource_impl(
326    attr_args: Punctuated<Meta, Token![,]>,
327    func: ItemFn,
328) -> syn::Result<TokenStream2> {
329    // ── Parse attribute options ───────────────────────────────────────────────
330    let mut uri: Option<String> = None;
331    let mut name: Option<String> = None;
332    let mut description: Option<String> = None;
333    let mut mime_type: Option<String> = None;
334
335    for meta in &attr_args {
336        match meta {
337            Meta::NameValue(MetaNameValue { path, value, .. }) => {
338                let key = path.get_ident().map(|i| i.to_string()).unwrap_or_default();
339                if let syn::Expr::Lit(syn::ExprLit {
340                    lit: Lit::Str(s), ..
341                }) = value
342                {
343                    match key.as_str() {
344                        "uri" => uri = Some(s.value()),
345                        "name" => name = Some(s.value()),
346                        "description" => description = Some(s.value()),
347                        "mime_type" => mime_type = Some(s.value()),
348                        other => {
349                            return Err(syn::Error::new(
350                                path.span(),
351                                format!("Unknown attribute: {other}"),
352                            ));
353                        }
354                    }
355                }
356            }
357            other => {
358                return Err(syn::Error::new(
359                    other.span(),
360                    "Expected key = \"value\" pairs",
361                ));
362            }
363        }
364    }
365
366    let uri = uri.ok_or_else(|| {
367        syn::Error::new(Span::call_site(), "#[resource] requires `uri = \"...\"`")
368    })?;
369    let name = name.ok_or_else(|| {
370        syn::Error::new(Span::call_site(), "#[resource] requires `name = \"...\"`")
371    })?;
372
373    let fn_ident = &func.sig.ident;
374    let def_fn_ident = syn::Ident::new(&format!("{fn_ident}_resource_def"), fn_ident.span());
375    let fn_vis = &func.vis;
376
377    // Check if URI is a template (contains {variable})
378    let is_template = uri.contains('{');
379
380    // Generate optional method calls
381    let with_description = description.as_ref().map(|desc| {
382        quote! { .with_description(#desc) }
383    });
384    let with_mime_type = mime_type.as_ref().map(|mime| {
385        quote! { .with_mime_type(#mime) }
386    });
387
388    let expanded = if is_template {
389        // Generate ResourceDef::Template
390        quote! {
391            // Keep the original function unchanged
392            #func
393
394            /// Auto-generated resource definition (from `#[resource]` macro).
395            #fn_vis fn #def_fn_ident() -> ::mcp_kit::__private::ResourceDef {
396                let template = ::mcp_kit::__private::ResourceTemplate::new(#uri, #name)
397                    #with_description
398                    #with_mime_type;
399
400                let handler = ::std::sync::Arc::new(move |req: ::mcp_kit::__private::ReadResourceRequest| {
401                    Box::pin(async move {
402                        #fn_ident(req).await
403                    }) as ::mcp_kit::__private::BoxFuture<'static, ::mcp_kit::__private::McpResult<::mcp_kit::__private::ReadResourceResult>>
404                });
405
406                ::mcp_kit::__private::ResourceDef::new_template(template, handler)
407            }
408        }
409    } else {
410        // Generate ResourceDef::Static
411        quote! {
412            // Keep the original function unchanged
413            #func
414
415            /// Auto-generated resource definition (from `#[resource]` macro).
416            #fn_vis fn #def_fn_ident() -> ::mcp_kit::__private::ResourceDef {
417                let resource = ::mcp_kit::__private::Resource::new(#uri, #name)
418                    #with_description
419                    #with_mime_type;
420
421                let handler = ::std::sync::Arc::new(move |req: ::mcp_kit::__private::ReadResourceRequest| {
422                    Box::pin(async move {
423                        #fn_ident(req).await
424                    }) as ::mcp_kit::__private::BoxFuture<'static, ::mcp_kit::__private::McpResult<::mcp_kit::__private::ReadResourceResult>>
425                });
426
427                ::mcp_kit::__private::ResourceDef::new_static(resource, handler)
428            }
429        }
430    };
431
432    Ok(expanded)
433}
434
435// ─── #[prompt] ───────────────────────────────────────────────────────────────
436
437/// Marks an async function as an MCP prompt handler and generates a companion
438/// `{fn_name}_prompt_def()` function that returns a `mcp::PromptDef`.
439///
440/// # Attributes
441/// - `name = "..."` — Prompt name (defaults to function name with `-` instead of `_`)
442/// - `description = "..."` — Optional description
443/// - `arguments = ["arg1", "arg2:required", "arg3:optional"]` — Optional argument list
444///
445/// # Examples
446///
447/// Basic prompt:
448/// ```rust,ignore
449/// use mcp_kit::prelude::*;
450///
451/// #[prompt(name = "greeting", description = "Generate a greeting message")]
452/// async fn greeting(_req: GetPromptRequest) -> McpResult<GetPromptResult> {
453///     Ok(GetPromptResult::new(vec![
454///         PromptMessage::user_text("Hello! How can I help you today?")
455///     ]))
456/// }
457/// ```
458///
459/// Prompt with arguments:
460/// ```rust,ignore
461/// use mcp_kit::prelude::*;
462///
463/// #[prompt(
464///     name = "code-review",
465///     description = "Generate a code review",
466///     arguments = ["code:required", "language:optional"]
467/// )]
468/// async fn code_review(req: GetPromptRequest) -> McpResult<GetPromptResult> {
469///     let code = req.arguments.get("code").cloned().unwrap_or_default();
470///     let lang = req.arguments.get("language").cloned().unwrap_or_else(|| "unknown".into());
471///     
472///     Ok(GetPromptResult::new(vec![
473///         PromptMessage::user_text(format!("Review this {lang} code:\n\n```{lang}\n{code}\n```"))
474///     ]))
475/// }
476/// ```
477#[proc_macro_attribute]
478pub fn prompt(args: TokenStream, input: TokenStream) -> TokenStream {
479    let attr_args = parse_macro_input!(args with Punctuated::<Meta, Token![,]>::parse_terminated);
480    let func = parse_macro_input!(input as ItemFn);
481
482    match prompt_impl(attr_args, func) {
483        Ok(ts) => ts.into(),
484        Err(e) => e.into_compile_error().into(),
485    }
486}
487
488fn prompt_impl(attr_args: Punctuated<Meta, Token![,]>, func: ItemFn) -> syn::Result<TokenStream2> {
489    // ── Parse attribute options ───────────────────────────────────────────────
490    let mut prompt_name: Option<String> = None;
491    let mut description: Option<String> = None;
492    let mut arguments: Vec<(String, bool)> = Vec::new(); // (name, required)
493
494    for meta in &attr_args {
495        match meta {
496            Meta::NameValue(MetaNameValue { path, value, .. }) => {
497                let key = path.get_ident().map(|i| i.to_string()).unwrap_or_default();
498                match key.as_str() {
499                    "name" => {
500                        if let syn::Expr::Lit(syn::ExprLit {
501                            lit: Lit::Str(s), ..
502                        }) = value
503                        {
504                            prompt_name = Some(s.value());
505                        }
506                    }
507                    "description" => {
508                        if let syn::Expr::Lit(syn::ExprLit {
509                            lit: Lit::Str(s), ..
510                        }) = value
511                        {
512                            description = Some(s.value());
513                        }
514                    }
515                    "arguments" => {
516                        // Parse array of argument strings: ["arg1", "arg2:required", "arg3:optional"]
517                        if let syn::Expr::Array(syn::ExprArray { elems, .. }) = value {
518                            for elem in elems {
519                                if let syn::Expr::Lit(syn::ExprLit {
520                                    lit: Lit::Str(s), ..
521                                }) = elem
522                                {
523                                    let arg_str = s.value();
524                                    let (name, required) = if arg_str.contains(':') {
525                                        let parts: Vec<&str> = arg_str.split(':').collect();
526                                        let name = parts[0].to_string();
527                                        let required =
528                                            parts.get(1).map_or(true, |&r| r == "required");
529                                        (name, required)
530                                    } else {
531                                        (arg_str, true) // default to required
532                                    };
533                                    arguments.push((name, required));
534                                }
535                            }
536                        }
537                    }
538                    other => {
539                        return Err(syn::Error::new(
540                            path.span(),
541                            format!("Unknown attribute: {other}"),
542                        ));
543                    }
544                }
545            }
546            other => {
547                return Err(syn::Error::new(
548                    other.span(),
549                    "Expected key = \"value\" pairs",
550                ));
551            }
552        }
553    }
554
555    let fn_ident = &func.sig.ident;
556    let prompt_name = prompt_name.unwrap_or_else(|| fn_ident.to_string().replace('_', "-"));
557    let def_fn_ident = syn::Ident::new(&format!("{fn_ident}_prompt_def"), fn_ident.span());
558    let fn_vis = &func.vis;
559
560    // Generate argument definitions
561    let arg_definitions: Vec<TokenStream2> = arguments
562        .iter()
563        .map(|(name, required)| {
564            if *required {
565                quote! {
566                    ::mcp_kit::__private::PromptArgument::required(#name)
567                }
568            } else {
569                quote! {
570                    ::mcp_kit::__private::PromptArgument::optional(#name)
571                }
572            }
573        })
574        .collect();
575
576    let with_args = if !arguments.is_empty() {
577        quote! {
578            .with_arguments(vec![#(#arg_definitions),*])
579        }
580    } else {
581        quote! {}
582    };
583
584    let with_desc = if let Some(desc) = &description {
585        quote! {
586            .with_description(#desc)
587        }
588    } else {
589        quote! {}
590    };
591
592    let expanded = quote! {
593        // Keep the original function unchanged
594        #func
595
596        /// Auto-generated prompt definition (from `#[prompt]` macro).
597        #fn_vis fn #def_fn_ident() -> ::mcp_kit::__private::PromptDef {
598            let prompt = ::mcp_kit::__private::Prompt::new(#prompt_name)
599                #with_desc
600                #with_args;
601
602            let handler = ::std::sync::Arc::new(move |req: ::mcp_kit::__private::GetPromptRequest| {
603                Box::pin(async move {
604                    #fn_ident(req).await
605                }) as ::mcp_kit::__private::BoxFuture<'static, ::mcp_kit::__private::McpResult<::mcp_kit::__private::GetPromptResult>>
606            });
607
608            ::mcp_kit::__private::PromptDef::new(prompt, handler)
609        }
610    };
611
612    Ok(expanded)
613}
614
615// ─── Helpers ──────────────────────────────────────────────────────────────────
616
617/// Returns `true` if `ty` refers to the `Auth` extractor type.
618///
619/// Matches: `Auth`, `mcp_kit::Auth`, `::mcp_kit::Auth`.
620fn type_is_auth(ty: &Type) -> bool {
621    if let Type::Path(tp) = ty {
622        let segments = &tp.path.segments;
623        if let Some(last) = segments.last() {
624            return last.ident == "Auth";
625        }
626    }
627    false
628}
629
630fn extract_doc_comment(attrs: &[Attribute]) -> Option<String> {
631    let lines: Vec<String> = attrs
632        .iter()
633        .filter_map(|attr| {
634            if !attr.path().is_ident("doc") {
635                return None;
636            }
637            if let Meta::NameValue(MetaNameValue {
638                value:
639                    syn::Expr::Lit(syn::ExprLit {
640                        lit: Lit::Str(s), ..
641                    }),
642                ..
643            }) = &attr.meta
644            {
645                Some(s.value().trim().to_owned())
646            } else {
647                None
648            }
649        })
650        .collect();
651
652    if lines.is_empty() {
653        None
654    } else {
655        Some(lines.join(" "))
656    }
657}