1extern crate proc_macro;
2
3use convert_case::{Case, Casing};
4use proc_macro::TokenStream;
5use quote::{format_ident, quote};
6use std::{collections::HashMap, ops::Deref};
7use syn::{
8 DeriveInput, Expr, ExprLit, Ident, Lit, Meta, PathArguments, ReturnType, Token, Type,
9 parse::{Parse, ParseStream},
10 parse_macro_input,
11 punctuated::Punctuated,
12};
13
14mod basic;
15mod client;
16mod custom;
17mod embed;
18
/// Name of the helper attribute registered by the `Embed` derive
/// (`attributes(embed)` on `derive_embedding_trait`).
/// NOTE(review): presumably matched against field attributes in the `embed`
/// module — confirm there.
pub(crate) const EMBED: &str = "embed";
20
21pub(crate) fn rig_core_path() -> proc_macro2::TokenStream {
22 match proc_macro_crate::crate_name("rig-core") {
23 Ok(proc_macro_crate::FoundCrate::Itself) => quote!(crate),
24 Ok(proc_macro_crate::FoundCrate::Name(name)) => {
25 let ident = format_ident!("{name}");
26 quote!(::#ident)
27 }
28 Err(_) => match proc_macro_crate::crate_name("rig") {
29 Ok(proc_macro_crate::FoundCrate::Itself) => quote!(crate),
30 Ok(proc_macro_crate::FoundCrate::Name(name)) => {
31 let ident = format_ident!("{name}");
32 quote!(::#ident)
33 }
34 Err(_) => quote!(::rig_core),
35 },
36 }
37}
38
39#[proc_macro_derive(ProviderClient, attributes(client))]
40pub fn derive_provider_client(input: TokenStream) -> TokenStream {
41 client::provider_client(input)
42}
43
44#[proc_macro_derive(Embed, attributes(embed))]
62pub fn derive_embedding_trait(item: TokenStream) -> TokenStream {
63 let mut input = parse_macro_input!(item as DeriveInput);
64
65 embed::expand_derive_embedding(&mut input)
66 .unwrap_or_else(syn::Error::into_compile_error)
67 .into()
68}
69
/// Arguments accepted by the `#[rig_tool(...)]` attribute, as produced by its
/// `Parse` implementation.
struct MacroArgs {
    /// Explicit tool name from `name = "..."`; `None` when not given.
    name: Option<String>,
    /// Tool description from `description = "..."`; `None` when not given.
    description: Option<String>,
    /// Per-parameter descriptions from `params(arg = "...", ...)`, keyed by the
    /// parameter identifier rendered with `Ident::to_string`.
    param_descriptions: HashMap<String, String>,
    /// Identifiers listed in `required(arg, ...)`, in the order written.
    required: Vec<String>,
}
76
77fn parse_string_literal(expr: &Expr, field_name: &str) -> syn::Result<String> {
78 match expr {
79 Expr::Lit(ExprLit {
80 lit: Lit::Str(lit_str),
81 ..
82 }) => Ok(lit_str.value()),
83 _ => Err(syn::Error::new_spanned(
84 expr,
85 format!("`{field_name}` must be a string literal"),
86 )),
87 }
88}
89
90fn validate_explicit_tool_name(name: &str, expr: &Expr) -> syn::Result<()> {
91 if name.is_empty() || name.len() > 64 {
92 return Err(syn::Error::new_spanned(
93 expr,
94 "`name` must be between 1 and 64 characters long",
95 ));
96 }
97
98 let mut chars = name.chars();
99 let Some(first_char) = chars.next() else {
100 return Err(syn::Error::new_spanned(
101 expr,
102 "`name` must be between 1 and 64 characters long",
103 ));
104 };
105
106 if !first_char.is_ascii_alphabetic() && first_char != '_' {
107 return Err(syn::Error::new_spanned(
108 expr,
109 "`name` must start with an ASCII letter or underscore",
110 ));
111 }
112
113 if chars.any(|ch| !ch.is_ascii_alphanumeric() && ch != '_' && ch != '-') {
114 return Err(syn::Error::new_spanned(
115 expr,
116 "`name` may only contain ASCII letters, digits, underscores, or hyphens",
117 ));
118 }
119
120 Ok(())
121}
122
123impl Parse for MacroArgs {
124 fn parse(input: ParseStream) -> syn::Result<Self> {
125 let mut name = None;
126 let mut description = None;
127 let mut param_descriptions = HashMap::new();
128 let mut required = Vec::new();
129
130 if input.is_empty() {
132 return Ok(MacroArgs {
133 name,
134 description,
135 param_descriptions,
136 required,
137 });
138 }
139
140 let meta_list: Punctuated<Meta, Token![,]> = Punctuated::parse_terminated(input)?;
141
142 for meta in meta_list {
143 match meta {
144 Meta::NameValue(nv) => {
145 let ident = nv.path.get_ident().ok_or_else(|| {
146 syn::Error::new_spanned(
147 &nv.path,
148 "unsupported top-level #[rig_tool] argument",
149 )
150 })?;
151
152 match ident.to_string().as_str() {
153 "name" => {
154 let parsed_name = parse_string_literal(&nv.value, "name")?;
155 validate_explicit_tool_name(&parsed_name, &nv.value)?;
156 name = Some(parsed_name);
157 }
158 "description" => {
159 description = Some(parse_string_literal(&nv.value, "description")?);
160 }
161 _ => {
162 return Err(syn::Error::new_spanned(
163 &nv.path,
164 format!("unsupported top-level #[rig_tool] argument `{}`", ident),
165 ));
166 }
167 }
168 }
169 Meta::List(list) => {
170 let ident = list.path.get_ident().ok_or_else(|| {
171 syn::Error::new_spanned(
172 &list.path,
173 "unsupported top-level #[rig_tool] argument",
174 )
175 })?;
176
177 match ident.to_string().as_str() {
178 "params" => {
179 let nested: Punctuated<Meta, Token![,]> =
180 list.parse_args_with(Punctuated::parse_terminated)?;
181
182 for meta in nested {
183 if let Meta::NameValue(nv) = meta
184 && let Expr::Lit(ExprLit {
185 lit: Lit::Str(lit_str),
186 ..
187 }) = nv.value
188 {
189 let Some(param_ident) = nv.path.get_ident() else {
190 return Err(syn::Error::new_spanned(
191 &nv.path,
192 "parameter descriptions must use identifier keys",
193 ));
194 };
195 let param_name = param_ident.to_string();
196 param_descriptions.insert(param_name, lit_str.value());
197 }
198 }
199 }
200 "required" => {
201 let required_variables: Punctuated<Ident, Token![,]> =
202 list.parse_args_with(Punctuated::parse_terminated)?;
203
204 required_variables.into_iter().for_each(|x| {
205 required.push(x.to_string());
206 });
207 }
208 _ => {
209 return Err(syn::Error::new_spanned(
210 &list.path,
211 format!("unsupported top-level #[rig_tool] argument `{}`", ident),
212 ));
213 }
214 }
215 }
216 Meta::Path(path) => {
217 let message = if let Some(ident) = path.get_ident() {
218 format!("unsupported top-level #[rig_tool] argument `{ident}`")
219 } else {
220 "unsupported top-level #[rig_tool] argument".to_string()
221 };
222
223 return Err(syn::Error::new_spanned(path, message));
224 }
225 }
226 }
227
228 Ok(MacroArgs {
229 name,
230 description,
231 param_descriptions,
232 required,
233 })
234 }
235}
236
237fn get_json_type(ty: &Type) -> proc_macro2::TokenStream {
238 match ty {
239 Type::Path(type_path) => {
240 let Some(segment) = type_path.path.segments.first() else {
241 return quote! { "type": "object" };
242 };
243 let type_name = segment.ident.to_string();
244
245 if type_name == "Vec" {
247 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments
248 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
249 {
250 let inner_json_type = get_json_type(inner_type);
251 return quote! {
252 "type": "array",
253 "items": { #inner_json_type }
254 };
255 }
256 return quote! { "type": "array" };
257 }
258
259 match type_name.as_str() {
261 "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "f32" | "f64" => {
262 quote! { "type": "number" }
263 }
264 "String" | "str" => {
265 quote! { "type": "string" }
266 }
267 "bool" => {
268 quote! { "type": "boolean" }
269 }
270 _ => {
272 quote! { "type": "object" }
273 }
274 }
275 }
276 _ => {
277 quote! { "type": "object" }
278 }
279 }
280}
281
282fn result_type_tokens(
283 return_type: &ReturnType,
284) -> syn::Result<(proc_macro2::TokenStream, proc_macro2::TokenStream)> {
285 let ReturnType::Type(_, ty) = return_type else {
286 return Err(syn::Error::new_spanned(
287 return_type,
288 "function must have a return type of Result<T, E>",
289 ));
290 };
291
292 let Type::Path(type_path) = ty.deref() else {
293 return Err(syn::Error::new_spanned(
294 ty,
295 "return type must be Result<T, E>",
296 ));
297 };
298
299 let Some(last_segment) = type_path.path.segments.last() else {
300 return Err(syn::Error::new_spanned(
301 &type_path.path,
302 "return type must be Result<T, E>",
303 ));
304 };
305
306 if last_segment.ident != "Result" {
307 return Err(syn::Error::new_spanned(
308 &last_segment.ident,
309 "return type must be Result<T, E>",
310 ));
311 }
312
313 let PathArguments::AngleBracketed(args) = &last_segment.arguments else {
314 return Err(syn::Error::new_spanned(
315 &last_segment.arguments,
316 "expected angle-bracketed type parameters for Result<T, E>",
317 ));
318 };
319
320 let mut generic_args = args.args.iter();
321 let Some(output) = generic_args.next() else {
322 return Err(syn::Error::new_spanned(
323 &args.args,
324 "expected Result<T, E> with exactly two type parameters",
325 ));
326 };
327 let Some(error) = generic_args.next() else {
328 return Err(syn::Error::new_spanned(
329 &args.args,
330 "expected Result<T, E> with exactly two type parameters",
331 ));
332 };
333
334 if generic_args.next().is_some() {
335 return Err(syn::Error::new_spanned(
336 &args.args,
337 "expected Result<T, E> with exactly two type parameters",
338 ));
339 }
340
341 Ok((quote!(#output), quote!(#error)))
342}
343
344#[proc_macro_attribute]
408pub fn rig_tool(args: TokenStream, input: TokenStream) -> TokenStream {
409 let args = parse_macro_input!(args as MacroArgs);
410 let input_fn = parse_macro_input!(input as syn::ItemFn);
411
412 let fn_name = &input_fn.sig.ident;
414 let fn_name_str = fn_name.to_string();
415 let tool_name = args.name.clone().unwrap_or_else(|| fn_name_str.clone());
416 let vis = &input_fn.vis;
417 let is_async = input_fn.sig.asyncness.is_some();
418
419 let return_type = &input_fn.sig.output;
421 let (output_type, error_type) = match result_type_tokens(return_type) {
422 Ok(types) => types,
423 Err(error) => return error.into_compile_error().into(),
424 };
425
426 let struct_name = format_ident!("{}", { fn_name_str.to_case(Case::Pascal) });
428
429 let tool_description = match args.description {
431 Some(desc) => quote! { #desc.to_string() },
432 None => quote! { format!("Function to {}", Self::NAME) },
433 };
434
435 let mut param_names = Vec::new();
437 let mut param_types = Vec::new();
438 let mut param_descriptions = Vec::new();
439 let mut json_types = Vec::new();
440
441 let required_args = args.required;
442
443 for arg in input_fn.sig.inputs.iter() {
444 if let syn::FnArg::Typed(pat_type) = arg
445 && let syn::Pat::Ident(param_ident) = &*pat_type.pat
446 {
447 let param_name = ¶m_ident.ident;
448 let param_name_str = param_name.to_string();
449 let ty = &pat_type.ty;
450 let default_parameter_description = format!("Parameter {param_name_str}");
451 let description = args
452 .param_descriptions
453 .get(¶m_name_str)
454 .map(|s| s.to_owned())
455 .unwrap_or(default_parameter_description);
456
457 param_names.push(param_name);
458 param_types.push(ty);
459 param_descriptions.push(description);
460 json_types.push(get_json_type(ty));
461 }
462 }
463
464 let params_struct_name = format_ident!("{}Parameters", struct_name);
465 let static_name = format_ident!("{}", fn_name_str.to_uppercase());
466
467 let call_impl = if is_async {
469 quote! {
470 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
471 #fn_name(#(args.#param_names,)*).await
472 }
473 }
474 } else {
475 quote! {
476 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
477 #fn_name(#(args.#param_names,)*)
478 }
479 }
480 };
481
482 let rig_core = rig_core_path();
483 let expanded = quote! {
484 #[derive(serde::Deserialize)]
485 #vis struct #params_struct_name {
486 #(#vis #param_names: #param_types,)*
487 }
488
489 #input_fn
490
491 #[derive(Default)]
492 #vis struct #struct_name;
493
494 impl #rig_core::tool::Tool for #struct_name {
495 const NAME: &'static str = #tool_name;
496
497 type Args = #params_struct_name;
498 type Output = #output_type;
499 type Error = #error_type;
500
501 fn name(&self) -> String {
502 #tool_name.to_string()
503 }
504
505 async fn definition(&self, _prompt: String) -> #rig_core::completion::ToolDefinition {
506 let parameters = serde_json::json!({
507 "type": "object",
508 "properties": {
509 #(
510 stringify!(#param_names): {
511 #json_types,
512 "description": #param_descriptions
513 }
514 ),*
515 },
516 "required": [#(#required_args),*]
517 });
518
519 #rig_core::completion::ToolDefinition {
520 name: #tool_name.to_string(),
521 description: #tool_description.to_string(),
522 parameters,
523 }
524 }
525
526 #call_impl
527 }
528
529 #vis static #static_name: #struct_name = #struct_name;
530 };
531
532 TokenStream::from(expanded)
533}