use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use syn::ItemFn;
use crate::codegen::{generate_compat_methods, generate_safe_methods, generate_schema_impl};
use crate::helpers::{
extract_doc_from_attrs, extract_fn_params, parse_tool_meta_tokens, snake_to_pascal,
};
/// Expands `#[tool(...)]` applied to a free function into the function itself
/// plus the glue needed to register it as a tool.
///
/// The generated tokens contain:
/// - the original function, minus any leftover `#[tool]` attributes on the item
///   and minus `#[doc]` attributes on its parameters (rustc rejects doc
///   attributes on function parameters, so they must not be re-emitted);
/// - an args struct named `{snake_to_pascal(fn_ident)}Args`, with one `pub`
///   field per parameter carrying that parameter's doc attributes, deriving
///   `Deserialize` and `JsonSchema` through the `::lellm_agent` re-exports;
/// - an `impl ::lellm_agent::ToolArgs` for that struct providing `NAME` and
///   `DESCRIPTION`, plus the schema / compat / safe helpers from `codegen`;
/// - `{fn_ident}_tool()`, which registers a handler that moves each field out
///   of the args struct into a call to the function (awaiting it if it is
///   `async`);
/// - `{fn_ident}_tool_with(f)`, which registers an arbitrary handler `f`.
///
/// The tool name defaults to `ident_to_snake_case(fn_ident)` and the
/// description defaults to the function's doc comment; both are overridden by
/// non-empty values parsed from `args`.
///
/// # Errors
///
/// Returns any error produced by `extract_fn_params` for the function inputs.
pub(crate) fn expand_tool_for_fn(
    args: TokenStream2,
    func: ItemFn,
) -> Result<TokenStream2, syn::Error> {
    let meta = parse_tool_meta_tokens(&args);
    let fn_name = &func.sig.ident;
    let fn_name_str = fn_name.to_string();

    // An explicit name/description in the attribute wins over derived defaults.
    let name = if meta.name.is_empty() {
        crate::helpers::ident_to_snake_case(&fn_name_str)
    } else {
        meta.name.clone()
    };
    let description = if meta.description.is_empty() {
        extract_doc_from_attrs(&func.attrs).unwrap_or_default()
    } else {
        meta.description.clone()
    };

    let params = extract_fn_params(&func.sig.inputs)?;
    let is_async = func.sig.asyncness.is_some();

    let pascal_name = snake_to_pascal(&fn_name_str);
    let struct_name = format_ident!("{}Args", pascal_name);
    let reg_fn_name = format_ident!("{}_tool", fn_name);
    let reg_fn_name_with = format_ident!("{}_tool_with", fn_name);

    // One public field per parameter; parameter docs become field docs so they
    // end up in the generated JSON schema.
    let fields: Vec<syn::Field> = params
        .iter()
        .map(|p| {
            let ident = &p.ident;
            let ty = &p.ty;
            let doc_attrs = &p.doc_attrs;
            syn::parse_quote! {
                #(#doc_attrs)*
                pub #ident: #ty
            }
        })
        .collect();

    // Arguments for the forwarding call: each field is moved out of `args`.
    let arg_refs: Vec<TokenStream2> = params
        .iter()
        .map(|p| {
            let ident = &p.ident;
            quote! { args.#ident }
        })
        .collect();

    let await_suffix = if is_async {
        quote! { .await }
    } else {
        quote! {}
    };

    // Re-emit the user's function without attributes that must not survive
    // expansion: stray `#[tool]` markers on the item, and `#[doc]` on
    // parameters (rustc errors on doc attributes in parameter position; their
    // content has already been moved onto the args struct fields above).
    let mut cleaned_func = func.clone();
    cleaned_func
        .attrs
        .retain(|attr| !attr.path().is_ident("tool"));
    for input in cleaned_func.sig.inputs.iter_mut() {
        if let syn::FnArg::Typed(pat_type) = input {
            pat_type.attrs.retain(|attr| !attr.path().is_ident("doc"));
        }
    }

    let visibility = &func.vis;
    let schema_fn = generate_schema_impl(&struct_name);
    let compat_methods = generate_compat_methods(&struct_name);
    let safe_methods = generate_safe_methods(&struct_name);

    // The `crate = ...` attributes point the serde / schemars derive output at
    // the `lellm_agent` re-exports, so user crates need not depend on serde or
    // schemars directly.
    Ok(quote! {
        #cleaned_func

        #[derive(
            ::lellm_agent::serde::Deserialize,
            ::lellm_agent::schemars::JsonSchema
        )]
        #[serde(crate = "::lellm_agent::serde")]
        #[schemars(crate = "::lellm_agent::schemars")]
        #visibility struct #struct_name {
            #(#fields),*
        }

        impl ::lellm_agent::ToolArgs for #struct_name {
            const NAME: &'static str = #name;
            const DESCRIPTION: &'static str = #description;
            #schema_fn
        }

        #compat_methods
        #safe_methods

        #visibility fn #reg_fn_name() -> ::lellm_agent::ToolRegistration {
            #reg_fn_name_with(|args| async move {
                #fn_name(#(#arg_refs),*) #await_suffix
            })
        }

        #visibility fn #reg_fn_name_with<F, Fut>(f: F) -> ::lellm_agent::ToolRegistration
        where
            F: Fn(#struct_name) -> Fut + Send + Sync + 'static,
            Fut: ::core::future::Future<Output = ::lellm_agent::ToolResult> + Send + 'static,
        {
            #struct_name::safe(f)
        }
    })
}