use convert_case::{Case, Casing};
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{FnArg, GenericArgument, ItemFn, Pat, PatType, PathArguments, Type, parse_macro_input};
/// Attribute macro entry point for `#[llm_tool]`.
///
/// Parses the annotated item as a function and hands it to [`tool_impl`].
/// A failed expansion is turned into a compile error token stream instead of
/// panicking. The attribute arguments (`_attr`) are not inspected.
#[proc_macro_attribute]
pub fn llm_tool(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let func = parse_macro_input!(item as ItemFn);
    tool_impl(&func)
        .unwrap_or_else(|err| err.to_compile_error())
        .into()
}
struct ParamInfo {
name: syn::Ident,
ty: Box<syn::Type>,
doc_attrs: Vec<syn::Attribute>,
is_context: bool,
}
enum ReturnInfo {
ResultType {
ok_type: Box<syn::Type>,
err_type: Box<syn::Type>,
},
BareType,
}
/// Expands a `#[llm_tool]`-annotated function into a parameter struct, a unit
/// tool struct, and a `RustTool` implementation for that struct.
///
/// Generated items (names derive from the function name in PascalCase):
/// - `<Name>Params`: one public field per non-context parameter, deriving
///   `serde::Deserialize` and `schemars::JsonSchema`; each parameter's doc
///   comment becomes its schemars `description`.
/// - `<Name>`: a unit struct implementing `::llm_tool::RustTool`. Its `call`
///   destructures the params, re-borrows `&str` parameters from their owned
///   `String` fields, binds the context parameter (if any) to the incoming
///   context, and then runs the original function body.
///
/// The tool's `NAME` is the function name verbatim and its `DESCRIPTION` is the
/// function's doc comment.
///
/// # Errors
///
/// Returns a `syn::Error` when the function has no doc comment, when any
/// non-context parameter has no doc comment, or when parameter or return-type
/// extraction fails.
fn tool_impl(func: &ItemFn) -> syn::Result<proc_macro2::TokenStream> {
    let crate_path = quote! { ::llm_tool };
    let fn_name = &func.sig.ident;
    let tool_name_str = fn_name.to_string();
    let struct_name = format_ident!("{}", tool_name_str.to_case(Case::Pascal));
    let params_name = format_ident!("{}Params", struct_name);

    let description = extract_doc_string(&func.attrs);
    if description.is_empty() {
        return Err(syn::Error::new_spanned(
            fn_name,
            "#[llm_tool] functions must have a doc comment (used as the tool description)",
        ));
    }

    let all_params = extract_params(func)?;
    // NOTE(review): only the first context-flagged parameter is bound below;
    // any further context parameters are neither struct fields nor bound.
    let ctx_param = all_params.iter().find(|p| p.is_context);
    let params: Vec<&ParamInfo> = all_params.iter().filter(|p| !p.is_context).collect();

    // Every schema-visible parameter needs a description.
    for param in &params {
        if param.doc_attrs.is_empty() {
            return Err(syn::Error::new_spanned(
                &param.name,
                format!(
                    "#[llm_tool] parameter `{}` must have a doc comment \
                     (used as the parameter description in the JSON schema)",
                    param.name
                ),
            ));
        }
    }

    let return_info = parse_return_type(func)?;

    let param_names: Vec<_> = params.iter().map(|p| &p.name).collect();
    let param_descriptions: Vec<String> = params
        .iter()
        .map(|p| extract_doc_string(&p.doc_attrs))
        .collect();
    let (param_struct_types, borrow_bindings) = build_param_types_and_borrows(&params);
    let serde_defaults = build_serde_defaults(&params);
    let body_tokens = build_body_tokens(func, &return_info, &crate_path);

    let vis = &func.vis;
    let params_doc = format!("Auto-generated parameters for the [`{struct_name}`] tool.");
    let struct_doc = format!(
        "Auto-generated tool struct. See the `#[llm_tool]`-annotated function `{fn_name}` for the implementation."
    );

    // Rebind the trait method's `_ctx` argument under the name the user chose.
    let ctx_binding = if let Some(cp) = ctx_param {
        let ctx_name = &cp.name;
        quote! { let #ctx_name = _ctx; }
    } else {
        quote! {}
    };

    Ok(quote! {
        #[doc = #params_doc]
        #[derive(::serde::Deserialize, ::schemars::JsonSchema)]
        #vis struct #params_name {
            #(
                #[schemars(description = #param_descriptions)]
                #serde_defaults
                pub #param_names: #param_struct_types,
            )*
        }

        #[doc = #struct_doc]
        #vis struct #struct_name;

        impl #crate_path::RustTool for #struct_name {
            type Params = #params_name;
            const NAME: &'static str = #tool_name_str;
            const DESCRIPTION: &'static str = #description;

            async fn call(
                &self,
                params: Self::Params,
                _ctx: &#crate_path::ToolContext
            ) -> ::std::result::Result<#crate_path::ToolOutput, #crate_path::ToolError> {
                use #crate_path::__private::SerializeFallback as _;
                let #params_name { #( #param_names, )* } = params;
                #( #borrow_bindings )*
                #ctx_binding
                #body_tokens
            }
        }
    })
}
/// Computes, for each parameter, the field type used in the generated params
/// struct and the statement that re-borrows it inside the generated `call`.
///
/// Immutable `&str` parameters are stored as `String` and re-borrowed as
/// `&str` under the same name; every other parameter keeps its declared type
/// and gets an empty re-borrow statement. Both vectors are index-aligned with
/// `params`.
fn build_param_types_and_borrows(
    params: &[&ParamInfo],
) -> (Vec<proc_macro2::TokenStream>, Vec<proc_macro2::TokenStream>) {
    let mut field_types = Vec::with_capacity(params.len());
    let mut reborrows = Vec::with_capacity(params.len());

    for info in params {
        if is_str_ref(&info.ty) {
            let ident = &info.name;
            field_types.push(quote! { String });
            reborrows.push(quote! { let #ident: &str = &#ident; });
        } else {
            let declared = &info.ty;
            field_types.push(quote! { #declared });
            reborrows.push(quote! {});
        }
    }

    (field_types, reborrows)
}
/// Produces one attribute token stream per parameter: `#[serde(default)]` for
/// `Option<_>` parameters and nothing for the rest, index-aligned with
/// `params`.
fn build_serde_defaults(params: &[&ParamInfo]) -> Vec<proc_macro2::TokenStream> {
    let mut attrs = Vec::with_capacity(params.len());
    for info in params {
        let attr = if is_option_type(&info.ty) {
            quote! { #[serde(default)] }
        } else {
            quote! {}
        };
        attrs.push(attr);
    }
    attrs
}
/// Emits the tail of the generated `call` method: run the original function
/// body, then convert its value into the tool's result.
///
/// The original statements are wrapped in an `async move` block that is
/// awaited (for `async fn`) or in an immediately-invoked closure (otherwise),
/// so a `return` inside the user body exits that wrapper.
///
/// - For a `Result<T, E>` return type the wrapper's value is annotated with
///   that type; `Ok` values go through `__private::Wrap(..).__convert()` and
///   `Err` values are converted with `Into::into`.
/// - For any other return type the value goes straight through
///   `__private::Wrap(..).__convert()`.
fn build_body_tokens(
    func: &ItemFn,
    return_info: &ReturnInfo,
    crate_path: &proc_macro2::TokenStream,
) -> proc_macro2::TokenStream {
    let body_stmts = &func.block.stmts;

    // Expression that evaluates the user's body; shared by both return shapes.
    let run_body = if func.sig.asyncness.is_some() {
        quote! { async move { #( #body_stmts )* }.await }
    } else {
        quote! { (|| { #( #body_stmts )* })() }
    };

    match return_info {
        ReturnInfo::BareType => quote! {
            let __v = #run_body;
            #crate_path::__private::Wrap(__v).__convert()
        },
        ReturnInfo::ResultType { ok_type, err_type } => quote! {
            let __r: ::std::result::Result<#ok_type, #err_type> = #run_body;
            match __r {
                ::std::result::Result::Ok(__v) => #crate_path::__private::Wrap(__v).__convert(),
                ::std::result::Result::Err(__e) => ::std::result::Result::Err(::std::convert::Into::into(__e)),
            }
        },
    }
}
/// Returns `true` when `ty` is a path type whose last segment is `Option`
/// with exactly one generic argument, and that argument is a type.
///
/// Only the last path segment is inspected, so any leading path is accepted.
fn is_option_type(ty: &syn::Type) -> bool {
    let last_segment = match ty {
        Type::Path(path_ty) => path_ty.path.segments.last(),
        _ => None,
    };

    match last_segment {
        Some(segment) if segment.ident == "Option" => match &segment.arguments {
            PathArguments::AngleBracketed(generics) => {
                generics.args.len() == 1
                    && matches!(generics.args.first(), Some(GenericArgument::Type(_)))
            }
            _ => false,
        },
        _ => false,
    }
}
/// Returns `true` when `ty`, after stripping at most one outer reference, is a
/// path type whose last segment is named `ToolContext`.
fn is_tool_context_type(ty: &syn::Type) -> bool {
    let target = if let Type::Reference(reference) = ty {
        reference.elem.as_ref()
    } else {
        ty
    };

    match target {
        Type::Path(path_ty) => match path_ty.path.segments.last() {
            Some(segment) => segment.ident == "ToolContext",
            None => false,
        },
        _ => false,
    }
}
/// Returns `true` when `ty` is an immutable reference to a path type whose
/// last segment is `str` with no path arguments (for example `&str` or
/// `&'a str`). `&mut str` is rejected.
fn is_str_ref(ty: &syn::Type) -> bool {
    match ty {
        Type::Reference(reference) if reference.mutability.is_none() => {
            match reference.elem.as_ref() {
                Type::Path(path_ty) => match path_ty.path.segments.last() {
                    Some(segment) => segment.ident == "str" && segment.arguments.is_none(),
                    None => false,
                },
                _ => false,
            }
        }
        _ => false,
    }
}
/// Reports whether `attr` is `#[llm_tool(context)]`.
///
/// Attributes whose path is not `llm_tool` yield `Ok(false)` without being
/// parsed further.
///
/// # Errors
///
/// Returns an error when an `llm_tool` attribute contains any nested item
/// other than `context`, or when its nested meta cannot be parsed.
fn is_explicit_context_attr(attr: &syn::Attribute) -> syn::Result<bool> {
    if attr.path().is_ident("llm_tool") {
        let mut saw_context = false;
        attr.parse_nested_meta(|nested| {
            if !nested.path.is_ident("context") {
                return Err(nested.error("unsupported llm_tool attribute"));
            }
            saw_context = true;
            Ok(())
        })?;
        Ok(saw_context)
    } else {
        Ok(false)
    }
}
/// Collects a [`ParamInfo`] for every argument of `func`, in declaration order.
///
/// A parameter is flagged as the context parameter when it carries
/// `#[llm_tool(context)]` or when its type names `ToolContext`.
///
/// # Errors
///
/// Stops at the first offending argument and returns an error when:
/// - the function takes a `self` receiver,
/// - a parameter pattern is not a plain identifier,
/// - an `llm_tool` attribute on a parameter is malformed or unsupported,
/// - a `ToolContext`-typed parameter is not taken by reference.
fn extract_params(func: &ItemFn) -> syn::Result<Vec<ParamInfo>> {
    func.sig
        .inputs
        .iter()
        .map(|arg| {
            let typed = match arg {
                FnArg::Typed(typed) => typed,
                FnArg::Receiver(receiver) => {
                    return Err(syn::Error::new_spanned(
                        receiver,
                        "#[llm_tool] functions must be free functions (no `self`)",
                    ));
                }
            };

            let Pat::Ident(binding) = typed.pat.as_ref() else {
                return Err(syn::Error::new_spanned(
                    typed.pat.as_ref(),
                    "#[llm_tool] parameters must be simple identifiers",
                ));
            };

            // Visit every attribute (no short-circuit) so a malformed
            // `llm_tool` attribute is always reported.
            let mut flagged_by_attr = false;
            for attr in &typed.attrs {
                if is_explicit_context_attr(attr)? {
                    flagged_by_attr = true;
                }
            }

            let names_tool_context = is_tool_context_type(&typed.ty);
            if names_tool_context && !matches!(typed.ty.as_ref(), syn::Type::Reference(_)) {
                return Err(syn::Error::new_spanned(
                    &typed.ty,
                    "ToolContext parameter must be a reference type (e.g., `&ToolContext` or `&'a ToolContext`)",
                ));
            }

            let doc_attrs = typed
                .attrs
                .iter()
                .filter(|attr| attr.path().is_ident("doc"))
                .cloned()
                .collect::<Vec<syn::Attribute>>();

            Ok(ParamInfo {
                name: binding.ident.clone(),
                ty: typed.ty.clone(),
                doc_attrs,
                is_context: flagged_by_attr || names_tool_context,
            })
        })
        .collect()
}
fn extract_doc_string(attrs: &[syn::Attribute]) -> String {
let lines: Vec<String> = attrs
.iter()
.filter_map(|attr| {
if !attr.path().is_ident("doc") {
return None;
}
if let syn::Meta::NameValue(nv) = &attr.meta
&& let syn::Expr::Lit(lit) = &nv.value
&& let syn::Lit::Str(s) = &lit.lit
{
return Some(s.value());
}
None
})
.collect();
lines
.iter()
.map(|l| l.trim())
.collect::<Vec<_>>()
.join("\n")
.trim()
.to_string()
}
fn parse_return_type(func: &ItemFn) -> syn::Result<ReturnInfo> {
let syn::ReturnType::Type(_, ty) = &func.sig.output else {
return Err(syn::Error::new_spanned(
&func.sig,
"#[llm_tool] functions must have an explicit return type",
));
};
if let Some(result_types) = try_extract_result_types(ty) {
return Ok(result_types);
}
Ok(ReturnInfo::BareType)
}
/// Extracts the success and error types from a `Result<T, E>`-shaped type.
///
/// Matches only path types whose last segment is `Result` with exactly two
/// angle-bracketed arguments, both of which are types. Anything else,
/// including single-argument `Result<T>` aliases, yields `None`.
fn try_extract_result_types(ty: &syn::Type) -> Option<ReturnInfo> {
    let segment = match ty {
        Type::Path(path_ty) => path_ty.path.segments.last()?,
        _ => return None,
    };
    if segment.ident != "Result" {
        return None;
    }

    let generics = match &segment.arguments {
        PathArguments::AngleBracketed(bracketed) => &bracketed.args,
        _ => return None,
    };

    // Exactly two arguments, and both must be types.
    let mut remaining = generics.iter();
    match (remaining.next(), remaining.next(), remaining.next()) {
        (Some(GenericArgument::Type(ok)), Some(GenericArgument::Type(err)), None) => {
            Some(ReturnInfo::ResultType {
                ok_type: Box::new(ok.clone()),
                err_type: Box::new(err.clone()),
            })
        }
        _ => None,
    }
}