use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{Data, DeriveInput, Lit, Meta, parse_macro_input};
#[proc_macro_derive(Tool, attributes(tool))]
pub fn derive_tool(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
match expand_tool(&input) {
Ok(ts) => ts.into(),
Err(e) => e.to_compile_error().into(),
}
}
/// Builds the token stream emitted by `#[derive(Tool)]`.
///
/// For a struct `Foo` this generates:
/// - a hidden unit struct `__FooToolImpl` implementing
///   `::claude_api::tool_dispatch::Tool`. Its `invoke` deserializes the JSON
///   input into `Foo` and awaits `<Foo>::run(parsed)`.
/// - an inherent `Foo::tool()` that returns a default `__FooToolImpl`.
///
/// The tool name comes from `#[tool(name = "...")]`, falling back to the
/// snake_case form of the struct name. The description comes from
/// `#[tool(description = "...")]`, falling back to the first sentence of the
/// struct's doc comments. When neither yields a description, no
/// `description` method is emitted.
///
/// # Errors
///
/// Returns a spanned `syn::Error` in three cases:
/// - the input is not a struct;
/// - the struct declares generic parameters;
/// - a `#[tool(...)]` attribute is rejected by `ToolAttrs::parse`.
fn expand_tool(input: &DeriveInput) -> syn::Result<TokenStream2> {
    if !matches!(input.data, Data::Struct(_)) {
        return Err(syn::Error::new_spanned(
            &input.ident,
            "#[derive(Tool)] only supports structs",
        ));
    }
    // The wrapper below is a non-generic unit struct, and the generated code
    // names `#struct_ident` bare (schema_for!, `<#struct_ident>::run`,
    // `impl #struct_ident`). A generic struct would otherwise fail later with
    // confusing "missing generics" errors, so reject it here with a clear one.
    if !input.generics.params.is_empty() {
        return Err(syn::Error::new_spanned(
            &input.generics,
            "#[derive(Tool)] does not support generic structs",
        ));
    }

    let struct_ident = &input.ident;
    let attrs = ToolAttrs::parse(&input.attrs)?;
    let tool_name = attrs
        .name
        .unwrap_or_else(|| pascal_to_snake(&struct_ident.to_string()));
    let tool_description: Option<String> =
        attrs.description.or_else(|| first_doc_line(&input.attrs));
    let wrapper_ident = quote::format_ident!("__{}ToolImpl", struct_ident);

    // Only emit `description` when there is text for it. Without it the impl
    // presumably relies on a default method on `Tool` — confirm in
    // `claude_api::tool_dispatch`.
    let description_method: TokenStream2 = if let Some(desc) = &tool_description {
        quote! {
            fn description(&self) -> ::core::option::Option<&str> {
                ::core::option::Option::Some(#desc)
            }
        }
    } else {
        quote! {}
    };

    Ok(quote! {
        #[doc(hidden)]
        #[derive(::core::default::Default)]
        pub struct #wrapper_ident;

        #[::claude_api::__private::async_trait::async_trait]
        impl ::claude_api::tool_dispatch::Tool for #wrapper_ident {
            fn name(&self) -> &str {
                #tool_name
            }

            #description_method

            fn schema(&self) -> ::claude_api::__private::serde_json::Value {
                let schema = ::claude_api::__private::schemars::schema_for!(#struct_ident);
                ::claude_api::__private::serde_json::to_value(&schema)
                    .unwrap_or_else(|_| ::claude_api::__private::serde_json::Value::Null)
            }

            async fn invoke(
                &self,
                input: ::claude_api::__private::serde_json::Value,
            ) -> ::core::result::Result<
                ::claude_api::__private::serde_json::Value,
                ::claude_api::tool_dispatch::ToolError,
            > {
                let parsed: #struct_ident =
                    ::claude_api::__private::serde_json::from_value(input)
                        .map_err(|e| ::claude_api::tool_dispatch::ToolError::invalid_input(
                            ::std::format!("input did not match {}'s schema: {}", #tool_name, e)
                        ))?;
                <#struct_ident>::run(parsed).await
            }
        }

        impl #struct_ident {
            pub fn tool() -> #wrapper_ident {
                #wrapper_ident::default()
            }
        }
    })
}
#[derive(Default)]
struct ToolAttrs {
name: Option<String>,
description: Option<String>,
}
impl ToolAttrs {
fn parse(attrs: &[syn::Attribute]) -> syn::Result<Self> {
let mut out = ToolAttrs::default();
for attr in attrs {
if !attr.path().is_ident("tool") {
continue;
}
attr.parse_nested_meta(|meta| {
if meta.path.is_ident("name") {
let value = meta.value()?;
let lit: syn::LitStr = value.parse()?;
out.name = Some(lit.value());
} else if meta.path.is_ident("description") {
let value = meta.value()?;
let lit: syn::LitStr = value.parse()?;
out.description = Some(lit.value());
} else {
return Err(meta
.error("unsupported #[tool(...)] key; expected `name` or `description`"));
}
Ok(())
})?;
}
Ok(out)
}
}
/// Extracts a one-sentence summary from an item's doc comments.
///
/// Processing steps:
/// 1. Collect the text of `#[doc = "..."]` attributes, which is what `///`
///    and `/** */` comments desugar to. Values are split on newlines so block
///    comments behave like line comments.
/// 2. Keep only the first paragraph: skip leading blank lines, then stop at
///    the next blank line.
/// 3. Join the paragraph's trimmed lines with single spaces.
/// 4. Truncate after the first `.` that is followed by whitespace or ends the
///    text. This keeps "1.5" or "foo.rs" intact.
///
/// Doc attributes whose value is not a string literal, such as
/// `#[doc(hidden)]`, are ignored. Returns `None` when there is no non-blank
/// doc text.
fn first_doc_line(attrs: &[syn::Attribute]) -> Option<String> {
    let mut lines: Vec<String> = Vec::new();
    for attr in attrs {
        if !attr.path().is_ident("doc") {
            continue;
        }
        if let Meta::NameValue(nv) = &attr.meta
            && let syn::Expr::Lit(syn::ExprLit {
                lit: Lit::Str(s), ..
            }) = &nv.value
        {
            // `split` (not `lines`) so an empty `///` still yields one blank
            // line, which is what marks a paragraph break.
            lines.extend(s.value().split('\n').map(|line| line.trim().to_string()));
        }
    }

    // The summary is the first paragraph only, so details further down the
    // docs never leak into the description.
    let summary = lines
        .iter()
        .skip_while(|line| line.is_empty())
        .take_while(|line| !line.is_empty())
        .map(String::as_str)
        .collect::<Vec<&str>>()
        .join(" ");
    if summary.is_empty() {
        return None;
    }

    let end = summary
        .char_indices()
        .find_map(|(i, ch)| {
            let after = i + ch.len_utf8();
            let ends_sentence = ch == '.'
                && summary[after..]
                    .chars()
                    .next()
                    .is_none_or(char::is_whitespace);
            ends_sentence.then_some(after)
        })
        .unwrap_or(summary.len());
    Some(summary[..end].to_string())
}
/// Converts a PascalCase identifier to snake_case.
///
/// Every character is lowercased, and each uppercase character other than
/// the first is preceded by `_`. Runs of capitals are therefore split one
/// letter at a time (`HTMLParser` → `h_t_m_l_parser`).
fn pascal_to_snake(s: &str) -> String {
    let mut chars = s.chars();
    let mut snake = String::with_capacity(s.len() + 4);
    // The leading character never gets a separator, even when uppercase.
    if let Some(head) = chars.next() {
        snake.extend(head.to_lowercase());
    }
    for ch in chars {
        if ch.is_uppercase() {
            snake.push('_');
        }
        snake.extend(ch.to_lowercase());
    }
    snake
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn snake_case_basic() {
        assert_eq!(pascal_to_snake("GetWeather"), "get_weather");
        assert_eq!(pascal_to_snake("HTMLParser"), "h_t_m_l_parser");
        assert_eq!(pascal_to_snake("F"), "f");
        assert_eq!(pascal_to_snake("Foo"), "foo");
        assert_eq!(pascal_to_snake(""), "");
    }

    // Fixtures are written as explicit `#[doc = "..."]` attributes, the form
    // `///` comments desugar to (including the leading space). This keeps the
    // doc text an unambiguous part of the token stream handed to
    // `parse_quote!`.

    #[test]
    fn first_doc_line_takes_first_sentence() {
        let attrs: Vec<syn::Attribute> = syn::parse_quote! {
            #[doc = " Hello world. More text."]
        };
        assert_eq!(first_doc_line(&attrs).as_deref(), Some("Hello world."));
    }

    #[test]
    fn first_doc_line_handles_no_period() {
        let attrs: Vec<syn::Attribute> = syn::parse_quote! {
            #[doc = " Hello world"]
        };
        assert_eq!(first_doc_line(&attrs).as_deref(), Some("Hello world"));
    }

    #[test]
    fn first_doc_line_joins_wrapped_lines() {
        let attrs: Vec<syn::Attribute> = syn::parse_quote! {
            #[doc = " Fetches the weather"]
            #[doc = " for a city. Extra detail."]
        };
        assert_eq!(
            first_doc_line(&attrs).as_deref(),
            Some("Fetches the weather for a city.")
        );
    }

    #[test]
    fn first_doc_line_ignores_non_doc_attrs_and_inner_periods() {
        let attrs: Vec<syn::Attribute> = syn::parse_quote! {
            #[derive(Debug)]
            #[doc(hidden)]
            #[doc = " Looks up 1.5 units"]
        };
        assert_eq!(first_doc_line(&attrs).as_deref(), Some("Looks up 1.5 units"));
    }

    #[test]
    fn first_doc_line_returns_none_on_empty() {
        let attrs: Vec<syn::Attribute> = vec![];
        assert!(first_doc_line(&attrs).is_none());

        let blank: Vec<syn::Attribute> = syn::parse_quote! {
            #[doc = ""]
        };
        assert!(first_doc_line(&blank).is_none());
    }
}