Skip to main content

cortexai_macros/
lib.rs

1//! # Procedural Macros for cortex
2//!
//! This crate provides the `Tool` derive macro and the `#[tool]` attribute macro
//! for easily defining tools and other agent components with minimal boilerplate.
5//!
6//! ## Tool Derive Macro
7//!
8//! ```rust,ignore
9//! use cortexai_macros::Tool;
10//!
11//! #[derive(Tool)]
12//! #[tool(name = "calculator", description = "Perform mathematical calculations")]
13//! struct CalculatorTool {
14//!     #[tool(param, required, description = "The mathematical expression to evaluate")]
15//!     expression: String,
16//!
17//!     #[tool(param, description = "Number of decimal places for result")]
18//!     precision: Option<u32>,
19//! }
20//!
21//! impl CalculatorTool {
22//!     async fn run(&self, expression: String, precision: Option<u32>) -> Result<serde_json::Value, String> {
23//!         // Implementation here
24//!         Ok(serde_json::json!(42.0))
25//!     }
26//! }
27//! ```
28
29use darling::{ast, FromDeriveInput, FromField};
30use proc_macro::TokenStream;
31use quote::{format_ident, quote};
32use syn::{parse::Parser, parse_macro_input, DeriveInput, FnArg, Ident, ItemFn, Meta, Type};
33
34/// Attributes for the Tool derive macro at struct level
35#[derive(Debug, FromDeriveInput)]
36#[darling(attributes(tool), supports(struct_named))]
37struct ToolArgs {
38    ident: Ident,
39    data: ast::Data<(), ToolField>,
40
41    /// Tool name (defaults to snake_case of struct name)
42    #[darling(default)]
43    name: Option<String>,
44
45    /// Tool description
46    #[darling(default)]
47    description: Option<String>,
48
49    /// Whether the tool is dangerous and requires confirmation
50    #[darling(default)]
51    dangerous: bool,
52}
53
54/// Attributes for tool parameters (struct fields)
55#[derive(Debug, FromField)]
56#[darling(attributes(tool))]
57struct ToolField {
58    ident: Option<Ident>,
59    ty: Type,
60
61    /// Mark this field as a tool parameter
62    #[darling(default)]
63    param: bool,
64
65    /// Whether this parameter is required
66    #[darling(default)]
67    required: bool,
68
69    /// Parameter description
70    #[darling(default)]
71    description: Option<String>,
72
73    /// Skip this field (not a parameter)
74    #[darling(default)]
75    skip: bool,
76}
77
78/// Derive macro for creating Tool implementations
79///
80/// # Example
81///
82/// ```rust,ignore
83/// #[derive(Tool)]
84/// #[tool(name = "search", description = "Search the web")]
85/// struct SearchTool {
86///     #[tool(param, required, description = "Search query")]
87///     query: String,
88///
89///     #[tool(param, description = "Maximum results to return")]
90///     max_results: Option<u32>,
91/// }
92///
93/// impl SearchTool {
94///     // You must implement this method
95///     async fn run(&self, query: String, max_results: Option<u32>) -> Result<serde_json::Value, String> {
96///         Ok(serde_json::json!({"results": []}))
97///     }
98/// }
99/// ```
100#[proc_macro_derive(Tool, attributes(tool))]
101pub fn derive_tool(input: TokenStream) -> TokenStream {
102    let input = parse_macro_input!(input as DeriveInput);
103
104    let args = match ToolArgs::from_derive_input(&input) {
105        Ok(args) => args,
106        Err(e) => return e.write_errors().into(),
107    };
108
109    let expanded = generate_tool_impl(&args);
110
111    TokenStream::from(expanded)
112}
113
114fn generate_tool_impl(args: &ToolArgs) -> proc_macro2::TokenStream {
115    let struct_name = &args.ident;
116
117    // Generate tool name (default to snake_case of struct name without "Tool" suffix)
118    let tool_name = args.name.clone().unwrap_or_else(|| {
119        let name = struct_name.to_string();
120        let name = name.strip_suffix("Tool").unwrap_or(&name);
121        to_snake_case(name)
122    });
123
124    let description = args
125        .description
126        .clone()
127        .unwrap_or_else(|| format!("{} tool", tool_name));
128
129    let dangerous = args.dangerous;
130
131    // Extract parameter fields
132    let fields = match &args.data {
133        ast::Data::Struct(fields) => fields,
134        _ => panic!("Tool derive only supports structs"),
135    };
136
137    let params: Vec<_> = fields
138        .fields
139        .iter()
140        .filter(|f| f.param && !f.skip)
141        .collect();
142
143    // Generate JSON schema for parameters
144    let param_properties = generate_param_properties(&params);
145    let required_params = generate_required_params(&params);
146
147    // Generate argument extraction code
148    let arg_extractions = generate_arg_extractions(&params);
149
150    // Generate parameter names for calling run()
151    let param_names: Vec<_> = params.iter().map(|f| f.ident.as_ref().unwrap()).collect();
152
153    quote! {
154        #[async_trait::async_trait]
155        impl cortexai_core::tool::Tool for #struct_name {
156            fn schema(&self) -> cortexai_core::tool::ToolSchema {
157                let mut __properties = serde_json::Map::new();
158                #(#param_properties)*
159
160                cortexai_core::tool::ToolSchema {
161                    name: #tool_name.to_string(),
162                    description: #description.to_string(),
163                    parameters: serde_json::json!({
164                        "type": "object",
165                        "properties": serde_json::Value::Object(__properties),
166                        "required": [#(#required_params),*]
167                    }),
168                    dangerous: #dangerous,
169                    metadata: std::collections::HashMap::new(),
170                    required_scopes: vec![],
171                }
172            }
173
174            async fn execute(
175                &self,
176                _context: &cortexai_core::tool::ExecutionContext,
177                arguments: serde_json::Value,
178            ) -> Result<serde_json::Value, cortexai_core::errors::ToolError> {
179                #(#arg_extractions)*
180
181                self.run(#(#param_names),*).await
182                    .map_err(|e| cortexai_core::errors::ToolError::ExecutionFailed(e.to_string()))
183            }
184        }
185    }
186}
187
188fn generate_param_properties(params: &[&ToolField]) -> Vec<proc_macro2::TokenStream> {
189    params
190        .iter()
191        .map(|field| {
192            let name = field.ident.as_ref().unwrap().to_string();
193            let description = field
194                .description
195                .clone()
196                .unwrap_or_else(|| format!("Parameter: {}", name));
197            let effective_ty = unwrap_option_type(&field.ty).unwrap_or(&field.ty);
198            let schema_tokens = type_to_json_schema(effective_ty);
199
200            quote! {
201                {
202                    let mut __prop_schema = #schema_tokens;
203                    if let serde_json::Value::Object(ref mut m) = __prop_schema {
204                        m.insert("description".to_string(), serde_json::Value::String(#description.to_string()));
205                    }
206                    __properties.insert(#name.to_string(), __prop_schema);
207                }
208            }
209        })
210        .collect()
211}
212
213fn generate_required_params(params: &[&ToolField]) -> Vec<proc_macro2::TokenStream> {
214    params
215        .iter()
216        .filter(|f| f.required)
217        .map(|field| {
218            let name = field.ident.as_ref().unwrap().to_string();
219            quote! { #name }
220        })
221        .collect()
222}
223
224fn generate_arg_extractions(params: &[&ToolField]) -> Vec<proc_macro2::TokenStream> {
225    params
226        .iter()
227        .map(|field| {
228            let ident = field.ident.as_ref().unwrap();
229            let name = ident.to_string();
230            let ty = &field.ty;
231
232            if is_option_type(ty) {
233                quote! {
234                    let #ident: #ty = arguments.get(#name)
235                        .and_then(|v| serde_json::from_value(v.clone()).ok());
236                }
237            } else if field.required {
238                quote! {
239                    let #ident: #ty = {
240                        let val = arguments.get(#name)
241                            .ok_or_else(|| cortexai_core::errors::ToolError::InvalidArguments(
242                                format!("Missing required parameter: {}", #name)
243                            ))?
244                            .clone();
245                        serde_json::from_value(val)
246                            .map_err(|e| cortexai_core::errors::ToolError::InvalidArguments(
247                                format!("Invalid type for {}: {}", #name, e)
248                            ))?
249                    };
250                }
251            } else {
252                // Non-required, non-option field - use default
253                quote! {
254                    let #ident: #ty = arguments.get(#name)
255                        .and_then(|v| serde_json::from_value(v.clone()).ok())
256                        .unwrap_or_default();
257                }
258            }
259        })
260        .collect()
261}
262
263/// Generate a `proc_macro2::TokenStream` that evaluates to a `serde_json::Value`
264/// representing the JSON Schema for the given Rust type.
265fn type_to_json_schema(ty: &Type) -> proc_macro2::TokenStream {
266    let type_name = extract_type_name(ty);
267
268    match type_name.as_str() {
269        "String" | "&str" | "str" => quote! { serde_json::json!({"type": "string"}) },
270        "bool" => quote! { serde_json::json!({"type": "boolean"}) },
271        "f32" | "f64" => quote! { serde_json::json!({"type": "number"}) },
272        "i8" | "i16" | "i32" | "i64" | "i128" | "isize"
273        | "u8" | "u16" | "u32" | "u64" | "u128" | "usize" => {
274            quote! { serde_json::json!({"type": "integer"}) }
275        }
276        "Vec" | "Array" => {
277            let items_schema = extract_first_generic_arg(ty)
278                .map(|inner| type_to_json_schema(inner))
279                .unwrap_or_else(|| quote! { serde_json::json!({}) });
280            quote! {
281                {
282                    let __items = #items_schema;
283                    serde_json::json!({"type": "array", "items": __items})
284                }
285            }
286        }
287        "HashMap" | "BTreeMap" => {
288            let value_schema = extract_second_generic_arg(ty)
289                .map(|inner| type_to_json_schema(inner))
290                .unwrap_or_else(|| quote! { serde_json::json!({}) });
291            quote! {
292                {
293                    let __additional = #value_schema;
294                    serde_json::json!({"type": "object", "additionalProperties": __additional})
295                }
296            }
297        }
298        "Option" => {
299            // If Option wasn't already unwrapped, handle it here
300            extract_first_generic_arg(ty)
301                .map(|inner| type_to_json_schema(inner))
302                .unwrap_or_else(|| quote! { serde_json::json!({"type": "object"}) })
303        }
304        _ => quote! { serde_json::json!({"type": "object"}) },
305    }
306}
307
308/// Extract the outermost type name (e.g. "Vec" from `Vec<String>`, "String" from `String`)
309fn extract_type_name(ty: &Type) -> String {
310    if let Type::Path(type_path) = ty {
311        if let Some(segment) = type_path.path.segments.last() {
312            return segment.ident.to_string();
313        }
314    }
315    String::new()
316}
317
318/// Extract the first generic type argument (e.g. `String` from `Vec<String>`)
319fn extract_first_generic_arg(ty: &Type) -> Option<&Type> {
320    if let Type::Path(type_path) = ty {
321        if let Some(segment) = type_path.path.segments.last() {
322            if let syn::PathArguments::AngleBracketed(ref args) = segment.arguments {
323                if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
324                    return Some(inner);
325                }
326            }
327        }
328    }
329    None
330}
331
332/// Extract the second generic type argument (e.g. `V` from `HashMap<K, V>`)
333fn extract_second_generic_arg(ty: &Type) -> Option<&Type> {
334    if let Type::Path(type_path) = ty {
335        if let Some(segment) = type_path.path.segments.last() {
336            if let syn::PathArguments::AngleBracketed(ref args) = segment.arguments {
337                if let Some(syn::GenericArgument::Type(inner)) = args.args.iter().nth(1) {
338                    return Some(inner);
339                }
340            }
341        }
342    }
343    None
344}
345
346/// If the type is `Option<T>`, return `Some(&T)`. Otherwise `None`.
347fn unwrap_option_type(ty: &Type) -> Option<&Type> {
348    if extract_type_name(ty) == "Option" {
349        extract_first_generic_arg(ty)
350    } else {
351        None
352    }
353}
354
355fn is_option_type(ty: &Type) -> bool {
356    extract_type_name(ty) == "Option"
357}
358
/// Converts a `PascalCase` / `camelCase` identifier to `snake_case`.
///
/// A word boundary (and therefore an underscore) is placed before an
/// uppercase letter when it follows a lowercase letter or digit, or when it
/// ends an acronym run (an uppercase letter followed by a lowercase one).
/// This keeps acronyms together (`HTTPServer` becomes `http_server`) and
/// avoids doubling an underscore that is already present.
/// Lowercasing is Unicode-aware.
fn to_snake_case(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let mut result = String::with_capacity(s.len() + 4);

    for (i, &c) in chars.iter().enumerate() {
        if !c.is_uppercase() {
            result.push(c);
            continue;
        }

        if i > 0 {
            let prev = chars[i - 1];
            let next_is_lower = chars.get(i + 1).map_or(false, |next| next.is_lowercase());
            let starts_word = prev.is_lowercase()
                || prev.is_numeric()
                || (prev.is_uppercase() && next_is_lower);
            if starts_word {
                result.push('_');
            }
        }

        // `to_lowercase` may expand to several chars for some scripts.
        result.extend(c.to_lowercase());
    }

    result
}
373
/// Converts a `snake_case` identifier to `PascalCase` by upper-casing the
/// first character of every underscore-separated word and dropping the
/// underscores. Empty words (from leading, trailing or doubled underscores)
/// contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());

    for word in s.split('_') {
        let mut rest = word.chars();
        if let Some(initial) = rest.next() {
            pascal.extend(initial.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }

    pascal
}
388
389/// Parsed parameter info from a function argument with `#[param(...)]` attributes.
390struct FnParam {
391    ident: Ident,
392    ty: Type,
393    description: Option<String>,
394    required: bool,
395    name_override: Option<String>,
396    is_option: bool,
397}
398
399fn parse_param_attrs(attrs: &[syn::Attribute]) -> (Option<String>, bool, Option<String>) {
400    let mut description = None;
401    let mut required = false;
402    let mut name_override = None;
403
404    for attr in attrs {
405        if !attr.path().is_ident("param") {
406            continue;
407        }
408        if let Meta::List(meta_list) = &attr.meta {
409            let tokens = meta_list.tokens.clone();
410            let parser = syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated;
411            if let Ok(nested) = parser.parse2(tokens) {
412                for meta in &nested {
413                    match meta {
414                        Meta::Path(path) if path.is_ident("required") => {
415                            required = true;
416                        }
417                        Meta::NameValue(nv) if nv.path.is_ident("description") => {
418                            if let syn::Expr::Lit(syn::ExprLit {
419                                lit: syn::Lit::Str(s),
420                                ..
421                            }) = &nv.value
422                            {
423                                description = Some(s.value());
424                            }
425                        }
426                        Meta::NameValue(nv) if nv.path.is_ident("name") => {
427                            if let syn::Expr::Lit(syn::ExprLit {
428                                lit: syn::Lit::Str(s),
429                                ..
430                            }) = &nv.value
431                            {
432                                name_override = Some(s.value());
433                            }
434                        }
435                        _ => {}
436                    }
437                }
438            }
439        }
440    }
441    (description, required, name_override)
442}
443
444fn extract_fn_params(sig: &syn::Signature) -> Vec<FnParam> {
445    sig.inputs
446        .iter()
447        .filter_map(|arg| {
448            if let FnArg::Typed(pat_type) = arg {
449                if let syn::Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
450                    let ident = pat_ident.ident.clone();
451                    let ty = *pat_type.ty.clone();
452                    let is_option = is_option_type(&ty);
453                    let (description, explicit_required, name_override) =
454                        parse_param_attrs(&pat_type.attrs);
455                    // Default: required unless Option<T>
456                    let required = explicit_required || !is_option;
457                    return Some(FnParam {
458                        ident,
459                        ty,
460                        description,
461                        required,
462                        name_override,
463                        is_option,
464                    });
465                }
466            }
467            None
468        })
469        .collect()
470}
471
472/// Attribute macro for defining tools from async functions.
473///
474/// Generates a `<PascalCaseFnName>Tool` struct implementing the `Tool` trait.
475/// Use `#[param(...)]` on function parameters for schema metadata.
476#[proc_macro_attribute]
477pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
478    let attr_args = proc_macro2::TokenStream::from(attr);
479    let input_fn = parse_macro_input!(item as ItemFn);
480
481    match generate_fn_tool(attr_args, &input_fn) {
482        Ok(tokens) => tokens.into(),
483        Err(err) => err.to_compile_error().into(),
484    }
485}
486
487fn parse_tool_attr_args(
488    tokens: proc_macro2::TokenStream,
489) -> syn::Result<(Option<String>, Option<String>)> {
490    let mut description = None;
491    let mut name_override = None;
492
493    if tokens.is_empty() {
494        return Ok((description, name_override));
495    }
496
497    let parser = syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated;
498    let metas = parser.parse2(tokens)?;
499
500    for meta in &metas {
501        if let Meta::NameValue(nv) = meta {
502            if nv.path.is_ident("description") {
503                if let syn::Expr::Lit(syn::ExprLit {
504                    lit: syn::Lit::Str(s),
505                    ..
506                }) = &nv.value
507                {
508                    description = Some(s.value());
509                }
510            } else if nv.path.is_ident("name") {
511                if let syn::Expr::Lit(syn::ExprLit {
512                    lit: syn::Lit::Str(s),
513                    ..
514                }) = &nv.value
515                {
516                    name_override = Some(s.value());
517                }
518            }
519        }
520    }
521
522    Ok((description, name_override))
523}
524
525fn generate_fn_tool(
526    attr_args: proc_macro2::TokenStream,
527    input_fn: &ItemFn,
528) -> syn::Result<proc_macro2::TokenStream> {
529    let (attr_description, attr_name) = parse_tool_attr_args(attr_args)?;
530
531    let fn_name = &input_fn.sig.ident;
532    let fn_name_str = fn_name.to_string();
533
534    let struct_name = format_ident!("{}Tool", to_pascal_case(&fn_name_str));
535    let tool_name = attr_name.unwrap_or_else(|| fn_name_str.clone());
536    let description = attr_description.unwrap_or_else(|| format!("{} tool", tool_name));
537
538    let params = extract_fn_params(&input_fn.sig);
539
540    // Generate properties for JSON schema (using T3's rich schema generation)
541    let param_properties: Vec<proc_macro2::TokenStream> = params
542        .iter()
543        .map(|p| {
544            let name = p
545                .name_override
546                .clone()
547                .unwrap_or_else(|| p.ident.to_string());
548            let desc = p
549                .description
550                .clone()
551                .unwrap_or_else(|| format!("Parameter: {}", name));
552            let effective_ty = unwrap_option_type(&p.ty).unwrap_or(&p.ty);
553            let schema_tokens = type_to_json_schema(effective_ty);
554
555            quote! {
556                {
557                    let mut __prop_schema = #schema_tokens;
558                    if let serde_json::Value::Object(ref mut m) = __prop_schema {
559                        m.insert("description".to_string(), serde_json::Value::String(#desc.to_string()));
560                    }
561                    __properties.insert(#name.to_string(), __prop_schema);
562                }
563            }
564        })
565        .collect();
566
567    let required_params: Vec<proc_macro2::TokenStream> = params
568        .iter()
569        .filter(|p| p.required)
570        .map(|p| {
571            let name = p
572                .name_override
573                .clone()
574                .unwrap_or_else(|| p.ident.to_string());
575            quote! { #name }
576        })
577        .collect();
578
579    // Generate arg extraction code
580    let arg_extractions: Vec<proc_macro2::TokenStream> = params
581        .iter()
582        .map(|p| {
583            let ident = &p.ident;
584            let name = p
585                .name_override
586                .clone()
587                .unwrap_or_else(|| p.ident.to_string());
588            let ty = &p.ty;
589
590            if p.is_option {
591                quote! {
592                    let #ident: #ty = arguments.get(#name)
593                        .and_then(|v| serde_json::from_value(v.clone()).ok());
594                }
595            } else if p.required {
596                quote! {
597                    let #ident: #ty = {
598                        let val = arguments.get(#name)
599                            .ok_or_else(|| cortexai_core::errors::ToolError::InvalidArguments(
600                                format!("Missing required parameter: {}", #name)
601                            ))?
602                            .clone();
603                        serde_json::from_value(val)
604                            .map_err(|e| cortexai_core::errors::ToolError::InvalidArguments(
605                                format!("Invalid type for {}: {}", #name, e)
606                            ))?
607                    };
608                }
609            } else {
610                quote! {
611                    let #ident: #ty = arguments.get(#name)
612                        .and_then(|v| serde_json::from_value(v.clone()).ok())
613                        .unwrap_or_default();
614                }
615            }
616        })
617        .collect();
618
619    let param_idents: Vec<&Ident> = params.iter().map(|p| &p.ident).collect();
620
621    // Strip #[param(...)] attributes from the original function signature
622    let mut clean_fn = input_fn.clone();
623    for arg in &mut clean_fn.sig.inputs {
624        if let FnArg::Typed(pat_type) = arg {
625            pat_type.attrs.retain(|a| !a.path().is_ident("param"));
626        }
627    }
628
629    Ok(quote! {
630        #clean_fn
631
632        #[derive(Default)]
633        pub struct #struct_name;
634
635        #[async_trait::async_trait]
636        impl cortexai_core::tool::Tool for #struct_name {
637            fn schema(&self) -> cortexai_core::tool::ToolSchema {
638                let mut __properties = serde_json::Map::new();
639                #(#param_properties)*
640
641                cortexai_core::tool::ToolSchema {
642                    name: #tool_name.to_string(),
643                    description: #description.to_string(),
644                    parameters: serde_json::json!({
645                        "type": "object",
646                        "properties": serde_json::Value::Object(__properties),
647                        "required": [#(#required_params),*]
648                    }),
649                    dangerous: false,
650                    metadata: std::collections::HashMap::new(),
651                    required_scopes: Vec::new(),
652                }
653            }
654
655            async fn execute(
656                &self,
657                _context: &cortexai_core::tool::ExecutionContext,
658                arguments: serde_json::Value,
659            ) -> Result<serde_json::Value, cortexai_core::errors::ToolError> {
660                #(#arg_extractions)*
661
662                #fn_name(#(#param_idents),*).await
663                    .map_err(|e| cortexai_core::errors::ToolError::ExecutionFailed(e.to_string()))
664            }
665        }
666    })
667}