1use proc_macro::TokenStream;
45use proc_macro2::TokenStream as TokenStream2;
46use quote::quote;
47use syn::{Data, DeriveInput, Lit, Meta, parse_macro_input};
48
49#[proc_macro_derive(Tool, attributes(tool))]
54pub fn derive_tool(input: TokenStream) -> TokenStream {
55 let input = parse_macro_input!(input as DeriveInput);
56 match expand_tool(&input) {
57 Ok(ts) => ts.into(),
58 Err(e) => e.to_compile_error().into(),
59 }
60}
61
62fn expand_tool(input: &DeriveInput) -> syn::Result<TokenStream2> {
63 if !matches!(input.data, Data::Struct(_)) {
64 return Err(syn::Error::new_spanned(
65 &input.ident,
66 "#[derive(Tool)] only supports structs",
67 ));
68 }
69
70 let struct_ident = &input.ident;
71 let attrs = ToolAttrs::parse(&input.attrs)?;
72
73 let tool_name = attrs
74 .name
75 .unwrap_or_else(|| pascal_to_snake(&struct_ident.to_string()));
76 let tool_description: Option<String> =
77 attrs.description.or_else(|| first_doc_line(&input.attrs));
78
79 let wrapper_ident = quote::format_ident!("__{}ToolImpl", struct_ident);
80
81 let description_method: TokenStream2 = if let Some(desc) = &tool_description {
82 quote! {
83 fn description(&self) -> ::core::option::Option<&str> {
84 ::core::option::Option::Some(#desc)
85 }
86 }
87 } else {
88 quote! {}
89 };
90
91 Ok(quote! {
92 #[doc(hidden)]
95 #[derive(::core::default::Default)]
96 pub struct #wrapper_ident;
97
98 #[::claude_api::__private::async_trait::async_trait]
99 impl ::claude_api::tool_dispatch::Tool for #wrapper_ident {
100 fn name(&self) -> &str {
101 #tool_name
102 }
103
104 #description_method
105
106 fn schema(&self) -> ::claude_api::__private::serde_json::Value {
107 let schema = ::claude_api::__private::schemars::schema_for!(#struct_ident);
108 ::claude_api::__private::serde_json::to_value(&schema)
109 .unwrap_or_else(|_| ::claude_api::__private::serde_json::Value::Null)
110 }
111
112 async fn invoke(
113 &self,
114 input: ::claude_api::__private::serde_json::Value,
115 ) -> ::core::result::Result<
116 ::claude_api::__private::serde_json::Value,
117 ::claude_api::tool_dispatch::ToolError,
118 > {
119 let parsed: #struct_ident =
120 ::claude_api::__private::serde_json::from_value(input)
121 .map_err(|e| ::claude_api::tool_dispatch::ToolError::invalid_input(
122 ::std::format!("input did not match {}'s schema: {}", #tool_name, e)
123 ))?;
124 <#struct_ident>::run(parsed).await
125 }
126 }
127
128 impl #struct_ident {
129 pub fn tool() -> #wrapper_ident {
135 #wrapper_ident::default()
136 }
137 }
138 })
139}
140
141#[derive(Default)]
142struct ToolAttrs {
143 name: Option<String>,
144 description: Option<String>,
145}
146
147impl ToolAttrs {
148 fn parse(attrs: &[syn::Attribute]) -> syn::Result<Self> {
149 let mut out = ToolAttrs::default();
150 for attr in attrs {
151 if !attr.path().is_ident("tool") {
152 continue;
153 }
154 attr.parse_nested_meta(|meta| {
155 if meta.path.is_ident("name") {
156 let value = meta.value()?;
157 let lit: syn::LitStr = value.parse()?;
158 out.name = Some(lit.value());
159 } else if meta.path.is_ident("description") {
160 let value = meta.value()?;
161 let lit: syn::LitStr = value.parse()?;
162 out.description = Some(lit.value());
163 } else {
164 return Err(meta
165 .error("unsupported #[tool(...)] key; expected `name` or `description`"));
166 }
167 Ok(())
168 })?;
169 }
170 Ok(out)
171 }
172}
173
174fn first_doc_line(attrs: &[syn::Attribute]) -> Option<String> {
175 let mut lines: Vec<String> = Vec::new();
176 for attr in attrs {
177 if !attr.path().is_ident("doc") {
178 continue;
179 }
180 if let Meta::NameValue(nv) = &attr.meta
181 && let syn::Expr::Lit(syn::ExprLit {
182 lit: Lit::Str(s), ..
183 }) = &nv.value
184 {
185 lines.push(s.value().trim().to_string());
186 }
187 }
188 let joined = lines.join(" ");
189 let trimmed = joined.trim();
190 if trimmed.is_empty() {
191 None
192 } else {
193 let mut end = trimmed.len();
196 for (i, ch) in trimmed.char_indices() {
197 if ch == '.' {
198 let after_idx = i + ch.len_utf8();
199 let after = &trimmed[after_idx..];
200 if after.is_empty() || after.starts_with(' ') {
201 end = after_idx;
202 break;
203 }
204 }
205 }
206 Some(trimmed[..end].to_string())
207 }
208}
209
/// Converts a PascalCase identifier to snake_case.
///
/// Every uppercase character except the very first starts a new word and is
/// preceded by `_`; all characters are lowercased (Unicode-aware). Because
/// each capital is a boundary, acronyms are split letter by letter:
/// `HTMLParser` becomes `h_t_m_l_parser`.
fn pascal_to_snake(s: &str) -> String {
    s.chars()
        .enumerate()
        .fold(String::with_capacity(s.len() + 4), |mut snake, (pos, c)| {
            let starts_word = pos != 0 && c.is_uppercase();
            if starts_word {
                snake.push('_');
            }
            snake.extend(c.to_lowercase());
            snake
        })
}
220
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn snake_case_basic() {
        assert_eq!(pascal_to_snake("GetWeather"), "get_weather");
        assert_eq!(pascal_to_snake("HTMLParser"), "h_t_m_l_parser");
        assert_eq!(pascal_to_snake("F"), "f");
        assert_eq!(pascal_to_snake("Foo"), "foo");
    }

    #[test]
    fn snake_case_edge_cases() {
        assert_eq!(pascal_to_snake(""), "");
        assert_eq!(pascal_to_snake("Get2Weather"), "get2_weather");
    }

    // Doc inputs are spelled as explicit `#[doc = ...]` attributes (what `///`
    // desugars to) so each test's input is visible and unambiguous.

    #[test]
    fn first_doc_line_takes_first_sentence() {
        let attrs: Vec<syn::Attribute> = syn::parse_quote! {
            #[doc = " Hello world. This second sentence is dropped."]
        };
        assert_eq!(first_doc_line(&attrs).as_deref(), Some("Hello world."));
    }

    #[test]
    fn first_doc_line_handles_no_period() {
        let attrs: Vec<syn::Attribute> = syn::parse_quote! {
            #[doc = " Hello world"]
        };
        assert_eq!(first_doc_line(&attrs).as_deref(), Some("Hello world"));
    }

    #[test]
    fn first_doc_line_joins_wrapped_lines() {
        let attrs: Vec<syn::Attribute> = syn::parse_quote! {
            #[doc = " Hello"]
            #[doc = " world. More text."]
        };
        assert_eq!(first_doc_line(&attrs).as_deref(), Some("Hello world."));
    }

    #[test]
    fn first_doc_line_ignores_periods_inside_tokens() {
        let attrs: Vec<syn::Attribute> = syn::parse_quote! {
            #[doc = " Supports v1.2 inputs. Details follow."]
        };
        assert_eq!(
            first_doc_line(&attrs).as_deref(),
            Some("Supports v1.2 inputs.")
        );
    }

    #[test]
    fn first_doc_line_ignores_non_doc_attributes() {
        let attrs: Vec<syn::Attribute> = syn::parse_quote! {
            #[tool(name = "ignored")]
            #[doc = " Hello world."]
        };
        assert_eq!(first_doc_line(&attrs).as_deref(), Some("Hello world."));
    }

    #[test]
    fn first_doc_line_returns_none_on_empty() {
        let attrs: Vec<syn::Attribute> = vec![];
        assert!(first_doc_line(&attrs).is_none());
    }

    #[test]
    fn first_doc_line_returns_none_on_blank_docs() {
        let attrs: Vec<syn::Attribute> = syn::parse_quote! {
            #[doc = ""]
            #[doc = "   "]
        };
        assert!(first_doc_line(&attrs).is_none());
    }
}