Skip to main content

rig_derive/
lib.rs

1extern crate proc_macro;
2
3use convert_case::{Case, Casing};
4use proc_macro::TokenStream;
5use quote::{format_ident, quote};
6use std::{collections::HashMap, ops::Deref};
7use syn::{
8    DeriveInput, Expr, ExprLit, Ident, Lit, Meta, PathArguments, ReturnType, Token, Type,
9    parse::{Parse, ParseStream},
10    parse_macro_input,
11    punctuated::Punctuated,
12};
13
14mod basic;
15mod client;
16mod custom;
17mod embed;
18
/// Name of the `#[embed]` helper attribute registered on `#[derive(Embed)]`.
///
/// Presumably consumed by the `embed` module to recognise which fields to embed.
pub(crate) const EMBED: &str = "embed";
20
21pub(crate) fn rig_core_path() -> proc_macro2::TokenStream {
22    match proc_macro_crate::crate_name("rig-core") {
23        Ok(proc_macro_crate::FoundCrate::Itself) => quote!(crate),
24        Ok(proc_macro_crate::FoundCrate::Name(name)) => {
25            let ident = format_ident!("{name}");
26            quote!(::#ident)
27        }
28        Err(_) => match proc_macro_crate::crate_name("rig") {
29            Ok(proc_macro_crate::FoundCrate::Itself) => quote!(crate),
30            Ok(proc_macro_crate::FoundCrate::Name(name)) => {
31                let ident = format_ident!("{name}");
32                quote!(::#ident)
33            }
34            Err(_) => quote!(::rig_core),
35        },
36    }
37}
38
39#[proc_macro_derive(ProviderClient, attributes(client))]
40pub fn derive_provider_client(input: TokenStream) -> TokenStream {
41    client::provider_client(input)
42}
43
44//References:
45//<https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro>
46//<https://doc.rust-lang.org/reference/procedural-macros.html>
47/// A macro that allows you to implement the `rig::embedding::Embed` trait by deriving it.
48/// Usage can be found below:
49///
50/// ```text
51/// use rig::Embed;
52/// use rig_derive::Embed;
53///
54/// #[derive(Embed)]
55/// struct Foo {
56///     id: String,
57///     #[embed] // this helper shows which field to embed
58///     description: String
59///}
60/// ```
61#[proc_macro_derive(Embed, attributes(embed))]
62pub fn derive_embedding_trait(item: TokenStream) -> TokenStream {
63    let mut input = parse_macro_input!(item as DeriveInput);
64
65    embed::expand_derive_embedding(&mut input)
66        .unwrap_or_else(syn::Error::into_compile_error)
67        .into()
68}
69
/// Arguments accepted by the `#[rig_tool(...)]` attribute, produced by its `Parse` impl.
struct MacroArgs {
    /// Explicit tool name from `name = "..."`; `None` means the function name is used.
    name: Option<String>,
    /// Tool description from `description = "..."`; `None` means a generated default is used.
    description: Option<String>,
    /// Per-parameter descriptions from `params(param = "...")`, keyed by parameter name
    /// as written in the attribute.
    param_descriptions: std::collections::HashMap<String, String>,
    /// Parameter names listed in `required(...)`, in the order they were written.
    required: Vec<String>,
}
76
77fn parse_string_literal(expr: &Expr, field_name: &str) -> syn::Result<String> {
78    match expr {
79        Expr::Lit(ExprLit {
80            lit: Lit::Str(lit_str),
81            ..
82        }) => Ok(lit_str.value()),
83        _ => Err(syn::Error::new_spanned(
84            expr,
85            format!("`{field_name}` must be a string literal"),
86        )),
87    }
88}
89
90fn validate_explicit_tool_name(name: &str, expr: &Expr) -> syn::Result<()> {
91    if name.is_empty() || name.len() > 64 {
92        return Err(syn::Error::new_spanned(
93            expr,
94            "`name` must be between 1 and 64 characters long",
95        ));
96    }
97
98    let mut chars = name.chars();
99    let Some(first_char) = chars.next() else {
100        return Err(syn::Error::new_spanned(
101            expr,
102            "`name` must be between 1 and 64 characters long",
103        ));
104    };
105
106    if !first_char.is_ascii_alphabetic() && first_char != '_' {
107        return Err(syn::Error::new_spanned(
108            expr,
109            "`name` must start with an ASCII letter or underscore",
110        ));
111    }
112
113    if chars.any(|ch| !ch.is_ascii_alphanumeric() && ch != '_' && ch != '-') {
114        return Err(syn::Error::new_spanned(
115            expr,
116            "`name` may only contain ASCII letters, digits, underscores, or hyphens",
117        ));
118    }
119
120    Ok(())
121}
122
123impl Parse for MacroArgs {
124    fn parse(input: ParseStream) -> syn::Result<Self> {
125        let mut name = None;
126        let mut description = None;
127        let mut param_descriptions = HashMap::new();
128        let mut required = Vec::new();
129
130        // If the input is empty, return default values
131        if input.is_empty() {
132            return Ok(MacroArgs {
133                name,
134                description,
135                param_descriptions,
136                required,
137            });
138        }
139
140        let meta_list: Punctuated<Meta, Token![,]> = Punctuated::parse_terminated(input)?;
141
142        for meta in meta_list {
143            match meta {
144                Meta::NameValue(nv) => {
145                    let ident = nv.path.get_ident().ok_or_else(|| {
146                        syn::Error::new_spanned(
147                            &nv.path,
148                            "unsupported top-level #[rig_tool] argument",
149                        )
150                    })?;
151
152                    match ident.to_string().as_str() {
153                        "name" => {
154                            let parsed_name = parse_string_literal(&nv.value, "name")?;
155                            validate_explicit_tool_name(&parsed_name, &nv.value)?;
156                            name = Some(parsed_name);
157                        }
158                        "description" => {
159                            description = Some(parse_string_literal(&nv.value, "description")?);
160                        }
161                        _ => {
162                            return Err(syn::Error::new_spanned(
163                                &nv.path,
164                                format!("unsupported top-level #[rig_tool] argument `{}`", ident),
165                            ));
166                        }
167                    }
168                }
169                Meta::List(list) => {
170                    let ident = list.path.get_ident().ok_or_else(|| {
171                        syn::Error::new_spanned(
172                            &list.path,
173                            "unsupported top-level #[rig_tool] argument",
174                        )
175                    })?;
176
177                    match ident.to_string().as_str() {
178                        "params" => {
179                            let nested: Punctuated<Meta, Token![,]> =
180                                list.parse_args_with(Punctuated::parse_terminated)?;
181
182                            for meta in nested {
183                                if let Meta::NameValue(nv) = meta
184                                    && let Expr::Lit(ExprLit {
185                                        lit: Lit::Str(lit_str),
186                                        ..
187                                    }) = nv.value
188                                {
189                                    let Some(param_ident) = nv.path.get_ident() else {
190                                        return Err(syn::Error::new_spanned(
191                                            &nv.path,
192                                            "parameter descriptions must use identifier keys",
193                                        ));
194                                    };
195                                    let param_name = param_ident.to_string();
196                                    param_descriptions.insert(param_name, lit_str.value());
197                                }
198                            }
199                        }
200                        "required" => {
201                            let required_variables: Punctuated<Ident, Token![,]> =
202                                list.parse_args_with(Punctuated::parse_terminated)?;
203
204                            required_variables.into_iter().for_each(|x| {
205                                required.push(x.to_string());
206                            });
207                        }
208                        _ => {
209                            return Err(syn::Error::new_spanned(
210                                &list.path,
211                                format!("unsupported top-level #[rig_tool] argument `{}`", ident),
212                            ));
213                        }
214                    }
215                }
216                Meta::Path(path) => {
217                    let message = if let Some(ident) = path.get_ident() {
218                        format!("unsupported top-level #[rig_tool] argument `{ident}`")
219                    } else {
220                        "unsupported top-level #[rig_tool] argument".to_string()
221                    };
222
223                    return Err(syn::Error::new_spanned(path, message));
224                }
225            }
226        }
227
228        Ok(MacroArgs {
229            name,
230            description,
231            param_descriptions,
232            required,
233        })
234    }
235}
236
237fn get_json_type(ty: &Type) -> proc_macro2::TokenStream {
238    match ty {
239        Type::Path(type_path) => {
240            let Some(segment) = type_path.path.segments.first() else {
241                return quote! { "type": "object" };
242            };
243            let type_name = segment.ident.to_string();
244
245            // Handle Vec types
246            if type_name == "Vec" {
247                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments
248                    && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
249                {
250                    let inner_json_type = get_json_type(inner_type);
251                    return quote! {
252                        "type": "array",
253                        "items": { #inner_json_type }
254                    };
255                }
256                return quote! { "type": "array" };
257            }
258
259            // Handle primitive types
260            match type_name.as_str() {
261                "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "f32" | "f64" => {
262                    quote! { "type": "number" }
263                }
264                "String" | "str" => {
265                    quote! { "type": "string" }
266                }
267                "bool" => {
268                    quote! { "type": "boolean" }
269                }
270                // Handle other types as objects
271                _ => {
272                    quote! { "type": "object" }
273                }
274            }
275        }
276        _ => {
277            quote! { "type": "object" }
278        }
279    }
280}
281
282fn result_type_tokens(
283    return_type: &ReturnType,
284) -> syn::Result<(proc_macro2::TokenStream, proc_macro2::TokenStream)> {
285    let ReturnType::Type(_, ty) = return_type else {
286        return Err(syn::Error::new_spanned(
287            return_type,
288            "function must have a return type of Result<T, E>",
289        ));
290    };
291
292    let Type::Path(type_path) = ty.deref() else {
293        return Err(syn::Error::new_spanned(
294            ty,
295            "return type must be Result<T, E>",
296        ));
297    };
298
299    let Some(last_segment) = type_path.path.segments.last() else {
300        return Err(syn::Error::new_spanned(
301            &type_path.path,
302            "return type must be Result<T, E>",
303        ));
304    };
305
306    if last_segment.ident != "Result" {
307        return Err(syn::Error::new_spanned(
308            &last_segment.ident,
309            "return type must be Result<T, E>",
310        ));
311    }
312
313    let PathArguments::AngleBracketed(args) = &last_segment.arguments else {
314        return Err(syn::Error::new_spanned(
315            &last_segment.arguments,
316            "expected angle-bracketed type parameters for Result<T, E>",
317        ));
318    };
319
320    let mut generic_args = args.args.iter();
321    let Some(output) = generic_args.next() else {
322        return Err(syn::Error::new_spanned(
323            &args.args,
324            "expected Result<T, E> with exactly two type parameters",
325        ));
326    };
327    let Some(error) = generic_args.next() else {
328        return Err(syn::Error::new_spanned(
329            &args.args,
330            "expected Result<T, E> with exactly two type parameters",
331        ));
332    };
333
334    if generic_args.next().is_some() {
335        return Err(syn::Error::new_spanned(
336            &args.args,
337            "expected Result<T, E> with exactly two type parameters",
338        ));
339    }
340
341    Ok((quote!(#output), quote!(#error)))
342}
343
344/// A procedural macro that transforms a function into a `rig::tool::Tool` that can be used with a `rig::agent::Agent`.
345///
346/// # Examples
347///
348/// Basic usage:
349/// ```text
350/// use rig_derive::rig_tool;
351///
352/// #[rig_tool]
353/// fn add(a: i32, b: i32) -> Result<i32, rig::tool::ToolError> {
354///     Ok(a + b)
355/// }
356/// ```
357///
358/// With description:
359/// ```text
360/// use rig_derive::rig_tool;
361///
362/// #[rig_tool(description = "Perform basic arithmetic operations")]
363/// fn calculator(x: i32, y: i32, operation: String) -> Result<i32, rig::tool::ToolError> {
364///     match operation.as_str() {
365///         "add" => Ok(x + y),
366///         "subtract" => Ok(x - y),
367///         "multiply" => Ok(x * y),
368///         "divide" => Ok(x / y),
369///         _ => Err(rig::tool::ToolError::ToolCallError("Unknown operation".into())),
370///     }
371/// }
372/// ```
373///
374/// With a custom tool name:
375/// ```text
376/// use rig_derive::rig_tool;
377///
378/// // Explicit names must be string literals that start with an ASCII letter
379/// // or `_`, may contain ASCII letters, digits, `_`, or `-`, and be at most
380/// // 64 characters long.
381/// #[rig_tool(name = "search-docs", description = "Search the documentation")]
382/// fn search_docs_impl(query: String) -> Result<String, rig::tool::ToolError> {
383///     Ok(format!("Searching docs for {query}"))
384/// }
385/// ```
386///
387/// With parameter descriptions:
388/// ```text
389/// use rig_derive::rig_tool;
390///
391/// #[rig_tool(
392///     description = "A tool that performs string operations",
393///     params(
394///         text = "The input text to process",
395///         operation = "The operation to perform (uppercase, lowercase, reverse)"
396///     )
397/// )]
398/// fn string_processor(text: String, operation: String) -> Result<String, rig::tool::ToolError> {
399///     match operation.as_str() {
400///         "uppercase" => Ok(text.to_uppercase()),
401///         "lowercase" => Ok(text.to_lowercase()),
402///         "reverse" => Ok(text.chars().rev().collect()),
403///         _ => Err(rig::tool::ToolError::ToolCallError("Unknown operation".into())),
404///     }
405/// }
406/// ```
407#[proc_macro_attribute]
408pub fn rig_tool(args: TokenStream, input: TokenStream) -> TokenStream {
409    let args = parse_macro_input!(args as MacroArgs);
410    let input_fn = parse_macro_input!(input as syn::ItemFn);
411
412    // Extract function details
413    let fn_name = &input_fn.sig.ident;
414    let fn_name_str = fn_name.to_string();
415    let tool_name = args.name.clone().unwrap_or_else(|| fn_name_str.clone());
416    let vis = &input_fn.vis;
417    let is_async = input_fn.sig.asyncness.is_some();
418
419    // Extract return type and get Output and Error types from Result<T, E>
420    let return_type = &input_fn.sig.output;
421    let (output_type, error_type) = match result_type_tokens(return_type) {
422        Ok(types) => types,
423        Err(error) => return error.into_compile_error().into(),
424    };
425
426    // Generate PascalCase struct name from the function name
427    let struct_name = format_ident!("{}", { fn_name_str.to_case(Case::Pascal) });
428
429    // Use provided description or generate a default one
430    let tool_description = match args.description {
431        Some(desc) => quote! { #desc.to_string() },
432        None => quote! { format!("Function to {}", Self::NAME) },
433    };
434
435    // Extract parameter names, types, and descriptions
436    let mut param_names = Vec::new();
437    let mut param_types = Vec::new();
438    let mut param_descriptions = Vec::new();
439    let mut json_types = Vec::new();
440
441    let required_args = args.required;
442
443    for arg in input_fn.sig.inputs.iter() {
444        if let syn::FnArg::Typed(pat_type) = arg
445            && let syn::Pat::Ident(param_ident) = &*pat_type.pat
446        {
447            let param_name = &param_ident.ident;
448            let param_name_str = param_name.to_string();
449            let ty = &pat_type.ty;
450            let default_parameter_description = format!("Parameter {param_name_str}");
451            let description = args
452                .param_descriptions
453                .get(&param_name_str)
454                .map(|s| s.to_owned())
455                .unwrap_or(default_parameter_description);
456
457            param_names.push(param_name);
458            param_types.push(ty);
459            param_descriptions.push(description);
460            json_types.push(get_json_type(ty));
461        }
462    }
463
464    let params_struct_name = format_ident!("{}Parameters", struct_name);
465    let static_name = format_ident!("{}", fn_name_str.to_uppercase());
466
467    // Generate the call implementation based on whether the function is async
468    let call_impl = if is_async {
469        quote! {
470            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
471                #fn_name(#(args.#param_names,)*).await
472            }
473        }
474    } else {
475        quote! {
476            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
477                #fn_name(#(args.#param_names,)*)
478            }
479        }
480    };
481
482    let rig_core = rig_core_path();
483    let expanded = quote! {
484        #[derive(serde::Deserialize)]
485        #vis struct #params_struct_name {
486            #(#vis #param_names: #param_types,)*
487        }
488
489        #input_fn
490
491        #[derive(Default)]
492        #vis struct #struct_name;
493
494        impl #rig_core::tool::Tool for #struct_name {
495            const NAME: &'static str = #tool_name;
496
497            type Args = #params_struct_name;
498            type Output = #output_type;
499            type Error = #error_type;
500
501            fn name(&self) -> String {
502                #tool_name.to_string()
503            }
504
505            async fn definition(&self, _prompt: String) -> #rig_core::completion::ToolDefinition {
506                let parameters = serde_json::json!({
507                    "type": "object",
508                    "properties": {
509                        #(
510                            stringify!(#param_names): {
511                                #json_types,
512                                "description": #param_descriptions
513                            }
514                        ),*
515                    },
516                    "required": [#(#required_args),*]
517                });
518
519                #rig_core::completion::ToolDefinition {
520                    name: #tool_name.to_string(),
521                    description: #tool_description.to_string(),
522                    parameters,
523                }
524            }
525
526            #call_impl
527        }
528
529        #vis static #static_name: #struct_name = #struct_name;
530    };
531
532    TokenStream::from(expanded)
533}