1use convert_case::{Case, Casing};
8use proc_macro::TokenStream;
9use quote::{format_ident, quote};
10use syn::{
11 parse::{Parse, ParseStream},
12 punctuated::Punctuated,
13 Expr, ExprLit, FnArg, ItemFn, Lit, Meta, Pat, PatType, Token, Type,
14};
15
16#[derive(Debug)]
17struct ToolArgs {
18 name: Option<String>,
19 description: Option<String>,
20 annotations: ToolAnnotations,
21}
22
/// Optional hints parsed from `annotations(...)` inside `#[tool(...)]`.
///
/// Every field starts out as `None` (the derived `Default`), meaning "not
/// specified in the attribute"; the `tool` macro substitutes its own fallback
/// for each unset field when it generates code.
#[derive(Debug, Default)]
struct ToolAnnotations {
    /// Set by the `title = "..."` annotation.
    title: Option<String>,
    /// Set by `read_only_hint = <bool>` (also spelled `readOnlyHint`).
    read_only_hint: Option<bool>,
    /// Set by `destructive_hint = <bool>` (also spelled `destructiveHint`).
    destructive_hint: Option<bool>,
    /// Set by `idempotent_hint = <bool>` (also spelled `idempotentHint`).
    idempotent_hint: Option<bool>,
    /// Set by `open_world_hint = <bool>` (also spelled `openWorldHint`).
    open_world_hint: Option<bool>,
}
43
44impl Parse for ToolArgs {
45 fn parse(input: ParseStream) -> syn::Result<Self> {
46 let mut name = None;
47 let mut description = None;
48 let mut annotations = ToolAnnotations::default();
49
50 let meta_list: Punctuated<Meta, Token![,]> = Punctuated::parse_terminated(input)?;
51
52 for meta in meta_list {
53 match meta {
54 Meta::NameValue(nv) => {
55 let ident = nv.path.get_ident().unwrap().to_string();
56 if let Expr::Lit(ExprLit {
57 lit: Lit::Str(lit_str),
58 ..
59 }) = nv.value
60 {
61 match ident.as_str() {
62 "name" => name = Some(lit_str.value()),
63 "description" => description = Some(lit_str.value()),
64 _ => {
65 return Err(syn::Error::new_spanned(
66 nv.path,
67 format!("Unknown attribute: {}", ident),
68 ))
69 }
70 }
71 } else {
72 return Err(syn::Error::new_spanned(nv.value, "Expected string literal"));
73 }
74 }
75 Meta::List(list) if list.path.is_ident("annotations") => {
76 let nested: Punctuated<Meta, Token![,]> =
77 list.parse_args_with(Punctuated::parse_terminated)?;
78
79 for meta in nested {
80 if let Meta::NameValue(nv) = meta {
81 let key = nv.path.get_ident().unwrap().to_string();
82
83 if let Expr::Lit(ExprLit {
84 lit: Lit::Str(lit_str),
85 ..
86 }) = nv.value
87 {
88 if key == "title" {
89 annotations.title = Some(lit_str.value());
90 } else {
91 return Err(syn::Error::new_spanned(
92 nv.path,
93 format!("Unknown string annotation: {}", key),
94 ));
95 }
96 } else if let Expr::Lit(ExprLit {
97 lit: Lit::Bool(lit_bool),
98 ..
99 }) = nv.value
100 {
101 match key.as_str() {
102 "read_only_hint" | "readOnlyHint" => {
103 annotations.read_only_hint = Some(lit_bool.value)
104 }
105 "destructive_hint" | "destructiveHint" => {
106 annotations.destructive_hint = Some(lit_bool.value)
107 }
108 "idempotent_hint" | "idempotentHint" => {
109 annotations.idempotent_hint = Some(lit_bool.value)
110 }
111 "open_world_hint" | "openWorldHint" => {
112 annotations.open_world_hint = Some(lit_bool.value)
113 }
114 _ => {
115 return Err(syn::Error::new_spanned(
116 nv.path,
117 format!("Unknown boolean annotation: {}", key),
118 ))
119 }
120 }
121 } else {
122 return Err(syn::Error::new_spanned(
123 nv.value,
124 "Expected string or boolean literal for annotation value",
125 ));
126 }
127 } else {
128 return Err(syn::Error::new_spanned(
129 meta,
130 "Expected name-value pair for annotation",
131 ));
132 }
133 }
134 }
135 _ => {
136 return Err(syn::Error::new_spanned(
137 meta,
138 "Expected name-value pair or list",
139 ))
140 }
141 }
142 }
143
144 Ok(ToolArgs {
145 name,
146 description,
147 annotations,
148 })
149 }
150}
151
152#[proc_macro_attribute]
193pub fn tool(args: TokenStream, input: TokenStream) -> TokenStream {
194 let args = match syn::parse::<ToolArgs>(args) {
195 Ok(args) => args,
196 Err(e) => return e.to_compile_error().into(),
197 };
198
199 let input_fn = match syn::parse::<ItemFn>(input.clone()) {
200 Ok(input_fn) => input_fn,
201 Err(e) => return e.to_compile_error().into(),
202 };
203
204 let fn_name = &input_fn.sig.ident;
205 let fn_name_str = fn_name.to_string();
206 let struct_name = format_ident!("{}", fn_name_str.to_case(Case::Pascal));
207 let tool_name = args.name.unwrap_or(fn_name_str.clone());
208 let tool_description = args.description.unwrap_or_default();
209
210 let title = args.annotations.title.unwrap_or(fn_name_str.clone());
212 let read_only_hint = args.annotations.read_only_hint.unwrap_or(false);
213 let destructive_hint = args.annotations.destructive_hint.unwrap_or(true);
214 let idempotent_hint = args.annotations.idempotent_hint.unwrap_or(false);
215 let open_world_hint = args.annotations.open_world_hint.unwrap_or(true);
216
217 let mut param_defs = Vec::new();
218 let mut param_names = Vec::new();
219 let mut required_params = Vec::new();
220 let mut hidden_params: Vec<String> = Vec::new();
221 let mut param_descriptions = Vec::new();
222
223 for arg in input_fn.sig.inputs.iter() {
224 if let FnArg::Typed(PatType { pat, ty, .. }) = arg {
225 let mut is_hidden = false;
226 let mut description: Option<String> = None;
227 let mut is_optional = false;
228
229 if let Type::Macro(type_macro) = &**ty {
231 if let Some(ident) = type_macro.mac.path.get_ident() {
232 if ident == "tool_param" {
233 if let Ok(args) =
234 syn::parse2::<ToolParamArgs>(type_macro.mac.tokens.clone())
235 {
236 is_hidden = args.hidden;
237 description = args.description;
238
239 if let Type::Path(type_path) = &args.ty {
241 is_optional = type_path
242 .path
243 .segments
244 .last()
245 .map_or(false, |segment| segment.ident == "Option");
246 }
247 }
248 }
249 }
250 }
251
252 if is_hidden {
253 if let Pat::Ident(ident) = &**pat {
254 hidden_params.push(ident.ident.to_string());
255 }
256 }
257
258 if let Pat::Ident(param_ident) = &**pat {
259 let param_name = ¶m_ident.ident;
260 let param_name_str = param_name.to_string();
261
262 param_names.push(param_name.clone());
263
264 if !is_optional {
266 is_optional = if let Type::Path(type_path) = &**ty {
267 type_path
268 .path
269 .segments
270 .last()
271 .map_or(false, |segment| segment.ident == "Option")
272 } else {
273 false
274 }
275 }
276
277 if !is_optional && !is_hidden {
279 required_params.push(param_name_str.clone());
280 }
281
282 if let Some(desc) = description {
283 param_descriptions.push(quote! {
284 if name == #param_name_str {
285 prop_obj.insert("description".to_string(), serde_json::Value::String(#desc.to_string()));
286 }
287 });
288 }
289
290 param_defs.push(quote! {
291 #param_name: #ty
292 });
293 }
294 }
295 }
296
297 let params_struct_name = format_ident!("{}Parameters", struct_name);
298 let expanded = quote! {
299 #[derive(serde::Deserialize, schemars::JsonSchema)]
300 struct #params_struct_name {
301 #(#param_defs,)*
302 }
303
304 #input_fn
305
306 #[derive(Default)]
307 pub struct #struct_name;
308
309 impl #struct_name {
310 pub fn tool() -> mcp_core::types::Tool {
311 let schema = schemars::schema_for!(#params_struct_name);
312 let mut schema = serde_json::to_value(schema.schema).unwrap_or_default();
313 if let serde_json::Value::Object(ref mut map) = schema {
314 map.insert("required".to_string(), serde_json::Value::Array(
316 vec![#(serde_json::Value::String(#required_params.to_string())),*]
317 ));
318 map.remove("title");
319
320 if let Some(serde_json::Value::Object(props)) = map.get_mut("properties") {
322 for (name, prop) in props.iter_mut() {
323 if let serde_json::Value::Object(prop_obj) = prop {
324 if let Some(type_val) = prop_obj.get("type") {
326 if type_val == "integer" || type_val == "number" || prop_obj.contains_key("format") {
327 prop_obj.insert("type".to_string(), serde_json::Value::String("number".to_string()));
329 prop_obj.remove("format");
330 prop_obj.remove("minimum");
331 prop_obj.remove("maximum");
332 }
333 }
334
335 if let Some(serde_json::Value::Array(types)) = prop_obj.get("type") {
337 if types.len() == 2 && types.contains(&serde_json::Value::String("null".to_string())) {
338 let mut main_type = types.iter()
339 .find(|&t| t != &serde_json::Value::String("null".to_string()))
340 .cloned()
341 .unwrap_or(serde_json::Value::String("string".to_string()));
342
343 if main_type == serde_json::Value::String("integer".to_string()) {
345 main_type = serde_json::Value::String("number".to_string());
346 }
347
348 prop_obj.insert("type".to_string(), main_type);
349 }
350 }
351
352 #(#param_descriptions)*
354 }
355 }
356
357 #(props.remove(#hidden_params);)*
358 }
359 }
360
361 let annotations = serde_json::json!({
362 "title": #title,
363 "readOnlyHint": #read_only_hint,
364 "destructiveHint": #destructive_hint,
365 "idempotentHint": #idempotent_hint,
366 "openWorldHint": #open_world_hint
367 });
368
369 mcp_core::types::Tool {
370 name: #tool_name.to_string(),
371 description: Some(#tool_description.to_string()),
372 input_schema: schema,
373 annotations: Some(mcp_core::types::ToolAnnotations {
374 title: Some(#title.to_string()),
375 read_only_hint: Some(#read_only_hint),
376 destructive_hint: Some(#destructive_hint),
377 idempotent_hint: Some(#idempotent_hint),
378 open_world_hint: Some(#open_world_hint),
379 }),
380 }
381 }
382
383 pub fn call() -> mcp_core::tools::ToolHandlerFn {
384 move |req: mcp_core::types::CallToolRequest| {
385 Box::pin(async move {
386 let params = match req.arguments {
387 Some(args) => serde_json::to_value(args).unwrap_or_default(),
388 None => serde_json::Value::Null,
389 };
390
391 let params: #params_struct_name = match serde_json::from_value(params) {
392 Ok(p) => p,
393 Err(e) => return mcp_core::types::CallToolResponse {
394 content: vec![mcp_core::types::ToolResponseContent::Text(
395 mcp_core::types::TextContent {
396 content_type: "text".to_string(),
397 text: format!("Invalid parameters: {}", e),
398 annotations: None,
399 }
400 )],
401 is_error: Some(true),
402 meta: None,
403 },
404 };
405
406 match #fn_name(#(params.#param_names,)*).await {
407 Ok(response) => {
408 let content = if let Ok(vec_content) = serde_json::from_value::<Vec<mcp_core::types::ToolResponseContent>>(serde_json::to_value(&response).unwrap_or_default()) {
409 vec_content
410 } else if let Ok(single_content) = serde_json::from_value::<mcp_core::types::ToolResponseContent>(serde_json::to_value(&response).unwrap_or_default()) {
411 vec![single_content]
412 } else {
413 vec![mcp_core::types::ToolResponseContent::Text(
414 mcp_core::types::TextContent {
415 content_type: "text".to_string(),
416 text: format!("Invalid response type: {:?}", response),
417 annotations: None,
418 }
419 )]
420 };
421
422 mcp_core::types::CallToolResponse {
423 content,
424 is_error: None,
425 meta: None,
426 }
427 }
428 Err(e) => mcp_core::types::CallToolResponse {
429 content: vec![mcp_core::types::ToolResponseContent::Text(
430 mcp_core::types::TextContent {
431 content_type: "text".to_string(),
432 text: format!("Tool execution error: {}", e),
433 annotations: None,
434 }
435 )],
436 is_error: Some(true),
437 meta: None,
438 },
439 }
440 })
441 }
442 }
443 }
444 };
445
446 TokenStream::from(expanded)
447}
448
449#[derive(Debug)]
450struct ToolParamArgs {
451 ty: Type,
452 hidden: bool,
453 description: Option<String>,
454}
455
456impl Parse for ToolParamArgs {
457 fn parse(input: ParseStream) -> syn::Result<Self> {
458 let mut hidden = false;
459 let mut description = None;
460 let ty = input.parse()?;
461
462 if input.peek(Token![,]) {
463 input.parse::<Token![,]>()?;
464 let meta_list: Punctuated<Meta, Token![,]> = Punctuated::parse_terminated(input)?;
465
466 for meta in meta_list {
467 match meta {
468 Meta::Path(path) if path.is_ident("hidden") => {
469 hidden = true;
470 }
471 Meta::NameValue(nv) if nv.path.is_ident("description") => {
472 if let Expr::Lit(ExprLit {
473 lit: Lit::Str(lit_str),
474 ..
475 }) = &nv.value
476 {
477 description = Some(lit_str.value().to_string());
478 }
479 }
480 _ => {}
481 }
482 }
483 }
484
485 Ok(ToolParamArgs {
486 ty,
487 hidden,
488 description,
489 })
490 }
491}
492
493#[proc_macro]
523pub fn tool_param(input: TokenStream) -> TokenStream {
524 let args = syn::parse_macro_input!(input as ToolParamArgs);
525 let ty = args.ty;
526
527 TokenStream::from(quote! {
528 #ty
529 })
530}