1extern crate proc_macro;
2
3use convert_case::{Case, Casing};
4use proc_macro::TokenStream;
5use quote::{format_ident, quote};
6use std::{collections::HashMap, ops::Deref};
7use syn::{
8 DeriveInput, Expr, ExprLit, Ident, Lit, Meta, PathArguments, ReturnType, Token, Type,
9 parse::{Parse, ParseStream},
10 parse_macro_input,
11 punctuated::Punctuated,
12};
13
14mod basic;
15mod client;
16mod custom;
17mod embed;
18
/// Name of the `#[embed]` helper attribute registered by the `Embed` derive
/// below. Presumably consumed by the `embed` module when scanning fields —
/// NOTE(review): confirm against `embed.rs`.
pub(crate) const EMBED: &str = "embed";
20
21#[proc_macro_derive(ProviderClient, attributes(client))]
22pub fn derive_provider_client(input: TokenStream) -> TokenStream {
23 client::provider_client(input)
24}
25
26#[proc_macro_derive(Embed, attributes(embed))]
41pub fn derive_embedding_trait(item: TokenStream) -> TokenStream {
42 let mut input = parse_macro_input!(item as DeriveInput);
43
44 embed::expand_derive_embedding(&mut input)
45 .unwrap_or_else(syn::Error::into_compile_error)
46 .into()
47}
48
/// Parsed arguments of the `#[rig_tool(...)]` attribute.
struct MacroArgs {
    /// Tool description taken from `description = "..."`, if present.
    description: Option<String>,
    /// Per-parameter descriptions from `params(name = "...", ...)`, keyed by
    /// parameter name.
    param_descriptions: HashMap<String, String>,
    /// Parameter names listed in `required(...)`, in the order written.
    required: Vec<String>,
}
54
55impl Parse for MacroArgs {
56 fn parse(input: ParseStream) -> syn::Result<Self> {
57 let mut description = None;
58 let mut param_descriptions = HashMap::new();
59 let mut required = Vec::new();
60
61 if input.is_empty() {
63 return Ok(MacroArgs {
64 description,
65 param_descriptions,
66 required,
67 });
68 }
69
70 let meta_list: Punctuated<Meta, Token![,]> = Punctuated::parse_terminated(input)?;
71
72 for meta in meta_list {
73 match meta {
74 Meta::NameValue(nv) => {
75 let ident = nv.path.get_ident().unwrap().to_string();
76 if let Expr::Lit(ExprLit {
77 lit: Lit::Str(lit_str),
78 ..
79 }) = nv.value
80 && ident.as_str() == "description"
81 {
82 description = Some(lit_str.value());
83 }
84 }
85 Meta::List(list) if list.path.is_ident("params") => {
86 let nested: Punctuated<Meta, Token![,]> =
87 list.parse_args_with(Punctuated::parse_terminated)?;
88
89 for meta in nested {
90 if let Meta::NameValue(nv) = meta
91 && let Expr::Lit(ExprLit {
92 lit: Lit::Str(lit_str),
93 ..
94 }) = nv.value
95 {
96 let param_name = nv.path.get_ident().unwrap().to_string();
97 param_descriptions.insert(param_name, lit_str.value());
98 }
99 }
100 }
101 Meta::List(list) if list.path.is_ident("required") => {
102 let required_variables: Punctuated<Ident, Token![,]> =
103 list.parse_args_with(Punctuated::parse_terminated)?;
104
105 required_variables.into_iter().for_each(|x| {
106 required.push(x.to_string());
107 });
108 }
109 _ => {}
110 }
111 }
112
113 Ok(MacroArgs {
114 description,
115 param_descriptions,
116 required,
117 })
118 }
119}
120
121fn get_json_type(ty: &Type) -> proc_macro2::TokenStream {
122 match ty {
123 Type::Path(type_path) => {
124 let segment = &type_path.path.segments[0];
125 let type_name = segment.ident.to_string();
126
127 if type_name == "Vec" {
129 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments
130 && let syn::GenericArgument::Type(inner_type) = &args.args[0]
131 {
132 let inner_json_type = get_json_type(inner_type);
133 return quote! {
134 "type": "array",
135 "items": { #inner_json_type }
136 };
137 }
138 return quote! { "type": "array" };
139 }
140
141 match type_name.as_str() {
143 "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "f32" | "f64" => {
144 quote! { "type": "number" }
145 }
146 "String" | "str" => {
147 quote! { "type": "string" }
148 }
149 "bool" => {
150 quote! { "type": "boolean" }
151 }
152 _ => {
154 quote! { "type": "object" }
155 }
156 }
157 }
158 _ => {
159 quote! { "type": "object" }
160 }
161 }
162}
163
164#[proc_macro_attribute]
215pub fn rig_tool(args: TokenStream, input: TokenStream) -> TokenStream {
216 let args = parse_macro_input!(args as MacroArgs);
217 let input_fn = parse_macro_input!(input as syn::ItemFn);
218
219 let fn_name = &input_fn.sig.ident;
221 let fn_name_str = fn_name.to_string();
222 let vis = &input_fn.vis;
223 let is_async = input_fn.sig.asyncness.is_some();
224
225 let return_type = &input_fn.sig.output;
227 let (output_type, error_type) = match return_type {
228 ReturnType::Type(_, ty) => {
229 if let Type::Path(type_path) = ty.deref() {
230 if let Some(last_segment) = type_path.path.segments.last() {
231 if last_segment.ident == "Result" {
232 if let PathArguments::AngleBracketed(args) = &last_segment.arguments {
233 if args.args.len() == 2 {
234 let output = args.args.first().unwrap();
235 let error = args.args.last().unwrap();
236
237 (quote!(#output), quote!(#error))
238 } else {
239 panic!("Expected Result with two type parameters");
240 }
241 } else {
242 panic!("Expected angle bracketed type parameters for Result");
243 }
244 } else {
245 panic!("Return type must be a Result");
246 }
247 } else {
248 panic!("Invalid return type");
249 }
250 } else {
251 panic!("Invalid return type");
252 }
253 }
254 _ => panic!("Function must have a return type"),
255 };
256
257 let struct_name = format_ident!("{}", { fn_name_str.to_case(Case::Pascal) });
259
260 let tool_description = match args.description {
262 Some(desc) => quote! { #desc.to_string() },
263 None => quote! { format!("Function to {}", Self::NAME) },
264 };
265
266 let mut param_names = Vec::new();
268 let mut param_types = Vec::new();
269 let mut param_descriptions = Vec::new();
270 let mut json_types = Vec::new();
271
272 let required_args = args.required;
273
274 for arg in input_fn.sig.inputs.iter() {
275 if let syn::FnArg::Typed(pat_type) = arg
276 && let syn::Pat::Ident(param_ident) = &*pat_type.pat
277 {
278 let param_name = ¶m_ident.ident;
279 let param_name_str = param_name.to_string();
280 let ty = &pat_type.ty;
281 let default_parameter_description = format!("Parameter {param_name_str}");
282 let description = args
283 .param_descriptions
284 .get(¶m_name_str)
285 .map(|s| s.to_owned())
286 .unwrap_or(default_parameter_description);
287
288 param_names.push(param_name);
289 param_types.push(ty);
290 param_descriptions.push(description);
291 json_types.push(get_json_type(ty));
292 }
293 }
294
295 let params_struct_name = format_ident!("{}Parameters", struct_name);
296 let static_name = format_ident!("{}", fn_name_str.to_uppercase());
297
298 let call_impl = if is_async {
300 quote! {
301 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
302 #fn_name(#(args.#param_names,)*).await
303 }
304 }
305 } else {
306 quote! {
307 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
308 #fn_name(#(args.#param_names,)*)
309 }
310 }
311 };
312
313 let expanded = quote! {
314 #[derive(serde::Deserialize)]
315 #vis struct #params_struct_name {
316 #(#vis #param_names: #param_types,)*
317 }
318
319 #input_fn
320
321 #[derive(Default)]
322 #vis struct #struct_name;
323
324 impl rig::tool::Tool for #struct_name {
325 const NAME: &'static str = #fn_name_str;
326
327 type Args = #params_struct_name;
328 type Output = #output_type;
329 type Error = #error_type;
330
331 fn name(&self) -> String {
332 #fn_name_str.to_string()
333 }
334
335 async fn definition(&self, _prompt: String) -> rig::completion::ToolDefinition {
336 let parameters = serde_json::json!({
337 "type": "object",
338 "properties": {
339 #(
340 stringify!(#param_names): {
341 #json_types,
342 "description": #param_descriptions
343 }
344 ),*
345 },
346 "required": [#(#required_args),*]
347 });
348
349 rig::completion::ToolDefinition {
350 name: #fn_name_str.to_string(),
351 description: #tool_description.to_string(),
352 parameters,
353 }
354 }
355
356 #call_impl
357 }
358
359 #vis static #static_name: #struct_name = #struct_name;
360 };
361
362 TokenStream::from(expanded)
363}