mcp_macros/
lib.rs

1use convert_case::{Case, Casing};
2use proc_macro::TokenStream;
3use quote::{format_ident, quote};
4use std::collections::HashMap;
5use syn::{
6    parse::Parse, parse::ParseStream, parse_macro_input, punctuated::Punctuated, Expr, ExprLit,
7    FnArg, ItemFn, Lit, Meta, Pat, PatType, Token,
8};
9
/// Parsed arguments of a `#[tool(...)]` attribute.
struct MacroArgs {
    /// Value of `name = "..."`, or `None` when the key was not given.
    name: Option<String>,
    /// Value of `description = "..."`, or `None` when the key was not given.
    description: Option<String>,
    /// Entries of `params(arg = "...", ...)`, keyed by parameter identifier.
    param_descriptions: HashMap<String, String>,
}
15
16impl Parse for MacroArgs {
17    fn parse(input: ParseStream) -> syn::Result<Self> {
18        let mut name = None;
19        let mut description = None;
20        let mut param_descriptions = HashMap::new();
21
22        let meta_list: Punctuated<Meta, Token![,]> = Punctuated::parse_terminated(input)?;
23
24        for meta in meta_list {
25            match meta {
26                Meta::NameValue(nv) => {
27                    let ident = nv.path.get_ident().unwrap().to_string();
28                    if let Expr::Lit(ExprLit {
29                        lit: Lit::Str(lit_str),
30                        ..
31                    }) = nv.value
32                    {
33                        match ident.as_str() {
34                            "name" => name = Some(lit_str.value()),
35                            "description" => description = Some(lit_str.value()),
36                            _ => {}
37                        }
38                    }
39                }
40                Meta::List(list) if list.path.is_ident("params") => {
41                    let nested: Punctuated<Meta, Token![,]> =
42                        list.parse_args_with(Punctuated::parse_terminated)?;
43
44                    for meta in nested {
45                        if let Meta::NameValue(nv) = meta {
46                            if let Expr::Lit(ExprLit {
47                                lit: Lit::Str(lit_str),
48                                ..
49                            }) = nv.value
50                            {
51                                let param_name = nv.path.get_ident().unwrap().to_string();
52                                param_descriptions.insert(param_name, lit_str.value());
53                            }
54                        }
55                    }
56                }
57                _ => {}
58            }
59        }
60
61        Ok(MacroArgs {
62            name,
63            description,
64            param_descriptions,
65        })
66    }
67}
68
69#[proc_macro_attribute]
70pub fn tool(args: TokenStream, input: TokenStream) -> TokenStream {
71    let args = parse_macro_input!(args as MacroArgs);
72    let input_fn = parse_macro_input!(input as ItemFn);
73
74    // Extract function details
75    let fn_name = &input_fn.sig.ident;
76    let fn_name_str = fn_name.to_string();
77
78    // Generate PascalCase struct name from the function name
79    let struct_name = format_ident!("{}", { fn_name_str.to_case(Case::Pascal) });
80
81    // Use provided name or function name as default
82    let tool_name = args.name.unwrap_or(fn_name_str);
83    let tool_description = args.description.unwrap_or_default();
84
85    // Extract parameter names, types, and descriptions
86    let mut param_defs = Vec::new();
87    let mut param_names = Vec::new();
88
89    for arg in input_fn.sig.inputs.iter() {
90        if let FnArg::Typed(PatType { pat, ty, .. }) = arg {
91            if let Pat::Ident(param_ident) = &**pat {
92                let param_name = &param_ident.ident;
93                let param_name_str = param_name.to_string();
94                let description = args
95                    .param_descriptions
96                    .get(&param_name_str)
97                    .map(|s| s.as_str())
98                    .unwrap_or("");
99
100                param_names.push(param_name);
101                param_defs.push(quote! {
102                    #[schemars(description = #description)]
103                    #param_name: #ty
104                });
105            }
106        }
107    }
108
109    // Generate the implementation
110    let params_struct_name = format_ident!("{}Parameters", struct_name);
111    let expanded = quote! {
112        #[derive(serde::Deserialize, schemars::JsonSchema)]
113        struct #params_struct_name {
114            #(#param_defs,)*
115        }
116
117        #input_fn
118
119        #[derive(Default)]
120        struct #struct_name;
121
122        #[async_trait::async_trait]
123        impl mcp_spec::handler::ToolHandler for #struct_name {
124            fn name(&self) -> &'static str {
125                #tool_name
126            }
127
128            fn description(&self) -> &'static str {
129                #tool_description
130            }
131
132            fn schema(&self) -> serde_json::Value {
133                mcp_spec::handler::generate_schema::<#params_struct_name>()
134                    .expect("Failed to generate schema")
135            }
136
137            async fn call(&self, params: serde_json::Value) -> Result<serde_json::Value, mcp_spec::handler::ToolError> {
138                let params: #params_struct_name = serde_json::from_value(params)
139                    .map_err(|e| mcp_spec::handler::ToolError::InvalidParameters(e.to_string()))?;
140
141                // Extract parameters and call the function
142                let result = #fn_name(#(params.#param_names,)*).await
143                    .map_err(|e| mcp_spec::handler::ToolError::ExecutionError(e.to_string()))?;
144
145                Ok(serde_json::to_value(result).expect("should serialize"))
146
147            }
148        }
149    };
150
151    TokenStream::from(expanded)
152}