1use convert_case::{Case, Casing};
2use proc_macro::TokenStream;
3use quote::{format_ident, quote};
4use std::collections::HashMap;
5use syn::{
6 parse::Parse, parse::ParseStream, parse_macro_input, punctuated::Punctuated, Expr, ExprLit,
7 FnArg, ItemFn, Lit, Meta, Pat, PatType, Token,
8};
9
/// Arguments parsed from the `#[tool(...)]` attribute.
///
/// Each field is optional in the attribute itself; absent values stay `None` / empty here.
struct MacroArgs {
    /// Value of `name = "..."`, when supplied.
    name: Option<String>,
    /// Value of `description = "..."`, when supplied.
    description: Option<String>,
    /// Entries of `params(arg = "...", ...)`, keyed by parameter name.
    param_descriptions: HashMap<String, String>,
}
15
16impl Parse for MacroArgs {
17 fn parse(input: ParseStream) -> syn::Result<Self> {
18 let mut name = None;
19 let mut description = None;
20 let mut param_descriptions = HashMap::new();
21
22 let meta_list: Punctuated<Meta, Token![,]> = Punctuated::parse_terminated(input)?;
23
24 for meta in meta_list {
25 match meta {
26 Meta::NameValue(nv) => {
27 let ident = nv.path.get_ident().unwrap().to_string();
28 if let Expr::Lit(ExprLit {
29 lit: Lit::Str(lit_str),
30 ..
31 }) = nv.value
32 {
33 match ident.as_str() {
34 "name" => name = Some(lit_str.value()),
35 "description" => description = Some(lit_str.value()),
36 _ => {}
37 }
38 }
39 }
40 Meta::List(list) if list.path.is_ident("params") => {
41 let nested: Punctuated<Meta, Token![,]> =
42 list.parse_args_with(Punctuated::parse_terminated)?;
43
44 for meta in nested {
45 if let Meta::NameValue(nv) = meta {
46 if let Expr::Lit(ExprLit {
47 lit: Lit::Str(lit_str),
48 ..
49 }) = nv.value
50 {
51 let param_name = nv.path.get_ident().unwrap().to_string();
52 param_descriptions.insert(param_name, lit_str.value());
53 }
54 }
55 }
56 }
57 _ => {}
58 }
59 }
60
61 Ok(MacroArgs {
62 name,
63 description,
64 param_descriptions,
65 })
66 }
67}
68
69#[proc_macro_attribute]
70pub fn tool(args: TokenStream, input: TokenStream) -> TokenStream {
71 let args = parse_macro_input!(args as MacroArgs);
72 let input_fn = parse_macro_input!(input as ItemFn);
73
74 let fn_name = &input_fn.sig.ident;
76 let fn_name_str = fn_name.to_string();
77
78 let struct_name = format_ident!("{}", { fn_name_str.to_case(Case::Pascal) });
80
81 let tool_name = args.name.unwrap_or(fn_name_str);
83 let tool_description = args.description.unwrap_or_default();
84
85 let mut param_defs = Vec::new();
87 let mut param_names = Vec::new();
88
89 for arg in input_fn.sig.inputs.iter() {
90 if let FnArg::Typed(PatType { pat, ty, .. }) = arg {
91 if let Pat::Ident(param_ident) = &**pat {
92 let param_name = ¶m_ident.ident;
93 let param_name_str = param_name.to_string();
94 let description = args
95 .param_descriptions
96 .get(¶m_name_str)
97 .map(|s| s.as_str())
98 .unwrap_or("");
99
100 param_names.push(param_name);
101 param_defs.push(quote! {
102 #[schemars(description = #description)]
103 #param_name: #ty
104 });
105 }
106 }
107 }
108
109 let params_struct_name = format_ident!("{}Parameters", struct_name);
111 let expanded = quote! {
112 #[derive(serde::Deserialize, schemars::JsonSchema)]
113 struct #params_struct_name {
114 #(#param_defs,)*
115 }
116
117 #input_fn
118
119 #[derive(Default)]
120 struct #struct_name;
121
122 #[async_trait::async_trait]
123 impl mcp_spec::handler::ToolHandler for #struct_name {
124 fn name(&self) -> &'static str {
125 #tool_name
126 }
127
128 fn description(&self) -> &'static str {
129 #tool_description
130 }
131
132 fn schema(&self) -> serde_json::Value {
133 mcp_spec::handler::generate_schema::<#params_struct_name>()
134 .expect("Failed to generate schema")
135 }
136
137 async fn call(&self, params: serde_json::Value) -> Result<serde_json::Value, mcp_spec::handler::ToolError> {
138 let params: #params_struct_name = serde_json::from_value(params)
139 .map_err(|e| mcp_spec::handler::ToolError::InvalidParameters(e.to_string()))?;
140
141 let result = #fn_name(#(params.#param_names,)*).await
143 .map_err(|e| mcp_spec::handler::ToolError::ExecutionError(e.to_string()))?;
144
145 Ok(serde_json::to_value(result).expect("should serialize"))
146
147 }
148 }
149 };
150
151 TokenStream::from(expanded)
152}