mcp_core_macros/
lib.rs

1//! A library of procedural macros for defining MCP core tools.
2//!
3//! This crate provides macros for defining tools that can be used with the MCP system.
4//! The main macro is `tool`, which is used to define a tool function that can be called
5//! by the system.
6
7use convert_case::{Case, Casing};
8use proc_macro::TokenStream;
9use quote::{format_ident, quote};
10use syn::{
11    parse::{Parse, ParseStream},
12    punctuated::Punctuated,
13    Expr, ExprLit, FnArg, ItemFn, Lit, Meta, Pat, PatType, Token, Type,
14};
15
/// Arguments accepted by the `#[tool(...)]` attribute.
#[derive(Debug)]
struct ToolArgs {
    /// `name = "..."`; when absent the tool is named after the function.
    name: Option<String>,
    /// `description = "..."`; when absent an empty description is used.
    description: Option<String>,
    /// Values collected from the nested `annotations(...)` list.
    annotations: ToolAnnotations,
}

/// Optional tool annotations given inside `annotations(...)`.
///
/// Every field is `None` unless explicitly set; `#[tool]` substitutes the documented
/// defaults for unset values when generating code.
#[derive(Debug, Default)]
struct ToolAnnotations {
    /// `title = "..."`.
    title: Option<String>,
    /// `read_only_hint = <bool>` (or `readOnlyHint`).
    read_only_hint: Option<bool>,
    /// `destructive_hint = <bool>` (or `destructiveHint`).
    destructive_hint: Option<bool>,
    /// `idempotent_hint = <bool>` (or `idempotentHint`).
    idempotent_hint: Option<bool>,
    /// `open_world_hint = <bool>` (or `openWorldHint`).
    open_world_hint: Option<bool>,
}
43
44impl Parse for ToolArgs {
45    fn parse(input: ParseStream) -> syn::Result<Self> {
46        let mut name = None;
47        let mut description = None;
48        let mut annotations = ToolAnnotations::default();
49
50        let meta_list: Punctuated<Meta, Token![,]> = Punctuated::parse_terminated(input)?;
51
52        for meta in meta_list {
53            match meta {
54                Meta::NameValue(nv) => {
55                    let ident = nv.path.get_ident().unwrap().to_string();
56                    if let Expr::Lit(ExprLit {
57                        lit: Lit::Str(lit_str),
58                        ..
59                    }) = nv.value
60                    {
61                        match ident.as_str() {
62                            "name" => name = Some(lit_str.value()),
63                            "description" => description = Some(lit_str.value()),
64                            _ => {
65                                return Err(syn::Error::new_spanned(
66                                    nv.path,
67                                    format!("Unknown attribute: {}", ident),
68                                ))
69                            }
70                        }
71                    } else {
72                        return Err(syn::Error::new_spanned(nv.value, "Expected string literal"));
73                    }
74                }
75                Meta::List(list) if list.path.is_ident("annotations") => {
76                    let nested: Punctuated<Meta, Token![,]> =
77                        list.parse_args_with(Punctuated::parse_terminated)?;
78
79                    for meta in nested {
80                        if let Meta::NameValue(nv) = meta {
81                            let key = nv.path.get_ident().unwrap().to_string();
82
83                            if let Expr::Lit(ExprLit {
84                                lit: Lit::Str(lit_str),
85                                ..
86                            }) = nv.value
87                            {
88                                if key == "title" {
89                                    annotations.title = Some(lit_str.value());
90                                } else {
91                                    return Err(syn::Error::new_spanned(
92                                        nv.path,
93                                        format!("Unknown string annotation: {}", key),
94                                    ));
95                                }
96                            } else if let Expr::Lit(ExprLit {
97                                lit: Lit::Bool(lit_bool),
98                                ..
99                            }) = nv.value
100                            {
101                                match key.as_str() {
102                                    "read_only_hint" | "readOnlyHint" => {
103                                        annotations.read_only_hint = Some(lit_bool.value)
104                                    }
105                                    "destructive_hint" | "destructiveHint" => {
106                                        annotations.destructive_hint = Some(lit_bool.value)
107                                    }
108                                    "idempotent_hint" | "idempotentHint" => {
109                                        annotations.idempotent_hint = Some(lit_bool.value)
110                                    }
111                                    "open_world_hint" | "openWorldHint" => {
112                                        annotations.open_world_hint = Some(lit_bool.value)
113                                    }
114                                    _ => {
115                                        return Err(syn::Error::new_spanned(
116                                            nv.path,
117                                            format!("Unknown boolean annotation: {}", key),
118                                        ))
119                                    }
120                                }
121                            } else {
122                                return Err(syn::Error::new_spanned(
123                                    nv.value,
124                                    "Expected string or boolean literal for annotation value",
125                                ));
126                            }
127                        } else {
128                            return Err(syn::Error::new_spanned(
129                                meta,
130                                "Expected name-value pair for annotation",
131                            ));
132                        }
133                    }
134                }
135                _ => {
136                    return Err(syn::Error::new_spanned(
137                        meta,
138                        "Expected name-value pair or list",
139                    ))
140                }
141            }
142        }
143
144        Ok(ToolArgs {
145            name,
146            description,
147            annotations,
148        })
149    }
150}
151
152/// Defines a tool function that can be called by the MCP system.
153///
154/// This macro transforms an async function into a tool that can be registered with the MCP system.
155/// It generates a corresponding structure with methods to get the tool definition and to handle
156/// calls to the tool.
157///
158/// # Arguments
159///
160/// * `name` - The name of the tool (optional, defaults to the function name)
161/// * `description` - A description of what the tool does
162/// * `annotations` - Additional metadata for the tool:
163///   * `title` - Display title for the tool (defaults to function name)
164///   * `read_only_hint` - Whether the tool only reads data (defaults to false)
165///   * `destructive_hint` - Whether the tool makes destructive changes (defaults to true)
166///   * `idempotent_hint` - Whether the tool is idempotent (defaults to false)
167///   * `open_world_hint` - Whether the tool can access resources outside the system (defaults to true)
168///
169/// # Example
170///
171/// ```rust
172/// use mcp_core_macros::{tool, tool_param};
173/// use mcp_core::types::ToolResponseContent;
174/// use mcp_core::tool_text_content;
175/// use anyhow::Result;
176///
177/// #[tool(name = "my_tool", description = "A tool with documented parameters", annotations(title = "My Tool"))]
178/// async fn my_tool(
179///     // A required parameter with description
180///     required_param: tool_param!(String, description = "A required parameter"),
181///     
182///     // An optional parameter
183///     optional_param: tool_param!(Option<String>, description = "An optional parameter"),
184///     
185///     // A hidden parameter that won't appear in the schema
186///     internal_param: tool_param!(String, hidden)
187/// ) -> Result<ToolResponseContent> {
188///     // Implementation
189///     Ok(tool_text_content!("Tool executed".to_string()))
190/// }
191/// ```
192#[proc_macro_attribute]
193pub fn tool(args: TokenStream, input: TokenStream) -> TokenStream {
194    let args = match syn::parse::<ToolArgs>(args) {
195        Ok(args) => args,
196        Err(e) => return e.to_compile_error().into(),
197    };
198
199    let input_fn = match syn::parse::<ItemFn>(input.clone()) {
200        Ok(input_fn) => input_fn,
201        Err(e) => return e.to_compile_error().into(),
202    };
203
204    let fn_name = &input_fn.sig.ident;
205    let fn_name_str = fn_name.to_string();
206    let struct_name = format_ident!("{}", fn_name_str.to_case(Case::Pascal));
207    let tool_name = args.name.unwrap_or(fn_name_str.clone());
208    let tool_description = args.description.unwrap_or_default();
209
210    // Tool annotations
211    let title = args.annotations.title.unwrap_or(fn_name_str.clone());
212    let read_only_hint = args.annotations.read_only_hint.unwrap_or(false);
213    let destructive_hint = args.annotations.destructive_hint.unwrap_or(true);
214    let idempotent_hint = args.annotations.idempotent_hint.unwrap_or(false);
215    let open_world_hint = args.annotations.open_world_hint.unwrap_or(true);
216
217    let mut param_defs = Vec::new();
218    let mut param_names = Vec::new();
219    let mut required_params = Vec::new();
220    let mut hidden_params: Vec<String> = Vec::new();
221    let mut param_descriptions = Vec::new();
222
223    for arg in input_fn.sig.inputs.iter() {
224        if let FnArg::Typed(PatType { pat, ty, .. }) = arg {
225            let mut is_hidden = false;
226            let mut description: Option<String> = None;
227            let mut is_optional = false;
228
229            // Check for tool_type macro usage
230            if let Type::Macro(type_macro) = &**ty {
231                if let Some(ident) = type_macro.mac.path.get_ident() {
232                    if ident == "tool_param" {
233                        if let Ok(args) =
234                            syn::parse2::<ToolParamArgs>(type_macro.mac.tokens.clone())
235                        {
236                            is_hidden = args.hidden;
237                            description = args.description;
238
239                            // Check if the parameter type is Option<T>
240                            if let Type::Path(type_path) = &args.ty {
241                                is_optional = type_path
242                                    .path
243                                    .segments
244                                    .last()
245                                    .map_or(false, |segment| segment.ident == "Option");
246                            }
247                        }
248                    }
249                }
250            }
251
252            if is_hidden {
253                if let Pat::Ident(ident) = &**pat {
254                    hidden_params.push(ident.ident.to_string());
255                }
256            }
257
258            if let Pat::Ident(param_ident) = &**pat {
259                let param_name = &param_ident.ident;
260                let param_name_str = param_name.to_string();
261
262                param_names.push(param_name.clone());
263
264                // Check if the parameter type is Option<T>
265                if !is_optional {
266                    is_optional = if let Type::Path(type_path) = &**ty {
267                        type_path
268                            .path
269                            .segments
270                            .last()
271                            .map_or(false, |segment| segment.ident == "Option")
272                    } else {
273                        false
274                    }
275                }
276
277                // Only require non-optional, non-hidden
278                if !is_optional && !is_hidden {
279                    required_params.push(param_name_str.clone());
280                }
281
282                if let Some(desc) = description {
283                    param_descriptions.push(quote! {
284                        if name == #param_name_str {
285                            prop_obj.insert("description".to_string(), serde_json::Value::String(#desc.to_string()));
286                        }
287                    });
288                }
289
290                param_defs.push(quote! {
291                    #param_name: #ty
292                });
293            }
294        }
295    }
296
297    let params_struct_name = format_ident!("{}Parameters", struct_name);
298    let expanded = quote! {
299        #[derive(serde::Deserialize, schemars::JsonSchema)]
300        struct #params_struct_name {
301            #(#param_defs,)*
302        }
303
304        #input_fn
305
306        #[derive(Default)]
307        pub struct #struct_name;
308
309        impl #struct_name {
310            pub fn tool() -> mcp_core::types::Tool {
311                let schema = schemars::schema_for!(#params_struct_name);
312                let mut schema = serde_json::to_value(schema.schema).unwrap_or_default();
313                if let serde_json::Value::Object(ref mut map) = schema {
314                    // Add required fields
315                    map.insert("required".to_string(), serde_json::Value::Array(
316                        vec![#(serde_json::Value::String(#required_params.to_string())),*]
317                    ));
318                    map.remove("title");
319
320                    // Normalize property types
321                    if let Some(serde_json::Value::Object(props)) = map.get_mut("properties") {
322                        for (name, prop) in props.iter_mut() {
323                            if let serde_json::Value::Object(prop_obj) = prop {
324                                // Fix number types
325                                if let Some(type_val) = prop_obj.get("type") {
326                                    if type_val == "integer" || type_val == "number" || prop_obj.contains_key("format") {
327                                        // Convert any numeric type to "number"
328                                        prop_obj.insert("type".to_string(), serde_json::Value::String("number".to_string()));
329                                        prop_obj.remove("format");
330                                        prop_obj.remove("minimum");
331                                        prop_obj.remove("maximum");
332                                    }
333                                }
334
335                                // Fix optional types (array with null)
336                                if let Some(serde_json::Value::Array(types)) = prop_obj.get("type") {
337                                    if types.len() == 2 && types.contains(&serde_json::Value::String("null".to_string())) {
338                                        let mut main_type = types.iter()
339                                            .find(|&t| t != &serde_json::Value::String("null".to_string()))
340                                            .cloned()
341                                            .unwrap_or(serde_json::Value::String("string".to_string()));
342
343                                        // If the main type is "integer", convert it to "number"
344                                        if main_type == serde_json::Value::String("integer".to_string()) {
345                                            main_type = serde_json::Value::String("number".to_string());
346                                        }
347
348                                        prop_obj.insert("type".to_string(), main_type);
349                                    }
350                                }
351
352                                // Add descriptions if they exist
353                                #(#param_descriptions)*
354                            }
355                        }
356
357                        #(props.remove(#hidden_params);)*
358                    }
359                }
360
361                let annotations = serde_json::json!({
362                    "title": #title,
363                    "readOnlyHint": #read_only_hint,
364                    "destructiveHint": #destructive_hint,
365                    "idempotentHint": #idempotent_hint,
366                    "openWorldHint": #open_world_hint
367                });
368
369                mcp_core::types::Tool {
370                    name: #tool_name.to_string(),
371                    description: Some(#tool_description.to_string()),
372                    input_schema: schema,
373                    annotations: Some(mcp_core::types::ToolAnnotations {
374                        title: Some(#title.to_string()),
375                        read_only_hint: Some(#read_only_hint),
376                        destructive_hint: Some(#destructive_hint),
377                        idempotent_hint: Some(#idempotent_hint),
378                        open_world_hint: Some(#open_world_hint),
379                    }),
380                }
381            }
382
383            pub fn call() -> mcp_core::tools::ToolHandlerFn {
384                move |req: mcp_core::types::CallToolRequest| {
385                    Box::pin(async move {
386                        let params = match req.arguments {
387                            Some(args) => serde_json::to_value(args).unwrap_or_default(),
388                            None => serde_json::Value::Null,
389                        };
390
391                        let params: #params_struct_name = match serde_json::from_value(params) {
392                            Ok(p) => p,
393                            Err(e) => return mcp_core::types::CallToolResponse {
394                                content: vec![mcp_core::types::ToolResponseContent::Text(
395                                    mcp_core::types::TextContent {
396                                        content_type: "text".to_string(),
397                                        text: format!("Invalid parameters: {}", e),
398                                        annotations: None,
399                                    }
400                                )],
401                                is_error: Some(true),
402                                meta: None,
403                            },
404                        };
405
406                        match #fn_name(#(params.#param_names,)*).await {
407                            Ok(response) => {
408                                let content = if let Ok(vec_content) = serde_json::from_value::<Vec<mcp_core::types::ToolResponseContent>>(serde_json::to_value(&response).unwrap_or_default()) {
409                                    vec_content
410                                } else if let Ok(single_content) = serde_json::from_value::<mcp_core::types::ToolResponseContent>(serde_json::to_value(&response).unwrap_or_default()) {
411                                    vec![single_content]
412                                } else {
413                                    vec![mcp_core::types::ToolResponseContent::Text(
414                                        mcp_core::types::TextContent {
415                                            content_type: "text".to_string(),
416                                            text: format!("Invalid response type: {:?}", response),
417                                            annotations: None,
418                                        }
419                                    )]
420                                };
421
422                                mcp_core::types::CallToolResponse {
423                                    content,
424                                    is_error: None,
425                                    meta: None,
426                                }
427                            }
428                            Err(e) => mcp_core::types::CallToolResponse {
429                                content: vec![mcp_core::types::ToolResponseContent::Text(
430                                    mcp_core::types::TextContent {
431                                        content_type: "text".to_string(),
432                                        text: format!("Tool execution error: {}", e),
433                                        annotations: None,
434                                    }
435                                )],
436                                is_error: Some(true),
437                                meta: None,
438                            },
439                        }
440                    })
441                }
442            }
443        }
444    };
445
446    TokenStream::from(expanded)
447}
448
449#[derive(Debug)]
450struct ToolParamArgs {
451    ty: Type,
452    hidden: bool,
453    description: Option<String>,
454}
455
456impl Parse for ToolParamArgs {
457    fn parse(input: ParseStream) -> syn::Result<Self> {
458        let mut hidden = false;
459        let mut description = None;
460        let ty = input.parse()?;
461
462        if input.peek(Token![,]) {
463            input.parse::<Token![,]>()?;
464            let meta_list: Punctuated<Meta, Token![,]> = Punctuated::parse_terminated(input)?;
465
466            for meta in meta_list {
467                match meta {
468                    Meta::Path(path) if path.is_ident("hidden") => {
469                        hidden = true;
470                    }
471                    Meta::NameValue(nv) if nv.path.is_ident("description") => {
472                        if let Expr::Lit(ExprLit {
473                            lit: Lit::Str(lit_str),
474                            ..
475                        }) = &nv.value
476                        {
477                            description = Some(lit_str.value().to_string());
478                        }
479                    }
480                    _ => {}
481                }
482            }
483        }
484
485        Ok(ToolParamArgs {
486            ty,
487            hidden,
488            description,
489        })
490    }
491}
492
493/// Defines a parameter for a tool function with additional metadata.
494///
495/// This macro allows specifying parameter attributes such as:
496/// * `hidden` - Excludes the parameter from the generated schema
497/// * `description` - Adds a description to the parameter in the schema
498///
499/// # Example
500///
501/// ```rust
502/// use mcp_core_macros::{tool, tool_param};
503/// use mcp_core::types::ToolResponseContent;
504/// use mcp_core::tool_text_content;
505/// use anyhow::Result;
506///
507/// #[tool(name = "my_tool", description = "A tool with documented parameters")]
508/// async fn my_tool(
509///     // A required parameter with description
510///     required_param: tool_param!(String, description = "A required parameter"),
511///     
512///     // An optional parameter
513///     optional_param: tool_param!(Option<String>, description = "An optional parameter"),
514///     
515///     // A hidden parameter that won't appear in the schema
516///     internal_param: tool_param!(String, hidden)
517/// ) -> Result<ToolResponseContent> {
518///     // Implementation
519///     Ok(tool_text_content!("Tool executed".to_string()))
520/// }
521/// ```
522#[proc_macro]
523pub fn tool_param(input: TokenStream) -> TokenStream {
524    let args = syn::parse_macro_input!(input as ToolParamArgs);
525    let ty = args.ty;
526
527    TokenStream::from(quote! {
528        #ty
529    })
530}