#![forbid(unsafe_code)]
use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::quote;
use syn::{FnArg, ItemFn, Pat, PatType, Type, parse_macro_input, parse_quote};
/// Converts a `snake_case` identifier into `UpperCamelCase`.
///
/// The input is split on `_`; empty segments (from leading, trailing or repeated underscores)
/// vanish, and the first character of every remaining segment is ASCII-uppercased. All other
/// characters are copied unchanged.
fn upper_camel(s: &str) -> String {
    s.split('_')
        .flat_map(|segment| {
            let mut chars = segment.chars();
            let head = chars.next().map(|c| c.to_ascii_uppercase());
            head.into_iter().chain(chars)
        })
        .collect()
}
/// Collects an item's `#[doc = "..."]` attributes (i.e. its `///` comments) into one string.
///
/// Each doc line is trimmed, and lines are joined with `\n`. Blank lines before the first
/// non-blank line and after the last one are dropped. Interior blank lines are kept, so
/// paragraph breaks survive. `doc` attributes that are not of the form `doc = "string"` are
/// ignored.
fn doc_comment(attrs: &[syn::Attribute]) -> String {
    let mut out = String::new();
    for attr in attrs.iter().filter(|a| a.path().is_ident("doc")) {
        // Borrow the meta: only the string literal is needed, so cloning it is wasted work.
        let syn::Meta::NameValue(nv) = &attr.meta else {
            continue;
        };
        let syn::Expr::Lit(expr_lit) = &nv.value else {
            continue;
        };
        let syn::Lit::Str(s) = &expr_lit.lit else {
            continue;
        };
        // Only separate once something has been written; this is what drops leading blanks.
        if !out.is_empty() {
            out.push('\n');
        }
        out.push_str(s.value().trim());
    }
    // A trailing bare `///` would otherwise leave a dangling newline in the description.
    out.truncate(out.trim_end().len());
    out
}
#[proc_macro_attribute]
#[allow(clippy::match_on_vec_items)] pub fn tool(_attr: TokenStream, item: TokenStream) -> TokenStream {
let f = parse_macro_input!(item as ItemFn);
let vis = &f.vis;
let name_ident = &f.sig.ident;
let name_str = name_ident.to_string();
let pascal = upper_camel(&name_str);
let struct_ident = syn::Ident::new(&pascal, Span::call_site());
let description = doc_comment(&f.attrs);
if f.sig.asyncness.is_none() {
return TokenStream::from(quote! {
compile_error!("#[adk_rs::tool] requires an async fn");
});
}
let inputs: Vec<&FnArg> = f.sig.inputs.iter().collect();
if inputs.len() != 2 {
return TokenStream::from(quote! {
compile_error!("#[adk_rs::tool] requires exactly two args: (args: T, ctx: &mut ToolContext)");
});
}
let PatType {
pat: arg_pat,
ty: arg_ty,
..
} = match inputs[0] {
FnArg::Typed(p) => p.clone(),
FnArg::Receiver(_) => {
return TokenStream::from(quote! {
compile_error!("#[adk_rs::tool] doesn't support receivers");
});
}
};
let arg_ident = match *arg_pat {
Pat::Ident(ref id) => id.ident.clone(),
_ => {
return TokenStream::from(
quote! { compile_error!("first arg must be a simple identifier"); },
);
}
};
let PatType {
pat: ctx_pat,
ty: _ctx_ty,
..
} = match inputs[1] {
FnArg::Typed(p) => p.clone(),
FnArg::Receiver(_) => {
return TokenStream::from(quote! {
compile_error!("#[adk_rs::tool] doesn't support receivers");
});
}
};
let ctx_ident = match *ctx_pat {
Pat::Ident(ref id) => id.ident.clone(),
_ => {
return TokenStream::from(
quote! { compile_error!("second arg must be a simple identifier"); },
);
}
};
let arg_ty_owned: Type = parse_quote!(#arg_ty);
let body = &f.block;
let ret_ty = &f.sig.output;
let constructor_name = name_ident.clone();
let expanded = quote! {
#[doc = #description]
#[derive(Debug, Default, Clone, Copy)]
#vis struct #struct_ident;
#[::async_trait::async_trait]
impl ::adk_rs::__private::DynTool for #struct_ident {
fn name(&self) -> &str { #name_str }
fn description(&self) -> &str { #description }
fn declaration(&self) -> ::std::option::Option<::adk_rs::__private::FunctionDeclaration> {
let root = ::schemars::schema_for!(#arg_ty_owned);
let schema = ::adk_rs::__private::Schema::from_schemars(&root)
.unwrap_or_else(|_| ::adk_rs::__private::Schema::object());
::std::option::Option::Some(
::adk_rs::__private::FunctionDeclaration::new(#name_str, #description)
.with_parameters(schema),
)
}
async fn run(
&self,
args: ::serde_json::Value,
#ctx_ident: &mut ::adk_rs::__private::ToolContext,
) -> ::adk_rs::__private::Result<::serde_json::Value> {
async fn __inner(#arg_ident: #arg_ty_owned, #ctx_ident: &mut ::adk_rs::__private::ToolContext) #ret_ty #body
let typed: #arg_ty_owned = ::serde_json::from_value(args).map_err(|e| {
::adk_rs::__private::Error::Tool(::adk_rs::__private::ToolError::InvalidArgs {
tool: #name_str.to_string(),
message: e.to_string(),
})
})?;
let r = __inner(typed, #ctx_ident).await?;
::serde_json::to_value(r).map_err(|e| {
::adk_rs::__private::Error::Tool(::adk_rs::__private::ToolError::Execution {
tool: #name_str.to_string(),
message: e.to_string(),
})
})
}
}
#vis fn #constructor_name() -> ::std::sync::Arc<dyn ::adk_rs::__private::DynTool> {
::std::sync::Arc::new(#struct_ident)
}
};
TokenStream::from(expanded)
}