1extern crate proc_macro;
2
3use convert_case::{Case, Casing};
4use proc_macro::TokenStream;
5use quote::{format_ident, quote};
6use std::{collections::HashMap, ops::Deref};
7use syn::{
8 DeriveInput, Expr, ExprLit, Ident, Lit, Meta, PathArguments, ReturnType, Token, Type,
9 parse::{Parse, ParseStream},
10 parse_macro_input,
11 punctuated::Punctuated,
12};
13
// Implementation modules. `client` supplies `provider_client`, used by
// `#[derive(ProviderClient)]`, and `embed` supplies `expand_derive_embedding`,
// used by `#[derive(Embed)]`. `basic` and `custom` are not referenced in this
// part of the file — NOTE(review): presumably helpers for `embed`; confirm.
mod basic;
mod client;
mod custom;
mod embed;

/// Name of the helper attribute registered by `#[derive(Embed)]`
/// (`attributes(embed)` on `derive_embedding_trait`).
/// NOTE(review): presumably compared against attribute paths inside the `embed`
/// module — confirm there.
pub(crate) const EMBED: &str = "embed";
20
21#[proc_macro_derive(ProviderClient, attributes(client))]
22pub fn derive_provider_client(input: TokenStream) -> TokenStream {
23 client::provider_client(input)
24}
25
26#[proc_macro_derive(Embed, attributes(embed))]
44pub fn derive_embedding_trait(item: TokenStream) -> TokenStream {
45 let mut input = parse_macro_input!(item as DeriveInput);
46
47 embed::expand_derive_embedding(&mut input)
48 .unwrap_or_else(syn::Error::into_compile_error)
49 .into()
50}
51
/// Arguments accepted by `#[rig_tool(...)]`, filled in by its `Parse` impl.
struct MacroArgs {
    /// Explicit tool name from `name = "..."`; when `None`, the function name is used.
    name: Option<String>,
    /// Tool description from `description = "..."`, if one was given.
    description: Option<String>,
    /// Parameter names listed in `required(...)`, in the order they were written.
    required: Vec<String>,
    /// Per-parameter descriptions from `params(param = "...")`, keyed by parameter name.
    param_descriptions: HashMap<String, String>,
}
58
59fn parse_string_literal(expr: &Expr, field_name: &str) -> syn::Result<String> {
60 match expr {
61 Expr::Lit(ExprLit {
62 lit: Lit::Str(lit_str),
63 ..
64 }) => Ok(lit_str.value()),
65 _ => Err(syn::Error::new_spanned(
66 expr,
67 format!("`{field_name}` must be a string literal"),
68 )),
69 }
70}
71
72fn validate_explicit_tool_name(name: &str, expr: &Expr) -> syn::Result<()> {
73 if name.is_empty() || name.len() > 64 {
74 return Err(syn::Error::new_spanned(
75 expr,
76 "`name` must be between 1 and 64 characters long",
77 ));
78 }
79
80 let mut chars = name.chars();
81 let Some(first_char) = chars.next() else {
82 return Err(syn::Error::new_spanned(
83 expr,
84 "`name` must be between 1 and 64 characters long",
85 ));
86 };
87
88 if !first_char.is_ascii_alphabetic() && first_char != '_' {
89 return Err(syn::Error::new_spanned(
90 expr,
91 "`name` must start with an ASCII letter or underscore",
92 ));
93 }
94
95 if chars.any(|ch| !ch.is_ascii_alphanumeric() && ch != '_' && ch != '-') {
96 return Err(syn::Error::new_spanned(
97 expr,
98 "`name` may only contain ASCII letters, digits, underscores, or hyphens",
99 ));
100 }
101
102 Ok(())
103}
104
105impl Parse for MacroArgs {
106 fn parse(input: ParseStream) -> syn::Result<Self> {
107 let mut name = None;
108 let mut description = None;
109 let mut param_descriptions = HashMap::new();
110 let mut required = Vec::new();
111
112 if input.is_empty() {
114 return Ok(MacroArgs {
115 name,
116 description,
117 param_descriptions,
118 required,
119 });
120 }
121
122 let meta_list: Punctuated<Meta, Token![,]> = Punctuated::parse_terminated(input)?;
123
124 for meta in meta_list {
125 match meta {
126 Meta::NameValue(nv) => {
127 let ident = nv.path.get_ident().ok_or_else(|| {
128 syn::Error::new_spanned(
129 &nv.path,
130 "unsupported top-level #[rig_tool] argument",
131 )
132 })?;
133
134 match ident.to_string().as_str() {
135 "name" => {
136 let parsed_name = parse_string_literal(&nv.value, "name")?;
137 validate_explicit_tool_name(&parsed_name, &nv.value)?;
138 name = Some(parsed_name);
139 }
140 "description" => {
141 description = Some(parse_string_literal(&nv.value, "description")?);
142 }
143 _ => {
144 return Err(syn::Error::new_spanned(
145 &nv.path,
146 format!("unsupported top-level #[rig_tool] argument `{}`", ident),
147 ));
148 }
149 }
150 }
151 Meta::List(list) => {
152 let ident = list.path.get_ident().ok_or_else(|| {
153 syn::Error::new_spanned(
154 &list.path,
155 "unsupported top-level #[rig_tool] argument",
156 )
157 })?;
158
159 match ident.to_string().as_str() {
160 "params" => {
161 let nested: Punctuated<Meta, Token![,]> =
162 list.parse_args_with(Punctuated::parse_terminated)?;
163
164 for meta in nested {
165 if let Meta::NameValue(nv) = meta
166 && let Expr::Lit(ExprLit {
167 lit: Lit::Str(lit_str),
168 ..
169 }) = nv.value
170 {
171 let Some(param_ident) = nv.path.get_ident() else {
172 return Err(syn::Error::new_spanned(
173 &nv.path,
174 "parameter descriptions must use identifier keys",
175 ));
176 };
177 let param_name = param_ident.to_string();
178 param_descriptions.insert(param_name, lit_str.value());
179 }
180 }
181 }
182 "required" => {
183 let required_variables: Punctuated<Ident, Token![,]> =
184 list.parse_args_with(Punctuated::parse_terminated)?;
185
186 required_variables.into_iter().for_each(|x| {
187 required.push(x.to_string());
188 });
189 }
190 _ => {
191 return Err(syn::Error::new_spanned(
192 &list.path,
193 format!("unsupported top-level #[rig_tool] argument `{}`", ident),
194 ));
195 }
196 }
197 }
198 Meta::Path(path) => {
199 let message = if let Some(ident) = path.get_ident() {
200 format!("unsupported top-level #[rig_tool] argument `{ident}`")
201 } else {
202 "unsupported top-level #[rig_tool] argument".to_string()
203 };
204
205 return Err(syn::Error::new_spanned(path, message));
206 }
207 }
208 }
209
210 Ok(MacroArgs {
211 name,
212 description,
213 param_descriptions,
214 required,
215 })
216 }
217}
218
219fn get_json_type(ty: &Type) -> proc_macro2::TokenStream {
220 match ty {
221 Type::Path(type_path) => {
222 let Some(segment) = type_path.path.segments.first() else {
223 return quote! { "type": "object" };
224 };
225 let type_name = segment.ident.to_string();
226
227 if type_name == "Vec" {
229 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments
230 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
231 {
232 let inner_json_type = get_json_type(inner_type);
233 return quote! {
234 "type": "array",
235 "items": { #inner_json_type }
236 };
237 }
238 return quote! { "type": "array" };
239 }
240
241 match type_name.as_str() {
243 "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "f32" | "f64" => {
244 quote! { "type": "number" }
245 }
246 "String" | "str" => {
247 quote! { "type": "string" }
248 }
249 "bool" => {
250 quote! { "type": "boolean" }
251 }
252 _ => {
254 quote! { "type": "object" }
255 }
256 }
257 }
258 _ => {
259 quote! { "type": "object" }
260 }
261 }
262}
263
264fn result_type_tokens(
265 return_type: &ReturnType,
266) -> syn::Result<(proc_macro2::TokenStream, proc_macro2::TokenStream)> {
267 let ReturnType::Type(_, ty) = return_type else {
268 return Err(syn::Error::new_spanned(
269 return_type,
270 "function must have a return type of Result<T, E>",
271 ));
272 };
273
274 let Type::Path(type_path) = ty.deref() else {
275 return Err(syn::Error::new_spanned(
276 ty,
277 "return type must be Result<T, E>",
278 ));
279 };
280
281 let Some(last_segment) = type_path.path.segments.last() else {
282 return Err(syn::Error::new_spanned(
283 &type_path.path,
284 "return type must be Result<T, E>",
285 ));
286 };
287
288 if last_segment.ident != "Result" {
289 return Err(syn::Error::new_spanned(
290 &last_segment.ident,
291 "return type must be Result<T, E>",
292 ));
293 }
294
295 let PathArguments::AngleBracketed(args) = &last_segment.arguments else {
296 return Err(syn::Error::new_spanned(
297 &last_segment.arguments,
298 "expected angle-bracketed type parameters for Result<T, E>",
299 ));
300 };
301
302 let mut generic_args = args.args.iter();
303 let Some(output) = generic_args.next() else {
304 return Err(syn::Error::new_spanned(
305 &args.args,
306 "expected Result<T, E> with exactly two type parameters",
307 ));
308 };
309 let Some(error) = generic_args.next() else {
310 return Err(syn::Error::new_spanned(
311 &args.args,
312 "expected Result<T, E> with exactly two type parameters",
313 ));
314 };
315
316 if generic_args.next().is_some() {
317 return Err(syn::Error::new_spanned(
318 &args.args,
319 "expected Result<T, E> with exactly two type parameters",
320 ));
321 }
322
323 Ok((quote!(#output), quote!(#error)))
324}
325
326#[proc_macro_attribute]
390pub fn rig_tool(args: TokenStream, input: TokenStream) -> TokenStream {
391 let args = parse_macro_input!(args as MacroArgs);
392 let input_fn = parse_macro_input!(input as syn::ItemFn);
393
394 let fn_name = &input_fn.sig.ident;
396 let fn_name_str = fn_name.to_string();
397 let tool_name = args.name.clone().unwrap_or_else(|| fn_name_str.clone());
398 let vis = &input_fn.vis;
399 let is_async = input_fn.sig.asyncness.is_some();
400
401 let return_type = &input_fn.sig.output;
403 let (output_type, error_type) = match result_type_tokens(return_type) {
404 Ok(types) => types,
405 Err(error) => return error.into_compile_error().into(),
406 };
407
408 let struct_name = format_ident!("{}", { fn_name_str.to_case(Case::Pascal) });
410
411 let tool_description = match args.description {
413 Some(desc) => quote! { #desc.to_string() },
414 None => quote! { format!("Function to {}", Self::NAME) },
415 };
416
417 let mut param_names = Vec::new();
419 let mut param_types = Vec::new();
420 let mut param_descriptions = Vec::new();
421 let mut json_types = Vec::new();
422
423 let required_args = args.required;
424
425 for arg in input_fn.sig.inputs.iter() {
426 if let syn::FnArg::Typed(pat_type) = arg
427 && let syn::Pat::Ident(param_ident) = &*pat_type.pat
428 {
429 let param_name = ¶m_ident.ident;
430 let param_name_str = param_name.to_string();
431 let ty = &pat_type.ty;
432 let default_parameter_description = format!("Parameter {param_name_str}");
433 let description = args
434 .param_descriptions
435 .get(¶m_name_str)
436 .map(|s| s.to_owned())
437 .unwrap_or(default_parameter_description);
438
439 param_names.push(param_name);
440 param_types.push(ty);
441 param_descriptions.push(description);
442 json_types.push(get_json_type(ty));
443 }
444 }
445
446 let params_struct_name = format_ident!("{}Parameters", struct_name);
447 let static_name = format_ident!("{}", fn_name_str.to_uppercase());
448
449 let call_impl = if is_async {
451 quote! {
452 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
453 #fn_name(#(args.#param_names,)*).await
454 }
455 }
456 } else {
457 quote! {
458 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
459 #fn_name(#(args.#param_names,)*)
460 }
461 }
462 };
463
464 let expanded = quote! {
465 #[derive(serde::Deserialize)]
466 #vis struct #params_struct_name {
467 #(#vis #param_names: #param_types,)*
468 }
469
470 #input_fn
471
472 #[derive(Default)]
473 #vis struct #struct_name;
474
475 impl rig::tool::Tool for #struct_name {
476 const NAME: &'static str = #tool_name;
477
478 type Args = #params_struct_name;
479 type Output = #output_type;
480 type Error = #error_type;
481
482 fn name(&self) -> String {
483 #tool_name.to_string()
484 }
485
486 async fn definition(&self, _prompt: String) -> rig::completion::ToolDefinition {
487 let parameters = serde_json::json!({
488 "type": "object",
489 "properties": {
490 #(
491 stringify!(#param_names): {
492 #json_types,
493 "description": #param_descriptions
494 }
495 ),*
496 },
497 "required": [#(#required_args),*]
498 });
499
500 rig::completion::ToolDefinition {
501 name: #tool_name.to_string(),
502 description: #tool_description.to_string(),
503 parameters,
504 }
505 }
506
507 #call_impl
508 }
509
510 #vis static #static_name: #struct_name = #struct_name;
511 };
512
513 TokenStream::from(expanded)
514}