adk_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::Lit;
4use syn::{parse_macro_input, Expr, ExprLit, FnArg, ItemFn, Pat, PatType, Type};
5
6/// A procedural macro that generates a tool with parameter schema from a function signature
7///
8/// Usage:
9/// ```
10/// #[tool_fn(name = "calculator", description = "A simple calculator")]
11/// fn calculator(context: &mut RunContext, a: i32, b: i32, operation: String) -> String {
12///     // Function implementation
13/// }
14/// ```
15#[proc_macro_attribute]
16pub fn tool_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
17    // Parse the function definition
18    let input_fn = parse_macro_input!(item as ItemFn);
19    let fn_name = &input_fn.sig.ident;
20    let fn_name_str = fn_name.to_string();
21
22    // Parse attributes as a punctuated sequence of Meta items
23    let attrs = parse_macro_input!(attr with syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated);
24
25    // Extract name and description from attributes
26    let mut tool_name = fn_name_str.clone();
27    let mut tool_description = format!("Tool function {}", fn_name_str);
28
29    for attr in attrs.iter() {
30        if let syn::Meta::NameValue(name_value) = attr {
31            if name_value.path.is_ident("name") {
32                if let Expr::Lit(ExprLit {
33                    lit: Lit::Str(lit_str),
34                    ..
35                }) = &name_value.value
36                {
37                    tool_name = lit_str.value();
38                }
39            } else if name_value.path.is_ident("description") {
40                if let Expr::Lit(ExprLit {
41                    lit: Lit::Str(lit_str),
42                    ..
43                }) = &name_value.value
44                {
45                    tool_description = lit_str.value();
46                }
47            }
48        }
49    }
50
51    // Extract parameter information from function signature
52    let params = extract_params(&input_fn);
53
54    // Generate the tool function name (append _tool to the original function name)
55    let tool_fn_name = format_ident!("{}_tool", fn_name);
56
57    // Generate parameter extraction and conversion code
58    let param_extractions = params.iter().map(|(name, type_name)| {
59        let param_name = format_ident!("{}", name);
60        match type_name.as_str() {
61            "i32" => quote! {
62                let #param_name = params[#name].as_i64()
63                    .ok_or_else(|| AgentError::InvalidInput(format!("Missing or invalid parameter: {}", #name)))?
64                    as i32;
65            },
66            "i64" => quote! {
67                let #param_name = params[#name].as_i64()
68                    .ok_or_else(|| AgentError::InvalidInput(format!("Missing or invalid parameter: {}", #name)))?;
69            },
70            "u32" | "u64" => quote! {
71                let #param_name = params[#name].as_u64()
72                    .ok_or_else(|| AgentError::InvalidInput(format!("Missing or invalid parameter: {}", #name)))?
73                    as u32;
74            },
75            "f32" | "f64" => quote! {
76                let #param_name = params[#name].as_f64()
77                    .ok_or_else(|| AgentError::InvalidInput(format!("Missing or invalid parameter: {}", #name)))?
78                    as f64;
79            },
80            "String" => quote! {
81                let #param_name = params[#name].as_str()
82                    .ok_or_else(|| AgentError::InvalidInput(format!("Missing or invalid parameter: {}", #name)))?
83                    .to_string();
84            },
85            "&str" => quote! {
86                let #param_name = params[#name].as_str()
87                    .ok_or_else(|| AgentError::InvalidInput(format!("Missing or invalid parameter: {}", #name)))?;
88            },
89            "bool" => quote! {
90                let #param_name = params[#name].as_bool()
91                    .ok_or_else(|| AgentError::InvalidInput(format!("Missing or invalid parameter: {}", #name)))?;
92            },
93            _ => quote! {
94                let #param_name = serde_json::from_value::<#param_name>(params[#name].clone())
95                    .map_err(|e| AgentError::InvalidInput(format!("Invalid parameter {}: {}", #name, e)))?;
96            },
97        }
98    });
99
100    // Collect parameter names for the function call
101    let param_names = params.iter().map(|(name, _)| format_ident!("{}", name));
102
103    // Generate the schema properties
104    let schema_properties = params.iter().map(|(name, type_name)| {
105        let type_str = match type_name.as_str() {
106            "i32" | "i64" | "u32" | "u64" | "f32" | "f64" => "number",
107            "String" | "&str" => "string",
108            "bool" => "boolean",
109            _ => "object",
110        };
111
112        quote! {
113            let mut property = serde_json::Map::new();
114            property.insert("type".to_string(), serde_json::Value::String(#type_str.to_string()));
115            properties.insert(#name.to_string(), serde_json::Value::Object(property));
116            required.push(serde_json::Value::String(#name.to_string()));
117        }
118    });
119
120    // Generate the expanded code
121    let expanded = quote! {
122        // Keep the original function
123        #input_fn
124
125        // Create a tool function that returns a FunctionTool
126        pub fn #tool_fn_name() -> ::adk::tool::FunctionTool {
127            use adk::error::AgentError;
128            use adk::tool::ToolResult;
129
130            ::adk::tool::FunctionTool::new(
131                #tool_name,
132                #tool_description,
133                // Generate schema based on function parameters
134                generate_parameter_schema(),
135                Box::new(|context, params_str| {
136                    // Parse parameters from JSON
137                    let params: serde_json::Value = serde_json::from_str(params_str)
138                        .map_err(|e| AgentError::InvalidInput(e.to_string()))?;
139
140                    // Extract and convert parameters
141                    #(#param_extractions)*
142
143                    // Call the function with parsed parameters
144                    let result = #fn_name(context, #(#param_names),*);
145
146                    Ok(ToolResult {
147                        tool_name: #tool_name.to_string(),
148                        output: result,
149                    })
150                })
151            )
152        }
153
154        // Generate the parameter schema as a function
155        fn generate_parameter_schema() -> serde_json::Value {
156            let mut properties = serde_json::Map::new();
157            let mut required = Vec::new();
158
159            #(#schema_properties)*
160
161            let mut schema = serde_json::Map::new();
162            schema.insert("type".to_string(), serde_json::Value::String("object".to_string()));
163            schema.insert("properties".to_string(), serde_json::Value::Object(properties));
164            schema.insert("required".to_string(), serde_json::Value::Array(required));
165
166            serde_json::Value::Object(schema)
167        }
168    };
169
170    expanded.into()
171}
172
173// Helper function to extract parameter info from a function
174fn extract_params(input_fn: &ItemFn) -> Vec<(String, String)> {
175    let mut params = Vec::new();
176
177    for arg in &input_fn.sig.inputs {
178        if let FnArg::Typed(PatType { pat, ty, .. }) = arg {
179            if let Pat::Ident(pat_ident) = &**pat {
180                let param_name = pat_ident.ident.to_string();
181                let param_type = get_type_name(ty);
182
183                // Skip the context parameter
184                if param_name != "context" && !param_type.contains("RunContext") {
185                    params.push((param_name, param_type));
186                }
187            }
188        }
189    }
190
191    params
192}
193
194// Helper function to get the name of a type
195fn get_type_name(ty: &Box<Type>) -> String {
196    match ty.as_ref() {
197        Type::Path(type_path) => {
198            if let Some(segment) = type_path.path.segments.last() {
199                segment.ident.to_string()
200            } else {
201                "unknown".to_string()
202            }
203        }
204        Type::Reference(type_ref) => {
205            if let Type::Path(type_path) = type_ref.elem.as_ref() {
206                if let Some(segment) = type_path.path.segments.last() {
207                    segment.ident.to_string()
208                } else {
209                    "unknown".to_string()
210                }
211            } else {
212                "unknown".to_string()
213            }
214        }
215        _ => "unknown".to_string(),
216    }
217}