adk_rust_macros/lib.rs
//! # adk-macros
//!
//! Proc macros for ADK-Rust that eliminate tool registration boilerplate.
//!
//! ## `#[tool]`
//!
//! Turns an async function into a fully wired [`adk_tool::Tool`] implementation:
//!
//! ```rust,ignore
//! use adk_macros::tool;
//! use schemars::JsonSchema;
//! use serde::Deserialize;
//!
//! #[derive(Deserialize, JsonSchema)]
//! struct WeatherArgs {
//!     /// The city to look up
//!     city: String,
//! }
//!
//! /// Get the current weather for a city.
//! #[tool]
//! async fn get_weather(args: WeatherArgs) -> Result<serde_json::Value, adk_tool::AdkError> {
//!     Ok(serde_json::json!({ "temp": 72, "city": args.city }))
//! }
//!
//! // This generates a struct `GetWeather` that implements `adk_tool::Tool`.
//! // Use it like: Arc::new(GetWeather)
//! ```
//!
//! The macro:
//! - Uses the function's doc comment as the tool description
//! - Derives the JSON schema from the argument type via `schemars::schema_for!`
//! - Uses the function's (snake_case) name as the tool name
//! - Generates a zero-sized struct, named in PascalCase, that implements `Tool`
//! - Optionally accepts a `ctx: Arc<dyn ToolContext>` parameter (placed before the
//!   args parameter), which receives the context passed to `Tool::execute`
35
36use proc_macro::TokenStream;
37use quote::{format_ident, quote};
38use syn::{FnArg, ItemFn, Type, parse_macro_input};
39
40/// Attribute macro that generates a `Tool` implementation from an async function.
41///
42/// # Requirements
43///
44/// - The function must be `async`
45/// - It must take exactly one argument (the args struct) that implements
46/// `serde::de::DeserializeOwned` and `schemars::JsonSchema`
47/// - It must return `Result<serde_json::Value, adk_tool::AdkError>`
48/// - Doc comments become the tool description
49///
50/// # Example
51///
52/// ```rust,ignore
53/// /// Search the knowledge base for documents matching a query.
54/// #[tool]
55/// async fn search_docs(args: SearchArgs) -> Result<serde_json::Value, adk_tool::AdkError> {
56/// // ...
57/// }
58///
59/// // Generated: pub struct SearchDocs; implements Tool
60/// // Use: agent_builder.tool(Arc::new(SearchDocs))
61/// ```
62#[proc_macro_attribute]
63pub fn tool(_attr: TokenStream, item: TokenStream) -> TokenStream {
64 let input_fn = parse_macro_input!(item as ItemFn);
65
66 let fn_name = &input_fn.sig.ident;
67 let fn_vis = &input_fn.vis;
68
69 // Extract doc comments for description
70 let doc_lines: Vec<String> = input_fn
71 .attrs
72 .iter()
73 .filter(|attr| attr.path().is_ident("doc"))
74 .filter_map(|attr| {
75 if let syn::Meta::NameValue(nv) = &attr.meta {
76 if let syn::Expr::Lit(lit) = &nv.value {
77 if let syn::Lit::Str(s) = &lit.lit {
78 return Some(s.value().trim().to_string());
79 }
80 }
81 }
82 None
83 })
84 .collect();
85
86 let description = if doc_lines.is_empty() {
87 fn_name.to_string().replace('_', " ")
88 } else {
89 doc_lines.join(" ")
90 };
91
92 let tool_name_str = fn_name.to_string();
93
94 // Generate PascalCase struct name: get_weather → GetWeather
95 let struct_name = format_ident!(
96 "{}",
97 tool_name_str
98 .split('_')
99 .map(|seg| {
100 let mut chars = seg.chars();
101 match chars.next() {
102 None => String::new(),
103 Some(c) => c.to_uppercase().to_string() + chars.as_str(),
104 }
105 })
106 .collect::<String>()
107 );
108
109 // Extract the single argument type
110 let args_type = extract_args_type(&input_fn);
111
112 // Check if we have a typed args parameter or no params
113 let (schema_gen, deserialize_call) = if let Some(args_ty) = &args_type {
114 (
115 quote! {
116 {
117 let mut schema = serde_json::to_value(
118 schemars::schema_for!(#args_ty)
119 ).unwrap_or_default();
120 // Strip fields that Gemini/LLM APIs don't accept
121 if let Some(obj) = schema.as_object_mut() {
122 obj.remove("$schema");
123 obj.remove("title");
124 }
125 // Simplify nullable types: {"type": ["string", "null"]} → {"type": "string"}
126 fn simplify_nullable(v: &mut serde_json::Value) {
127 match v {
128 serde_json::Value::Object(map) => {
129 if let Some(serde_json::Value::Array(types)) = map.get("type") {
130 let non_null: Vec<_> = types.iter()
131 .filter(|t| t.as_str() != Some("null"))
132 .cloned()
133 .collect();
134 if non_null.len() == 1 {
135 map.insert("type".to_string(), non_null[0].clone());
136 }
137 }
138 // Remove anyOf wrappers for simple nullable types
139 if let Some(serde_json::Value::Array(any_of)) = map.remove("anyOf") {
140 for variant in &any_of {
141 if let Some(obj) = variant.as_object() {
142 if obj.get("type").and_then(|t| t.as_str()) != Some("null") {
143 for (k, val) in obj {
144 map.insert(k.clone(), val.clone());
145 }
146 break;
147 }
148 }
149 }
150 }
151 for val in map.values_mut() {
152 simplify_nullable(val);
153 }
154 }
155 serde_json::Value::Array(arr) => {
156 for item in arr {
157 simplify_nullable(item);
158 }
159 }
160 _ => {}
161 }
162 }
163 simplify_nullable(&mut schema);
164 Some(schema)
165 }
166 },
167 quote! {
168 let typed_args: #args_ty = serde_json::from_value(args)
169 .map_err(|e| adk_tool::AdkError::Tool(
170 format!("invalid arguments for '{}': {e}", #tool_name_str)
171 ))?;
172 #fn_name(typed_args).await
173 },
174 )
175 } else {
176 (
177 quote! { None },
178 quote! {
179 let _ = args;
180 #fn_name().await
181 },
182 )
183 };
184
185 // Check if the function signature includes ctx: Arc<dyn ToolContext>
186 let has_ctx = has_tool_context_param(&input_fn);
187 let execute_body = if has_ctx {
188 if let Some(args_ty) = &args_type {
189 quote! {
190 let typed_args: #args_ty = serde_json::from_value(args)
191 .map_err(|e| adk_tool::AdkError::Tool(
192 format!("invalid arguments for '{}': {e}", #tool_name_str)
193 ))?;
194 #fn_name(ctx, typed_args).await
195 }
196 } else {
197 quote! {
198 let _ = args;
199 #fn_name(ctx).await
200 }
201 }
202 } else {
203 deserialize_call
204 };
205
206 let output = quote! {
207 // Keep the original function
208 #input_fn
209
210 /// Auto-generated tool struct for [`#fn_name`].
211 #fn_vis struct #struct_name;
212
213 #[async_trait::async_trait]
214 impl adk_tool::Tool for #struct_name {
215 fn name(&self) -> &str {
216 #tool_name_str
217 }
218
219 fn description(&self) -> &str {
220 #description
221 }
222
223 fn parameters_schema(&self) -> Option<serde_json::Value> {
224 #schema_gen
225 }
226
227 async fn execute(
228 &self,
229 ctx: std::sync::Arc<dyn adk_tool::ToolContext>,
230 args: serde_json::Value,
231 ) -> adk_tool::Result<serde_json::Value> {
232 #execute_body
233 }
234 }
235 };
236
237 output.into()
238}
239
240/// Extract the args type from the function signature.
241/// Skips any `Arc<dyn ToolContext>` parameter.
242fn extract_args_type(func: &ItemFn) -> Option<Type> {
243 for arg in &func.sig.inputs {
244 if let FnArg::Typed(pat_type) = arg {
245 // Skip context parameters (Arc<dyn ToolContext>)
246 let ty_str = quote!(#pat_type.ty).to_string();
247 if ty_str.contains("ToolContext") || ty_str.contains("Arc") {
248 continue;
249 }
250 return Some((*pat_type.ty).clone());
251 }
252 }
253 None
254}
255
256/// Check if the function has an Arc<dyn ToolContext> parameter.
257fn has_tool_context_param(func: &ItemFn) -> bool {
258 func.sig.inputs.iter().any(|arg| {
259 if let FnArg::Typed(pat_type) = arg {
260 let ty = &pat_type.ty;
261 let ty_str = quote!(#ty).to_string();
262 ty_str.contains("ToolContext")
263 } else {
264 false
265 }
266 })
267}