Skip to main content

rig_derive/
lib.rs

1extern crate proc_macro;
2
3use convert_case::{Case, Casing};
4use proc_macro::TokenStream;
5use quote::{format_ident, quote};
6use std::{collections::HashMap, ops::Deref};
7use syn::{
8    Attribute, DeriveInput, Expr, ExprLit, Ident, Lit, Meta, PathArguments, ReturnType, Token,
9    Type,
10    parse::{Parse, ParseStream},
11    parse_macro_input,
12    punctuated::Punctuated,
13};
14
15mod basic;
16mod client;
17mod custom;
18mod embed;
19
/// Name of the `embed` helper attribute registered by `#[derive(Embed)]`
/// (see `attributes(embed)` on `derive_embedding_trait`).
pub(crate) const EMBED: &str = "embed";
21
22pub(crate) fn rig_core_path() -> proc_macro2::TokenStream {
23    match proc_macro_crate::crate_name("rig-core") {
24        Ok(proc_macro_crate::FoundCrate::Itself) => quote!(crate),
25        Ok(proc_macro_crate::FoundCrate::Name(name)) => {
26            let ident = format_ident!("{name}");
27            quote!(::#ident)
28        }
29        Err(_) => match proc_macro_crate::crate_name("rig") {
30            Ok(proc_macro_crate::FoundCrate::Itself) => quote!(crate),
31            Ok(proc_macro_crate::FoundCrate::Name(name)) => {
32                let ident = format_ident!("{name}");
33                quote!(::#ident)
34            }
35            Err(_) => quote!(::rig_core),
36        },
37    }
38}
39
40#[proc_macro_derive(ProviderClient, attributes(client))]
41pub fn derive_provider_client(input: TokenStream) -> TokenStream {
42    client::provider_client(input)
43}
44
45//References:
46//<https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro>
47//<https://doc.rust-lang.org/reference/procedural-macros.html>
48/// A macro that allows you to implement the `rig::embedding::Embed` trait by deriving it.
49/// Usage can be found below:
50///
51/// ```text
52/// use rig::Embed;
53/// use rig_derive::Embed;
54///
55/// #[derive(Embed)]
56/// struct Foo {
57///     id: String,
58///     #[embed] // this helper shows which field to embed
59///     description: String
60///}
61/// ```
62#[proc_macro_derive(Embed, attributes(embed))]
63pub fn derive_embedding_trait(item: TokenStream) -> TokenStream {
64    let mut input = parse_macro_input!(item as DeriveInput);
65
66    embed::expand_derive_embedding(&mut input)
67        .unwrap_or_else(syn::Error::into_compile_error)
68        .into()
69}
70
/// Arguments accepted by the `#[rig_tool(...)]` attribute.
struct MacroArgs {
    /// Explicit tool name from `name = "..."`; `None` means the function name is used.
    name: Option<String>,
    /// Tool description from `description = "..."`.
    description: Option<String>,
    /// Per-parameter descriptions from `params(arg = "...")`, keyed by parameter name.
    param_descriptions: HashMap<String, String>,
    /// Parameter names from `required(...)`; `None` means every parameter is required.
    required: Option<Vec<String>>,
}
77
78fn parse_string_literal(expr: &Expr, field_name: &str) -> syn::Result<String> {
79    match expr {
80        Expr::Lit(ExprLit {
81            lit: Lit::Str(lit_str),
82            ..
83        }) => Ok(lit_str.value()),
84        _ => Err(syn::Error::new_spanned(
85            expr,
86            format!("`{field_name}` must be a string literal"),
87        )),
88    }
89}
90
91fn validate_explicit_tool_name(name: &str, expr: &Expr) -> syn::Result<()> {
92    if name.is_empty() || name.len() > 64 {
93        return Err(syn::Error::new_spanned(
94            expr,
95            "`name` must be between 1 and 64 characters long",
96        ));
97    }
98
99    let mut chars = name.chars();
100    let Some(first_char) = chars.next() else {
101        return Err(syn::Error::new_spanned(
102            expr,
103            "`name` must be between 1 and 64 characters long",
104        ));
105    };
106
107    if !first_char.is_ascii_alphabetic() && first_char != '_' {
108        return Err(syn::Error::new_spanned(
109            expr,
110            "`name` must start with an ASCII letter or underscore",
111        ));
112    }
113
114    if chars.any(|ch| !ch.is_ascii_alphanumeric() && ch != '_' && ch != '-') {
115        return Err(syn::Error::new_spanned(
116            expr,
117            "`name` may only contain ASCII letters, digits, underscores, or hyphens",
118        ));
119    }
120
121    Ok(())
122}
123
124impl Parse for MacroArgs {
125    fn parse(input: ParseStream) -> syn::Result<Self> {
126        let mut name = None;
127        let mut description = None;
128        let mut param_descriptions = HashMap::new();
129        let mut required = None;
130
131        // If the input is empty, return default values
132        if input.is_empty() {
133            return Ok(MacroArgs {
134                name,
135                description,
136                param_descriptions,
137                required,
138            });
139        }
140
141        let meta_list: Punctuated<Meta, Token![,]> = Punctuated::parse_terminated(input)?;
142
143        for meta in meta_list {
144            match meta {
145                Meta::NameValue(nv) => {
146                    let ident = nv.path.get_ident().ok_or_else(|| {
147                        syn::Error::new_spanned(
148                            &nv.path,
149                            "unsupported top-level #[rig_tool] argument",
150                        )
151                    })?;
152
153                    match ident.to_string().as_str() {
154                        "name" => {
155                            let parsed_name = parse_string_literal(&nv.value, "name")?;
156                            validate_explicit_tool_name(&parsed_name, &nv.value)?;
157                            name = Some(parsed_name);
158                        }
159                        "description" => {
160                            description = Some(parse_string_literal(&nv.value, "description")?);
161                        }
162                        _ => {
163                            return Err(syn::Error::new_spanned(
164                                &nv.path,
165                                format!("unsupported top-level #[rig_tool] argument `{}`", ident),
166                            ));
167                        }
168                    }
169                }
170                Meta::List(list) => {
171                    let ident = list.path.get_ident().ok_or_else(|| {
172                        syn::Error::new_spanned(
173                            &list.path,
174                            "unsupported top-level #[rig_tool] argument",
175                        )
176                    })?;
177
178                    match ident.to_string().as_str() {
179                        "params" => {
180                            let nested: Punctuated<Meta, Token![,]> =
181                                list.parse_args_with(Punctuated::parse_terminated)?;
182
183                            for meta in nested {
184                                if let Meta::NameValue(nv) = meta
185                                    && let Expr::Lit(ExprLit {
186                                        lit: Lit::Str(lit_str),
187                                        ..
188                                    }) = nv.value
189                                {
190                                    let Some(param_ident) = nv.path.get_ident() else {
191                                        return Err(syn::Error::new_spanned(
192                                            &nv.path,
193                                            "parameter descriptions must use identifier keys",
194                                        ));
195                                    };
196                                    let param_name = param_ident.to_string();
197                                    param_descriptions.insert(param_name, lit_str.value());
198                                }
199                            }
200                        }
201                        "required" => {
202                            let required_variables: Punctuated<Ident, Token![,]> =
203                                list.parse_args_with(Punctuated::parse_terminated)?;
204
205                            required = Some(
206                                required_variables
207                                    .into_iter()
208                                    .map(|x| x.to_string())
209                                    .collect(),
210                            );
211                        }
212                        _ => {
213                            return Err(syn::Error::new_spanned(
214                                &list.path,
215                                format!("unsupported top-level #[rig_tool] argument `{}`", ident),
216                            ));
217                        }
218                    }
219                }
220                Meta::Path(path) => {
221                    let message = if let Some(ident) = path.get_ident() {
222                        format!("unsupported top-level #[rig_tool] argument `{ident}`")
223                    } else {
224                        "unsupported top-level #[rig_tool] argument".to_string()
225                    };
226
227                    return Err(syn::Error::new_spanned(path, message));
228                }
229            }
230        }
231
232        Ok(MacroArgs {
233            name,
234            description,
235            param_descriptions,
236            required,
237        })
238    }
239}
240
241/// Extract doc comment text from `#[doc = "..."]` attributes.
242fn extract_doc_comment(attrs: &[Attribute]) -> Option<String> {
243    let lines: Vec<String> = attrs
244        .iter()
245        .filter_map(|attr| {
246            if !attr.path().is_ident("doc") {
247                return None;
248            }
249            if let Meta::NameValue(nv) = &attr.meta
250                && let Expr::Lit(ExprLit {
251                    lit: Lit::Str(s), ..
252                }) = &nv.value
253            {
254                return Some(s.value());
255            }
256            None
257        })
258        .collect();
259
260    if lines.is_empty() {
261        return None;
262    }
263
264    Some(
265        lines
266            .iter()
267            .map(|l| l.strip_prefix(' ').unwrap_or(l))
268            .collect::<Vec<_>>()
269            .join("\n")
270            .trim()
271            .to_string(),
272    )
273}
274
275/// Check if a type is `Option<T>`.
276fn is_option_type(ty: &Type) -> bool {
277    if let Type::Path(type_path) = ty
278        && let Some(segment) = type_path.path.segments.last()
279    {
280        return segment.ident == "Option";
281    }
282    false
283}
284
285fn result_type_tokens(
286    return_type: &ReturnType,
287) -> syn::Result<(proc_macro2::TokenStream, proc_macro2::TokenStream)> {
288    let ReturnType::Type(_, ty) = return_type else {
289        return Err(syn::Error::new_spanned(
290            return_type,
291            "function must have a return type of Result<T, E>",
292        ));
293    };
294
295    let Type::Path(type_path) = ty.deref() else {
296        return Err(syn::Error::new_spanned(
297            ty,
298            "return type must be Result<T, E>",
299        ));
300    };
301
302    let Some(last_segment) = type_path.path.segments.last() else {
303        return Err(syn::Error::new_spanned(
304            &type_path.path,
305            "return type must be Result<T, E>",
306        ));
307    };
308
309    if last_segment.ident != "Result" {
310        return Err(syn::Error::new_spanned(
311            &last_segment.ident,
312            "return type must be Result<T, E>",
313        ));
314    }
315
316    let PathArguments::AngleBracketed(args) = &last_segment.arguments else {
317        return Err(syn::Error::new_spanned(
318            &last_segment.arguments,
319            "expected angle-bracketed type parameters for Result<T, E>",
320        ));
321    };
322
323    let mut generic_args = args.args.iter();
324    let Some(output) = generic_args.next() else {
325        return Err(syn::Error::new_spanned(
326            &args.args,
327            "expected Result<T, E> with exactly two type parameters",
328        ));
329    };
330    let Some(error) = generic_args.next() else {
331        return Err(syn::Error::new_spanned(
332            &args.args,
333            "expected Result<T, E> with exactly two type parameters",
334        ));
335    };
336
337    if generic_args.next().is_some() {
338        return Err(syn::Error::new_spanned(
339            &args.args,
340            "expected Result<T, E> with exactly two type parameters",
341        ));
342    }
343
344    Ok((quote!(#output), quote!(#error)))
345}
346
347/// A procedural macro that transforms a function into a `rig::tool::Tool` that can be used with a `rig::agent::Agent`.
348///
349/// # Examples
350///
351/// Basic usage:
352/// ```text
353/// use rig_derive::rig_tool;
354///
355/// #[rig_tool]
356/// fn add(a: i32, b: i32) -> Result<i32, rig::tool::ToolError> {
357///     Ok(a + b)
358/// }
359/// ```
360///
361/// With description:
362/// ```text
363/// use rig_derive::rig_tool;
364///
365/// #[rig_tool(description = "Perform basic arithmetic operations")]
366/// fn calculator(x: i32, y: i32, operation: String) -> Result<i32, rig::tool::ToolError> {
367///     match operation.as_str() {
368///         "add" => Ok(x + y),
369///         "subtract" => Ok(x - y),
370///         "multiply" => Ok(x * y),
371///         "divide" => Ok(x / y),
372///         _ => Err(rig::tool::ToolError::ToolCallError("Unknown operation".into())),
373///     }
374/// }
375/// ```
376///
377/// With a custom tool name:
378/// ```text
379/// use rig_derive::rig_tool;
380///
381/// // Explicit names must be string literals that start with an ASCII letter
382/// // or `_`, may contain ASCII letters, digits, `_`, or `-`, and be at most
383/// // 64 characters long.
384/// #[rig_tool(name = "search-docs", description = "Search the documentation")]
385/// fn search_docs_impl(query: String) -> Result<String, rig::tool::ToolError> {
386///     Ok(format!("Searching docs for {query}"))
387/// }
388/// ```
389///
390/// With parameter descriptions:
391/// ```text
392/// use rig_derive::rig_tool;
393///
394/// #[rig_tool(
395///     description = "A tool that performs string operations",
396///     params(
397///         text = "The input text to process",
398///         operation = "The operation to perform (uppercase, lowercase, reverse)"
399///     )
400/// )]
401/// fn string_processor(text: String, operation: String) -> Result<String, rig::tool::ToolError> {
402///     match operation.as_str() {
403///         "uppercase" => Ok(text.to_uppercase()),
404///         "lowercase" => Ok(text.to_lowercase()),
405///         "reverse" => Ok(text.chars().rev().collect()),
406///         _ => Err(rig::tool::ToolError::ToolCallError("Unknown operation".into())),
407///     }
408/// }
409/// ```
410#[proc_macro_attribute]
411pub fn rig_tool(args: TokenStream, input: TokenStream) -> TokenStream {
412    let args = parse_macro_input!(args as MacroArgs);
413    let input_fn = parse_macro_input!(input as syn::ItemFn);
414
415    // Extract function details
416    let fn_name = &input_fn.sig.ident;
417    let fn_name_str = fn_name.to_string();
418    let tool_name = args.name.clone().unwrap_or_else(|| fn_name_str.clone());
419    let vis = &input_fn.vis;
420    let is_async = input_fn.sig.asyncness.is_some();
421
422    // Build a cleaned copy of the function with doc attrs stripped from parameters,
423    // since `#[doc]` on function parameters is not allowed by the compiler.
424    let cleaned_fn = {
425        let mut f = input_fn.clone();
426        for arg in f.sig.inputs.iter_mut() {
427            if let syn::FnArg::Typed(pat_type) = arg {
428                pat_type.attrs.retain(|a| !a.path().is_ident("doc"));
429            }
430        }
431        f
432    };
433
434    // Extract return type and get Output and Error types from Result<T, E>
435    let return_type = &input_fn.sig.output;
436    let (output_type, error_type) = match result_type_tokens(return_type) {
437        Ok(types) => types,
438        Err(error) => return error.into_compile_error().into(),
439    };
440
441    // Generate PascalCase struct name from the function name
442    let struct_name = format_ident!("{}", { fn_name_str.to_case(Case::Pascal) });
443
444    // Tool description: explicit attribute > doc comment > default
445    let fn_doc = extract_doc_comment(&input_fn.attrs);
446    let tool_description = match args.description {
447        Some(desc) => quote! { #desc.to_string() },
448        None => match fn_doc {
449            Some(doc) => quote! { #doc.to_string() },
450            None => quote! { format!("Function to {}", Self::NAME) },
451        },
452    };
453
454    // Extract parameter names, doc comments, and build struct field tokens
455    let mut param_names = Vec::new();
456    let mut field_tokens = Vec::new();
457
458    for arg in input_fn.sig.inputs.iter() {
459        if let syn::FnArg::Typed(pat_type) = arg
460            && let syn::Pat::Ident(param_ident) = &*pat_type.pat
461        {
462            let param_name = &param_ident.ident;
463            let param_name_str = param_name.to_string();
464            let ty = &pat_type.ty;
465
466            // Determine the description for this field:
467            // explicit params() > parameter doc comment > default
468            let field_doc_attr =
469                if let Some(explicit) = args.param_descriptions.get(&param_name_str) {
470                    // Explicit override via params() — use #[schemars(description = "...")]
471                    quote! { #[schemars(description = #explicit)] }
472                } else if let Some(doc) = extract_doc_comment(&pat_type.attrs) {
473                    // Doc comment on the parameter — propagate as #[doc = "..."]
474                    quote! { #[doc = #doc] }
475                } else {
476                    // Default fallback
477                    let default_desc = format!("Parameter {param_name_str}");
478                    quote! { #[schemars(description = #default_desc)] }
479                };
480
481            // Auto-add #[serde(default)] for Option<T> fields
482            let serde_default = if is_option_type(ty) {
483                quote! { #[serde(default)] }
484            } else {
485                quote! {}
486            };
487
488            field_tokens.push(quote! {
489                #field_doc_attr
490                #serde_default
491                #vis #param_name: #ty
492            });
493
494            param_names.push(param_name);
495        }
496    }
497
498    // Default required to all parameters only when required(...) was omitted.
499    let required_args: Vec<String> = args
500        .required
501        .unwrap_or_else(|| param_names.iter().map(|n| n.to_string()).collect());
502
503    let params_struct_name = format_ident!("{}Parameters", struct_name);
504    let static_name = format_ident!("{}", fn_name_str.to_uppercase());
505
506    // Generate the call implementation based on whether the function is async
507    let call_impl = if is_async {
508        quote! {
509            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
510                #fn_name(#(args.#param_names,)*).await
511            }
512        }
513    } else {
514        quote! {
515            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
516                #fn_name(#(args.#param_names,)*)
517            }
518        }
519    };
520
521    let rig_core = rig_core_path();
522    let schemars_crate = format!("{}::schemars", rig_core.to_string().replace(' ', ""));
523    let expanded = quote! {
524        #[derive(serde::Deserialize, #rig_core::schemars::JsonSchema)]
525        #[schemars(crate = #schemars_crate)]
526        #vis struct #params_struct_name {
527            #(#field_tokens,)*
528        }
529
530        #cleaned_fn
531
532        #[derive(Default)]
533        #vis struct #struct_name;
534
535        impl #rig_core::tool::Tool for #struct_name {
536            const NAME: &'static str = #tool_name;
537
538            type Args = #params_struct_name;
539            type Output = #output_type;
540            type Error = #error_type;
541
542            fn name(&self) -> String {
543                #tool_name.to_string()
544            }
545
546            async fn definition(&self, _prompt: String) -> #rig_core::completion::ToolDefinition {
547                let mut schema = serde_json::to_value(
548                    #rig_core::schemars::schema_for!(#params_struct_name)
549                ).expect("schema serialization");
550                schema["required"] = serde_json::json!([#(#required_args),*]);
551
552                #rig_core::completion::ToolDefinition {
553                    name: #tool_name.to_string(),
554                    description: #tool_description.to_string(),
555                    parameters: schema,
556                }
557            }
558
559            #call_impl
560        }
561
562        #vis static #static_name: #struct_name = #struct_name;
563    };
564
565    TokenStream::from(expanded)
566}