rig_derive/
lib.rs

1extern crate proc_macro;
2
3use convert_case::{Case, Casing};
4use proc_macro::TokenStream;
5use quote::{format_ident, quote};
6use std::{collections::HashMap, ops::Deref};
7use syn::{
8    parse::{Parse, ParseStream},
9    parse_macro_input,
10    punctuated::Punctuated,
11    DeriveInput, Expr, ExprLit, Lit, Meta, PathArguments, ReturnType, Token, Type,
12};
13
14mod basic;
15mod client;
16mod custom;
17mod embed;
18
/// The string `"embed"`, matching the helper attribute registered by the `Embed`
/// derive (`attributes(embed)`).
/// NOTE(review): presumably compared against attribute paths inside the `embed`
/// module — confirm against its uses there.
pub(crate) const EMBED: &str = "embed";
20
21#[proc_macro_derive(ProviderClient, attributes(client))]
22pub fn derive_provider_client(input: TokenStream) -> TokenStream {
23    client::provider_client(input)
24}
25
26/// References:
27/// <https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro>
28/// <https://doc.rust-lang.org/reference/procedural-macros.html>
29#[proc_macro_derive(Embed, attributes(embed))]
30pub fn derive_embedding_trait(item: TokenStream) -> TokenStream {
31    let mut input = parse_macro_input!(item as DeriveInput);
32
33    embed::expand_derive_embedding(&mut input)
34        .unwrap_or_else(syn::Error::into_compile_error)
35        .into()
36}
37
/// Arguments accepted by the `#[rig_tool(...)]` attribute.
struct MacroArgs {
    /// Descriptions from `params(name = "...", ...)`, keyed by parameter name.
    param_descriptions: HashMap<String, String>,
    /// Tool description from `description = "..."`, if one was given.
    description: Option<String>,
}
42
43impl Parse for MacroArgs {
44    fn parse(input: ParseStream) -> syn::Result<Self> {
45        let mut description = None;
46        let mut param_descriptions = HashMap::new();
47
48        // If the input is empty, return default values
49        if input.is_empty() {
50            return Ok(MacroArgs {
51                description,
52                param_descriptions,
53            });
54        }
55
56        let meta_list: Punctuated<Meta, Token![,]> = Punctuated::parse_terminated(input)?;
57
58        for meta in meta_list {
59            match meta {
60                Meta::NameValue(nv) => {
61                    let ident = nv.path.get_ident().unwrap().to_string();
62                    if let Expr::Lit(ExprLit {
63                        lit: Lit::Str(lit_str),
64                        ..
65                    }) = nv.value
66                    {
67                        if ident.as_str() == "description" {
68                            description = Some(lit_str.value());
69                        }
70                    }
71                }
72                Meta::List(list) if list.path.is_ident("params") => {
73                    let nested: Punctuated<Meta, Token![,]> =
74                        list.parse_args_with(Punctuated::parse_terminated)?;
75
76                    for meta in nested {
77                        if let Meta::NameValue(nv) = meta {
78                            if let Expr::Lit(ExprLit {
79                                lit: Lit::Str(lit_str),
80                                ..
81                            }) = nv.value
82                            {
83                                let param_name = nv.path.get_ident().unwrap().to_string();
84                                param_descriptions.insert(param_name, lit_str.value());
85                            }
86                        }
87                    }
88                }
89                _ => {}
90            }
91        }
92
93        Ok(MacroArgs {
94            description,
95            param_descriptions,
96        })
97    }
98}
99
100fn get_json_type(ty: &Type) -> proc_macro2::TokenStream {
101    match ty {
102        Type::Path(type_path) => {
103            let segment = &type_path.path.segments[0];
104            let type_name = segment.ident.to_string();
105
106            // Handle Vec types
107            if type_name == "Vec" {
108                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
109                    if let syn::GenericArgument::Type(inner_type) = &args.args[0] {
110                        let inner_json_type = get_json_type(inner_type);
111                        return quote! {
112                            "type": "array",
113                            "items": { #inner_json_type }
114                        };
115                    }
116                }
117                return quote! { "type": "array" };
118            }
119
120            // Handle primitive types
121            match type_name.as_str() {
122                "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "f32" | "f64" => {
123                    quote! { "type": "number" }
124                }
125                "String" | "str" => {
126                    quote! { "type": "string" }
127                }
128                "bool" => {
129                    quote! { "type": "boolean" }
130                }
131                // Handle other types as objects
132                _ => {
133                    quote! { "type": "object" }
134                }
135            }
136        }
137        _ => {
138            quote! { "type": "object" }
139        }
140    }
141}
142
143/// A procedural macro that transforms a function into a `rig::tool::Tool` that can be used with a `rig::agent::Agent`.
144///
145/// # Examples
146///
147/// Basic usage:
148/// ```rust
149/// use rig_derive::rig_tool;
150///
151/// #[rig_tool]
152/// fn add(a: i32, b: i32) -> Result<i32, rig::tool::ToolError> {
153///     Ok(a + b)
154/// }
155/// ```
156///
157/// With description:
158/// ```rust
159/// use rig_derive::rig_tool;
160///
161/// #[rig_tool(description = "Perform basic arithmetic operations")]
162/// fn calculator(x: i32, y: i32, operation: String) -> Result<i32, rig::tool::ToolError> {
163///     match operation.as_str() {
164///         "add" => Ok(x + y),
165///         "subtract" => Ok(x - y),
166///         "multiply" => Ok(x * y),
167///         "divide" => Ok(x / y),
168///         _ => Err(rig::tool::ToolError::ToolCallError("Unknown operation".into())),
169///     }
170/// }
171/// ```
172///
173/// With parameter descriptions:
174/// ```rust
175/// use rig_derive::rig_tool;
176///
177/// #[rig_tool(
178///     description = "A tool that performs string operations",
179///     params(
180///         text = "The input text to process",
181///         operation = "The operation to perform (uppercase, lowercase, reverse)"
182///     )
183/// )]
184/// fn string_processor(text: String, operation: String) -> Result<String, rig::tool::ToolError> {
185///     match operation.as_str() {
186///         "uppercase" => Ok(text.to_uppercase()),
187///         "lowercase" => Ok(text.to_lowercase()),
188///         "reverse" => Ok(text.chars().rev().collect()),
189///         _ => Err(rig::tool::ToolError::ToolCallError("Unknown operation".into())),
190///     }
191/// }
192/// ```
193#[proc_macro_attribute]
194pub fn rig_tool(args: TokenStream, input: TokenStream) -> TokenStream {
195    let args = parse_macro_input!(args as MacroArgs);
196    let input_fn = parse_macro_input!(input as syn::ItemFn);
197
198    // Extract function details
199    let fn_name = &input_fn.sig.ident;
200    let fn_name_str = fn_name.to_string();
201    let is_async = input_fn.sig.asyncness.is_some();
202
203    // Extract return type and get Output and Error types from Result<T, E>
204    let return_type = &input_fn.sig.output;
205    let output_type = match return_type {
206        ReturnType::Type(_, ty) => {
207            if let Type::Path(type_path) = ty.deref() {
208                if let Some(last_segment) = type_path.path.segments.last() {
209                    if last_segment.ident == "Result" {
210                        if let PathArguments::AngleBracketed(args) = &last_segment.arguments {
211                            if args.args.len() == 2 {
212                                let output = args.args.first().unwrap();
213                                let error = args.args.last().unwrap();
214
215                                // Convert the error type to a string for comparison
216                                let error_str = quote!(#error).to_string().replace(" ", "");
217                                if !error_str.contains("rig::tool::ToolError") {
218                                    panic!("Expected rig::tool::ToolError as second type parameter but found {}", error_str);
219                                }
220
221                                quote!(#output)
222                            } else {
223                                panic!("Expected Result with two type parameters");
224                            }
225                        } else {
226                            panic!("Expected angle bracketed type parameters for Result");
227                        }
228                    } else {
229                        panic!("Return type must be a Result");
230                    }
231                } else {
232                    panic!("Invalid return type");
233                }
234            } else {
235                panic!("Invalid return type");
236            }
237        }
238        _ => panic!("Function must have a return type"),
239    };
240
241    // Generate PascalCase struct name from the function name
242    let struct_name = format_ident!("{}", { fn_name_str.to_case(Case::Pascal) });
243
244    // Use provided description or generate a default one
245    let tool_description = match args.description {
246        Some(desc) => quote! { #desc.to_string() },
247        None => quote! { format!("Function to {}", Self::NAME) },
248    };
249
250    // Extract parameter names, types, and descriptions
251    let mut param_names = Vec::new();
252    let mut param_types = Vec::new();
253    let mut param_descriptions = Vec::new();
254    let mut json_types = Vec::new();
255
256    for arg in input_fn.sig.inputs.iter() {
257        if let syn::FnArg::Typed(pat_type) = arg {
258            if let syn::Pat::Ident(param_ident) = &*pat_type.pat {
259                let param_name = &param_ident.ident;
260                let param_name_str = param_name.to_string();
261                let ty = &pat_type.ty;
262                let default_parameter_description = format!("Parameter {}", param_name_str);
263                let description = args
264                    .param_descriptions
265                    .get(&param_name_str)
266                    .map(|s| s.to_owned())
267                    .unwrap_or(default_parameter_description);
268
269                param_names.push(param_name);
270                param_types.push(ty);
271                param_descriptions.push(description);
272                json_types.push(get_json_type(ty));
273            }
274        }
275    }
276
277    let params_struct_name = format_ident!("{}Parameters", struct_name);
278    let static_name = format_ident!("{}", fn_name_str.to_uppercase());
279
280    // Generate the call implementation based on whether the function is async
281    let call_impl = if is_async {
282        quote! {
283            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
284                #fn_name(#(args.#param_names,)*).await.map_err(|e| rig::tool::ToolError::ToolCallError(e.into()))
285            }
286        }
287    } else {
288        quote! {
289            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
290                #fn_name(#(args.#param_names,)*).map_err(|e| rig::tool::ToolError::ToolCallError(e.into()))
291            }
292        }
293    };
294
295    let expanded = quote! {
296        #[derive(serde::Deserialize)]
297        pub(crate) struct #params_struct_name {
298            #(#param_names: #param_types,)*
299        }
300
301        #input_fn
302
303        #[derive(Default)]
304        pub(crate) struct #struct_name;
305
306        impl rig::tool::Tool for #struct_name {
307            const NAME: &'static str = #fn_name_str;
308
309            type Args = #params_struct_name;
310            type Output = #output_type;
311            type Error = rig::tool::ToolError;
312
313            fn name(&self) -> String {
314                #fn_name_str.to_string()
315            }
316
317            async fn definition(&self, _prompt: String) -> rig::completion::ToolDefinition {
318                let parameters = serde_json::json!({
319                    "type": "object",
320                    "properties": {
321                        #(
322                            stringify!(#param_names): {
323                                #json_types,
324                                "description": #param_descriptions
325                            }
326                        ),*
327                    }
328                });
329
330                rig::completion::ToolDefinition {
331                    name: #fn_name_str.to_string(),
332                    description: #tool_description.to_string(),
333                    parameters,
334                }
335            }
336
337            #call_impl
338        }
339
340        pub static #static_name: #struct_name = #struct_name;
341    };
342
343    TokenStream::from(expanded)
344}