Skip to main content

claude_api_derive/
lib.rs

1//! Procedural macros for [`claude-api`](https://docs.rs/claude-api).
2//!
3//! Re-exported from the parent crate behind the `derive` feature; users
4//! should write `claude_api::derive::Tool`, not depend on this crate
5//! directly.
6//!
7//! # `#[derive(Tool)]`
8//!
9//! Derive an implementation of `claude_api::tool_dispatch::Tool` for a
10//! struct that already implements `serde::Deserialize` and
11//! `schemars::JsonSchema`. The struct's fields define the tool input;
12//! the user supplies the behavior via an inherent `async fn run(self)`.
13//!
14//! ```ignore
15//! use claude_api::derive::Tool;
16//! use claude_api::tool_dispatch::ToolError;
17//! use serde::Deserialize;
18//! use schemars::JsonSchema;
19//!
20//! /// Get the current weather for a city.
21//! #[derive(Deserialize, JsonSchema, Tool)]
22//! struct GetWeather {
23//!     /// City to look up.
24//!     city: String,
25//! }
26//!
27//! impl GetWeather {
28//!     async fn run(self) -> Result<serde_json::Value, ToolError> {
29//!         Ok(serde_json::json!({"temp": 72, "city": self.city}))
30//!     }
31//! }
32//!
33//! // Use:
34//! let tool = GetWeather::tool();
35//! ```
36//!
37//! ## Attribute syntax
38//!
39//! - `#[tool(name = "...")]` -- override the tool name.
40//!   Default: `snake_case` of the type name.
//! - `#[tool(description = "...")]` -- override the description.
//!   Default: the first sentence of the struct's doc comment.
43
44use proc_macro::TokenStream;
45use proc_macro2::TokenStream as TokenStream2;
46use quote::quote;
47use syn::{Data, DeriveInput, Lit, Meta, parse_macro_input};
48
49/// Derive `claude_api::tool_dispatch::Tool` for a struct.
50///
51/// See the [crate-level docs](crate) for the supported attribute syntax and
52/// the requirements on the underlying struct.
53#[proc_macro_derive(Tool, attributes(tool))]
54pub fn derive_tool(input: TokenStream) -> TokenStream {
55    let input = parse_macro_input!(input as DeriveInput);
56    match expand_tool(&input) {
57        Ok(ts) => ts.into(),
58        Err(e) => e.to_compile_error().into(),
59    }
60}
61
62fn expand_tool(input: &DeriveInput) -> syn::Result<TokenStream2> {
63    if !matches!(input.data, Data::Struct(_)) {
64        return Err(syn::Error::new_spanned(
65            &input.ident,
66            "#[derive(Tool)] only supports structs",
67        ));
68    }
69
70    let struct_ident = &input.ident;
71    let attrs = ToolAttrs::parse(&input.attrs)?;
72
73    let tool_name = attrs
74        .name
75        .unwrap_or_else(|| pascal_to_snake(&struct_ident.to_string()));
76    let tool_description: Option<String> =
77        attrs.description.or_else(|| first_doc_line(&input.attrs));
78
79    let wrapper_ident = quote::format_ident!("__{}ToolImpl", struct_ident);
80
81    let description_method: TokenStream2 = if let Some(desc) = &tool_description {
82        quote! {
83            fn description(&self) -> ::core::option::Option<&str> {
84                ::core::option::Option::Some(#desc)
85            }
86        }
87    } else {
88        quote! {}
89    };
90
91    Ok(quote! {
92        // Hidden wrapper struct that carries the Tool impl. Unit struct so
93        // it's trivially Send + Sync + 'static.
94        #[doc(hidden)]
95        #[derive(::core::default::Default)]
96        pub struct #wrapper_ident;
97
98        #[::claude_api::__private::async_trait::async_trait]
99        impl ::claude_api::tool_dispatch::Tool for #wrapper_ident {
100            fn name(&self) -> &str {
101                #tool_name
102            }
103
104            #description_method
105
106            fn schema(&self) -> ::claude_api::__private::serde_json::Value {
107                let schema = ::claude_api::__private::schemars::schema_for!(#struct_ident);
108                ::claude_api::__private::serde_json::to_value(&schema)
109                    .unwrap_or_else(|_| ::claude_api::__private::serde_json::Value::Null)
110            }
111
112            async fn invoke(
113                &self,
114                input: ::claude_api::__private::serde_json::Value,
115            ) -> ::core::result::Result<
116                ::claude_api::__private::serde_json::Value,
117                ::claude_api::tool_dispatch::ToolError,
118            > {
119                let parsed: #struct_ident =
120                    ::claude_api::__private::serde_json::from_value(input)
121                        .map_err(|e| ::claude_api::tool_dispatch::ToolError::invalid_input(
122                            ::std::format!("input did not match {}'s schema: {}", #tool_name, e)
123                        ))?;
124                <#struct_ident>::run(parsed).await
125            }
126        }
127
128        impl #struct_ident {
129            /// Build the [`Tool`](::claude_api::tool_dispatch::Tool) impl
130            /// derived for this type. The returned value is a unit struct
131            /// implementing `Tool`; pass it to
132            /// [`ToolRegistry::register_tool`](::claude_api::tool_dispatch::ToolRegistry::register_tool)
133            /// or wrap in `Arc` for trait-object dispatch.
134            pub fn tool() -> #wrapper_ident {
135                #wrapper_ident::default()
136            }
137        }
138    })
139}
140
141#[derive(Default)]
142struct ToolAttrs {
143    name: Option<String>,
144    description: Option<String>,
145}
146
147impl ToolAttrs {
148    fn parse(attrs: &[syn::Attribute]) -> syn::Result<Self> {
149        let mut out = ToolAttrs::default();
150        for attr in attrs {
151            if !attr.path().is_ident("tool") {
152                continue;
153            }
154            attr.parse_nested_meta(|meta| {
155                if meta.path.is_ident("name") {
156                    let value = meta.value()?;
157                    let lit: syn::LitStr = value.parse()?;
158                    out.name = Some(lit.value());
159                } else if meta.path.is_ident("description") {
160                    let value = meta.value()?;
161                    let lit: syn::LitStr = value.parse()?;
162                    out.description = Some(lit.value());
163                } else {
164                    return Err(meta
165                        .error("unsupported #[tool(...)] key; expected `name` or `description`"));
166                }
167                Ok(())
168            })?;
169        }
170        Ok(out)
171    }
172}
173
174fn first_doc_line(attrs: &[syn::Attribute]) -> Option<String> {
175    let mut lines: Vec<String> = Vec::new();
176    for attr in attrs {
177        if !attr.path().is_ident("doc") {
178            continue;
179        }
180        if let Meta::NameValue(nv) = &attr.meta
181            && let syn::Expr::Lit(syn::ExprLit {
182                lit: Lit::Str(s), ..
183            }) = &nv.value
184        {
185            lines.push(s.value().trim().to_string());
186        }
187    }
188    let joined = lines.join(" ");
189    let trimmed = joined.trim();
190    if trimmed.is_empty() {
191        None
192    } else {
193        // First sentence: take up to and including the first period that
194        // ends a sentence (followed by space, end-of-string, or end-of-line).
195        let mut end = trimmed.len();
196        for (i, ch) in trimmed.char_indices() {
197            if ch == '.' {
198                let after_idx = i + ch.len_utf8();
199                let after = &trimmed[after_idx..];
200                if after.is_empty() || after.starts_with(' ') {
201                    end = after_idx;
202                    break;
203                }
204            }
205        }
206        Some(trimmed[..end].to_string())
207    }
208}
209
/// Convert a `PascalCase` identifier to `snake_case`.
///
/// Every uppercase character except the very first is preceded by `_`, and
/// all characters are lowercased. Runs of capitals are therefore split letter
/// by letter (`HTMLParser` becomes `h_t_m_l_parser`).
fn pascal_to_snake(s: &str) -> String {
    let mut snake = String::with_capacity(s.len() + 4);
    let mut rest = s.chars();
    // The leading character never gets an underscore in front of it.
    if let Some(first) = rest.next() {
        snake.extend(first.to_lowercase());
    }
    for c in rest {
        if c.is_uppercase() {
            snake.push('_');
        }
        snake.extend(c.to_lowercase());
    }
    snake
}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the snake-casing rules, including the letter-by-letter split of
    /// acronyms and the single-character case.
    #[test]
    fn snake_case_basic() {
        let cases = [
            ("GetWeather", "get_weather"),
            ("HTMLParser", "h_t_m_l_parser"),
            ("F", "f"),
            ("Foo", "foo"),
        ];
        for (input, expected) in cases {
            assert_eq!(pascal_to_snake(input), expected, "input: {input}");
        }
    }

    /// A doc comment with two sentences yields only the first one.
    #[test]
    fn first_doc_line_takes_first_sentence() {
        // Explicit form of the doc comment `/// Hello world. More text here.`
        let attrs: Vec<syn::Attribute> = syn::parse_quote! {
            #[doc = " Hello world. More text here."]
        };
        let summary = first_doc_line(&attrs);
        assert_eq!(summary.as_deref(), Some("Hello world."));
    }

    /// Without a terminating period the whole text is returned.
    #[test]
    fn first_doc_line_handles_no_period() {
        // Explicit form of the doc comment `/// Hello world`
        let attrs: Vec<syn::Attribute> = syn::parse_quote! {
            #[doc = " Hello world"]
        };
        let summary = first_doc_line(&attrs);
        assert_eq!(summary.as_deref(), Some("Hello world"));
    }

    /// No attributes at all means no description.
    #[test]
    fn first_doc_line_returns_none_on_empty() {
        assert_eq!(first_doc_line(&[]), None);
    }
}