Skip to main content

rig_derive/
lib.rs

1extern crate proc_macro;
2
3use convert_case::{Case, Casing};
4use proc_macro::TokenStream;
5use quote::{format_ident, quote};
6use std::{collections::HashMap, ops::Deref};
7use syn::{
8    DeriveInput, Expr, ExprLit, Ident, Lit, Meta, PathArguments, ReturnType, Token, Type,
9    parse::{Parse, ParseStream},
10    parse_macro_input,
11    punctuated::Punctuated,
12};
13
14mod basic;
15mod client;
16mod custom;
17mod embed;
18
/// The literal `"embed"`, shared crate-wide; it matches the `embed` helper attribute
/// name registered by the `Embed` derive macro in this file.
pub(crate) const EMBED: &str = "embed";
20
/// Derive macro `ProviderClient`, registered with the `client` helper attribute.
///
/// The whole expansion is delegated to `client::provider_client`; this function only
/// forwards the incoming token stream and returns whatever that function produces.
#[proc_macro_derive(ProviderClient, attributes(client))]
pub fn derive_provider_client(input: TokenStream) -> TokenStream {
    client::provider_client(input)
}
25
26//References:
27//<https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro>
28//<https://doc.rust-lang.org/reference/procedural-macros.html>
29/// A macro that allows you to implement the `rig::embedding::Embed` trait by deriving it.
30/// Usage can be found below:
31///
32/// ```rust
33/// #[derive(Embed)]
34/// struct Foo {
35///     id: String,
36///     #[embed] // this helper shows which field to embed
37///     description: String
38///}
39/// ```
40#[proc_macro_derive(Embed, attributes(embed))]
41pub fn derive_embedding_trait(item: TokenStream) -> TokenStream {
42    let mut input = parse_macro_input!(item as DeriveInput);
43
44    embed::expand_derive_embedding(&mut input)
45        .unwrap_or_else(syn::Error::into_compile_error)
46        .into()
47}
48
49struct MacroArgs {
50    description: Option<String>,
51    param_descriptions: HashMap<String, String>,
52    required: Vec<String>,
53}
54
55impl Parse for MacroArgs {
56    fn parse(input: ParseStream) -> syn::Result<Self> {
57        let mut description = None;
58        let mut param_descriptions = HashMap::new();
59        let mut required = Vec::new();
60
61        // If the input is empty, return default values
62        if input.is_empty() {
63            return Ok(MacroArgs {
64                description,
65                param_descriptions,
66                required,
67            });
68        }
69
70        let meta_list: Punctuated<Meta, Token![,]> = Punctuated::parse_terminated(input)?;
71
72        for meta in meta_list {
73            match meta {
74                Meta::NameValue(nv) => {
75                    let ident = nv.path.get_ident().unwrap().to_string();
76                    if let Expr::Lit(ExprLit {
77                        lit: Lit::Str(lit_str),
78                        ..
79                    }) = nv.value
80                        && ident.as_str() == "description"
81                    {
82                        description = Some(lit_str.value());
83                    }
84                }
85                Meta::List(list) if list.path.is_ident("params") => {
86                    let nested: Punctuated<Meta, Token![,]> =
87                        list.parse_args_with(Punctuated::parse_terminated)?;
88
89                    for meta in nested {
90                        if let Meta::NameValue(nv) = meta
91                            && let Expr::Lit(ExprLit {
92                                lit: Lit::Str(lit_str),
93                                ..
94                            }) = nv.value
95                        {
96                            let param_name = nv.path.get_ident().unwrap().to_string();
97                            param_descriptions.insert(param_name, lit_str.value());
98                        }
99                    }
100                }
101                Meta::List(list) if list.path.is_ident("required") => {
102                    let required_variables: Punctuated<Ident, Token![,]> =
103                        list.parse_args_with(Punctuated::parse_terminated)?;
104
105                    required_variables.into_iter().for_each(|x| {
106                        required.push(x.to_string());
107                    });
108                }
109                _ => {}
110            }
111        }
112
113        Ok(MacroArgs {
114            description,
115            param_descriptions,
116            required,
117        })
118    }
119}
120
121fn get_json_type(ty: &Type) -> proc_macro2::TokenStream {
122    match ty {
123        Type::Path(type_path) => {
124            let segment = &type_path.path.segments[0];
125            let type_name = segment.ident.to_string();
126
127            // Handle Vec types
128            if type_name == "Vec" {
129                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments
130                    && let syn::GenericArgument::Type(inner_type) = &args.args[0]
131                {
132                    let inner_json_type = get_json_type(inner_type);
133                    return quote! {
134                        "type": "array",
135                        "items": { #inner_json_type }
136                    };
137                }
138                return quote! { "type": "array" };
139            }
140
141            // Handle primitive types
142            match type_name.as_str() {
143                "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "f32" | "f64" => {
144                    quote! { "type": "number" }
145                }
146                "String" | "str" => {
147                    quote! { "type": "string" }
148                }
149                "bool" => {
150                    quote! { "type": "boolean" }
151                }
152                // Handle other types as objects
153                _ => {
154                    quote! { "type": "object" }
155                }
156            }
157        }
158        _ => {
159            quote! { "type": "object" }
160        }
161    }
162}
163
164/// A procedural macro that transforms a function into a `rig::tool::Tool` that can be used with a `rig::agent::Agent`.
165///
166/// # Examples
167///
168/// Basic usage:
169/// ```rust
170/// use rig_derive::rig_tool;
171///
172/// #[rig_tool]
173/// fn add(a: i32, b: i32) -> Result<i32, rig::tool::ToolError> {
174///     Ok(a + b)
175/// }
176/// ```
177///
178/// With description:
179/// ```rust
180/// use rig_derive::rig_tool;
181///
182/// #[rig_tool(description = "Perform basic arithmetic operations")]
183/// fn calculator(x: i32, y: i32, operation: String) -> Result<i32, rig::tool::ToolError> {
184///     match operation.as_str() {
185///         "add" => Ok(x + y),
186///         "subtract" => Ok(x - y),
187///         "multiply" => Ok(x * y),
188///         "divide" => Ok(x / y),
189///         _ => Err(rig::tool::ToolError::ToolCallError("Unknown operation".into())),
190///     }
191/// }
192/// ```
193///
194/// With parameter descriptions:
195/// ```rust
196/// use rig_derive::rig_tool;
197///
198/// #[rig_tool(
199///     description = "A tool that performs string operations",
200///     params(
201///         text = "The input text to process",
202///         operation = "The operation to perform (uppercase, lowercase, reverse)"
203///     )
204/// )]
205/// fn string_processor(text: String, operation: String) -> Result<String, rig::tool::ToolError> {
206///     match operation.as_str() {
207///         "uppercase" => Ok(text.to_uppercase()),
208///         "lowercase" => Ok(text.to_lowercase()),
209///         "reverse" => Ok(text.chars().rev().collect()),
210///         _ => Err(rig::tool::ToolError::ToolCallError("Unknown operation".into())),
211///     }
212/// }
213/// ```
214#[proc_macro_attribute]
215pub fn rig_tool(args: TokenStream, input: TokenStream) -> TokenStream {
216    let args = parse_macro_input!(args as MacroArgs);
217    let input_fn = parse_macro_input!(input as syn::ItemFn);
218
219    // Extract function details
220    let fn_name = &input_fn.sig.ident;
221    let fn_name_str = fn_name.to_string();
222    let vis = &input_fn.vis;
223    let is_async = input_fn.sig.asyncness.is_some();
224
225    // Extract return type and get Output and Error types from Result<T, E>
226    let return_type = &input_fn.sig.output;
227    let (output_type, error_type) = match return_type {
228        ReturnType::Type(_, ty) => {
229            if let Type::Path(type_path) = ty.deref() {
230                if let Some(last_segment) = type_path.path.segments.last() {
231                    if last_segment.ident == "Result" {
232                        if let PathArguments::AngleBracketed(args) = &last_segment.arguments {
233                            if args.args.len() == 2 {
234                                let output = args.args.first().unwrap();
235                                let error = args.args.last().unwrap();
236
237                                (quote!(#output), quote!(#error))
238                            } else {
239                                panic!("Expected Result with two type parameters");
240                            }
241                        } else {
242                            panic!("Expected angle bracketed type parameters for Result");
243                        }
244                    } else {
245                        panic!("Return type must be a Result");
246                    }
247                } else {
248                    panic!("Invalid return type");
249                }
250            } else {
251                panic!("Invalid return type");
252            }
253        }
254        _ => panic!("Function must have a return type"),
255    };
256
257    // Generate PascalCase struct name from the function name
258    let struct_name = format_ident!("{}", { fn_name_str.to_case(Case::Pascal) });
259
260    // Use provided description or generate a default one
261    let tool_description = match args.description {
262        Some(desc) => quote! { #desc.to_string() },
263        None => quote! { format!("Function to {}", Self::NAME) },
264    };
265
266    // Extract parameter names, types, and descriptions
267    let mut param_names = Vec::new();
268    let mut param_types = Vec::new();
269    let mut param_descriptions = Vec::new();
270    let mut json_types = Vec::new();
271
272    let required_args = args.required;
273
274    for arg in input_fn.sig.inputs.iter() {
275        if let syn::FnArg::Typed(pat_type) = arg
276            && let syn::Pat::Ident(param_ident) = &*pat_type.pat
277        {
278            let param_name = &param_ident.ident;
279            let param_name_str = param_name.to_string();
280            let ty = &pat_type.ty;
281            let default_parameter_description = format!("Parameter {param_name_str}");
282            let description = args
283                .param_descriptions
284                .get(&param_name_str)
285                .map(|s| s.to_owned())
286                .unwrap_or(default_parameter_description);
287
288            param_names.push(param_name);
289            param_types.push(ty);
290            param_descriptions.push(description);
291            json_types.push(get_json_type(ty));
292        }
293    }
294
295    let params_struct_name = format_ident!("{}Parameters", struct_name);
296    let static_name = format_ident!("{}", fn_name_str.to_uppercase());
297
298    // Generate the call implementation based on whether the function is async
299    let call_impl = if is_async {
300        quote! {
301            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
302                #fn_name(#(args.#param_names,)*).await
303            }
304        }
305    } else {
306        quote! {
307            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
308                #fn_name(#(args.#param_names,)*)
309            }
310        }
311    };
312
313    let expanded = quote! {
314        #[derive(serde::Deserialize)]
315        #vis struct #params_struct_name {
316            #(#vis #param_names: #param_types,)*
317        }
318
319        #input_fn
320
321        #[derive(Default)]
322        #vis struct #struct_name;
323
324        impl rig::tool::Tool for #struct_name {
325            const NAME: &'static str = #fn_name_str;
326
327            type Args = #params_struct_name;
328            type Output = #output_type;
329            type Error = #error_type;
330
331            fn name(&self) -> String {
332                #fn_name_str.to_string()
333            }
334
335            async fn definition(&self, _prompt: String) -> rig::completion::ToolDefinition {
336                let parameters = serde_json::json!({
337                    "type": "object",
338                    "properties": {
339                        #(
340                            stringify!(#param_names): {
341                                #json_types,
342                                "description": #param_descriptions
343                            }
344                        ),*
345                    },
346                    "required": [#(#required_args),*]
347                });
348
349                rig::completion::ToolDefinition {
350                    name: #fn_name_str.to_string(),
351                    description: #tool_description.to_string(),
352                    parameters,
353                }
354            }
355
356            #call_impl
357        }
358
359        #vis static #static_name: #struct_name = #struct_name;
360    };
361
362    TokenStream::from(expanded)
363}