1extern crate proc_macro;
2
3use convert_case::{Case, Casing};
4use proc_macro::TokenStream;
5use quote::{format_ident, quote};
6use std::{collections::HashMap, ops::Deref};
7use syn::{
8 parse::{Parse, ParseStream},
9 parse_macro_input,
10 punctuated::Punctuated,
11 DeriveInput, Expr, ExprLit, Lit, Meta, PathArguments, ReturnType, Token, Type,
12};
13
14mod basic;
15mod client;
16mod custom;
17mod embed;
18
/// The literal `"embed"`, shared across this crate.
///
/// NOTE(review): presumably the name of the helper attribute registered by the
/// `Embed` derive and looked up by the submodules — confirm against `embed.rs`.
pub(crate) const EMBED: &str = "embed";
20
21#[proc_macro_derive(ProviderClient, attributes(client))]
22pub fn derive_provider_client(input: TokenStream) -> TokenStream {
23 client::provider_client(input)
24}
25
26#[proc_macro_derive(Embed, attributes(embed))]
30pub fn derive_embedding_trait(item: TokenStream) -> TokenStream {
31 let mut input = parse_macro_input!(item as DeriveInput);
32
33 embed::expand_derive_embedding(&mut input)
34 .unwrap_or_else(syn::Error::into_compile_error)
35 .into()
36}
37
/// Arguments accepted by the `#[rig_tool(...)]` attribute.
struct MacroArgs {
    /// Value of `description = "..."`, if one was supplied.
    description: Option<String>,
    /// Entries of `params(name = "...")`, keyed by parameter name.
    param_descriptions: HashMap<String, String>,
}
42
43impl Parse for MacroArgs {
44 fn parse(input: ParseStream) -> syn::Result<Self> {
45 let mut description = None;
46 let mut param_descriptions = HashMap::new();
47
48 if input.is_empty() {
50 return Ok(MacroArgs {
51 description,
52 param_descriptions,
53 });
54 }
55
56 let meta_list: Punctuated<Meta, Token![,]> = Punctuated::parse_terminated(input)?;
57
58 for meta in meta_list {
59 match meta {
60 Meta::NameValue(nv) => {
61 let ident = nv.path.get_ident().unwrap().to_string();
62 if let Expr::Lit(ExprLit {
63 lit: Lit::Str(lit_str),
64 ..
65 }) = nv.value
66 {
67 if ident.as_str() == "description" {
68 description = Some(lit_str.value());
69 }
70 }
71 }
72 Meta::List(list) if list.path.is_ident("params") => {
73 let nested: Punctuated<Meta, Token![,]> =
74 list.parse_args_with(Punctuated::parse_terminated)?;
75
76 for meta in nested {
77 if let Meta::NameValue(nv) = meta {
78 if let Expr::Lit(ExprLit {
79 lit: Lit::Str(lit_str),
80 ..
81 }) = nv.value
82 {
83 let param_name = nv.path.get_ident().unwrap().to_string();
84 param_descriptions.insert(param_name, lit_str.value());
85 }
86 }
87 }
88 }
89 _ => {}
90 }
91 }
92
93 Ok(MacroArgs {
94 description,
95 param_descriptions,
96 })
97 }
98}
99
100fn get_json_type(ty: &Type) -> proc_macro2::TokenStream {
101 match ty {
102 Type::Path(type_path) => {
103 let segment = &type_path.path.segments[0];
104 let type_name = segment.ident.to_string();
105
106 if type_name == "Vec" {
108 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
109 if let syn::GenericArgument::Type(inner_type) = &args.args[0] {
110 let inner_json_type = get_json_type(inner_type);
111 return quote! {
112 "type": "array",
113 "items": { #inner_json_type }
114 };
115 }
116 }
117 return quote! { "type": "array" };
118 }
119
120 match type_name.as_str() {
122 "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "f32" | "f64" => {
123 quote! { "type": "number" }
124 }
125 "String" | "str" => {
126 quote! { "type": "string" }
127 }
128 "bool" => {
129 quote! { "type": "boolean" }
130 }
131 _ => {
133 quote! { "type": "object" }
134 }
135 }
136 }
137 _ => {
138 quote! { "type": "object" }
139 }
140 }
141}
142
143#[proc_macro_attribute]
194pub fn rig_tool(args: TokenStream, input: TokenStream) -> TokenStream {
195 let args = parse_macro_input!(args as MacroArgs);
196 let input_fn = parse_macro_input!(input as syn::ItemFn);
197
198 let fn_name = &input_fn.sig.ident;
200 let fn_name_str = fn_name.to_string();
201 let is_async = input_fn.sig.asyncness.is_some();
202
203 let return_type = &input_fn.sig.output;
205 let output_type = match return_type {
206 ReturnType::Type(_, ty) => {
207 if let Type::Path(type_path) = ty.deref() {
208 if let Some(last_segment) = type_path.path.segments.last() {
209 if last_segment.ident == "Result" {
210 if let PathArguments::AngleBracketed(args) = &last_segment.arguments {
211 if args.args.len() == 2 {
212 let output = args.args.first().unwrap();
213 let error = args.args.last().unwrap();
214
215 let error_str = quote!(#error).to_string().replace(" ", "");
217 if !error_str.contains("rig::tool::ToolError") {
218 panic!("Expected rig::tool::ToolError as second type parameter but found {}", error_str);
219 }
220
221 quote!(#output)
222 } else {
223 panic!("Expected Result with two type parameters");
224 }
225 } else {
226 panic!("Expected angle bracketed type parameters for Result");
227 }
228 } else {
229 panic!("Return type must be a Result");
230 }
231 } else {
232 panic!("Invalid return type");
233 }
234 } else {
235 panic!("Invalid return type");
236 }
237 }
238 _ => panic!("Function must have a return type"),
239 };
240
241 let struct_name = format_ident!("{}", { fn_name_str.to_case(Case::Pascal) });
243
244 let tool_description = match args.description {
246 Some(desc) => quote! { #desc.to_string() },
247 None => quote! { format!("Function to {}", Self::NAME) },
248 };
249
250 let mut param_names = Vec::new();
252 let mut param_types = Vec::new();
253 let mut param_descriptions = Vec::new();
254 let mut json_types = Vec::new();
255
256 for arg in input_fn.sig.inputs.iter() {
257 if let syn::FnArg::Typed(pat_type) = arg {
258 if let syn::Pat::Ident(param_ident) = &*pat_type.pat {
259 let param_name = ¶m_ident.ident;
260 let param_name_str = param_name.to_string();
261 let ty = &pat_type.ty;
262 let default_parameter_description = format!("Parameter {}", param_name_str);
263 let description = args
264 .param_descriptions
265 .get(¶m_name_str)
266 .map(|s| s.to_owned())
267 .unwrap_or(default_parameter_description);
268
269 param_names.push(param_name);
270 param_types.push(ty);
271 param_descriptions.push(description);
272 json_types.push(get_json_type(ty));
273 }
274 }
275 }
276
277 let params_struct_name = format_ident!("{}Parameters", struct_name);
278 let static_name = format_ident!("{}", fn_name_str.to_uppercase());
279
280 let call_impl = if is_async {
282 quote! {
283 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
284 #fn_name(#(args.#param_names,)*).await.map_err(|e| rig::tool::ToolError::ToolCallError(e.into()))
285 }
286 }
287 } else {
288 quote! {
289 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
290 #fn_name(#(args.#param_names,)*).map_err(|e| rig::tool::ToolError::ToolCallError(e.into()))
291 }
292 }
293 };
294
295 let expanded = quote! {
296 #[derive(serde::Deserialize)]
297 pub(crate) struct #params_struct_name {
298 #(#param_names: #param_types,)*
299 }
300
301 #input_fn
302
303 #[derive(Default)]
304 pub(crate) struct #struct_name;
305
306 impl rig::tool::Tool for #struct_name {
307 const NAME: &'static str = #fn_name_str;
308
309 type Args = #params_struct_name;
310 type Output = #output_type;
311 type Error = rig::tool::ToolError;
312
313 fn name(&self) -> String {
314 #fn_name_str.to_string()
315 }
316
317 async fn definition(&self, _prompt: String) -> rig::completion::ToolDefinition {
318 let parameters = serde_json::json!({
319 "type": "object",
320 "properties": {
321 #(
322 stringify!(#param_names): {
323 #json_types,
324 "description": #param_descriptions
325 }
326 ),*
327 }
328 });
329
330 rig::completion::ToolDefinition {
331 name: #fn_name_str.to_string(),
332 description: #tool_description.to_string(),
333 parameters,
334 }
335 }
336
337 #call_impl
338 }
339
340 pub static #static_name: #struct_name = #struct_name;
341 };
342
343 TokenStream::from(expanded)
344}