use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{parse_macro_input, Expr, ExprLit, ItemFn, Lit, Meta};
#[proc_macro_attribute]
pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
let input_fn = parse_macro_input!(item as ItemFn);
let mut description = String::new();
let mut custom_name: Option<String> = None;
let attr_parser = syn::meta::parser(|meta| {
if meta.path.is_ident("description") {
let value: syn::LitStr = meta.value()?.parse()?;
description = value.value();
Ok(())
} else if meta.path.is_ident("name") {
let value: syn::LitStr = meta.value()?.parse()?;
custom_name = Some(value.value());
Ok(())
} else {
Err(meta.error("unsupported attribute"))
}
});
parse_macro_input!(attr with attr_parser);
let fn_name = &input_fn.sig.ident;
let fn_vis = &input_fn.vis;
let fn_block = &input_fn.block;
let fn_inputs = &input_fn.sig.inputs;
let fn_output = &input_fn.sig.output;
let fn_asyncness = &input_fn.sig.asyncness;
let tool_name = custom_name.unwrap_or_else(|| fn_name.to_string());
let description = if description.is_empty() {
let mut doc = String::new();
for attr in &input_fn.attrs {
if attr.path().is_ident("doc") {
if let Meta::NameValue(nv) = &attr.meta {
if let Expr::Lit(ExprLit {
lit: Lit::Str(lit), ..
}) = &nv.value
{
if !doc.is_empty() {
doc.push(' ');
}
doc.push_str(lit.value().trim());
}
}
}
}
if doc.is_empty() {
format!("Tool: {}", tool_name)
} else {
doc
}
} else {
description
};
let mut param_names: Vec<syn::Ident> = Vec::new();
let mut param_types: Vec<syn::Type> = Vec::new();
let mut param_name_strs: Vec<String> = Vec::new();
let mut param_json_types: Vec<String> = Vec::new();
for input in fn_inputs.iter() {
if let syn::FnArg::Typed(pat_type) = input {
if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
let name = pat_ident.ident.clone();
let name_str = name.to_string();
let ty = (*pat_type.ty).clone();
let json_type = rust_type_to_json_schema(&pat_type.ty);
param_names.push(name);
param_name_strs.push(name_str);
param_types.push(ty);
param_json_types.push(json_type);
}
}
}
let struct_name = format_ident!("{}Tool", to_pascal_case(&tool_name));
let expanded = quote! {
#fn_vis #fn_asyncness fn #fn_name(#fn_inputs) #fn_output #fn_block
#[derive(Debug, Clone)]
#fn_vis struct #struct_name;
impl #struct_name {
pub fn new() -> Self {
Self
}
}
impl Default for #struct_name {
fn default() -> Self {
Self::new()
}
}
#[async_trait::async_trait]
impl praisonai::Tool for #struct_name {
fn name(&self) -> &str {
#tool_name
}
fn description(&self) -> &str {
#description
}
fn parameters_schema(&self) -> serde_json::Value {
let mut properties = serde_json::Map::new();
let mut required = Vec::new();
#(
properties.insert(
#param_name_strs.to_string(),
serde_json::json!({ "type": #param_json_types })
);
required.push(serde_json::Value::String(#param_name_strs.to_string()));
)*
serde_json::json!({
"type": "object",
"properties": properties,
"required": required
})
}
async fn execute(&self, args: serde_json::Value) -> praisonai::Result<serde_json::Value> {
#(
let #param_names: #param_types = serde_json::from_value(
args.get(#param_name_strs)
.cloned()
.unwrap_or(serde_json::Value::Null)
).map_err(|e| praisonai::Error::tool(format!("Failed to parse {}: {}", #param_name_strs, e)))?;
)*
let result = #fn_name(#(#param_names),*).await;
serde_json::to_value(result)
.map_err(|e| praisonai::Error::tool(format!("Failed to serialize result: {}", e)))
}
}
};
TokenStream::from(expanded)
}
/// Maps a parameter's Rust type to the JSON Schema `"type"` string used in the
/// generated `parameters_schema`.
///
/// The type is rendered to text with all spaces removed, every enclosing
/// `Option<...>` layer is peeled off (optionality does not change the JSON
/// type of the value), and the remaining text is classified by
/// `rust_type_str_to_json_schema`. Unrecognized types map to `"object"`.
fn rust_type_to_json_schema(ty: &syn::Type) -> String {
    let type_str = quote!(#ty).to_string().replace(' ', "");

    // Peel nested options too, so `Option<Option<T>>` is classified as `T`.
    let mut inner = type_str.as_str();
    while let Some(unwrapped) = inner
        .strip_prefix("Option<")
        .and_then(|rest| rest.strip_suffix('>'))
    {
        inner = unwrapped;
    }

    // Single source of truth for the name -> JSON type table.
    rust_type_str_to_json_schema(inner)
}
/// Maps a Rust type, given as text without spaces, to a JSON Schema `"type"`
/// string.
///
/// - string-like types (`String`, `&str`, `str`, `char`) -> `"string"`
/// - built-in integer types -> `"integer"`
/// - `f32` / `f64` -> `"number"`
/// - `bool` -> `"boolean"`
/// - `Vec<...>` -> `"array"`
/// - `Option<T>` -> whatever `T` maps to (applied recursively)
/// - anything else -> `"object"`
fn rust_type_str_to_json_schema(type_str: &str) -> String {
    // An optional value has the same JSON type as the value it wraps.
    if let Some(inner) = type_str
        .strip_prefix("Option<")
        .and_then(|rest| rest.strip_suffix('>'))
    {
        return rust_type_str_to_json_schema(inner);
    }

    let json_type = match type_str {
        // `char` is included because serde represents it as a JSON string.
        "String" | "&str" | "str" | "char" => "string",
        "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32" | "u64"
        | "u128" | "usize" => "integer",
        "f32" | "f64" => "number",
        "bool" => "boolean",
        _ if type_str.starts_with("Vec<") => "array",
        _ => "object",
    };
    json_type.to_string()
}
/// Converts a tool name to PascalCase for use in a generated type name.
///
/// The input is split on every non-alphanumeric character (so `_`, `-`, `.`
/// and spaces all act as separators), empty segments are dropped, and the
/// first character of each remaining segment is upper-cased. For names that
/// are already valid `snake_case` identifiers the result is the same as
/// splitting on `_` alone; splitting on the other separators keeps custom
/// names such as `"get-weather"` from producing an invalid identifier.
fn to_pascal_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for word in s.split(|c: char| !c.is_alphanumeric()) {
        let mut chars = word.chars();
        if let Some(first) = chars.next() {
            // `to_uppercase` may yield more than one char for some scripts.
            out.extend(first.to_uppercase());
            out.push_str(chars.as_str());
        }
    }
    out
}