agent_chain_macros/
lib.rs

1//! Procedural macros for agent-chain.
2
3use proc_macro::TokenStream;
4use quote::{format_ident, quote};
5use syn::{FnArg, ItemFn, Pat, ReturnType, Type, parse_macro_input};
6
7/// Marks a function as a tool that can be used by an LLM.
8///
9/// This macro generates a struct that implements the `Tool` trait,
10/// allowing the function to be invoked by an AI model.
11///
12/// # Example
13///
14/// ```ignore
15/// use agent_chain::tools::tool;
16///
17/// #[tool]
18/// fn multiply(a: i64, b: i64) -> i64 {
19///     a * b
20/// }
21///
22/// // Creates a tool instance
23/// let tool = multiply::tool();
24/// ```
25#[proc_macro_attribute]
26pub fn tool(_attr: TokenStream, item: TokenStream) -> TokenStream {
27    let input = parse_macro_input!(item as ItemFn);
28
29    let fn_name = &input.sig.ident;
30    let fn_name_str = fn_name.to_string();
31    let mod_name = format_ident!("{}", fn_name);
32    let struct_name = format_ident!("{}Tool", to_pascal_case(&fn_name_str));
33
34    let fn_body = &input.block;
35    let fn_vis = &input.vis;
36    let _fn_asyncness = &input.sig.asyncness;
37    let fn_return_type = &input.sig.output;
38
39    // Extract parameters
40    let params: Vec<_> = input
41        .sig
42        .inputs
43        .iter()
44        .filter_map(|arg| {
45            if let FnArg::Typed(pat_type) = arg
46                && let Pat::Ident(pat_ident) = pat_type.pat.as_ref()
47            {
48                let param_name = &pat_ident.ident;
49                let param_type = &pat_type.ty;
50                return Some((param_name.clone(), param_type.clone()));
51            }
52            None
53        })
54        .collect();
55
56    let param_names: Vec<_> = params.iter().map(|(name, _)| name.clone()).collect();
57    let param_types: Vec<_> = params.iter().map(|(_, ty)| ty.clone()).collect();
58    let param_names_str: Vec<_> = params.iter().map(|(name, _)| name.to_string()).collect();
59
60    // Generate JSON schema properties for parameters
61    let schema_properties: Vec<_> = params
62        .iter()
63        .map(|(name, ty)| {
64            let name_str = name.to_string();
65            let type_str = get_json_type(ty);
66            quote! {
67                (#name_str.to_string(), serde_json::json!({ "type": #type_str }))
68            }
69        })
70        .collect();
71
72    // Get the return type without the `-> `
73    let actual_return_type = match fn_return_type {
74        ReturnType::Default => quote! { () },
75        ReturnType::Type(_, ty) => quote! { #ty },
76    };
77
78    let expanded = quote! {
79        #fn_vis mod #mod_name {
80            use super::*;
81            use std::collections::HashMap;
82            use serde_json;
83
84            /// The tool implementation struct
85            pub struct #struct_name;
86
87            impl #struct_name {
88                /// Create a new instance of this tool
89                pub fn new() -> Self {
90                    Self
91                }
92            }
93
94            impl Default for #struct_name {
95                fn default() -> Self {
96                    Self::new()
97                }
98            }
99
100            #[agent_chain::async_trait]
101            impl agent_chain::tools::Tool for #struct_name {
102                fn name(&self) -> &str {
103                    #fn_name_str
104                }
105
106                fn description(&self) -> &str {
107                    concat!("Tool: ", #fn_name_str)
108                }
109
110                fn parameters_schema(&self) -> serde_json::Value {
111                    let properties: HashMap<String, serde_json::Value> = [
112                        #(#schema_properties),*
113                    ].into_iter().collect();
114
115                    let required: Vec<String> = vec![
116                        #(#param_names_str.to_string()),*
117                    ];
118
119                    serde_json::json!({
120                        "type": "object",
121                        "properties": properties,
122                        "required": required
123                    })
124                }
125
126                async fn invoke(&self, tool_call: agent_chain::messages::ToolCall) -> agent_chain::messages::BaseMessage {
127                    let args = tool_call.args();
128
129                    #(
130                        let #param_names: #param_types = serde_json::from_value(
131                            args.get(#param_names_str).cloned().unwrap_or(serde_json::Value::Null)
132                        ).expect(&format!("Failed to parse parameter '{}'", #param_names_str));
133                    )*
134
135                    let result: #actual_return_type = { #fn_body };
136
137                    let result_str = serde_json::to_string(&result).unwrap_or_else(|_| format!("{:?}", result));
138
139                    agent_chain::messages::ToolMessage::new(result_str, tool_call.id()).into()
140                }
141            }
142
143            /// Create a new instance of this tool
144            pub fn tool() -> #struct_name {
145                #struct_name::new()
146            }
147        }
148    };
149
150    TokenStream::from(expanded)
151}
152
/// Converts a `snake_case` identifier into `PascalCase`.
///
/// Each `_`-separated segment has its first character upper-cased (full Unicode case
/// mapping, so one character may expand to several) and the remainder copied verbatim.
/// Empty segments produced by leading, trailing or doubled underscores contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for segment in s.split('_') {
        let mut rest = segment.chars();
        if let Some(head) = rest.next() {
            pascal.extend(head.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}
165
166/// Get the JSON schema type for a Rust type
167fn get_json_type(ty: &Type) -> &'static str {
168    let type_str = quote!(#ty).to_string();
169    match type_str.as_str() {
170        "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32" | "u64" | "u128"
171        | "usize" => "integer",
172        "f32" | "f64" => "number",
173        "bool" => "boolean",
174        "String" | "& str" | "& 'static str" => "string",
175        _ => "string", // Default to string for unknown types
176    }
177}