1use convert_case::{Case, Casing};
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::{
9 parse::{Parse, ParseStream},
10 parse_macro_input, ItemFn, ItemStruct, LitStr, Token,
11};
12
13#[proc_macro_attribute]
20pub fn arg(_args: TokenStream, input: TokenStream) -> TokenStream {
21 input
24}
25
26#[proc_macro_attribute]
33pub fn context(_args: TokenStream, input: TokenStream) -> TokenStream {
34 input
37}
38
39#[derive(Debug, Default)]
41struct ToolArgs {
42 name: Option<String>,
44
45 description: Option<String>,
47}
48
49impl Parse for ToolArgs {
51 fn parse(input: ParseStream) -> syn::Result<Self> {
52 let mut args = ToolArgs::default();
53
54 if input.is_empty() {
56 return Err(syn::Error::new(
57 proc_macro2::Span::call_site(),
58 "description is required: use #[tool(description = \"your description here\")]",
59 ));
60 }
61
62 while !input.is_empty() {
63 let ident: syn::Ident = input.parse().map_err(|e| {
65 syn::Error::new(
66 e.span(),
67 format!("Expected 'name' or 'description', got parse error: {}", e),
68 )
69 })?;
70
71 let ident_str = ident.to_string();
72
73 let _: Token![=] = input.parse().map_err(|e| {
75 syn::Error::new(
76 e.span(),
77 format!(
78 "Expected '=' after '{}', use syntax: {} = \"...\"",
79 ident_str, ident_str
80 ),
81 )
82 })?;
83
84 let value: LitStr = input.parse().map_err(|e| {
86 syn::Error::new(
87 e.span(),
88 format!(
89 "Expected string literal after '{} =', got parse error: {}",
90 ident_str, e
91 ),
92 )
93 })?;
94
95 match ident_str.as_str() {
96 "name" => args.name = Some(value.value()),
97 "description" => args.description = Some(value.value()),
98 _ => {
99 return Err(syn::Error::new_spanned(
100 ident,
101 format!(
102 "Unknown attribute '{}'. Expected 'name' or 'description'",
103 ident_str
104 ),
105 ))
106 }
107 }
108
109 if input.peek(Token![,]) {
111 let _: Token![,] = input.parse()?;
112 }
113 }
114
115 if args.description.is_none() {
116 return Err(syn::Error::new(
117 proc_macro2::Span::call_site(),
118 "description is required: use #[tool(description = \"your description here\")]",
119 ));
120 }
121
122 Ok(args)
123 }
124}
125
126#[proc_macro_attribute]
159pub fn tool(args: TokenStream, input: TokenStream) -> TokenStream {
160 let tool_args = parse_macro_input!(args as ToolArgs);
162
163 if let Ok(func) = syn::parse::<ItemFn>(input.clone()) {
165 return expand_tool_function(tool_args, func);
166 }
167
168 if let Ok(struct_item) = syn::parse::<ItemStruct>(input) {
170 return expand_tool_struct(tool_args, struct_item);
171 }
172
173 syn::Error::new(
175 proc_macro2::Span::call_site(),
176 "#[tool] can only be applied to functions or structs",
177 )
178 .to_compile_error()
179 .into()
180}
181
/// One parameter of a `#[tool]` function, as collected by
/// `analyze_flexible_parameters`.
#[derive(Debug, Clone)]
struct ParamInfo {
    /// Identifier bound by the parameter's `Pat::Ident` pattern.
    name: syn::Ident,
    /// Declared type of the parameter.
    param_type: Box<syn::Type>,
    /// Set only for `#[arg(...)]` parameters: the `description` value, or the
    /// parameter name when none could be read.
    description: Option<String>,
    /// `true` when the parameter carries an `#[arg(...)]` attribute.
    is_individual: bool,
}
190
191fn analyze_flexible_parameters(
194 inputs: &syn::punctuated::Punctuated<syn::FnArg, Token![,]>,
195) -> syn::Result<Vec<ParamInfo>> {
196 if inputs.is_empty() {
197 return Err(syn::Error::new_spanned(
198 inputs,
199 "Tool function must have at least one parameter",
200 ));
201 }
202
203 let mut params = Vec::new();
204
205 for input in inputs {
206 if let syn::FnArg::Typed(pat_type) = input {
207 let arg_attr = pat_type
209 .attrs
210 .iter()
211 .find(|attr| attr.path().is_ident("arg"));
212
213 if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
214 let param_name = pat_ident.ident.clone();
215 let param_type = pat_type.ty.clone();
216
217 let (is_individual, description) = if let Some(attr) = arg_attr {
218 let desc =
220 extract_arg_description(attr).unwrap_or_else(|| param_name.to_string());
221 (true, Some(desc))
222 } else {
223 (false, None)
225 };
226
227 params.push(ParamInfo {
228 name: param_name,
229 param_type,
230 description,
231 is_individual,
232 });
233 }
234 }
235 }
236
237 Ok(params)
238}
239
240fn extract_arg_description(attr: &syn::Attribute) -> Option<String> {
242 if let syn::Meta::List(list) = &attr.meta {
243 let mut desc = None;
245 let _ = list.parse_nested_meta(|meta| {
246 if meta.path.is_ident("description") {
247 if let Ok(value) = meta.value() {
248 if let Ok(lit) = value.parse::<LitStr>() {
249 desc = Some(lit.value());
250 }
251 }
252 }
253 Ok(())
254 });
255 desc
256 } else {
257 None
258 }
259}
260
261fn expand_tool_function(args: ToolArgs, func: ItemFn) -> TokenStream {
263 let func_name = func.sig.ident.clone();
264 let tool_name = args.name.unwrap_or_else(|| func_name.to_string());
265 let description = args.description.expect("description is required");
266
267 let params = match analyze_flexible_parameters(&func.sig.inputs) {
269 Ok(params) => params,
270 Err(e) => return TokenStream::from(e.to_compile_error()),
271 };
272
273 let individual_params: Vec<_> = params.iter().filter(|p| p.is_individual).cloned().collect();
275 let struct_params: Vec<_> = params
276 .iter()
277 .filter(|p| !p.is_individual)
278 .cloned()
279 .collect();
280
281 expand_flexible_tool(
283 func,
284 &func_name,
285 &tool_name,
286 &description,
287 individual_params,
288 struct_params,
289 )
290}
291
292fn expand_flexible_tool(
295 func: ItemFn,
296 func_name: &syn::Ident,
297 tool_name: &str,
298 description: &str,
299 individual_params: Vec<ParamInfo>,
300 struct_params: Vec<ParamInfo>,
301) -> TokenStream {
302 let pascal_name = func_name.to_string().to_case(Case::Pascal);
303 let struct_name = syn::Ident::new(&format!("{}Tool", pascal_name), func_name.span());
304
305 if individual_params.is_empty() && struct_params.len() == 1 {
307 return expand_single_struct_tool(
309 func,
310 func_name,
311 tool_name,
312 description,
313 struct_params[0].param_type.clone(),
314 );
315 }
316
317 if struct_params.is_empty() && !individual_params.is_empty() {
319 return expand_individual_params_tool(
320 func,
321 func_name,
322 tool_name,
323 description,
324 individual_params,
325 );
326 }
327
328 if !individual_params.is_empty() && !struct_params.is_empty() {
334 return syn::Error::new_spanned(
335 &func.sig.inputs,
336 "Mixed parameters (individual + struct) not yet supported.\n\
337 Use either all individual params with #[arg(...)] OR single struct param.",
338 )
339 .to_compile_error()
340 .into();
341 }
342
343 if struct_params.len() > 1 {
344 return syn::Error::new_spanned(
345 &func.sig.inputs,
346 "Multiple struct parameters not yet supported.\n\
347 Combine all params into a single struct.",
348 )
349 .to_compile_error()
350 .into();
351 }
352
353 syn::Error::new_spanned(&func.sig.inputs, "Unable to determine parameter mode")
355 .to_compile_error()
356 .into()
357}
358
359fn expand_single_struct_tool(
361 func: ItemFn,
362 func_name: &syn::Ident,
363 tool_name: &str,
364 description: &str,
365 param_type: Box<syn::Type>,
366) -> TokenStream {
367 let _return_type = match &func.sig.output {
369 syn::ReturnType::Type(_, ty) => ty,
370 _ => {
371 return syn::Error::new_spanned(&func.sig.output, "Tool function must return a Result")
372 .to_compile_error()
373 .into();
374 }
375 };
376
377 let pascal_name = func_name.to_string().to_case(Case::Pascal);
380 let struct_name = syn::Ident::new(&format!("{}Tool", pascal_name), func_name.span());
381
382 let expanded = quote! {
384 #func
385
386 pub struct #struct_name;
388
389 impl ::rsllm::tools::SchemaBasedTool for #struct_name {
390 type Params = #param_type;
391
392 fn name(&self) -> &str {
393 #tool_name
394 }
395
396 fn description(&self) -> &str {
397 #description
398 }
399
400 fn execute_typed(
401 &self,
402 params: Self::Params,
403 ) -> Result<::serde_json::Value, Box<dyn std::error::Error + Send + Sync>> {
404 let result = #func_name(params)?;
406
407 ::serde_json::to_value(&result)
409 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
410 }
411 }
412 };
413
414 TokenStream::from(expanded)
415}
416
417fn expand_individual_params_tool(
419 func: ItemFn,
420 func_name: &syn::Ident,
421 tool_name: &str,
422 description: &str,
423 params: Vec<ParamInfo>,
424) -> TokenStream {
425 let pascal_name = func_name.to_string().to_case(Case::Pascal);
427 let params_struct_name = syn::Ident::new(&format!("{}Params", pascal_name), func_name.span());
428 let struct_name = syn::Ident::new(&format!("{}Tool", pascal_name), func_name.span());
429
430 let param_fields = params.iter().map(|p| {
432 let name = &p.name;
433 let ty = &p.param_type;
434 let doc = p.description.as_deref().unwrap_or("");
435 quote! {
436 #[doc = #doc]
437 pub #name: #ty
438 }
439 });
440
441 let call_args = params.iter().map(|p| {
443 let name = &p.name;
444 quote! { generated_params.#name }
445 });
446
447 let expanded = quote! {
448 #func
449
450 #[derive(::schemars::JsonSchema, ::serde::Serialize, ::serde::Deserialize)]
452 pub struct #params_struct_name {
453 #(#param_fields),*
454 }
455
456 pub struct #struct_name;
458
459 impl ::rsllm::tools::SchemaBasedTool for #struct_name {
460 type Params = #params_struct_name;
461
462 fn name(&self) -> &str {
463 #tool_name
464 }
465
466 fn description(&self) -> &str {
467 #description
468 }
469
470 fn execute_typed(
471 &self,
472 generated_params: Self::Params,
473 ) -> Result<::serde_json::Value, Box<dyn std::error::Error + Send + Sync>> {
474 let result = #func_name(#(#call_args),*)?;
476
477 ::serde_json::to_value(&result)
479 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
480 }
481 }
482 };
483
484 TokenStream::from(expanded)
485}
486
487fn expand_tool_struct(args: ToolArgs, struct_item: ItemStruct) -> TokenStream {
489 let struct_name = &struct_item.ident;
490 let tool_name = args
491 .name
492 .unwrap_or_else(|| struct_name.to_string().to_lowercase());
493 let description = args.description.expect("description is required");
494
495 let expanded = quote! {
499 #struct_item
500
501 impl #struct_name {
504 pub fn tool_name() -> &'static str {
505 #tool_name
506 }
507
508 pub fn tool_description() -> &'static str {
509 #description
510 }
511 }
512 };
513
514 TokenStream::from(expanded)
515}