Skip to main content

rig_derive/
lib.rs

1extern crate proc_macro;
2
3use convert_case::{Case, Casing};
4use proc_macro::TokenStream;
5use quote::{format_ident, quote};
6use std::{collections::HashMap, ops::Deref};
7use syn::{
8    DeriveInput, Expr, ExprLit, Ident, Lit, Meta, PathArguments, ReturnType, Token, Type,
9    parse::{Parse, ParseStream},
10    parse_macro_input,
11    punctuated::Punctuated,
12};
13
14mod basic;
15mod client;
16mod custom;
17mod embed;
18
/// Name of the `#[embed]` helper attribute registered by `#[derive(Embed)]`
/// (see `attributes(embed)` on `derive_embedding_trait`). Presumably consumed by
/// the `embed` module when scanning fields — its use is outside this file.
pub(crate) const EMBED: &str = "embed";
20
/// Entry point for `#[derive(ProviderClient)]`.
///
/// Registers `client` as a helper attribute, so `#[client(...)]` may be written on
/// the deriving item. All expansion logic is delegated to `client::provider_client`
/// in this crate's `client` module; this function only forwards the input.
#[proc_macro_derive(ProviderClient, attributes(client))]
pub fn derive_provider_client(input: TokenStream) -> TokenStream {
    client::provider_client(input)
}
25
26//References:
27//<https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro>
28//<https://doc.rust-lang.org/reference/procedural-macros.html>
29/// A macro that allows you to implement the `rig::embedding::Embed` trait by deriving it.
30/// Usage can be found below:
31///
32/// ```text
33/// use rig::Embed;
34/// use rig_derive::Embed;
35///
36/// #[derive(Embed)]
37/// struct Foo {
38///     id: String,
39///     #[embed] // this helper shows which field to embed
40///     description: String
41///}
42/// ```
43#[proc_macro_derive(Embed, attributes(embed))]
44pub fn derive_embedding_trait(item: TokenStream) -> TokenStream {
45    let mut input = parse_macro_input!(item as DeriveInput);
46
47    embed::expand_derive_embedding(&mut input)
48        .unwrap_or_else(syn::Error::into_compile_error)
49        .into()
50}
51
/// Arguments accepted by the `#[rig_tool(...)]` attribute, filled in by its `Parse` impl.
struct MacroArgs {
    /// Explicit tool name from `name = "..."`; `None` when the attribute omits it.
    name: Option<String>,
    /// Tool description from `description = "..."`; `None` when omitted.
    description: Option<String>,
    /// Parameter names listed in `required(...)`, in the order they were written.
    required: Vec<String>,
    /// Per-parameter descriptions from `params(arg = "...")`, keyed by parameter name.
    param_descriptions: HashMap<String, String>,
}
58
59fn parse_string_literal(expr: &Expr, field_name: &str) -> syn::Result<String> {
60    match expr {
61        Expr::Lit(ExprLit {
62            lit: Lit::Str(lit_str),
63            ..
64        }) => Ok(lit_str.value()),
65        _ => Err(syn::Error::new_spanned(
66            expr,
67            format!("`{field_name}` must be a string literal"),
68        )),
69    }
70}
71
72fn validate_explicit_tool_name(name: &str, expr: &Expr) -> syn::Result<()> {
73    if name.is_empty() || name.len() > 64 {
74        return Err(syn::Error::new_spanned(
75            expr,
76            "`name` must be between 1 and 64 characters long",
77        ));
78    }
79
80    let mut chars = name.chars();
81    let Some(first_char) = chars.next() else {
82        return Err(syn::Error::new_spanned(
83            expr,
84            "`name` must be between 1 and 64 characters long",
85        ));
86    };
87
88    if !first_char.is_ascii_alphabetic() && first_char != '_' {
89        return Err(syn::Error::new_spanned(
90            expr,
91            "`name` must start with an ASCII letter or underscore",
92        ));
93    }
94
95    if chars.any(|ch| !ch.is_ascii_alphanumeric() && ch != '_' && ch != '-') {
96        return Err(syn::Error::new_spanned(
97            expr,
98            "`name` may only contain ASCII letters, digits, underscores, or hyphens",
99        ));
100    }
101
102    Ok(())
103}
104
105impl Parse for MacroArgs {
106    fn parse(input: ParseStream) -> syn::Result<Self> {
107        let mut name = None;
108        let mut description = None;
109        let mut param_descriptions = HashMap::new();
110        let mut required = Vec::new();
111
112        // If the input is empty, return default values
113        if input.is_empty() {
114            return Ok(MacroArgs {
115                name,
116                description,
117                param_descriptions,
118                required,
119            });
120        }
121
122        let meta_list: Punctuated<Meta, Token![,]> = Punctuated::parse_terminated(input)?;
123
124        for meta in meta_list {
125            match meta {
126                Meta::NameValue(nv) => {
127                    let ident = nv.path.get_ident().ok_or_else(|| {
128                        syn::Error::new_spanned(
129                            &nv.path,
130                            "unsupported top-level #[rig_tool] argument",
131                        )
132                    })?;
133
134                    match ident.to_string().as_str() {
135                        "name" => {
136                            let parsed_name = parse_string_literal(&nv.value, "name")?;
137                            validate_explicit_tool_name(&parsed_name, &nv.value)?;
138                            name = Some(parsed_name);
139                        }
140                        "description" => {
141                            description = Some(parse_string_literal(&nv.value, "description")?);
142                        }
143                        _ => {
144                            return Err(syn::Error::new_spanned(
145                                &nv.path,
146                                format!("unsupported top-level #[rig_tool] argument `{}`", ident),
147                            ));
148                        }
149                    }
150                }
151                Meta::List(list) => {
152                    let ident = list.path.get_ident().ok_or_else(|| {
153                        syn::Error::new_spanned(
154                            &list.path,
155                            "unsupported top-level #[rig_tool] argument",
156                        )
157                    })?;
158
159                    match ident.to_string().as_str() {
160                        "params" => {
161                            let nested: Punctuated<Meta, Token![,]> =
162                                list.parse_args_with(Punctuated::parse_terminated)?;
163
164                            for meta in nested {
165                                if let Meta::NameValue(nv) = meta
166                                    && let Expr::Lit(ExprLit {
167                                        lit: Lit::Str(lit_str),
168                                        ..
169                                    }) = nv.value
170                                {
171                                    let Some(param_ident) = nv.path.get_ident() else {
172                                        return Err(syn::Error::new_spanned(
173                                            &nv.path,
174                                            "parameter descriptions must use identifier keys",
175                                        ));
176                                    };
177                                    let param_name = param_ident.to_string();
178                                    param_descriptions.insert(param_name, lit_str.value());
179                                }
180                            }
181                        }
182                        "required" => {
183                            let required_variables: Punctuated<Ident, Token![,]> =
184                                list.parse_args_with(Punctuated::parse_terminated)?;
185
186                            required_variables.into_iter().for_each(|x| {
187                                required.push(x.to_string());
188                            });
189                        }
190                        _ => {
191                            return Err(syn::Error::new_spanned(
192                                &list.path,
193                                format!("unsupported top-level #[rig_tool] argument `{}`", ident),
194                            ));
195                        }
196                    }
197                }
198                Meta::Path(path) => {
199                    let message = if let Some(ident) = path.get_ident() {
200                        format!("unsupported top-level #[rig_tool] argument `{ident}`")
201                    } else {
202                        "unsupported top-level #[rig_tool] argument".to_string()
203                    };
204
205                    return Err(syn::Error::new_spanned(path, message));
206                }
207            }
208        }
209
210        Ok(MacroArgs {
211            name,
212            description,
213            param_descriptions,
214            required,
215        })
216    }
217}
218
219fn get_json_type(ty: &Type) -> proc_macro2::TokenStream {
220    match ty {
221        Type::Path(type_path) => {
222            let Some(segment) = type_path.path.segments.first() else {
223                return quote! { "type": "object" };
224            };
225            let type_name = segment.ident.to_string();
226
227            // Handle Vec types
228            if type_name == "Vec" {
229                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments
230                    && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
231                {
232                    let inner_json_type = get_json_type(inner_type);
233                    return quote! {
234                        "type": "array",
235                        "items": { #inner_json_type }
236                    };
237                }
238                return quote! { "type": "array" };
239            }
240
241            // Handle primitive types
242            match type_name.as_str() {
243                "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "f32" | "f64" => {
244                    quote! { "type": "number" }
245                }
246                "String" | "str" => {
247                    quote! { "type": "string" }
248                }
249                "bool" => {
250                    quote! { "type": "boolean" }
251                }
252                // Handle other types as objects
253                _ => {
254                    quote! { "type": "object" }
255                }
256            }
257        }
258        _ => {
259            quote! { "type": "object" }
260        }
261    }
262}
263
264fn result_type_tokens(
265    return_type: &ReturnType,
266) -> syn::Result<(proc_macro2::TokenStream, proc_macro2::TokenStream)> {
267    let ReturnType::Type(_, ty) = return_type else {
268        return Err(syn::Error::new_spanned(
269            return_type,
270            "function must have a return type of Result<T, E>",
271        ));
272    };
273
274    let Type::Path(type_path) = ty.deref() else {
275        return Err(syn::Error::new_spanned(
276            ty,
277            "return type must be Result<T, E>",
278        ));
279    };
280
281    let Some(last_segment) = type_path.path.segments.last() else {
282        return Err(syn::Error::new_spanned(
283            &type_path.path,
284            "return type must be Result<T, E>",
285        ));
286    };
287
288    if last_segment.ident != "Result" {
289        return Err(syn::Error::new_spanned(
290            &last_segment.ident,
291            "return type must be Result<T, E>",
292        ));
293    }
294
295    let PathArguments::AngleBracketed(args) = &last_segment.arguments else {
296        return Err(syn::Error::new_spanned(
297            &last_segment.arguments,
298            "expected angle-bracketed type parameters for Result<T, E>",
299        ));
300    };
301
302    let mut generic_args = args.args.iter();
303    let Some(output) = generic_args.next() else {
304        return Err(syn::Error::new_spanned(
305            &args.args,
306            "expected Result<T, E> with exactly two type parameters",
307        ));
308    };
309    let Some(error) = generic_args.next() else {
310        return Err(syn::Error::new_spanned(
311            &args.args,
312            "expected Result<T, E> with exactly two type parameters",
313        ));
314    };
315
316    if generic_args.next().is_some() {
317        return Err(syn::Error::new_spanned(
318            &args.args,
319            "expected Result<T, E> with exactly two type parameters",
320        ));
321    }
322
323    Ok((quote!(#output), quote!(#error)))
324}
325
326/// A procedural macro that transforms a function into a `rig::tool::Tool` that can be used with a `rig::agent::Agent`.
327///
328/// # Examples
329///
330/// Basic usage:
331/// ```text
332/// use rig_derive::rig_tool;
333///
334/// #[rig_tool]
335/// fn add(a: i32, b: i32) -> Result<i32, rig::tool::ToolError> {
336///     Ok(a + b)
337/// }
338/// ```
339///
340/// With description:
341/// ```text
342/// use rig_derive::rig_tool;
343///
344/// #[rig_tool(description = "Perform basic arithmetic operations")]
345/// fn calculator(x: i32, y: i32, operation: String) -> Result<i32, rig::tool::ToolError> {
346///     match operation.as_str() {
347///         "add" => Ok(x + y),
348///         "subtract" => Ok(x - y),
349///         "multiply" => Ok(x * y),
350///         "divide" => Ok(x / y),
351///         _ => Err(rig::tool::ToolError::ToolCallError("Unknown operation".into())),
352///     }
353/// }
354/// ```
355///
356/// With a custom tool name:
357/// ```text
358/// use rig_derive::rig_tool;
359///
360/// // Explicit names must be string literals that start with an ASCII letter
361/// // or `_`, may contain ASCII letters, digits, `_`, or `-`, and be at most
362/// // 64 characters long.
363/// #[rig_tool(name = "search-docs", description = "Search the documentation")]
364/// fn search_docs_impl(query: String) -> Result<String, rig::tool::ToolError> {
365///     Ok(format!("Searching docs for {query}"))
366/// }
367/// ```
368///
369/// With parameter descriptions:
370/// ```text
371/// use rig_derive::rig_tool;
372///
373/// #[rig_tool(
374///     description = "A tool that performs string operations",
375///     params(
376///         text = "The input text to process",
377///         operation = "The operation to perform (uppercase, lowercase, reverse)"
378///     )
379/// )]
380/// fn string_processor(text: String, operation: String) -> Result<String, rig::tool::ToolError> {
381///     match operation.as_str() {
382///         "uppercase" => Ok(text.to_uppercase()),
383///         "lowercase" => Ok(text.to_lowercase()),
384///         "reverse" => Ok(text.chars().rev().collect()),
385///         _ => Err(rig::tool::ToolError::ToolCallError("Unknown operation".into())),
386///     }
387/// }
388/// ```
389#[proc_macro_attribute]
390pub fn rig_tool(args: TokenStream, input: TokenStream) -> TokenStream {
391    let args = parse_macro_input!(args as MacroArgs);
392    let input_fn = parse_macro_input!(input as syn::ItemFn);
393
394    // Extract function details
395    let fn_name = &input_fn.sig.ident;
396    let fn_name_str = fn_name.to_string();
397    let tool_name = args.name.clone().unwrap_or_else(|| fn_name_str.clone());
398    let vis = &input_fn.vis;
399    let is_async = input_fn.sig.asyncness.is_some();
400
401    // Extract return type and get Output and Error types from Result<T, E>
402    let return_type = &input_fn.sig.output;
403    let (output_type, error_type) = match result_type_tokens(return_type) {
404        Ok(types) => types,
405        Err(error) => return error.into_compile_error().into(),
406    };
407
408    // Generate PascalCase struct name from the function name
409    let struct_name = format_ident!("{}", { fn_name_str.to_case(Case::Pascal) });
410
411    // Use provided description or generate a default one
412    let tool_description = match args.description {
413        Some(desc) => quote! { #desc.to_string() },
414        None => quote! { format!("Function to {}", Self::NAME) },
415    };
416
417    // Extract parameter names, types, and descriptions
418    let mut param_names = Vec::new();
419    let mut param_types = Vec::new();
420    let mut param_descriptions = Vec::new();
421    let mut json_types = Vec::new();
422
423    let required_args = args.required;
424
425    for arg in input_fn.sig.inputs.iter() {
426        if let syn::FnArg::Typed(pat_type) = arg
427            && let syn::Pat::Ident(param_ident) = &*pat_type.pat
428        {
429            let param_name = &param_ident.ident;
430            let param_name_str = param_name.to_string();
431            let ty = &pat_type.ty;
432            let default_parameter_description = format!("Parameter {param_name_str}");
433            let description = args
434                .param_descriptions
435                .get(&param_name_str)
436                .map(|s| s.to_owned())
437                .unwrap_or(default_parameter_description);
438
439            param_names.push(param_name);
440            param_types.push(ty);
441            param_descriptions.push(description);
442            json_types.push(get_json_type(ty));
443        }
444    }
445
446    let params_struct_name = format_ident!("{}Parameters", struct_name);
447    let static_name = format_ident!("{}", fn_name_str.to_uppercase());
448
449    // Generate the call implementation based on whether the function is async
450    let call_impl = if is_async {
451        quote! {
452            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
453                #fn_name(#(args.#param_names,)*).await
454            }
455        }
456    } else {
457        quote! {
458            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
459                #fn_name(#(args.#param_names,)*)
460            }
461        }
462    };
463
464    let expanded = quote! {
465        #[derive(serde::Deserialize)]
466        #vis struct #params_struct_name {
467            #(#vis #param_names: #param_types,)*
468        }
469
470        #input_fn
471
472        #[derive(Default)]
473        #vis struct #struct_name;
474
475        impl rig::tool::Tool for #struct_name {
476            const NAME: &'static str = #tool_name;
477
478            type Args = #params_struct_name;
479            type Output = #output_type;
480            type Error = #error_type;
481
482            fn name(&self) -> String {
483                #tool_name.to_string()
484            }
485
486            async fn definition(&self, _prompt: String) -> rig::completion::ToolDefinition {
487                let parameters = serde_json::json!({
488                    "type": "object",
489                    "properties": {
490                        #(
491                            stringify!(#param_names): {
492                                #json_types,
493                                "description": #param_descriptions
494                            }
495                        ),*
496                    },
497                    "required": [#(#required_args),*]
498                });
499
500                rig::completion::ToolDefinition {
501                    name: #tool_name.to_string(),
502                    description: #tool_description.to_string(),
503                    parameters,
504                }
505            }
506
507            #call_impl
508        }
509
510        #vis static #static_name: #struct_name = #struct_name;
511    };
512
513    TokenStream::from(expanded)
514}