Skip to main content

adk_rust_macros/
lib.rs

1//! # adk-macros
2//!
3//! Proc macros for ADK-Rust that eliminate tool registration boilerplate.
4//!
5//! ## `#[tool]`
6//!
7//! Turns an async function into a fully-wired `adk_tool::Tool` implementation:
8//!
9//! ```rust,ignore
10//! use adk_macros::tool;
11//! use schemars::JsonSchema;
12//! use serde::Deserialize;
13//!
14//! #[derive(Deserialize, JsonSchema)]
15//! struct WeatherArgs {
16//!     /// The city to look up
17//!     city: String,
18//! }
19//!
20//! /// Get the current weather for a city.
21//! #[tool]
22//! async fn get_weather(args: WeatherArgs) -> Result<serde_json::Value, adk_tool::AdkError> {
23//!     Ok(serde_json::json!({ "temp": 72, "city": args.city }))
24//! }
25//!
26//! // This generates a struct `GetWeather` that implements `adk_tool::Tool`.
27//! // Use it like: Arc::new(GetWeather)
28//! ```
29//!
30//! The macro:
31//! - Uses the function's doc comment as the tool description
32//! - Derives the JSON schema from the argument type via `schemars::schema_for!`
33//! - Names the tool after the function (snake_case)
34//! - Generates a zero-sized struct (PascalCase) implementing `Tool`
35
36use proc_macro::TokenStream;
37use quote::{format_ident, quote};
38use syn::{FnArg, ItemFn, Meta, Type, parse_macro_input};
39
40/// Attribute macro that generates a `Tool` implementation from an async function.
41///
42/// # Requirements
43///
44/// - The function must be `async`
45/// - It must take exactly one argument (the args struct) that implements
46///   `serde::de::DeserializeOwned` and `schemars::JsonSchema`
47/// - It must return `Result<serde_json::Value, adk_tool::AdkError>`
48/// - Doc comments become the tool description
49///
50/// # Attributes
51///
52/// Optional attributes can be passed to configure tool metadata:
53///
54/// - `read_only` — marks the tool as having no side effects (`is_read_only() → true`)
55/// - `concurrency_safe` — marks the tool as safe for concurrent execution (`is_concurrency_safe() → true`)
56/// - `long_running` — marks the tool as long-running (`is_long_running() → true`)
57///
58/// # Examples
59///
60/// ```rust,ignore
61/// /// Search the knowledge base for documents matching a query.
62/// #[tool]
63/// async fn search_docs(args: SearchArgs) -> Result<serde_json::Value, adk_tool::AdkError> {
64///     // ...
65/// }
66///
67/// /// Look up cached data (read-only, safe for parallel dispatch).
68/// #[tool(read_only, concurrency_safe)]
69/// async fn cache_lookup(args: LookupArgs) -> Result<serde_json::Value, adk_tool::AdkError> {
70///     // ...
71/// }
72///
73/// // Generated: pub struct SearchDocs; implements Tool
74/// // Use: agent_builder.tool(Arc::new(SearchDocs))
75/// ```
76#[proc_macro_attribute]
77pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
78    let input_fn = parse_macro_input!(item as ItemFn);
79
80    // Parse optional attributes: #[tool(read_only, concurrency_safe, long_running)]
81    let mut is_read_only = false;
82    let mut is_concurrency_safe = false;
83    let mut is_long_running = false;
84
85    if !attr.is_empty() {
86        let meta = parse_macro_input!(attr as ToolAttrs);
87        is_read_only = meta.read_only;
88        is_concurrency_safe = meta.concurrency_safe;
89        is_long_running = meta.long_running;
90    }
91
92    let fn_name = &input_fn.sig.ident;
93    let fn_vis = &input_fn.vis;
94
95    // Extract doc comments for description
96    let doc_lines: Vec<String> = input_fn
97        .attrs
98        .iter()
99        .filter(|attr| attr.path().is_ident("doc"))
100        .filter_map(|attr| {
101            if let syn::Meta::NameValue(nv) = &attr.meta {
102                if let syn::Expr::Lit(lit) = &nv.value {
103                    if let syn::Lit::Str(s) = &lit.lit {
104                        return Some(s.value().trim().to_string());
105                    }
106                }
107            }
108            None
109        })
110        .collect();
111
112    let description = if doc_lines.is_empty() {
113        fn_name.to_string().replace('_', " ")
114    } else {
115        doc_lines.join(" ")
116    };
117
118    let tool_name_str = fn_name.to_string();
119
120    // Generate PascalCase struct name: get_weather → GetWeather
121    let struct_name = format_ident!(
122        "{}",
123        tool_name_str
124            .split('_')
125            .map(|seg| {
126                let mut chars = seg.chars();
127                match chars.next() {
128                    None => String::new(),
129                    Some(c) => c.to_uppercase().to_string() + chars.as_str(),
130                }
131            })
132            .collect::<String>()
133    );
134
135    // Extract the single argument type
136    let args_type = extract_args_type(&input_fn);
137
138    // Check if we have a typed args parameter or no params
139    let (schema_gen, deserialize_call) = if let Some(args_ty) = &args_type {
140        (
141            quote! {
142                {
143                    let mut schema = serde_json::to_value(
144                        schemars::schema_for!(#args_ty)
145                    ).unwrap_or_default();
146                    // Strip fields that Gemini/LLM APIs don't accept
147                    if let Some(obj) = schema.as_object_mut() {
148                        obj.remove("$schema");
149                        obj.remove("title");
150                    }
151                    // Simplify nullable types: {"type": ["string", "null"]} → {"type": "string"}
152                    fn simplify_nullable(v: &mut serde_json::Value) {
153                        match v {
154                            serde_json::Value::Object(map) => {
155                                if let Some(serde_json::Value::Array(types)) = map.get("type") {
156                                    let non_null: Vec<_> = types.iter()
157                                        .filter(|t| t.as_str() != Some("null"))
158                                        .cloned()
159                                        .collect();
160                                    if non_null.len() == 1 {
161                                        map.insert("type".to_string(), non_null[0].clone());
162                                    }
163                                }
164                                // Remove anyOf wrappers for simple nullable types
165                                if let Some(serde_json::Value::Array(any_of)) = map.remove("anyOf") {
166                                    for variant in &any_of {
167                                        if let Some(obj) = variant.as_object() {
168                                            if obj.get("type").and_then(|t| t.as_str()) != Some("null") {
169                                                for (k, val) in obj {
170                                                    map.insert(k.clone(), val.clone());
171                                                }
172                                                break;
173                                            }
174                                        }
175                                    }
176                                }
177                                for val in map.values_mut() {
178                                    simplify_nullable(val);
179                                }
180                            }
181                            serde_json::Value::Array(arr) => {
182                                for item in arr {
183                                    simplify_nullable(item);
184                                }
185                            }
186                            _ => {}
187                        }
188                    }
189                    simplify_nullable(&mut schema);
190                    Some(schema)
191                }
192            },
193            quote! {
194                let typed_args: #args_ty = serde_json::from_value(args)
195                    .map_err(|e| adk_tool::AdkError::tool(
196                        format!("invalid arguments for '{}': {e}", #tool_name_str)
197                    ))?;
198                #fn_name(typed_args).await
199            },
200        )
201    } else {
202        (
203            quote! { None },
204            quote! {
205                let _ = args;
206                #fn_name().await
207            },
208        )
209    };
210
211    // Check if the function signature includes ctx: Arc<dyn ToolContext>
212    let has_ctx = has_tool_context_param(&input_fn);
213    let execute_body = if has_ctx {
214        if let Some(args_ty) = &args_type {
215            quote! {
216                let typed_args: #args_ty = serde_json::from_value(args)
217                    .map_err(|e| adk_tool::AdkError::tool(
218                        format!("invalid arguments for '{}': {e}", #tool_name_str)
219                    ))?;
220                #fn_name(ctx, typed_args).await
221            }
222        } else {
223            quote! {
224                let _ = args;
225                #fn_name(ctx).await
226            }
227        }
228    } else {
229        deserialize_call
230    };
231
232    // Generate optional trait method overrides
233    let read_only_override = if is_read_only {
234        quote! {
235            fn is_read_only(&self) -> bool { true }
236        }
237    } else {
238        quote! {}
239    };
240
241    let concurrency_safe_override = if is_concurrency_safe {
242        quote! {
243            fn is_concurrency_safe(&self) -> bool { true }
244        }
245    } else {
246        quote! {}
247    };
248
249    let long_running_override = if is_long_running {
250        quote! {
251            fn is_long_running(&self) -> bool { true }
252        }
253    } else {
254        quote! {}
255    };
256
257    let output = quote! {
258        // Keep the original function
259        #input_fn
260
261        /// Auto-generated tool struct for [`#fn_name`].
262        #fn_vis struct #struct_name;
263
264        #[adk_tool::async_trait]
265        impl adk_tool::Tool for #struct_name {
266            fn name(&self) -> &str {
267                #tool_name_str
268            }
269
270            fn description(&self) -> &str {
271                #description
272            }
273
274            fn parameters_schema(&self) -> Option<serde_json::Value> {
275                #schema_gen
276            }
277
278            #read_only_override
279            #concurrency_safe_override
280            #long_running_override
281
282            async fn execute(
283                &self,
284                ctx: std::sync::Arc<dyn adk_tool::ToolContext>,
285                args: serde_json::Value,
286            ) -> adk_tool::Result<serde_json::Value> {
287                #execute_body
288            }
289        }
290    };
291
292    output.into()
293}
294
295/// Extract the args type from the function signature.
296/// Skips any `Arc<dyn ToolContext>` parameter.
297fn extract_args_type(func: &ItemFn) -> Option<Type> {
298    for arg in &func.sig.inputs {
299        if let FnArg::Typed(pat_type) = arg {
300            // Skip context parameters (Arc<dyn ToolContext>)
301            let ty = &pat_type.ty;
302            let ty_str = quote!(#ty).to_string();
303            if ty_str.contains("ToolContext") || ty_str.contains("Arc") {
304                continue;
305            }
306            return Some((*pat_type.ty).clone());
307        }
308    }
309    None
310}
311
312/// Check if the function has an Arc<dyn ToolContext> parameter.
313fn has_tool_context_param(func: &ItemFn) -> bool {
314    func.sig.inputs.iter().any(|arg| {
315        if let FnArg::Typed(pat_type) = arg {
316            let ty = &pat_type.ty;
317            let ty_str = quote!(#ty).to_string();
318            ty_str.contains("ToolContext")
319        } else {
320            false
321        }
322    })
323}
324
325/// Parsed attributes from `#[tool(read_only, concurrency_safe, long_running)]`.
326struct ToolAttrs {
327    read_only: bool,
328    concurrency_safe: bool,
329    long_running: bool,
330}
331
332impl syn::parse::Parse for ToolAttrs {
333    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
334        let mut attrs =
335            ToolAttrs { read_only: false, concurrency_safe: false, long_running: false };
336
337        let punctuated =
338            syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated(input)?;
339
340        for meta in punctuated {
341            if let Meta::Path(path) = &meta {
342                if path.is_ident("read_only") {
343                    attrs.read_only = true;
344                } else if path.is_ident("concurrency_safe") {
345                    attrs.concurrency_safe = true;
346                } else if path.is_ident("long_running") {
347                    attrs.long_running = true;
348                } else {
349                    return Err(syn::Error::new_spanned(
350                        path,
351                        "unknown tool attribute; expected `read_only`, `concurrency_safe`, or `long_running`",
352                    ));
353                }
354            } else {
355                return Err(syn::Error::new_spanned(
356                    meta,
357                    "expected identifier (e.g., `read_only`), not key-value",
358                ));
359            }
360        }
361
362        Ok(attrs)
363    }
364}