1use convert_case::{Case, Casing};
2use proc_macro::TokenStream;
3use quote::{format_ident, quote};
4use std::collections::HashMap;
5use syn::{
6    parse::Parse, parse::ParseStream, parse_macro_input, punctuated::Punctuated, Expr, ExprLit,
7    FnArg, ItemFn, Lit, Meta, Pat, PatType, Token, Type,
8};
9
10struct MacroArgs {
11    name: Option<String>,
12    description: Option<String>,
13    param_descriptions: HashMap<String, String>,
14    annotations: ToolAnnotations,
15}
16
/// MCP tool annotations collected from the `#[tool(...)]` arguments.
///
/// Every field is `None` unless the attribute explicitly supplied it, so callers can tell
/// "not given" apart from an explicit `false`.
#[derive(Default)]
struct ToolAnnotations {
    /// Display title (`title = "..."`).
    title: Option<String>,
    /// `read_only_hint` / `readOnlyHint`.
    read_only_hint: Option<bool>,
    /// `destructive_hint` / `destructiveHint`.
    destructive_hint: Option<bool>,
    /// `idempotent_hint` / `idempotentHint`.
    idempotent_hint: Option<bool>,
    /// `open_world_hint` / `openWorldHint`.
    open_world_hint: Option<bool>,
}
36
37impl Parse for MacroArgs {
38    fn parse(input: ParseStream) -> syn::Result<Self> {
39        let mut name = None;
40        let mut description = None;
41        let mut param_descriptions = HashMap::new();
42        let mut annotations = ToolAnnotations::default();
43
44        let meta_list: Punctuated<Meta, Token![,]> = Punctuated::parse_terminated(input)?;
45
46        for meta in meta_list {
47            match meta {
48                Meta::NameValue(nv) => {
49                    let ident = nv.path.get_ident().unwrap().to_string();
50                    if let Expr::Lit(ExprLit {
51                        lit: Lit::Str(lit_str),
52                        ..
53                    }) = nv.value
54                    {
55                        match ident.as_str() {
56                            "name" => name = Some(lit_str.value()),
57                            "description" => description = Some(lit_str.value()),
58                            "title" => annotations.title = Some(lit_str.value()),
59                            _ => {}
60                        }
61                    } else if let Expr::Lit(ExprLit {
62                        lit: Lit::Bool(lit_bool),
63                        ..
64                    }) = nv.value
65                    {
66                        match ident.as_str() {
67                            "read_only_hint" | "readOnlyHint" => {
68                                annotations.read_only_hint = Some(lit_bool.value)
69                            }
70                            "destructive_hint" | "destructiveHint" => {
71                                annotations.destructive_hint = Some(lit_bool.value)
72                            }
73                            "idempotent_hint" | "idempotentHint" => {
74                                annotations.idempotent_hint = Some(lit_bool.value)
75                            }
76                            "open_world_hint" | "openWorldHint" => {
77                                annotations.open_world_hint = Some(lit_bool.value)
78                            }
79                            _ => {}
80                        }
81                    }
82                }
83                Meta::List(list) if list.path.is_ident("params") => {
84                    let nested: Punctuated<Meta, Token![,]> =
85                        list.parse_args_with(Punctuated::parse_terminated)?;
86
87                    for meta in nested {
88                        if let Meta::NameValue(nv) = meta {
89                            if let Expr::Lit(ExprLit {
90                                lit: Lit::Str(lit_str),
91                                ..
92                            }) = nv.value
93                            {
94                                let param_name = nv.path.get_ident().unwrap().to_string();
95                                param_descriptions.insert(param_name, lit_str.value());
96                            }
97                        }
98                    }
99                }
100                Meta::List(list) if list.path.is_ident("annotations") => {
101                    let nested: Punctuated<Meta, Token![,]> =
102                        list.parse_args_with(Punctuated::parse_terminated)?;
103
104                    for meta in nested {
105                        if let Meta::NameValue(nv) = meta {
106                            let key = nv.path.get_ident().unwrap().to_string();
107
108                            if let Expr::Lit(ExprLit {
109                                lit: Lit::Str(lit_str),
110                                ..
111                            }) = nv.value
112                            {
113                                if key == "title" {
114                                    annotations.title = Some(lit_str.value());
115                                }
116                            } else if let Expr::Lit(ExprLit {
117                                lit: Lit::Bool(lit_bool),
118                                ..
119                            }) = nv.value
120                            {
121                                match key.as_str() {
122                                    "read_only_hint" | "readOnlyHint" => {
123                                        annotations.read_only_hint = Some(lit_bool.value)
124                                    }
125                                    "destructive_hint" | "destructiveHint" => {
126                                        annotations.destructive_hint = Some(lit_bool.value)
127                                    }
128                                    "idempotent_hint" | "idempotentHint" => {
129                                        annotations.idempotent_hint = Some(lit_bool.value)
130                                    }
131                                    "open_world_hint" | "openWorldHint" => {
132                                        annotations.open_world_hint = Some(lit_bool.value)
133                                    }
134                                    _ => {}
135                                }
136                            }
137                        }
138                    }
139                }
140                _ => {}
141            }
142        }
143
144        Ok(MacroArgs {
145            name,
146            description,
147            param_descriptions,
148            annotations,
149        })
150    }
151}
152
153#[proc_macro_attribute]
154pub fn tool(args: TokenStream, input: TokenStream) -> TokenStream {
155    let args = parse_macro_input!(args as MacroArgs);
156    let input_fn = parse_macro_input!(input as ItemFn);
157
158    let fn_name = &input_fn.sig.ident;
159    let fn_name_str = fn_name.to_string();
160    let struct_name = format_ident!("{}", { fn_name_str.to_case(Case::Pascal) });
161    let tool_name = args.name.unwrap_or(fn_name_str);
162    let tool_description = args.description.unwrap_or_default();
163
164    let title = args.annotations.title.unwrap_or_else(|| tool_name.clone());
166    let read_only_hint = args.annotations.read_only_hint.unwrap_or(false);
167    let destructive_hint = args.annotations.destructive_hint.unwrap_or(true);
168    let idempotent_hint = args.annotations.idempotent_hint.unwrap_or(false);
169    let open_world_hint = args.annotations.open_world_hint.unwrap_or(true);
170
171    let mut param_defs = Vec::new();
172    let mut param_names = Vec::new();
173    let mut required_params = Vec::new();
174
175    for arg in input_fn.sig.inputs.iter() {
176        if let FnArg::Typed(PatType { pat, ty, .. }) = arg {
177            if let Pat::Ident(param_ident) = &**pat {
178                let param_name = ¶m_ident.ident;
179                let param_name_str = param_name.to_string();
180                let description = args
181                    .param_descriptions
182                    .get(¶m_name_str)
183                    .map(|s| s.as_str())
184                    .unwrap_or("");
185
186                param_names.push(param_name);
187
188                let is_optional = if let Type::Path(type_path) = &**ty {
190                    type_path
191                        .path
192                        .segments
193                        .last()
194                        .map_or(false, |segment| segment.ident == "Option")
195                } else {
196                    false
197                };
198
199                if !is_optional {
200                    required_params.push(param_name_str.clone());
201                }
202
203                param_defs.push(quote! {
204                    #[schemars(description = #description)]
205                    #param_name: #ty
206                });
207            }
208        }
209    }
210
211    let params_struct_name = format_ident!("{}Parameters", struct_name);
212    let expanded = quote! {
213        #[derive(serde::Deserialize, schemars::JsonSchema)]
214        struct #params_struct_name {
215            #(#param_defs,)*
216        }
217
218        #input_fn
219
220        #[derive(Default)]
221        pub struct #struct_name;
222
223        impl #struct_name {
224            pub fn tool() -> mcp_core::types::Tool {
225                let schema = schemars::schema_for!(#params_struct_name);
226                let mut schema = serde_json::to_value(schema.schema).unwrap_or_default();
227                if let serde_json::Value::Object(ref mut map) = schema {
228                    map.insert("required".to_string(), serde_json::Value::Array(
230                        vec![#(serde_json::Value::String(#required_params.to_string())),*]
231                    ));
232                    map.remove("title");
233
234                    if let Some(serde_json::Value::Object(props)) = map.get_mut("properties") {
236                        for (_name, prop) in props.iter_mut() {
237                            if let serde_json::Value::Object(prop_obj) = prop {
238                                if let Some(type_val) = prop_obj.get("type") {
240                                    if type_val == "integer" || prop_obj.contains_key("format") {
241                                        prop_obj.insert("type".to_string(), serde_json::Value::String("number".to_string()));
243                                        prop_obj.remove("format");
245                                        prop_obj.remove("minimum");
246                                    }
247                                }
248
249                                if let Some(serde_json::Value::Array(types)) = prop_obj.get("type") {
251                                    if types.len() == 2 && types.contains(&serde_json::Value::String("null".to_string())) {
252                                        let main_type = types.iter()
254                                            .find(|&t| t != &serde_json::Value::String("null".to_string()))
255                                            .cloned()
256                                            .unwrap_or(serde_json::Value::String("string".to_string()));
257
258                                        prop_obj.insert("type".to_string(), main_type);
260                                    }
261                                }
262                            }
263                        }
264                    }
265                }
266
267                let annotations = serde_json::json!({
269                    "title": #title,
270                    "readOnlyHint": #read_only_hint,
271                    "destructiveHint": #destructive_hint,
272                    "idempotentHint": #idempotent_hint,
273                    "openWorldHint": #open_world_hint
274                });
275
276                mcp_core::types::Tool {
277                    name: #tool_name.to_string(),
278                    description: Some(#tool_description.to_string()),
279                    input_schema: schema,
280                    annotations: Some(annotations),
281                }
282            }
283
284            pub fn call() -> mcp_core::tools::ToolHandlerFn {
285                move |req: mcp_core::types::CallToolRequest| {
286                    Box::pin(async move {
287                        let params = match req.arguments {
288                            Some(args) => serde_json::to_value(args).unwrap_or_default(),
289                            None => serde_json::Value::Null,
290                        };
291
292                        let params: #params_struct_name = match serde_json::from_value(params) {
293                            Ok(p) => p,
294                            Err(e) => return mcp_core::types::CallToolResponse {
295                                content: vec![mcp_core::types::ToolResponseContent::Text {
296                                    text: format!("Invalid parameters: {}", e)
297                                }],
298                                is_error: Some(true),
299                                meta: None,
300                            },
301                        };
302
303                        match #fn_name(#(params.#param_names,)*).await {
304                            Ok(response) => {
305                                let content = if let Ok(vec_content) = serde_json::from_value::<Vec<mcp_core::types::ToolResponseContent>>(
306                                    serde_json::to_value(&response).unwrap_or_default()
307                                ) {
308                                    vec_content
309                                } else if let Ok(single_content) = serde_json::from_value::<mcp_core::types::ToolResponseContent>(
310                                    serde_json::to_value(&response).unwrap_or_default()
311                                ) {
312                                    vec![single_content]
313                                } else {
314                                    vec![mcp_core::types::ToolResponseContent::Text {
315                                        text: format!("Invalid response type: {:?}", response)
316                                    }]
317                                };
318
319                                mcp_core::types::CallToolResponse {
320                                    content,
321                                    is_error: None,
322                                    meta: None,
323                                }
324                            },
325                            Err(e) => mcp_core::types::CallToolResponse {
326                                content: vec![mcp_core::types::ToolResponseContent::Text {
327                                    text: format!("Tool execution error: {}", e)
328                                }],
329                                is_error: Some(true),
330                                meta: None,
331                            },
332                        }
333                    })
334                }
335            }
336        }
337    };
338
339    TokenStream::from(expanded)
340}