Skip to main content

resonai_macros/
lib.rs

1//! Reson procedural macros
2//!
3//! Provides ergonomic decorators for agentic functions and tool definitions.
4
5use proc_macro::TokenStream;
6use quote::quote;
7use syn::parse::{Parse, ParseStream};
8use syn::{parse_macro_input, FnArg, ItemFn, Lit, Pat, PatType, Token};
9
/// Arguments accepted by `#[agentic(...)]` and `#[agentic_generator(...)]`.
///
/// Every key is optional at the call site; the `Parse` impl supplies the defaults
/// for anything the user leaves out.
struct AgenticArgs {
    /// `"provider:model"` string, when one was supplied.
    model: Option<String>,
    /// Explicit API key, when one was supplied.
    api_key: Option<String>,
    /// Whether callable parameters should be auto-bound as tools.
    autobind: bool,
    /// Whether native tool calling should be used.
    native_tools: bool,
}
17
18impl Parse for AgenticArgs {
19    fn parse(input: ParseStream) -> syn::Result<Self> {
20        let mut model = None;
21        let mut api_key = None;
22        let mut autobind = true;
23        let mut native_tools = true; // Default true like Python
24
25        while !input.is_empty() {
26            let ident: syn::Ident = input.parse()?;
27            input.parse::<Token![=]>()?;
28
29            match ident.to_string().as_str() {
30                "model" => {
31                    let lit: Lit = input.parse()?;
32                    if let Lit::Str(s) = lit {
33                        model = Some(s.value());
34                    }
35                }
36                "api_key" => {
37                    let lit: Lit = input.parse()?;
38                    if let Lit::Str(s) = lit {
39                        api_key = Some(s.value());
40                    }
41                }
42                "autobind" => {
43                    let lit: Lit = input.parse()?;
44                    if let Lit::Bool(b) = lit {
45                        autobind = b.value();
46                    }
47                }
48                "native_tools" => {
49                    let lit: Lit = input.parse()?;
50                    if let Lit::Bool(b) = lit {
51                        native_tools = b.value();
52                    }
53                }
54                _ => {
55                    return Err(syn::Error::new(ident.span(), "Unknown attribute"));
56                }
57            }
58
59            if !input.is_empty() {
60                input.parse::<Token![,]>()?;
61            }
62        }
63
64        Ok(Self {
65            model,
66            api_key,
67            autobind,
68            native_tools,
69        })
70    }
71}
72
73/// `#[agentic]` attribute macro for async functions
74///
75/// Automatically creates a Runtime, injects it into the function, binds tools, and validates usage.
76///
77/// The decorated function must have a `runtime: Runtime` parameter which will be automatically
78/// injected by the macro - callers should NOT pass it.
79///
80/// # Attributes
81/// - `model`: Model string in format "provider:model" (e.g., "anthropic:claude-3-5-sonnet-20241022")
82/// - `api_key`: Optional API key (defaults to environment variable based on provider)
83/// - `autobind`: Whether to auto-bind callable parameters as tools (default: true)
84/// - `native_tools`: Whether to use native tool calling (default: true)
85///
86/// # Example
87/// ```ignore
88/// use reson_agentic::prelude::*;
89/// use reson_agentic::agentic;
90///
91/// #[agentic(model = "anthropic:claude-3-5-sonnet-20241022")]
92/// async fn extract_people(text: String, runtime: Runtime) -> Result<Vec<Person>> {
93///     runtime.run(
94///         Some(&format!("Extract people from: {}", text)),
95///         None, None, None, None, None, None, None, None
96///     ).await
97/// }
98///
99/// // Usage - note: runtime is NOT passed by caller
100/// let result = extract_people("Alice is 30 years old".to_string()).await?;
101/// ```
102#[proc_macro_attribute]
103pub fn agentic(attr: TokenStream, item: TokenStream) -> TokenStream {
104    let args = parse_macro_input!(attr as AgenticArgs);
105    let input_fn = parse_macro_input!(item as ItemFn);
106
107    let model = args.model;
108    let api_key = args.api_key;
109    let _autobind = args.autobind;
110    let _native_tools = args.native_tools;
111
112    // Extract function components
113    let fn_name = &input_fn.sig.ident;
114    let fn_vis = &input_fn.vis;
115    let fn_generics = &input_fn.sig.generics;
116    let fn_output = &input_fn.sig.output;
117    let fn_block = &input_fn.block;
118    let fn_attrs = &input_fn.attrs;
119    let fn_asyncness = &input_fn.sig.asyncness;
120
121    // Separate runtime parameter from other parameters
122    let mut other_params = Vec::new();
123    let mut has_runtime = false;
124
125    for param in input_fn.sig.inputs.iter() {
126        match param {
127            FnArg::Typed(PatType { pat, .. }) => {
128                if let Pat::Ident(pat_ident) = pat.as_ref() {
129                    if pat_ident.ident == "runtime" {
130                        has_runtime = true;
131                        continue; // Skip runtime param - we'll inject it
132                    }
133                }
134                other_params.push(param.clone());
135            }
136            FnArg::Receiver(_) => {
137                other_params.push(param.clone());
138            }
139        }
140    }
141
142    if !has_runtime {
143        return syn::Error::new_spanned(
144            &input_fn.sig.ident,
145            "#[agentic] function must have a `runtime: Runtime` parameter",
146        )
147        .to_compile_error()
148        .into();
149    }
150
151    // Generate model setup
152    let model_setup = if let Some(m) = model {
153        quote! { Some(#m.to_string()) }
154    } else {
155        quote! { None }
156    };
157
158    let api_key_setup = if let Some(k) = api_key {
159        quote! { Some(#k.to_string()) }
160    } else {
161        quote! { None }
162    };
163
164    // Generate wrapper function that:
165    // 1. Creates a Runtime
166    // 2. Calls the original function with runtime injected
167    // 3. Validates runtime.used after completion
168    let expanded = quote! {
169        #(#fn_attrs)*
170        #fn_vis #fn_asyncness fn #fn_name #fn_generics(#(#other_params),*) #fn_output {
171            // Create Runtime with native tools enabled
172            let mut runtime = ::reson_agentic::runtime::Runtime::with_config(
173                #model_setup,
174                #api_key_setup,
175            );
176
177            // Execute the original function body with runtime in scope
178            let result = {
179                // The original function body has access to `runtime`
180                #fn_block
181            };
182
183            // Validate runtime was used
184            if !runtime.used {
185                panic!(
186                    "agentic function '{}' completed without calling runtime.run() or runtime.run_stream()",
187                    stringify!(#fn_name)
188                );
189            }
190
191            result
192        }
193    };
194
195    TokenStream::from(expanded)
196}
197
198/// `#[agentic_generator]` attribute macro for async generator functions
199///
200/// Similar to `#[agentic]` but for functions that return `impl Stream`.
201/// Generator functions can yield intermediate results while processing.
202///
203/// # Example
204/// ```ignore
205/// #[agentic_generator(model = "anthropic:claude-3-5-sonnet-20241022")]
206/// async fn process_items(items: Vec<String>, runtime: Runtime) -> impl Stream<Item = Result<String>> {
207///     async_stream::stream! {
208///         for item in items {
209///             let result = runtime.run(Some(&item), None, None, None, None, None, None, None, None).await?;
210///             yield Ok(result.to_string());
211///         }
212///     }
213/// }
214/// ```
215#[proc_macro_attribute]
216pub fn agentic_generator(attr: TokenStream, item: TokenStream) -> TokenStream {
217    // For now, generators use the same logic as regular agentic functions
218    // Full generator support with yield tracking would need async_stream integration
219    agentic(attr, item)
220}
221
222/// `#[derive(Tool)]` for tool structs
223///
224/// Automatically implements schema generation for tool types.
225/// Use with Serialize/Deserialize for full tool support.
226///
227/// # Example
228/// ```ignore
229/// #[derive(Tool, Serialize, Deserialize)]
230/// struct CalculateTool {
231///     /// The operation to perform (add, subtract, multiply, divide)
232///     operation: String,
233///     /// First operand
234///     a: f64,
235///     /// Second operand
236///     b: f64,
237/// }
238/// ```
239#[proc_macro_derive(Tool, attributes(tool))]
240pub fn derive_tool(input: TokenStream) -> TokenStream {
241    let input = parse_macro_input!(input as syn::DeriveInput);
242    let name = &input.ident;
243    let name_str = name.to_string();
244
245    // Convert PascalCase to snake_case for tool name
246    let tool_name = convert_to_snake_case(&name_str);
247
248    // Extract struct-level doc comment for description
249    let struct_description = extract_doc_comments(&input.attrs);
250
251    // Extract fields and their documentation
252    let fields = match &input.data {
253        syn::Data::Struct(data_struct) => match &data_struct.fields {
254            syn::Fields::Named(fields) => &fields.named,
255            _ => {
256                return syn::Error::new_spanned(
257                    name,
258                    "Tool derive only supports structs with named fields",
259                )
260                .to_compile_error()
261                .into();
262            }
263        },
264        _ => {
265            return syn::Error::new_spanned(name, "Tool derive only supports structs")
266                .to_compile_error()
267                .into();
268        }
269    };
270
271    // Build schema properties from fields
272    let mut schema_properties = Vec::new();
273    let mut required_fields = Vec::new();
274
275    for field in fields {
276        let field_name = field.ident.as_ref().unwrap();
277        let field_name_str = field_name.to_string();
278
279        // Extract field documentation
280        let field_desc = extract_doc_comments(&field.attrs);
281
282        // Get the JSON type for this field
283        let json_type = get_json_type(&field.ty);
284
285        // Determine if field is Option<T> (optional)
286        let is_optional = is_option_type(&field.ty);
287
288        if !is_optional {
289            required_fields.push(field_name_str.clone());
290        }
291
292        // Check if this is an array type and get the item type
293        let array_item_info = get_array_item_type(&field.ty);
294
295        match array_item_info {
296            Some(ArrayItemType::Primitive(item_type)) => {
297                // Array of primitives - static schema
298                schema_properties.push(quote! {
299                    properties.insert(
300                        #field_name_str.to_string(),
301                        serde_json::json!({
302                            "type": #json_type,
303                            "description": #field_desc,
304                            "items": {
305                                "type": #item_type
306                            }
307                        })
308                    );
309                });
310            }
311            Some(ArrayItemType::Complex(inner_ty)) => {
312                // Array of complex types - call inner type's schema() at runtime
313                schema_properties.push(quote! {
314                    {
315                        let mut arr_schema = serde_json::json!({
316                            "type": #json_type,
317                            "description": #field_desc
318                        });
319                        arr_schema["items"] = #inner_ty::schema();
320                        properties.insert(#field_name_str.to_string(), arr_schema);
321                    }
322                });
323            }
324            None => {
325                // Non-array type
326                schema_properties.push(quote! {
327                    properties.insert(
328                        #field_name_str.to_string(),
329                        serde_json::json!({
330                            "type": #json_type,
331                            "description": #field_desc
332                        })
333                    );
334                });
335            }
336        }
337    }
338
339    let required_array = if required_fields.is_empty() {
340        quote! { serde_json::json!([]) }
341    } else {
342        let req_fields = required_fields.iter();
343        quote! { serde_json::json!([#(#req_fields),*]) }
344    };
345
346    let expanded = quote! {
347        impl #name {
348            /// Get the tool name (snake_case version of struct name)
349            pub fn tool_name() -> &'static str {
350                #tool_name
351            }
352
353            /// Get the tool description from doc comments
354            pub fn description() -> &'static str {
355                #struct_description
356            }
357
358            /// Generate JSON schema for this tool
359            pub fn schema() -> serde_json::Value {
360                let mut properties = serde_json::Map::new();
361                #(#schema_properties)*
362
363                serde_json::json!({
364                    "type": "object",
365                    "properties": serde_json::Value::Object(properties),
366                    "required": #required_array
367                })
368            }
369
370            /// Generate provider-specific tool schema using a SchemaGenerator
371            pub fn tool_schema(generator: &dyn ::reson_agentic::schema::SchemaGenerator) -> serde_json::Value {
372                generator.generate_schema(
373                    #tool_name,
374                    #struct_description,
375                    Self::schema()
376                )
377            }
378        }
379    };
380
381    TokenStream::from(expanded)
382}
383
384/// `#[derive(Deserializable)]` for streaming-parseable types
385///
386/// Implements the Deserializable trait for progressive parsing during streaming.
387/// Types with this derive can be constructed from partial JSON as it arrives.
388///
389/// # Example
390/// ```ignore
391/// #[derive(Deserializable, Serialize, Deserialize)]
392/// struct Person {
393///     /// The person's name
394///     name: String,
395///     /// The person's age
396///     age: u32,
397///     /// Optional email address
398///     email: Option<String>,
399/// }
400/// ```
401#[proc_macro_derive(Deserializable)]
402pub fn derive_deserializable(input: TokenStream) -> TokenStream {
403    let input = parse_macro_input!(input as syn::DeriveInput);
404    let name = &input.ident;
405
406    // Extract fields
407    let fields = match &input.data {
408        syn::Data::Struct(data_struct) => match &data_struct.fields {
409            syn::Fields::Named(fields) => &fields.named,
410            _ => {
411                return syn::Error::new_spanned(
412                    name,
413                    "Deserializable derive only supports structs with named fields",
414                )
415                .to_compile_error()
416                .into();
417            }
418        },
419        _ => {
420            return syn::Error::new_spanned(name, "Deserializable derive only supports structs")
421                .to_compile_error()
422                .into();
423        }
424    };
425
426    // Build field descriptions
427    let mut field_desc_tokens = Vec::new();
428    let mut validation_checks = Vec::new();
429
430    for field in fields {
431        let field_name = field.ident.as_ref().unwrap();
432        let field_name_str = field_name.to_string();
433        let field_desc = extract_doc_comments(&field.attrs);
434        let field_type = &field.ty;
435        let is_optional = is_option_type(&field.ty);
436        let is_required = !is_optional;
437
438        field_desc_tokens.push(quote! {
439            ::reson_agentic::parsers::FieldDescription {
440                name: #field_name_str.to_string(),
441                field_type: ::std::any::type_name::<#field_type>().to_string(),
442                description: #field_desc.to_string(),
443                required: #is_required,
444            }
445        });
446
447        // Add validation for required fields
448        if is_required {
449            validation_checks.push(quote! {
450                if let serde_json::Value::Null = serde_json::to_value(&self.#field_name)
451                    .map_err(|e| ::reson_agentic::error::Error::NonRetryable(e.to_string()))? {
452                    return Err(::reson_agentic::error::Error::NonRetryable(
453                        format!("Required field '{}' is missing or null", #field_name_str)
454                    ));
455                }
456            });
457        }
458    }
459
460    let validation_logic = if validation_checks.is_empty() {
461        quote! { Ok(()) }
462    } else {
463        quote! {
464            #(#validation_checks)*
465            Ok(())
466        }
467    };
468
469    let expanded = quote! {
470        impl ::reson_agentic::parsers::Deserializable for #name {
471            fn from_partial(partial: serde_json::Value) -> ::reson_agentic::error::Result<Self> {
472                serde_json::from_value(partial).map_err(|e| {
473                    ::reson_agentic::error::Error::NonRetryable(format!("Failed to parse {}: {}", stringify!(#name), e))
474                })
475            }
476
477            fn validate_complete(&self) -> ::reson_agentic::error::Result<()> {
478                #validation_logic
479            }
480
481            fn field_descriptions() -> Vec<::reson_agentic::parsers::FieldDescription> {
482                vec![
483                    #(#field_desc_tokens),*
484                ]
485            }
486        }
487    };
488
489    TokenStream::from(expanded)
490}
491
/// Converts a PascalCase identifier to snake_case.
///
/// Every uppercase character except one at the very start is preceded by `_`, so
/// acronyms are split letter by letter (`HTTPClient` -> `h_t_t_p_client`). When a
/// character's lowercase form expands to several chars, only the first one is kept.
fn convert_to_snake_case(s: &str) -> String {
    let mut snake = String::with_capacity(s.len() + 4);
    for (index, ch) in s.chars().enumerate() {
        if !ch.is_uppercase() {
            snake.push(ch);
            continue;
        }
        if index != 0 {
            snake.push('_');
        }
        if let Some(lower) = ch.to_lowercase().next() {
            snake.push(lower);
        }
    }
    snake
}
507
508// Helper to check if a type is Option<T>
509fn is_option_type(ty: &syn::Type) -> bool {
510    if let syn::Type::Path(type_path) = ty {
511        if let Some(segment) = type_path.path.segments.last() {
512            return segment.ident == "Option";
513        }
514    }
515    false
516}
517
518// Helper to extract doc comments from attributes
519fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
520    let mut docs = Vec::new();
521    for attr in attrs {
522        if attr.path().is_ident("doc") {
523            if let syn::Meta::NameValue(meta) = &attr.meta {
524                if let syn::Expr::Lit(expr_lit) = &meta.value {
525                    if let syn::Lit::Str(lit_str) = &expr_lit.lit {
526                        docs.push(lit_str.value().trim().to_string());
527                    }
528                }
529            }
530        }
531    }
532    docs.join(" ")
533}
534
535// Helper to get JSON schema type from Rust type
536fn get_json_type(ty: &syn::Type) -> String {
537    if let syn::Type::Path(type_path) = ty {
538        if let Some(segment) = type_path.path.segments.last() {
539            let ident = segment.ident.to_string();
540
541            // Handle Option<T> - extract inner type
542            if ident == "Option" {
543                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
544                    if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
545                        return get_json_type(inner_ty);
546                    }
547                }
548            }
549
550            // Map Rust types to JSON schema types
551            return match ident.as_str() {
552                "String" | "str" => "string",
553                "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32" | "u64"
554                | "u128" | "usize" => "integer",
555                "f32" | "f64" => "number",
556                "bool" => "boolean",
557                "Vec" => "array",
558                "HashMap" | "BTreeMap" => "object",
559                _ => "object", // Default for custom types
560            }
561            .to_string();
562        }
563    }
564    "object".to_string()
565}
566
567/// Represents the inner type of an array
568enum ArrayItemType {
569    /// Primitive type like string, integer, number, boolean
570    Primitive(String),
571    /// Complex type (struct) that has its own schema() method
572    Complex(syn::Type),
573}
574
575/// Get the inner type of Vec<T> or Option<Vec<T>>
576fn get_array_item_type(ty: &syn::Type) -> Option<ArrayItemType> {
577    if let syn::Type::Path(type_path) = ty {
578        if let Some(segment) = type_path.path.segments.last() {
579            let ident = segment.ident.to_string();
580
581            // Handle Option<Vec<T>> - extract Vec<T> first
582            if ident == "Option" {
583                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
584                    if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
585                        return get_array_item_type(inner_ty);
586                    }
587                }
588            }
589
590            // Handle Vec<T> - extract T
591            if ident == "Vec" {
592                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
593                    if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
594                        let json_type = get_json_type(inner_ty);
595                        // Primitive types just need "type": "string" etc
596                        if matches!(json_type.as_str(), "string" | "integer" | "number" | "boolean") {
597                            return Some(ArrayItemType::Primitive(json_type));
598                        }
599                        // Complex types need full schema from T::schema()
600                        return Some(ArrayItemType::Complex(inner_ty.clone()));
601                    }
602                }
603                // Default to string if we can't determine the inner type
604                return Some(ArrayItemType::Primitive("string".to_string()));
605            }
606        }
607    }
608    None
609}
610
#[cfg(test)]
mod tests {
    use super::*;

    /// Table-driven check of PascalCase -> snake_case conversion.
    #[test]
    fn test_snake_case_conversion() {
        let cases = [
            ("CalculatorTool", "calculator_tool"),
            ("GetWeather", "get_weather"),
            // Every capital starts a new word, so acronyms are split letter by letter.
            ("HTTPClient", "h_t_t_p_client"),
            ("Simple", "simple"),
        ];
        for (input, expected) in cases {
            assert_eq!(convert_to_snake_case(input), expected, "input: {}", input);
        }
    }
}