Skip to main content

llm_tool_macros/
lib.rs

1//! Proc-macro crate for `llm-tool`.
2//!
3//! Provides the `#[llm_tool]` attribute macro that transforms a plain function
4//! into a strongly-typed [`RustTool`](https://docs.rs/llm-tool/latest/llm_tool/trait.RustTool.html)
5//! implementation.
6
7use convert_case::{Case, Casing};
8use proc_macro::TokenStream;
9use quote::{format_ident, quote};
10use syn::{FnArg, GenericArgument, ItemFn, Pat, PatType, PathArguments, Type, parse_macro_input};
11
12/// Transforms a function into a `RustTool` implementation.
13///
14/// The macro generates:
15/// - A `{FnName}Params` struct deriving `Deserialize` and `JsonSchema`
16/// - A `{FnName}` unit struct (`PascalCase`) implementing `RustTool`
17///
18/// The tool **name** is the function name (`snake_case`).
19/// The tool **description** comes from the function's doc comment.
20/// Parameter names and types come from the function signature.
21/// Doc comments on parameters become schema descriptions.
22///
23/// # Typed parameters
24///
25/// Parameters may use `&str` — the generated params struct stores an owned
26/// `String` and the macro auto-borrows it before passing to your function body.
27///
28/// # Return types
29///
30/// The return type can be `Result<T, E>` or just `T` (infallible):
31///
32/// - **`T`**: `String` (wrapped as-is), `ToolOutput` (passed through), any
33///   `T: Serialize` (auto-serialized to JSON), or any `T: Into<ToolOutput>`
34/// - **`E`**: any `E: Into<ToolError>` — built-in for `String`, `ToolError`,
35///   `std::io::Error`, `serde_json::Error`
36///
37/// # Usage
38///
39/// ```ignore
40/// use llm_tool::{RustTool, ToolContext, ToolRegistry};
41///
42/// /// Adds two numbers together (with a twist).
43/// #[llm_tool::llm_tool]
44/// fn wonky_add(
45///     /// First number.
46///     a: i64,
47///     /// Second number.
48///     b: i64,
49/// ) -> Result<String, String> {
50///     Ok(format!("{}", a + b + 1))
51/// }
52///
53/// let mut registry = ToolRegistry::new();
54/// registry.register(WonkyAdd);
55/// assert_eq!(registry.definitions().len(), 1);
56/// ```
57#[proc_macro_attribute]
58pub fn llm_tool(_attr: TokenStream, item: TokenStream) -> TokenStream {
59    let func = parse_macro_input!(item as ItemFn);
60    match tool_impl(&func) {
61        Ok(tokens) => tokens.into(),
62        Err(err) => err.to_compile_error().into(),
63    }
64}
65
66// ── Implementation ──────────────────────────────────────────────────────────
67
68/// Parsed information about a single function parameter.
69struct ParamInfo {
70    name: syn::Ident,
71    ty: Box<syn::Type>,
72    doc_attrs: Vec<syn::Attribute>,
73    is_context: bool,
74}
75
76/// Information about the function's return type.
77enum ReturnInfo {
78    /// `Result<T, E>` — fallible tool.
79    ResultType {
80        ok_type: Box<syn::Type>,
81        err_type: Box<syn::Type>,
82    },
83    /// Bare `T` — infallible tool.
84    BareType,
85}
86
87fn tool_impl(func: &ItemFn) -> syn::Result<proc_macro2::TokenStream> {
88    let crate_path = quote! { ::llm_tool };
89    let fn_name = &func.sig.ident;
90    let tool_name_str = fn_name.to_string();
91    let struct_name = format_ident!("{}", tool_name_str.to_case(Case::Pascal));
92    let params_name = format_ident!("{}Params", struct_name);
93
94    // Extract doc comment from function attributes → tool description.
95    let description = extract_doc_string(&func.attrs);
96    if description.is_empty() {
97        return Err(syn::Error::new_spanned(
98            fn_name,
99            "#[llm_tool] functions must have a doc comment (used as the tool description)",
100        ));
101    }
102
103    // Extract parameters, separating ToolContext from regular params.
104    let all_params = extract_params(func)?;
105    let ctx_param = all_params.iter().find(|p| p.is_context);
106    let params: Vec<&ParamInfo> = all_params.iter().filter(|p| !p.is_context).collect();
107
108    // Enforce doc comments on every non-ToolContext parameter.
109    for param in &params {
110        if param.doc_attrs.is_empty() {
111            return Err(syn::Error::new_spanned(
112                &param.name,
113                format!(
114                    "#[llm_tool] parameter `{}` must have a doc comment \
115                     (used as the parameter description in the JSON schema)",
116                    param.name
117                ),
118            ));
119        }
120    }
121
122    // Parse return type: either Result<T, E> or bare T.
123    let return_info = parse_return_type(func)?;
124
125    let param_names: Vec<_> = params.iter().map(|p| &p.name).collect();
126    let param_descriptions: Vec<String> = params
127        .iter()
128        .map(|p| extract_doc_string(&p.doc_attrs))
129        .collect();
130
131    let (param_struct_types, borrow_bindings) = build_param_types_and_borrows(&params);
132    let serde_defaults = build_serde_defaults(&params);
133    let body_tokens = build_body_tokens(func, &return_info, &crate_path);
134
135    let vis = &func.vis;
136
137    let params_doc = format!("Auto-generated parameters for the [`{struct_name}`] tool.");
138    let struct_doc = format!(
139        "Auto-generated tool struct. See the `#[llm_tool]`-annotated function `{fn_name}` for the implementation."
140    );
141
142    // If the user's function takes a ToolContext parameter, bind it from the
143    // `_ctx` reference provided by the RustTool::call signature.
144    let ctx_binding = if let Some(cp) = ctx_param {
145        let ctx_name = &cp.name;
146        quote! { let #ctx_name = _ctx; }
147    } else {
148        quote! {}
149    };
150
151    Ok(quote! {
152        #[doc = #params_doc]
153        #[derive(::serde::Deserialize, ::schemars::JsonSchema)]
154        #vis struct #params_name {
155            #(
156                #[schemars(description = #param_descriptions)]
157                #serde_defaults
158                pub #param_names: #param_struct_types,
159            )*
160        }
161
162        #[doc = #struct_doc]
163        #vis struct #struct_name;
164
165        impl #crate_path::RustTool for #struct_name {
166            type Params = #params_name;
167            const NAME: &'static str = #tool_name_str;
168            const DESCRIPTION: &'static str = #description;
169
170            async fn call(&self, params: Self::Params, _ctx: &#crate_path::ToolContext) -> ::std::result::Result<#crate_path::ToolOutput, #crate_path::ToolError> {
171                // Import the fallback trait so `Wrap<T>::__convert()` resolves
172                // for `T: Serialize` types that lack an inherent `__convert`.
173                use #crate_path::__private::SerializeFallback as _;
174                // Destructure params into local bindings matching the original
175                // function signature.
176                let #params_name { #( #param_names, )* } = params;
177                // Auto-borrow &str params from their owned String fields.
178                #( #borrow_bindings )*
179                #ctx_binding
180                #body_tokens
181            }
182        }
183    })
184}
185
186/// Build the struct field types and any auto-borrow bindings for `&str` params.
187fn build_param_types_and_borrows(
188    params: &[&ParamInfo],
189) -> (Vec<proc_macro2::TokenStream>, Vec<proc_macro2::TokenStream>) {
190    params
191        .iter()
192        .map(|p| {
193            if is_str_ref(&p.ty) {
194                // &str → String in struct, auto-borrow in body
195                let name = &p.name;
196                (quote! { String }, quote! { let #name: &str = &#name; })
197            } else {
198                let ty = &p.ty;
199                (quote! { #ty }, quote! {})
200            }
201        })
202        .unzip()
203}
204
205/// Build `#[serde(default)]` annotations for `Option<T>` params.
206fn build_serde_defaults(params: &[&ParamInfo]) -> Vec<proc_macro2::TokenStream> {
207    params
208        .iter()
209        .map(|p| {
210            if is_option_type(&p.ty) {
211                quote! { #[serde(default)] }
212            } else {
213                quote! {}
214            }
215        })
216        .collect()
217}
218
219/// Build the body tokens that wrap the user's function body.
220///
221/// Uses compile-time dispatch via `__private::Wrap(v).__convert()` —
222/// the compiler resolves the correct conversion (inherent method for
223/// `String`/`ToolOutput`/`Json<T>`, or `SerializeFallback` trait for
224/// `T: Serialize`) without any proc-macro type-name matching.
225fn build_body_tokens(
226    func: &ItemFn,
227    return_info: &ReturnInfo,
228    crate_path: &proc_macro2::TokenStream,
229) -> proc_macro2::TokenStream {
230    let is_async = func.sig.asyncness.is_some();
231    let body_stmts = &func.block.stmts;
232
233    match return_info {
234        ReturnInfo::ResultType { ok_type, err_type } => {
235            let inner = if is_async {
236                quote! {
237                    let __r: ::std::result::Result<#ok_type, #err_type> = async move {
238                        #( #body_stmts )*
239                    }.await;
240                }
241            } else {
242                quote! {
243                    let __r: ::std::result::Result<#ok_type, #err_type> = (|| { #( #body_stmts )* })();
244                }
245            };
246            quote! {
247                #inner
248                match __r {
249                    ::std::result::Result::Ok(__v) => #crate_path::__private::Wrap(__v).__convert(),
250                    ::std::result::Result::Err(__e) => ::std::result::Result::Err(::std::convert::Into::into(__e)),
251                }
252            }
253        }
254        ReturnInfo::BareType => {
255            let inner = if is_async {
256                quote! {
257                    let __v = async move { #( #body_stmts )* }.await;
258                }
259            } else {
260                quote! {
261                    let __v = (|| { #( #body_stmts )* })();
262                }
263            };
264            quote! {
265                #inner
266                #crate_path::__private::Wrap(__v).__convert()
267            }
268        }
269    }
270}
271
272/// Check whether `ty` is `Option<T>` (or `std::option::Option<T>`).
273fn is_option_type(ty: &syn::Type) -> bool {
274    let Type::Path(type_path) = ty else {
275        return false;
276    };
277    let Some(last_seg) = type_path.path.segments.last() else {
278        return false;
279    };
280    if last_seg.ident != "Option" {
281        return false;
282    }
283    matches!(&last_seg.arguments, PathArguments::AngleBracketed(args)
284        if args.args.len() == 1
285            && matches!(args.args.first(), Some(GenericArgument::Type(_))))
286}
287
288/// Check whether `ty` is `ToolContext`, `&ToolContext`, or a qualified path
289/// ending in `ToolContext`.
290fn is_tool_context_type(ty: &syn::Type) -> bool {
291    let inner = match ty {
292        Type::Reference(r) => r.elem.as_ref(),
293        other => other,
294    };
295    let Type::Path(type_path) = inner else {
296        return false;
297    };
298    type_path
299        .path
300        .segments
301        .last()
302        .is_some_and(|seg| seg.ident == "ToolContext")
303}
304
305/// Check whether `ty` is `&str`.
306fn is_str_ref(ty: &syn::Type) -> bool {
307    let Type::Reference(ref_type) = ty else {
308        return false;
309    };
310    if ref_type.mutability.is_some() {
311        return false;
312    }
313    let Type::Path(type_path) = ref_type.elem.as_ref() else {
314        return false;
315    };
316    type_path
317        .path
318        .segments
319        .last()
320        .is_some_and(|seg| seg.ident == "str" && seg.arguments.is_none())
321}
322
323fn is_explicit_context_attr(attr: &syn::Attribute) -> syn::Result<bool> {
324    if !attr.path().is_ident("llm_tool") {
325        return Ok(false);
326    }
327    let mut is_context = false;
328    attr.parse_nested_meta(|meta| {
329        if meta.path.is_ident("context") {
330            is_context = true;
331            Ok(())
332        } else {
333            Err(meta.error("unsupported llm_tool attribute"))
334        }
335    })?;
336    Ok(is_context)
337}
338
339fn extract_params(func: &ItemFn) -> syn::Result<Vec<ParamInfo>> {
340    let mut params = Vec::new();
341    for arg in &func.sig.inputs {
342        match arg {
343            FnArg::Receiver(r) => {
344                return Err(syn::Error::new_spanned(
345                    r,
346                    "#[llm_tool] functions must be free functions (no `self`)",
347                ));
348            }
349            FnArg::Typed(PatType { pat, ty, attrs, .. }) => {
350                let name = match pat.as_ref() {
351                    Pat::Ident(ident) => ident.ident.clone(),
352                    other => {
353                        return Err(syn::Error::new_spanned(
354                            other,
355                            "#[llm_tool] parameters must be simple identifiers",
356                        ));
357                    }
358                };
359
360                let mut has_context_attr = false;
361                for a in attrs {
362                    has_context_attr |= is_explicit_context_attr(a)?;
363                }
364                let is_tool_context = is_tool_context_type(ty);
365                let is_context = has_context_attr || is_tool_context;
366
367                if is_tool_context && !matches!(ty.as_ref(), syn::Type::Reference(_)) {
368                    return Err(syn::Error::new_spanned(
369                        ty,
370                        "ToolContext parameter must be a reference type (e.g., `&ToolContext` or `&'a ToolContext`)",
371                    ));
372                }
373
374                let doc_attrs: Vec<syn::Attribute> = attrs
375                    .iter()
376                    .filter(|a| a.path().is_ident("doc"))
377                    .cloned()
378                    .collect();
379                params.push(ParamInfo {
380                    name,
381                    ty: ty.clone(),
382                    doc_attrs,
383                    is_context,
384                });
385            }
386        }
387    }
388    Ok(params)
389}
390
391fn extract_doc_string(attrs: &[syn::Attribute]) -> String {
392    let lines: Vec<String> = attrs
393        .iter()
394        .filter_map(|attr| {
395            if !attr.path().is_ident("doc") {
396                return None;
397            }
398            if let syn::Meta::NameValue(nv) = &attr.meta
399                && let syn::Expr::Lit(lit) = &nv.value
400                && let syn::Lit::Str(s) = &lit.lit
401            {
402                return Some(s.value());
403            }
404            None
405        })
406        .collect();
407    lines
408        .iter()
409        .map(|l| l.trim())
410        .collect::<Vec<_>>()
411        .join("\n")
412        .trim()
413        .to_string()
414}
415
416/// Parse the return type — either `Result<T, E>` or a bare type `T`.
417fn parse_return_type(func: &ItemFn) -> syn::Result<ReturnInfo> {
418    let syn::ReturnType::Type(_, ty) = &func.sig.output else {
419        return Err(syn::Error::new_spanned(
420            &func.sig,
421            "#[llm_tool] functions must have an explicit return type",
422        ));
423    };
424
425    // Try to parse as Result<T, E>.
426    if let Some(result_types) = try_extract_result_types(ty) {
427        return Ok(result_types);
428    }
429
430    // Not a Result — treat as infallible bare type.
431    Ok(ReturnInfo::BareType)
432}
433
434/// Try to extract `T` and `E` from a `Result<T, E>` return type.
435/// Returns `None` if the type is not a `Result`.
436fn try_extract_result_types(ty: &syn::Type) -> Option<ReturnInfo> {
437    let Type::Path(type_path) = ty else {
438        return None;
439    };
440
441    let last_seg = type_path.path.segments.last()?;
442
443    if last_seg.ident != "Result" {
444        return None;
445    }
446
447    let PathArguments::AngleBracketed(args) = &last_seg.arguments else {
448        return None;
449    };
450
451    if args.args.len() != 2 {
452        return None;
453    }
454
455    let GenericArgument::Type(ok_type) = &args.args[0] else {
456        return None;
457    };
458
459    let GenericArgument::Type(err_type) = &args.args[1] else {
460        return None;
461    };
462
463    Some(ReturnInfo::ResultType {
464        ok_type: Box::new(ok_type.clone()),
465        err_type: Box::new(err_type.clone()),
466    })
467}