use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{FnArg, ItemFn, Meta, Type, parse_macro_input};
#[proc_macro_attribute]
pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
let input_fn = parse_macro_input!(item as ItemFn);
let mut is_read_only = false;
let mut is_concurrency_safe = false;
let mut is_long_running = false;
if !attr.is_empty() {
let meta = parse_macro_input!(attr as ToolAttrs);
is_read_only = meta.read_only;
is_concurrency_safe = meta.concurrency_safe;
is_long_running = meta.long_running;
}
let fn_name = &input_fn.sig.ident;
let fn_vis = &input_fn.vis;
let doc_lines: Vec<String> = input_fn
.attrs
.iter()
.filter(|attr| attr.path().is_ident("doc"))
.filter_map(|attr| {
if let syn::Meta::NameValue(nv) = &attr.meta {
if let syn::Expr::Lit(lit) = &nv.value {
if let syn::Lit::Str(s) = &lit.lit {
return Some(s.value().trim().to_string());
}
}
}
None
})
.collect();
let description = if doc_lines.is_empty() {
fn_name.to_string().replace('_', " ")
} else {
doc_lines.join(" ")
};
let tool_name_str = fn_name.to_string();
let struct_name = format_ident!(
"{}",
tool_name_str
.split('_')
.map(|seg| {
let mut chars = seg.chars();
match chars.next() {
None => String::new(),
Some(c) => c.to_uppercase().to_string() + chars.as_str(),
}
})
.collect::<String>()
);
let args_type = extract_args_type(&input_fn);
let (schema_gen, deserialize_call) = if let Some(args_ty) = &args_type {
(
quote! {
{
let mut schema = serde_json::to_value(
schemars::schema_for!(#args_ty)
).unwrap_or_default();
if let Some(obj) = schema.as_object_mut() {
obj.remove("$schema");
obj.remove("title");
}
fn simplify_nullable(v: &mut serde_json::Value) {
match v {
serde_json::Value::Object(map) => {
if let Some(serde_json::Value::Array(types)) = map.get("type") {
let non_null: Vec<_> = types.iter()
.filter(|t| t.as_str() != Some("null"))
.cloned()
.collect();
if non_null.len() == 1 {
map.insert("type".to_string(), non_null[0].clone());
}
}
if let Some(serde_json::Value::Array(any_of)) = map.remove("anyOf") {
for variant in &any_of {
if let Some(obj) = variant.as_object() {
if obj.get("type").and_then(|t| t.as_str()) != Some("null") {
for (k, val) in obj {
map.insert(k.clone(), val.clone());
}
break;
}
}
}
}
for val in map.values_mut() {
simplify_nullable(val);
}
}
serde_json::Value::Array(arr) => {
for item in arr {
simplify_nullable(item);
}
}
_ => {}
}
}
simplify_nullable(&mut schema);
Some(schema)
}
},
quote! {
let typed_args: #args_ty = serde_json::from_value(args)
.map_err(|e| adk_tool::AdkError::tool(
format!("invalid arguments for '{}': {e}", #tool_name_str)
))?;
#fn_name(typed_args).await
},
)
} else {
(
quote! { None },
quote! {
let _ = args;
#fn_name().await
},
)
};
let has_ctx = has_tool_context_param(&input_fn);
let execute_body = if has_ctx {
if let Some(args_ty) = &args_type {
quote! {
let typed_args: #args_ty = serde_json::from_value(args)
.map_err(|e| adk_tool::AdkError::tool(
format!("invalid arguments for '{}': {e}", #tool_name_str)
))?;
#fn_name(ctx, typed_args).await
}
} else {
quote! {
let _ = args;
#fn_name(ctx).await
}
}
} else {
deserialize_call
};
let read_only_override = if is_read_only {
quote! {
fn is_read_only(&self) -> bool { true }
}
} else {
quote! {}
};
let concurrency_safe_override = if is_concurrency_safe {
quote! {
fn is_concurrency_safe(&self) -> bool { true }
}
} else {
quote! {}
};
let long_running_override = if is_long_running {
quote! {
fn is_long_running(&self) -> bool { true }
}
} else {
quote! {}
};
let output = quote! {
#input_fn
#fn_vis struct #struct_name;
#[adk_tool::async_trait]
impl adk_tool::Tool for #struct_name {
fn name(&self) -> &str {
#tool_name_str
}
fn description(&self) -> &str {
#description
}
fn parameters_schema(&self) -> Option<serde_json::Value> {
#schema_gen
}
#read_only_override
#concurrency_safe_override
#long_running_override
async fn execute(
&self,
ctx: std::sync::Arc<dyn adk_tool::ToolContext>,
args: serde_json::Value,
) -> adk_tool::Result<serde_json::Value> {
#execute_body
}
}
};
output.into()
}
/// Returns the type of the first typed parameter of `func` that is not the tool context.
///
/// A parameter is treated as context and skipped when its type mentions `ToolContext` or
/// `Arc` as a *whole identifier*, e.g. `Arc<dyn ToolContext>` or
/// `std::sync::Arc<dyn adk_tool::ToolContext>`. Matching whole identifiers rather than
/// substrings keeps argument types such as `ArchiveArgs` or `MyToolContextArgs` from being
/// skipped by accident. Receiver parameters (`self`) are ignored.
///
/// Returns `None` when every parameter is skipped or the function has no parameters.
fn extract_args_type(func: &ItemFn) -> Option<Type> {
    /// True if the token text of `ty` contains `ToolContext` or `Arc` as a complete identifier.
    fn is_context_like(ty: &Type) -> bool {
        quote!(#ty)
            .to_string()
            .split(|c: char| !(c.is_alphanumeric() || c == '_'))
            .any(|word| word == "ToolContext" || word == "Arc")
    }

    func.sig.inputs.iter().find_map(|arg| match arg {
        FnArg::Typed(pat_type) if !is_context_like(&pat_type.ty) => {
            Some((*pat_type.ty).clone())
        }
        _ => None,
    })
}
/// Returns `true` if any typed parameter of `func` mentions `ToolContext` as a whole
/// identifier, e.g. `Arc<dyn ToolContext>` or `Arc<dyn adk_tool::ToolContext>`.
///
/// Whole-identifier matching prevents argument types such as `MyToolContextArgs` from being
/// mistaken for the context. Receiver parameters (`self`) never count.
fn has_tool_context_param(func: &ItemFn) -> bool {
    func.sig.inputs.iter().any(|arg| match arg {
        FnArg::Typed(pat_type) => {
            let ty = &pat_type.ty;
            quote!(#ty)
                .to_string()
                .split(|c: char| !(c.is_alphanumeric() || c == '_'))
                .any(|word| word == "ToolContext")
        }
        FnArg::Receiver(_) => false,
    })
}
/// Flags accepted by `#[tool(...)]`. A flag that is not listed stays `false`.
struct ToolAttrs {
    /// Set by the `read_only` flag.
    read_only: bool,
    /// Set by the `concurrency_safe` flag.
    concurrency_safe: bool,
    /// Set by the `long_running` flag.
    long_running: bool,
}

impl syn::parse::Parse for ToolAttrs {
    /// Parses a comma-separated list of bare flag identifiers (a trailing comma is allowed).
    ///
    /// # Errors
    ///
    /// Fails on the first entry that is not a bare path, such as `key = value` or
    /// `name(...)`. It also fails on a path that is not one of the three known flags.
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        let entries =
            syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated(input)?;

        let (mut read_only, mut concurrency_safe, mut long_running) = (false, false, false);

        for entry in &entries {
            let path = match entry {
                Meta::Path(path) => path,
                other => {
                    return Err(syn::Error::new_spanned(
                        other,
                        "expected identifier (e.g., `read_only`), not key-value",
                    ));
                }
            };

            // `get_ident` is `Some` only for a single-segment path with no leading `::`
            // and no generic arguments, the same condition `Path::is_ident` checks.
            let flag = match path.get_ident().map(ToString::to_string).as_deref() {
                Some("read_only") => &mut read_only,
                Some("concurrency_safe") => &mut concurrency_safe,
                Some("long_running") => &mut long_running,
                _ => {
                    return Err(syn::Error::new_spanned(
                        path,
                        "unknown tool attribute; expected `read_only`, `concurrency_safe`, or `long_running`",
                    ));
                }
            };
            *flag = true;
        }

        Ok(ToolAttrs { read_only, concurrency_safe, long_running })
    }
}