adk_rust_macros/lib.rs
1//! # adk-macros
2//!
3//! Proc macros for ADK-Rust that eliminate tool registration boilerplate.
4//!
5//! ## `#[tool]`
6//!
7//! Turns an async function into a fully-wired `adk_tool::Tool` implementation:
8//!
9//! ```rust,ignore
10//! use adk_macros::tool;
11//! use schemars::JsonSchema;
12//! use serde::Deserialize;
13//!
14//! #[derive(Deserialize, JsonSchema)]
15//! struct WeatherArgs {
16//! /// The city to look up
17//! city: String,
18//! }
19//!
20//! /// Get the current weather for a city.
21//! #[tool]
22//! async fn get_weather(args: WeatherArgs) -> Result<serde_json::Value, adk_tool::AdkError> {
23//! Ok(serde_json::json!({ "temp": 72, "city": args.city }))
24//! }
25//!
26//! // This generates a struct `GetWeather` that implements `adk_tool::Tool`.
27//! // Use it like: Arc::new(GetWeather)
28//! ```
29//!
30//! The macro:
31//! - Uses the function's doc comment as the tool description
32//! - Derives the JSON schema from the argument type via `schemars::schema_for!`
33//! - Names the tool after the function (snake_case)
34//! - Generates a zero-sized struct (PascalCase) implementing `Tool`
35
36use proc_macro::TokenStream;
37use quote::{format_ident, quote};
38use syn::{FnArg, ItemFn, Meta, Type, parse_macro_input};
39
40/// Attribute macro that generates a `Tool` implementation from an async function.
41///
42/// # Requirements
43///
44/// - The function must be `async`
45/// - It must take exactly one argument (the args struct) that implements
46/// `serde::de::DeserializeOwned` and `schemars::JsonSchema`
47/// - It must return `Result<serde_json::Value, adk_tool::AdkError>`
48/// - Doc comments become the tool description
49///
50/// # Attributes
51///
52/// Optional attributes can be passed to configure tool metadata:
53///
54/// - `read_only` — marks the tool as having no side effects (`is_read_only() → true`)
55/// - `concurrency_safe` — marks the tool as safe for concurrent execution (`is_concurrency_safe() → true`)
56/// - `long_running` — marks the tool as long-running (`is_long_running() → true`)
57///
58/// # Examples
59///
60/// ```rust,ignore
61/// /// Search the knowledge base for documents matching a query.
62/// #[tool]
63/// async fn search_docs(args: SearchArgs) -> Result<serde_json::Value, adk_tool::AdkError> {
64/// // ...
65/// }
66///
67/// /// Look up cached data (read-only, safe for parallel dispatch).
68/// #[tool(read_only, concurrency_safe)]
69/// async fn cache_lookup(args: LookupArgs) -> Result<serde_json::Value, adk_tool::AdkError> {
70/// // ...
71/// }
72///
73/// // Generated: pub struct SearchDocs; implements Tool
74/// // Use: agent_builder.tool(Arc::new(SearchDocs))
75/// ```
76#[proc_macro_attribute]
77pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
78 let input_fn = parse_macro_input!(item as ItemFn);
79
80 // Parse optional attributes: #[tool(read_only, concurrency_safe, long_running)]
81 let mut is_read_only = false;
82 let mut is_concurrency_safe = false;
83 let mut is_long_running = false;
84
85 if !attr.is_empty() {
86 let meta = parse_macro_input!(attr as ToolAttrs);
87 is_read_only = meta.read_only;
88 is_concurrency_safe = meta.concurrency_safe;
89 is_long_running = meta.long_running;
90 }
91
92 let fn_name = &input_fn.sig.ident;
93 let fn_vis = &input_fn.vis;
94
95 // Extract doc comments for description
96 let doc_lines: Vec<String> = input_fn
97 .attrs
98 .iter()
99 .filter(|attr| attr.path().is_ident("doc"))
100 .filter_map(|attr| {
101 if let syn::Meta::NameValue(nv) = &attr.meta {
102 if let syn::Expr::Lit(lit) = &nv.value {
103 if let syn::Lit::Str(s) = &lit.lit {
104 return Some(s.value().trim().to_string());
105 }
106 }
107 }
108 None
109 })
110 .collect();
111
112 let description = if doc_lines.is_empty() {
113 fn_name.to_string().replace('_', " ")
114 } else {
115 doc_lines.join(" ")
116 };
117
118 let tool_name_str = fn_name.to_string();
119
120 // Generate PascalCase struct name: get_weather → GetWeather
121 let struct_name = format_ident!(
122 "{}",
123 tool_name_str
124 .split('_')
125 .map(|seg| {
126 let mut chars = seg.chars();
127 match chars.next() {
128 None => String::new(),
129 Some(c) => c.to_uppercase().to_string() + chars.as_str(),
130 }
131 })
132 .collect::<String>()
133 );
134
135 // Extract the single argument type
136 let args_type = extract_args_type(&input_fn);
137
138 // Check if we have a typed args parameter or no params
139 let (schema_gen, deserialize_call) = if let Some(args_ty) = &args_type {
140 (
141 quote! {
142 {
143 let mut schema = serde_json::to_value(
144 schemars::schema_for!(#args_ty)
145 ).unwrap_or_default();
146 // Strip fields that Gemini/LLM APIs don't accept
147 if let Some(obj) = schema.as_object_mut() {
148 obj.remove("$schema");
149 obj.remove("title");
150 }
151 // Simplify nullable types: {"type": ["string", "null"]} → {"type": "string"}
152 fn simplify_nullable(v: &mut serde_json::Value) {
153 match v {
154 serde_json::Value::Object(map) => {
155 if let Some(serde_json::Value::Array(types)) = map.get("type") {
156 let non_null: Vec<_> = types.iter()
157 .filter(|t| t.as_str() != Some("null"))
158 .cloned()
159 .collect();
160 if non_null.len() == 1 {
161 map.insert("type".to_string(), non_null[0].clone());
162 }
163 }
164 // Remove anyOf wrappers for simple nullable types
165 if let Some(serde_json::Value::Array(any_of)) = map.remove("anyOf") {
166 for variant in &any_of {
167 if let Some(obj) = variant.as_object() {
168 if obj.get("type").and_then(|t| t.as_str()) != Some("null") {
169 for (k, val) in obj {
170 map.insert(k.clone(), val.clone());
171 }
172 break;
173 }
174 }
175 }
176 }
177 for val in map.values_mut() {
178 simplify_nullable(val);
179 }
180 }
181 serde_json::Value::Array(arr) => {
182 for item in arr {
183 simplify_nullable(item);
184 }
185 }
186 _ => {}
187 }
188 }
189 simplify_nullable(&mut schema);
190 Some(schema)
191 }
192 },
193 quote! {
194 let typed_args: #args_ty = serde_json::from_value(args)
195 .map_err(|e| adk_tool::AdkError::tool(
196 format!("invalid arguments for '{}': {e}", #tool_name_str)
197 ))?;
198 #fn_name(typed_args).await
199 },
200 )
201 } else {
202 (
203 quote! { None },
204 quote! {
205 let _ = args;
206 #fn_name().await
207 },
208 )
209 };
210
211 // Check if the function signature includes ctx: Arc<dyn ToolContext>
212 let has_ctx = has_tool_context_param(&input_fn);
213 let execute_body = if has_ctx {
214 if let Some(args_ty) = &args_type {
215 quote! {
216 let typed_args: #args_ty = serde_json::from_value(args)
217 .map_err(|e| adk_tool::AdkError::tool(
218 format!("invalid arguments for '{}': {e}", #tool_name_str)
219 ))?;
220 #fn_name(ctx, typed_args).await
221 }
222 } else {
223 quote! {
224 let _ = args;
225 #fn_name(ctx).await
226 }
227 }
228 } else {
229 deserialize_call
230 };
231
232 // Generate optional trait method overrides
233 let read_only_override = if is_read_only {
234 quote! {
235 fn is_read_only(&self) -> bool { true }
236 }
237 } else {
238 quote! {}
239 };
240
241 let concurrency_safe_override = if is_concurrency_safe {
242 quote! {
243 fn is_concurrency_safe(&self) -> bool { true }
244 }
245 } else {
246 quote! {}
247 };
248
249 let long_running_override = if is_long_running {
250 quote! {
251 fn is_long_running(&self) -> bool { true }
252 }
253 } else {
254 quote! {}
255 };
256
257 let output = quote! {
258 // Keep the original function
259 #input_fn
260
261 /// Auto-generated tool struct for [`#fn_name`].
262 #fn_vis struct #struct_name;
263
264 #[adk_tool::async_trait]
265 impl adk_tool::Tool for #struct_name {
266 fn name(&self) -> &str {
267 #tool_name_str
268 }
269
270 fn description(&self) -> &str {
271 #description
272 }
273
274 fn parameters_schema(&self) -> Option<serde_json::Value> {
275 #schema_gen
276 }
277
278 #read_only_override
279 #concurrency_safe_override
280 #long_running_override
281
282 async fn execute(
283 &self,
284 ctx: std::sync::Arc<dyn adk_tool::ToolContext>,
285 args: serde_json::Value,
286 ) -> adk_tool::Result<serde_json::Value> {
287 #execute_body
288 }
289 }
290 };
291
292 output.into()
293}
294
295/// Extract the args type from the function signature.
296/// Skips any `Arc<dyn ToolContext>` parameter.
297fn extract_args_type(func: &ItemFn) -> Option<Type> {
298 for arg in &func.sig.inputs {
299 if let FnArg::Typed(pat_type) = arg {
300 // Skip context parameters (Arc<dyn ToolContext>)
301 let ty = &pat_type.ty;
302 let ty_str = quote!(#ty).to_string();
303 if ty_str.contains("ToolContext") || ty_str.contains("Arc") {
304 continue;
305 }
306 return Some((*pat_type.ty).clone());
307 }
308 }
309 None
310}
311
312/// Check if the function has an Arc<dyn ToolContext> parameter.
313fn has_tool_context_param(func: &ItemFn) -> bool {
314 func.sig.inputs.iter().any(|arg| {
315 if let FnArg::Typed(pat_type) = arg {
316 let ty = &pat_type.ty;
317 let ty_str = quote!(#ty).to_string();
318 ty_str.contains("ToolContext")
319 } else {
320 false
321 }
322 })
323}
324
/// Flags accepted inside `#[tool(...)]`, e.g. `#[tool(read_only, long_running)]`.
///
/// Each field records whether the bare identifier of the same name appeared in
/// the attribute's comma-separated argument list.
struct ToolAttrs {
    /// `read_only` was listed.
    read_only: bool,
    /// `concurrency_safe` was listed.
    concurrency_safe: bool,
    /// `long_running` was listed.
    long_running: bool,
}
331
332impl syn::parse::Parse for ToolAttrs {
333 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
334 let mut attrs =
335 ToolAttrs { read_only: false, concurrency_safe: false, long_running: false };
336
337 let punctuated =
338 syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated(input)?;
339
340 for meta in punctuated {
341 if let Meta::Path(path) = &meta {
342 if path.is_ident("read_only") {
343 attrs.read_only = true;
344 } else if path.is_ident("concurrency_safe") {
345 attrs.concurrency_safe = true;
346 } else if path.is_ident("long_running") {
347 attrs.long_running = true;
348 } else {
349 return Err(syn::Error::new_spanned(
350 path,
351 "unknown tool attribute; expected `read_only`, `concurrency_safe`, or `long_running`",
352 ));
353 }
354 } else {
355 return Err(syn::Error::new_spanned(
356 meta,
357 "expected identifier (e.g., `read_only`), not key-value",
358 ));
359 }
360 }
361
362 Ok(attrs)
363 }
364}