Skip to main content

adk_rust_macros/
lib.rs

1//! # adk-macros
2//!
3//! Proc macros for ADK-Rust that eliminate tool registration boilerplate.
4//!
5//! ## `#[tool]`
6//!
7//! Turns an async function into a fully-wired [`adk_tool::Tool`] implementation:
8//!
9//! ```rust,ignore
10//! use adk_macros::tool;
11//! use schemars::JsonSchema;
12//! use serde::Deserialize;
13//!
14//! #[derive(Deserialize, JsonSchema)]
15//! struct WeatherArgs {
16//!     /// The city to look up
17//!     city: String,
18//! }
19//!
20//! /// Get the current weather for a city.
21//! #[tool]
22//! async fn get_weather(args: WeatherArgs) -> Result<serde_json::Value, adk_tool::AdkError> {
23//!     Ok(serde_json::json!({ "temp": 72, "city": args.city }))
24//! }
25//!
26//! // This generates a struct `GetWeather` that implements `adk_tool::Tool`.
27//! // Use it like: Arc::new(GetWeather)
28//! ```
29//!
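//! The macro:
//! - Uses the function's doc comment as the tool description
//! - Derives the JSON schema from the argument type via `schemars::schema_for!`, stripping
//!   `$schema`/`title` and collapsing nullable types for LLM API compatibility
//! - Supports functions with no arguments parameter (no parameter schema is advertised)
//! - Optionally accepts an `Arc<dyn ToolContext>` parameter, which receives the context
//!   passed to `Tool::execute`
//! - Names the tool after the function (snake_case)
//! - Generates a zero-sized struct (PascalCase, with the function's visibility) implementing `Tool`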
35
36use proc_macro::TokenStream;
37use quote::{format_ident, quote};
38use syn::{FnArg, ItemFn, Type, parse_macro_input};
39
40/// Attribute macro that generates a `Tool` implementation from an async function.
41///
42/// # Requirements
43///
44/// - The function must be `async`
45/// - It must take exactly one argument (the args struct) that implements
46///   `serde::de::DeserializeOwned` and `schemars::JsonSchema`
47/// - It must return `Result<serde_json::Value, adk_tool::AdkError>`
48/// - Doc comments become the tool description
49///
50/// # Example
51///
52/// ```rust,ignore
53/// /// Search the knowledge base for documents matching a query.
54/// #[tool]
55/// async fn search_docs(args: SearchArgs) -> Result<serde_json::Value, adk_tool::AdkError> {
56///     // ...
57/// }
58///
59/// // Generated: pub struct SearchDocs; implements Tool
60/// // Use: agent_builder.tool(Arc::new(SearchDocs))
61/// ```
62#[proc_macro_attribute]
63pub fn tool(_attr: TokenStream, item: TokenStream) -> TokenStream {
64    let input_fn = parse_macro_input!(item as ItemFn);
65
66    let fn_name = &input_fn.sig.ident;
67    let fn_vis = &input_fn.vis;
68
69    // Extract doc comments for description
70    let doc_lines: Vec<String> = input_fn
71        .attrs
72        .iter()
73        .filter(|attr| attr.path().is_ident("doc"))
74        .filter_map(|attr| {
75            if let syn::Meta::NameValue(nv) = &attr.meta {
76                if let syn::Expr::Lit(lit) = &nv.value {
77                    if let syn::Lit::Str(s) = &lit.lit {
78                        return Some(s.value().trim().to_string());
79                    }
80                }
81            }
82            None
83        })
84        .collect();
85
86    let description = if doc_lines.is_empty() {
87        fn_name.to_string().replace('_', " ")
88    } else {
89        doc_lines.join(" ")
90    };
91
92    let tool_name_str = fn_name.to_string();
93
94    // Generate PascalCase struct name: get_weather → GetWeather
95    let struct_name = format_ident!(
96        "{}",
97        tool_name_str
98            .split('_')
99            .map(|seg| {
100                let mut chars = seg.chars();
101                match chars.next() {
102                    None => String::new(),
103                    Some(c) => c.to_uppercase().to_string() + chars.as_str(),
104                }
105            })
106            .collect::<String>()
107    );
108
109    // Extract the single argument type
110    let args_type = extract_args_type(&input_fn);
111
112    // Check if we have a typed args parameter or no params
113    let (schema_gen, deserialize_call) = if let Some(args_ty) = &args_type {
114        (
115            quote! {
116                {
117                    let mut schema = serde_json::to_value(
118                        schemars::schema_for!(#args_ty)
119                    ).unwrap_or_default();
120                    // Strip fields that Gemini/LLM APIs don't accept
121                    if let Some(obj) = schema.as_object_mut() {
122                        obj.remove("$schema");
123                        obj.remove("title");
124                    }
125                    // Simplify nullable types: {"type": ["string", "null"]} → {"type": "string"}
126                    fn simplify_nullable(v: &mut serde_json::Value) {
127                        match v {
128                            serde_json::Value::Object(map) => {
129                                if let Some(serde_json::Value::Array(types)) = map.get("type") {
130                                    let non_null: Vec<_> = types.iter()
131                                        .filter(|t| t.as_str() != Some("null"))
132                                        .cloned()
133                                        .collect();
134                                    if non_null.len() == 1 {
135                                        map.insert("type".to_string(), non_null[0].clone());
136                                    }
137                                }
138                                // Remove anyOf wrappers for simple nullable types
139                                if let Some(serde_json::Value::Array(any_of)) = map.remove("anyOf") {
140                                    for variant in &any_of {
141                                        if let Some(obj) = variant.as_object() {
142                                            if obj.get("type").and_then(|t| t.as_str()) != Some("null") {
143                                                for (k, val) in obj {
144                                                    map.insert(k.clone(), val.clone());
145                                                }
146                                                break;
147                                            }
148                                        }
149                                    }
150                                }
151                                for val in map.values_mut() {
152                                    simplify_nullable(val);
153                                }
154                            }
155                            serde_json::Value::Array(arr) => {
156                                for item in arr {
157                                    simplify_nullable(item);
158                                }
159                            }
160                            _ => {}
161                        }
162                    }
163                    simplify_nullable(&mut schema);
164                    Some(schema)
165                }
166            },
167            quote! {
168                let typed_args: #args_ty = serde_json::from_value(args)
169                    .map_err(|e| adk_tool::AdkError::Tool(
170                        format!("invalid arguments for '{}': {e}", #tool_name_str)
171                    ))?;
172                #fn_name(typed_args).await
173            },
174        )
175    } else {
176        (
177            quote! { None },
178            quote! {
179                let _ = args;
180                #fn_name().await
181            },
182        )
183    };
184
185    // Check if the function signature includes ctx: Arc<dyn ToolContext>
186    let has_ctx = has_tool_context_param(&input_fn);
187    let execute_body = if has_ctx {
188        if let Some(args_ty) = &args_type {
189            quote! {
190                let typed_args: #args_ty = serde_json::from_value(args)
191                    .map_err(|e| adk_tool::AdkError::Tool(
192                        format!("invalid arguments for '{}': {e}", #tool_name_str)
193                    ))?;
194                #fn_name(ctx, typed_args).await
195            }
196        } else {
197            quote! {
198                let _ = args;
199                #fn_name(ctx).await
200            }
201        }
202    } else {
203        deserialize_call
204    };
205
206    let output = quote! {
207        // Keep the original function
208        #input_fn
209
210        /// Auto-generated tool struct for [`#fn_name`].
211        #fn_vis struct #struct_name;
212
213        #[async_trait::async_trait]
214        impl adk_tool::Tool for #struct_name {
215            fn name(&self) -> &str {
216                #tool_name_str
217            }
218
219            fn description(&self) -> &str {
220                #description
221            }
222
223            fn parameters_schema(&self) -> Option<serde_json::Value> {
224                #schema_gen
225            }
226
227            async fn execute(
228                &self,
229                ctx: std::sync::Arc<dyn adk_tool::ToolContext>,
230                args: serde_json::Value,
231            ) -> adk_tool::Result<serde_json::Value> {
232                #execute_body
233            }
234        }
235    };
236
237    output.into()
238}
239
240/// Extract the args type from the function signature.
241/// Skips any `Arc<dyn ToolContext>` parameter.
242fn extract_args_type(func: &ItemFn) -> Option<Type> {
243    for arg in &func.sig.inputs {
244        if let FnArg::Typed(pat_type) = arg {
245            // Skip context parameters (Arc<dyn ToolContext>)
246            let ty_str = quote!(#pat_type.ty).to_string();
247            if ty_str.contains("ToolContext") || ty_str.contains("Arc") {
248                continue;
249            }
250            return Some((*pat_type.ty).clone());
251        }
252    }
253    None
254}
255
256/// Check if the function has an Arc<dyn ToolContext> parameter.
257fn has_tool_context_param(func: &ItemFn) -> bool {
258    func.sig.inputs.iter().any(|arg| {
259        if let FnArg::Typed(pat_type) = arg {
260            let ty = &pat_type.ty;
261            let ty_str = quote!(#ty).to_string();
262            ty_str.contains("ToolContext")
263        } else {
264            false
265        }
266    })
267}