rexis_macros/
lib.rs

//! Procedural macros for RSLLM tool calling.
//!
//! This crate provides the `#[tool]`, `#[arg]`, and `#[context]` attribute macros for defining tools.
4
5use convert_case::{Case, Casing};
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::{
9    parse::{Parse, ParseStream},
10    parse_macro_input, ItemFn, ItemStruct, LitStr, Token,
11};
12
13/// The `#[arg]` attribute for marking individual tool parameters
14///
15/// Usage: `#[arg(description = "Parameter description")]`
16///
17/// This is a marker attribute that the `#[tool]` macro looks for.
18/// It doesn't do anything on its own.
19#[proc_macro_attribute]
20pub fn arg(_args: TokenStream, input: TokenStream) -> TokenStream {
21    // This is a pass-through attribute
22    // The #[tool] macro will process it
23    input
24}
25
26/// The `#[context]` attribute for marking context parameters
27///
28/// Usage: `#[context]`
29///
30/// Context parameters are not included in the LLM schema.
31/// They must be provided at runtime (for dependency injection).
32#[proc_macro_attribute]
33pub fn context(_args: TokenStream, input: TokenStream) -> TokenStream {
34    // This is a pass-through attribute
35    // The #[tool] macro will process it
36    input
37}
38
/// Parsed arguments of the `#[tool(...)]` attribute.
#[derive(Default, Debug)]
struct ToolArgs {
    /// Explicit tool name from `name = "..."`. When absent, functions fall back
    /// to the function identifier and structs to the lowercase struct identifier.
    name: Option<String>,

    /// Tool description from `description = "..."`. The `Parse` impl rejects
    /// input without it, so a successfully parsed value always holds `Some`.
    description: Option<String>,
}
48
49/// Custom parser for tool arguments
50impl Parse for ToolArgs {
51    fn parse(input: ParseStream) -> syn::Result<Self> {
52        let mut args = ToolArgs::default();
53
54        // Handle empty attribute like #[tool]
55        if input.is_empty() {
56            return Err(syn::Error::new(
57                proc_macro2::Span::call_site(),
58                "description is required: use #[tool(description = \"your description here\")]",
59            ));
60        }
61
62        while !input.is_empty() {
63            // Parse the identifier (name or description)
64            let ident: syn::Ident = input.parse().map_err(|e| {
65                syn::Error::new(
66                    e.span(),
67                    format!("Expected 'name' or 'description', got parse error: {}", e),
68                )
69            })?;
70
71            let ident_str = ident.to_string();
72
73            // Parse the equals sign
74            let _: Token![=] = input.parse().map_err(|e| {
75                syn::Error::new(
76                    e.span(),
77                    format!(
78                        "Expected '=' after '{}', use syntax: {} = \"...\"",
79                        ident_str, ident_str
80                    ),
81                )
82            })?;
83
84            // Parse the string value
85            let value: LitStr = input.parse().map_err(|e| {
86                syn::Error::new(
87                    e.span(),
88                    format!(
89                        "Expected string literal after '{} =', got parse error: {}",
90                        ident_str, e
91                    ),
92                )
93            })?;
94
95            match ident_str.as_str() {
96                "name" => args.name = Some(value.value()),
97                "description" => args.description = Some(value.value()),
98                _ => {
99                    return Err(syn::Error::new_spanned(
100                        ident,
101                        format!(
102                            "Unknown attribute '{}'. Expected 'name' or 'description'",
103                            ident_str
104                        ),
105                    ))
106                }
107            }
108
109            // Handle optional comma
110            if input.peek(Token![,]) {
111                let _: Token![,] = input.parse()?;
112            }
113        }
114
115        if args.description.is_none() {
116            return Err(syn::Error::new(
117                proc_macro2::Span::call_site(),
118                "description is required: use #[tool(description = \"your description here\")]",
119            ));
120        }
121
122        Ok(args)
123    }
124}
125
126/// The `#[tool]` attribute macro for easy tool definition
127///
128/// # Usage on Functions
129///
130/// ```rust,ignore
131/// #[tool(description = "Adds two numbers")]
132/// fn add_numbers(params: AddParams) -> Result<AddResult, Error> {
133///     Ok(AddResult { sum: params.a + params.b })
134/// }
135/// ```
136///
137/// # Usage on Structs
138///
139/// ```rust,ignore
140/// #[tool(
141///     name = "calculator",
142///     description = "Performs arithmetic operations"
143/// )]
144/// struct Calculator;
145///
146/// impl Calculator {
147///     fn execute(&self, params: CalcParams) -> Result<Value, Error> {
148///         // implementation
149///     }
150/// }
151/// ```
152///
153/// This automatically:
154/// - Generates Tool trait implementation
155/// - Uses SchemaBasedTool for automatic schema generation
156/// - Handles type conversions
157/// - Provides error handling
158#[proc_macro_attribute]
159pub fn tool(args: TokenStream, input: TokenStream) -> TokenStream {
160    // Parse the tool arguments
161    let tool_args = parse_macro_input!(args as ToolArgs);
162
163    // Try to parse as function first
164    if let Ok(func) = syn::parse::<ItemFn>(input.clone()) {
165        return expand_tool_function(tool_args, func);
166    }
167
168    // Try to parse as struct
169    if let Ok(struct_item) = syn::parse::<ItemStruct>(input) {
170        return expand_tool_struct(tool_args, struct_item);
171    }
172
173    // If neither, return error
174    syn::Error::new(
175        proc_macro2::Span::call_site(),
176        "#[tool] can only be applied to functions or structs",
177    )
178    .to_compile_error()
179    .into()
180}
181
182/// Parameter information extracted from function signature
183#[derive(Debug, Clone)]
184struct ParamInfo {
185    name: syn::Ident,
186    param_type: Box<syn::Type>,
187    description: Option<String>,
188    is_individual: bool, // Has #[arg] attribute
189}
190
191/// Analyze function parameters flexibly
192/// Supports any combination of individual args and struct params in any order
193fn analyze_flexible_parameters(
194    inputs: &syn::punctuated::Punctuated<syn::FnArg, Token![,]>,
195) -> syn::Result<Vec<ParamInfo>> {
196    if inputs.is_empty() {
197        return Err(syn::Error::new_spanned(
198            inputs,
199            "Tool function must have at least one parameter",
200        ));
201    }
202
203    let mut params = Vec::new();
204
205    for input in inputs {
206        if let syn::FnArg::Typed(pat_type) = input {
207            // Check for #[arg(...)] attribute
208            let arg_attr = pat_type
209                .attrs
210                .iter()
211                .find(|attr| attr.path().is_ident("arg"));
212
213            if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
214                let param_name = pat_ident.ident.clone();
215                let param_type = pat_type.ty.clone();
216
217                let (is_individual, description) = if let Some(attr) = arg_attr {
218                    // Extract description from #[arg(description = "...")]
219                    let desc =
220                        extract_arg_description(attr).unwrap_or_else(|| param_name.to_string());
221                    (true, Some(desc))
222                } else {
223                    // No #[arg] attribute - treat as struct parameter
224                    (false, None)
225                };
226
227                params.push(ParamInfo {
228                    name: param_name,
229                    param_type,
230                    description,
231                    is_individual,
232                });
233            }
234        }
235    }
236
237    Ok(params)
238}
239
240/// Extract description from #[arg(description = "...")] attribute
241fn extract_arg_description(attr: &syn::Attribute) -> Option<String> {
242    if let syn::Meta::List(list) = &attr.meta {
243        // Try to parse nested meta
244        let mut desc = None;
245        let _ = list.parse_nested_meta(|meta| {
246            if meta.path.is_ident("description") {
247                if let Ok(value) = meta.value() {
248                    if let Ok(lit) = value.parse::<LitStr>() {
249                        desc = Some(lit.value());
250                    }
251                }
252            }
253            Ok(())
254        });
255        desc
256    } else {
257        None
258    }
259}
260
261/// Expand #[tool] on a function
262fn expand_tool_function(args: ToolArgs, func: ItemFn) -> TokenStream {
263    let func_name = func.sig.ident.clone();
264    let tool_name = args.name.unwrap_or_else(|| func_name.to_string());
265    let description = args.description.expect("description is required");
266
267    // Analyze all parameters flexibly
268    let params = match analyze_flexible_parameters(&func.sig.inputs) {
269        Ok(params) => params,
270        Err(e) => return TokenStream::from(e.to_compile_error()),
271    };
272
273    // Separate individual args from struct params
274    let individual_params: Vec<_> = params.iter().filter(|p| p.is_individual).cloned().collect();
275    let struct_params: Vec<_> = params
276        .iter()
277        .filter(|p| !p.is_individual)
278        .cloned()
279        .collect();
280
281    // Generate the appropriate expansion
282    expand_flexible_tool(
283        func,
284        &func_name,
285        &tool_name,
286        &description,
287        individual_params,
288        struct_params,
289    )
290}
291
292/// Expand tool with flexible parameter combination
293/// Supports 0-n individual args + 0-n struct params in any order
294fn expand_flexible_tool(
295    func: ItemFn,
296    func_name: &syn::Ident,
297    tool_name: &str,
298    description: &str,
299    individual_params: Vec<ParamInfo>,
300    struct_params: Vec<ParamInfo>,
301) -> TokenStream {
302    let pascal_name = func_name.to_string().to_case(Case::Pascal);
303    let struct_name = syn::Ident::new(&format!("{}Tool", pascal_name), func_name.span());
304
305    // Case 1: Only struct params (1 or more)
306    if individual_params.is_empty() && struct_params.len() == 1 {
307        // Simple case: single struct parameter (original mode)
308        return expand_single_struct_tool(
309            func,
310            func_name,
311            tool_name,
312            description,
313            struct_params[0].param_type.clone(),
314        );
315    }
316
317    // Case 2: Only individual params (1 or more)
318    if struct_params.is_empty() && !individual_params.is_empty() {
319        return expand_individual_params_tool(
320            func,
321            func_name,
322            tool_name,
323            description,
324            individual_params,
325        );
326    }
327
328    // Case 3: Mixed (both individual and struct params)
329    // Case 4: Multiple struct params
330    // For now, these advanced cases are not fully implemented
331    // Show a helpful error message
332
333    if !individual_params.is_empty() && !struct_params.is_empty() {
334        return syn::Error::new_spanned(
335            &func.sig.inputs,
336            "Mixed parameters (individual + struct) not yet supported.\n\
337             Use either all individual params with #[arg(...)] OR single struct param.",
338        )
339        .to_compile_error()
340        .into();
341    }
342
343    if struct_params.len() > 1 {
344        return syn::Error::new_spanned(
345            &func.sig.inputs,
346            "Multiple struct parameters not yet supported.\n\
347             Combine all params into a single struct.",
348        )
349        .to_compile_error()
350        .into();
351    }
352
353    // Fallback error
354    syn::Error::new_spanned(&func.sig.inputs, "Unable to determine parameter mode")
355        .to_compile_error()
356        .into()
357}
358
359/// Expand tool with single struct parameter (original mode)
360fn expand_single_struct_tool(
361    func: ItemFn,
362    func_name: &syn::Ident,
363    tool_name: &str,
364    description: &str,
365    param_type: Box<syn::Type>,
366) -> TokenStream {
367    // Extract return type (for future validation)
368    let _return_type = match &func.sig.output {
369        syn::ReturnType::Type(_, ty) => ty,
370        _ => {
371            return syn::Error::new_spanned(&func.sig.output, "Tool function must return a Result")
372                .to_compile_error()
373                .into();
374        }
375    };
376
377    // Generate struct name for the tool
378    // Convert snake_case function name to PascalCase + Tool suffix
379    let pascal_name = func_name.to_string().to_case(Case::Pascal);
380    let struct_name = syn::Ident::new(&format!("{}Tool", pascal_name), func_name.span());
381
382    // Generate the implementation
383    let expanded = quote! {
384        #func
385
386        // Generate a struct to implement the tool trait
387        pub struct #struct_name;
388
389        impl ::rsllm::tools::SchemaBasedTool for #struct_name {
390            type Params = #param_type;
391
392            fn name(&self) -> &str {
393                #tool_name
394            }
395
396            fn description(&self) -> &str {
397                #description
398            }
399
400            fn execute_typed(
401                &self,
402                params: Self::Params,
403            ) -> Result<::serde_json::Value, Box<dyn std::error::Error + Send + Sync>> {
404                // Call the original function
405                let result = #func_name(params)?;
406
407                // Convert result to JSON
408                ::serde_json::to_value(&result)
409                    .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
410            }
411        }
412    };
413
414    TokenStream::from(expanded)
415}
416
417/// Expand tool with individual parameters only
418fn expand_individual_params_tool(
419    func: ItemFn,
420    func_name: &syn::Ident,
421    tool_name: &str,
422    description: &str,
423    params: Vec<ParamInfo>,
424) -> TokenStream {
425    // Generate a params struct from individual parameters
426    let pascal_name = func_name.to_string().to_case(Case::Pascal);
427    let params_struct_name = syn::Ident::new(&format!("{}Params", pascal_name), func_name.span());
428    let struct_name = syn::Ident::new(&format!("{}Tool", pascal_name), func_name.span());
429
430    // Build struct fields with doc comments
431    let param_fields = params.iter().map(|p| {
432        let name = &p.name;
433        let ty = &p.param_type;
434        let doc = p.description.as_deref().unwrap_or("");
435        quote! {
436            #[doc = #doc]
437            pub #name: #ty
438        }
439    });
440
441    // Build function call arguments in the original order
442    let call_args = params.iter().map(|p| {
443        let name = &p.name;
444        quote! { generated_params.#name }
445    });
446
447    let expanded = quote! {
448        #func
449
450        // Auto-generated params struct
451        #[derive(::schemars::JsonSchema, ::serde::Serialize, ::serde::Deserialize)]
452        pub struct #params_struct_name {
453            #(#param_fields),*
454        }
455
456        // Generate tool struct
457        pub struct #struct_name;
458
459        impl ::rsllm::tools::SchemaBasedTool for #struct_name {
460            type Params = #params_struct_name;
461
462            fn name(&self) -> &str {
463                #tool_name
464            }
465
466            fn description(&self) -> &str {
467                #description
468            }
469
470            fn execute_typed(
471                &self,
472                generated_params: Self::Params,
473            ) -> Result<::serde_json::Value, Box<dyn std::error::Error + Send + Sync>> {
474                // Call the original function with unpacked params
475                let result = #func_name(#(#call_args),*)?;
476
477                // Convert result to JSON
478                ::serde_json::to_value(&result)
479                    .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
480            }
481        }
482    };
483
484    TokenStream::from(expanded)
485}
486
487/// Expand #[tool] on a struct
488fn expand_tool_struct(args: ToolArgs, struct_item: ItemStruct) -> TokenStream {
489    let struct_name = &struct_item.ident;
490    let tool_name = args
491        .name
492        .unwrap_or_else(|| struct_name.to_string().to_lowercase());
493    let description = args.description.expect("description is required");
494
495    // For struct, we expect the user to implement an execute method manually
496    // This macro just generates the Tool trait boilerplate
497
498    let expanded = quote! {
499        #struct_item
500
501        // Note: User must implement execute_typed manually
502        // This just provides the name and description
503        impl #struct_name {
504            pub fn tool_name() -> &'static str {
505                #tool_name
506            }
507
508            pub fn tool_description() -> &'static str {
509                #description
510            }
511        }
512    };
513
514    TokenStream::from(expanded)
515}