use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{parse_macro_input, Data, DeriveInput};
/// Entry point for `#[derive(Tool)]`.
///
/// Parses the annotated item as a `DeriveInput` and hands it to `expand_tool`.
/// If expansion fails, the `syn::Error` is converted into compile-error tokens
/// (via `syn::Error::to_compile_error`), so the user sees a diagnostic at the
/// offending span.
///
/// Optional `#[tool(name = "...", description = "...")]` helper attributes
/// override the generated tool name and description.
#[proc_macro_derive(Tool, attributes(tool))]
pub fn derive_tool(input: TokenStream) -> TokenStream {
    let ast = parse_macro_input!(input as DeriveInput);
    expand_tool(ast)
        .unwrap_or_else(|err| err.to_compile_error())
        .into()
}
/// Values collected from the `#[tool(...)]` helper attributes on a struct.
#[derive(Default)]
struct ToolAttrs {
    /// Value of `#[tool(name = "...")]`, if given.
    name: Option<String>,
    /// Value of `#[tool(description = "...")]`, if given.
    description: Option<String>,
}

/// Reads `name` and `description` from every `#[tool(...)]` attribute.
///
/// Later occurrences override earlier ones. Any other key is rejected with an
/// error spanned on that key, so typos are reported instead of being silently
/// ignored.
///
/// # Errors
///
/// Returns an error if a value is not a string literal or if an unsupported
/// key is used.
fn parse_tool_attrs(attrs: &[syn::Attribute]) -> syn::Result<ToolAttrs> {
    let mut parsed = ToolAttrs::default();
    for attr in attrs.iter().filter(|a| a.path().is_ident("tool")) {
        attr.parse_nested_meta(|meta| {
            if meta.path.is_ident("description") {
                let value: syn::LitStr = meta.value()?.parse()?;
                parsed.description = Some(value.value());
            } else if meta.path.is_ident("name") {
                let value: syn::LitStr = meta.value()?.parse()?;
                parsed.name = Some(value.value());
            } else {
                // Without this, a bare unknown key was accepted silently, and a
                // keyed one failed with a confusing "expected `,`" error.
                return Err(meta.error(
                    "unsupported `tool` attribute key; expected `name` or `description`",
                ));
            }
            Ok(())
        })?;
    }
    Ok(parsed)
}

/// Generates the token stream for `#[derive(Tool)]`.
///
/// For the annotated struct this emits:
/// - an inherent impl with `tool_name()`, `tool_description()` and
///   `tool_schema()` (built via `schemars::schema_for!`), and
/// - a `traitclaw_core::ErasedTool` impl whose `execute_json` deserializes the
///   JSON input into the struct and awaits its `execute()` method.
///
/// The tool name defaults to the snake_case struct name and the description to
/// its title-cased form, unless overridden by `#[tool(name = ...)]` /
/// `#[tool(description = ...)]`. Generic parameters and where-clauses of the
/// struct are carried over to both impls.
///
/// # Errors
///
/// Returns an error if the input is not a struct or if the `#[tool(...)]`
/// attributes are malformed.
fn expand_tool(input: DeriveInput) -> syn::Result<TokenStream2> {
    if !matches!(input.data, Data::Struct(_)) {
        return Err(syn::Error::new_spanned(
            &input.ident,
            "#[derive(Tool)] can only be applied to structs",
        ));
    }

    let struct_name = &input.ident;
    let attrs = parse_tool_attrs(&input.attrs)?;

    // Drop a raw-identifier prefix (`r#match` -> `match`) so it does not leak
    // into the generated name/description strings; tokens still use the raw ident.
    let plain_name = syn::ext::IdentExt::unraw(struct_name).to_string();
    let description = attrs
        .description
        .unwrap_or_else(|| to_title_case(&plain_name));
    let tool_name = attrs.name.unwrap_or_else(|| to_snake_case(&plain_name));

    // Empty for non-generic structs, so the generated code is unchanged there.
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

    let expanded = quote! {
        impl #impl_generics #struct_name #ty_generics #where_clause {
            pub fn tool_name() -> &'static str {
                #tool_name
            }
            pub fn tool_description() -> &'static str {
                #description
            }
            pub fn tool_schema() -> traitclaw_core::ToolSchema {
                let schema = schemars::schema_for!(#struct_name #ty_generics);
                traitclaw_core::ToolSchema {
                    name: #tool_name.to_string(),
                    description: #description.to_string(),
                    parameters: serde_json::to_value(schema)
                        .unwrap_or_else(|_| serde_json::Value::Object(Default::default())),
                }
            }
        }
        #[async_trait::async_trait]
        impl #impl_generics traitclaw_core::ErasedTool for #struct_name #ty_generics #where_clause {
            fn name(&self) -> &str {
                #tool_name
            }
            fn description(&self) -> &str {
                #description
            }
            fn schema(&self) -> traitclaw_core::ToolSchema {
                Self::tool_schema()
            }
            async fn execute_json(
                &self,
                input: serde_json::Value,
            ) -> traitclaw_core::Result<serde_json::Value> {
                let typed: #struct_name #ty_generics = serde_json::from_value(input)
                    .map_err(|e| traitclaw_core::Error::tool_execution(
                        #tool_name,
                        format!("Invalid input: {e}"),
                    ))?;
                typed.execute().await
            }
        }
    };
    Ok(expanded)
}
/// Converts an UpperCamelCase identifier to snake_case (`WebSearch` -> `web_search`).
///
/// A word boundary is placed before an uppercase letter that follows a
/// non-uppercase character, or before the last capital of an acronym run when a
/// lowercase letter follows it. So `HTTPTool` becomes `http_tool`, not
/// `h_t_t_p_tool`. No extra underscore is inserted right after an existing `_`.
fn to_snake_case(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let mut result = String::with_capacity(s.len() + 4);
    for (i, &c) in chars.iter().enumerate() {
        if c.is_uppercase() && i > 0 {
            let prev = chars[i - 1];
            let next_is_lower = matches!(chars.get(i + 1), Some(n) if n.is_lowercase());
            // Break inside an acronym only where the next word starts (`TPT|ool`).
            if prev != '_' && (!prev.is_uppercase() || next_is_lower) {
                result.push('_');
            }
        }
        // `to_lowercase` may yield several chars for some non-ASCII letters.
        result.extend(c.to_lowercase());
    }
    result
}
/// Splits an UpperCamelCase identifier into space-separated words
/// (`WebSearch` -> `Web Search`), keeping the original letter case.
///
/// A space is placed before an uppercase letter that follows a non-uppercase
/// character, or before the last capital of an acronym run when a lowercase
/// letter follows it. So `HTTPTool` becomes `HTTP Tool`, not `H T T P Tool`.
fn to_title_case(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let mut result = String::with_capacity(s.len() + 4);
    for (i, &c) in chars.iter().enumerate() {
        if c.is_uppercase() && i > 0 {
            let prev = chars[i - 1];
            let next_is_lower = matches!(chars.get(i + 1), Some(n) if n.is_lowercase());
            // Keep acronyms together; split only where the next word starts.
            if !prev.is_uppercase() || next_is_lower {
                result.push(' ');
            }
        }
        result.push(c);
    }
    result
}