1extern crate proc_macro;
2
3use convert_case::{Case, Casing};
4use proc_macro::TokenStream;
5use quote::{format_ident, quote};
6use std::{collections::HashMap, ops::Deref};
7use syn::{
8 Attribute, DeriveInput, Expr, ExprLit, Ident, Lit, Meta, PathArguments, ReturnType, Token,
9 Type,
10 parse::{Parse, ParseStream},
11 parse_macro_input,
12 punctuated::Punctuated,
13};
14
15mod basic;
16mod client;
17mod custom;
18mod embed;
19
/// The string `"embed"`, the same name registered as the helper attribute of
/// `#[derive(Embed)]` (`attributes(embed)`). Presumably used by the `embed` module to
/// recognise that attribute on fields — confirm there.
pub(crate) const EMBED: &str = "embed";
21
22pub(crate) fn rig_core_path() -> proc_macro2::TokenStream {
23 match proc_macro_crate::crate_name("rig-core") {
24 Ok(proc_macro_crate::FoundCrate::Itself) => quote!(crate),
25 Ok(proc_macro_crate::FoundCrate::Name(name)) => {
26 let ident = format_ident!("{name}");
27 quote!(::#ident)
28 }
29 Err(_) => match proc_macro_crate::crate_name("rig") {
30 Ok(proc_macro_crate::FoundCrate::Itself) => quote!(crate),
31 Ok(proc_macro_crate::FoundCrate::Name(name)) => {
32 let ident = format_ident!("{name}");
33 quote!(::#ident)
34 }
35 Err(_) => quote!(::rig_core),
36 },
37 }
38}
39
40#[proc_macro_derive(ProviderClient, attributes(client))]
41pub fn derive_provider_client(input: TokenStream) -> TokenStream {
42 client::provider_client(input)
43}
44
45#[proc_macro_derive(Embed, attributes(embed))]
63pub fn derive_embedding_trait(item: TokenStream) -> TokenStream {
64 let mut input = parse_macro_input!(item as DeriveInput);
65
66 embed::expand_derive_embedding(&mut input)
67 .unwrap_or_else(syn::Error::into_compile_error)
68 .into()
69}
70
/// Arguments accepted by the `#[rig_tool(...)]` attribute, as produced by its `Parse` impl.
struct MacroArgs {
    /// Explicit tool name from `name = "..."`, checked by `validate_explicit_tool_name`.
    name: Option<String>,
    /// Explicit tool description from `description = "..."`.
    description: Option<String>,
    /// Per-parameter descriptions from `params(param = "description", ...)`, keyed by the
    /// parameter identifier as written.
    param_descriptions: HashMap<String, String>,
    /// Parameter identifiers listed in `required(a, b, ...)`; `None` when `required` is absent.
    required: Option<Vec<String>>,
}
77
78fn parse_string_literal(expr: &Expr, field_name: &str) -> syn::Result<String> {
79 match expr {
80 Expr::Lit(ExprLit {
81 lit: Lit::Str(lit_str),
82 ..
83 }) => Ok(lit_str.value()),
84 _ => Err(syn::Error::new_spanned(
85 expr,
86 format!("`{field_name}` must be a string literal"),
87 )),
88 }
89}
90
91fn validate_explicit_tool_name(name: &str, expr: &Expr) -> syn::Result<()> {
92 if name.is_empty() || name.len() > 64 {
93 return Err(syn::Error::new_spanned(
94 expr,
95 "`name` must be between 1 and 64 characters long",
96 ));
97 }
98
99 let mut chars = name.chars();
100 let Some(first_char) = chars.next() else {
101 return Err(syn::Error::new_spanned(
102 expr,
103 "`name` must be between 1 and 64 characters long",
104 ));
105 };
106
107 if !first_char.is_ascii_alphabetic() && first_char != '_' {
108 return Err(syn::Error::new_spanned(
109 expr,
110 "`name` must start with an ASCII letter or underscore",
111 ));
112 }
113
114 if chars.any(|ch| !ch.is_ascii_alphanumeric() && ch != '_' && ch != '-') {
115 return Err(syn::Error::new_spanned(
116 expr,
117 "`name` may only contain ASCII letters, digits, underscores, or hyphens",
118 ));
119 }
120
121 Ok(())
122}
123
124impl Parse for MacroArgs {
125 fn parse(input: ParseStream) -> syn::Result<Self> {
126 let mut name = None;
127 let mut description = None;
128 let mut param_descriptions = HashMap::new();
129 let mut required = None;
130
131 if input.is_empty() {
133 return Ok(MacroArgs {
134 name,
135 description,
136 param_descriptions,
137 required,
138 });
139 }
140
141 let meta_list: Punctuated<Meta, Token![,]> = Punctuated::parse_terminated(input)?;
142
143 for meta in meta_list {
144 match meta {
145 Meta::NameValue(nv) => {
146 let ident = nv.path.get_ident().ok_or_else(|| {
147 syn::Error::new_spanned(
148 &nv.path,
149 "unsupported top-level #[rig_tool] argument",
150 )
151 })?;
152
153 match ident.to_string().as_str() {
154 "name" => {
155 let parsed_name = parse_string_literal(&nv.value, "name")?;
156 validate_explicit_tool_name(&parsed_name, &nv.value)?;
157 name = Some(parsed_name);
158 }
159 "description" => {
160 description = Some(parse_string_literal(&nv.value, "description")?);
161 }
162 _ => {
163 return Err(syn::Error::new_spanned(
164 &nv.path,
165 format!("unsupported top-level #[rig_tool] argument `{}`", ident),
166 ));
167 }
168 }
169 }
170 Meta::List(list) => {
171 let ident = list.path.get_ident().ok_or_else(|| {
172 syn::Error::new_spanned(
173 &list.path,
174 "unsupported top-level #[rig_tool] argument",
175 )
176 })?;
177
178 match ident.to_string().as_str() {
179 "params" => {
180 let nested: Punctuated<Meta, Token![,]> =
181 list.parse_args_with(Punctuated::parse_terminated)?;
182
183 for meta in nested {
184 if let Meta::NameValue(nv) = meta
185 && let Expr::Lit(ExprLit {
186 lit: Lit::Str(lit_str),
187 ..
188 }) = nv.value
189 {
190 let Some(param_ident) = nv.path.get_ident() else {
191 return Err(syn::Error::new_spanned(
192 &nv.path,
193 "parameter descriptions must use identifier keys",
194 ));
195 };
196 let param_name = param_ident.to_string();
197 param_descriptions.insert(param_name, lit_str.value());
198 }
199 }
200 }
201 "required" => {
202 let required_variables: Punctuated<Ident, Token![,]> =
203 list.parse_args_with(Punctuated::parse_terminated)?;
204
205 required = Some(
206 required_variables
207 .into_iter()
208 .map(|x| x.to_string())
209 .collect(),
210 );
211 }
212 _ => {
213 return Err(syn::Error::new_spanned(
214 &list.path,
215 format!("unsupported top-level #[rig_tool] argument `{}`", ident),
216 ));
217 }
218 }
219 }
220 Meta::Path(path) => {
221 let message = if let Some(ident) = path.get_ident() {
222 format!("unsupported top-level #[rig_tool] argument `{ident}`")
223 } else {
224 "unsupported top-level #[rig_tool] argument".to_string()
225 };
226
227 return Err(syn::Error::new_spanned(path, message));
228 }
229 }
230 }
231
232 Ok(MacroArgs {
233 name,
234 description,
235 param_descriptions,
236 required,
237 })
238 }
239}
240
241fn extract_doc_comment(attrs: &[Attribute]) -> Option<String> {
243 let lines: Vec<String> = attrs
244 .iter()
245 .filter_map(|attr| {
246 if !attr.path().is_ident("doc") {
247 return None;
248 }
249 if let Meta::NameValue(nv) = &attr.meta
250 && let Expr::Lit(ExprLit {
251 lit: Lit::Str(s), ..
252 }) = &nv.value
253 {
254 return Some(s.value());
255 }
256 None
257 })
258 .collect();
259
260 if lines.is_empty() {
261 return None;
262 }
263
264 Some(
265 lines
266 .iter()
267 .map(|l| l.strip_prefix(' ').unwrap_or(l))
268 .collect::<Vec<_>>()
269 .join("\n")
270 .trim()
271 .to_string(),
272 )
273}
274
275fn is_option_type(ty: &Type) -> bool {
277 if let Type::Path(type_path) = ty
278 && let Some(segment) = type_path.path.segments.last()
279 {
280 return segment.ident == "Option";
281 }
282 false
283}
284
285fn result_type_tokens(
286 return_type: &ReturnType,
287) -> syn::Result<(proc_macro2::TokenStream, proc_macro2::TokenStream)> {
288 let ReturnType::Type(_, ty) = return_type else {
289 return Err(syn::Error::new_spanned(
290 return_type,
291 "function must have a return type of Result<T, E>",
292 ));
293 };
294
295 let Type::Path(type_path) = ty.deref() else {
296 return Err(syn::Error::new_spanned(
297 ty,
298 "return type must be Result<T, E>",
299 ));
300 };
301
302 let Some(last_segment) = type_path.path.segments.last() else {
303 return Err(syn::Error::new_spanned(
304 &type_path.path,
305 "return type must be Result<T, E>",
306 ));
307 };
308
309 if last_segment.ident != "Result" {
310 return Err(syn::Error::new_spanned(
311 &last_segment.ident,
312 "return type must be Result<T, E>",
313 ));
314 }
315
316 let PathArguments::AngleBracketed(args) = &last_segment.arguments else {
317 return Err(syn::Error::new_spanned(
318 &last_segment.arguments,
319 "expected angle-bracketed type parameters for Result<T, E>",
320 ));
321 };
322
323 let mut generic_args = args.args.iter();
324 let Some(output) = generic_args.next() else {
325 return Err(syn::Error::new_spanned(
326 &args.args,
327 "expected Result<T, E> with exactly two type parameters",
328 ));
329 };
330 let Some(error) = generic_args.next() else {
331 return Err(syn::Error::new_spanned(
332 &args.args,
333 "expected Result<T, E> with exactly two type parameters",
334 ));
335 };
336
337 if generic_args.next().is_some() {
338 return Err(syn::Error::new_spanned(
339 &args.args,
340 "expected Result<T, E> with exactly two type parameters",
341 ));
342 }
343
344 Ok((quote!(#output), quote!(#error)))
345}
346
347#[proc_macro_attribute]
411pub fn rig_tool(args: TokenStream, input: TokenStream) -> TokenStream {
412 let args = parse_macro_input!(args as MacroArgs);
413 let input_fn = parse_macro_input!(input as syn::ItemFn);
414
415 let fn_name = &input_fn.sig.ident;
417 let fn_name_str = fn_name.to_string();
418 let tool_name = args.name.clone().unwrap_or_else(|| fn_name_str.clone());
419 let vis = &input_fn.vis;
420 let is_async = input_fn.sig.asyncness.is_some();
421
422 let cleaned_fn = {
425 let mut f = input_fn.clone();
426 for arg in f.sig.inputs.iter_mut() {
427 if let syn::FnArg::Typed(pat_type) = arg {
428 pat_type.attrs.retain(|a| !a.path().is_ident("doc"));
429 }
430 }
431 f
432 };
433
434 let return_type = &input_fn.sig.output;
436 let (output_type, error_type) = match result_type_tokens(return_type) {
437 Ok(types) => types,
438 Err(error) => return error.into_compile_error().into(),
439 };
440
441 let struct_name = format_ident!("{}", { fn_name_str.to_case(Case::Pascal) });
443
444 let fn_doc = extract_doc_comment(&input_fn.attrs);
446 let tool_description = match args.description {
447 Some(desc) => quote! { #desc.to_string() },
448 None => match fn_doc {
449 Some(doc) => quote! { #doc.to_string() },
450 None => quote! { format!("Function to {}", Self::NAME) },
451 },
452 };
453
454 let mut param_names = Vec::new();
456 let mut field_tokens = Vec::new();
457
458 for arg in input_fn.sig.inputs.iter() {
459 if let syn::FnArg::Typed(pat_type) = arg
460 && let syn::Pat::Ident(param_ident) = &*pat_type.pat
461 {
462 let param_name = ¶m_ident.ident;
463 let param_name_str = param_name.to_string();
464 let ty = &pat_type.ty;
465
466 let field_doc_attr =
469 if let Some(explicit) = args.param_descriptions.get(¶m_name_str) {
470 quote! { #[schemars(description = #explicit)] }
472 } else if let Some(doc) = extract_doc_comment(&pat_type.attrs) {
473 quote! { #[doc = #doc] }
475 } else {
476 let default_desc = format!("Parameter {param_name_str}");
478 quote! { #[schemars(description = #default_desc)] }
479 };
480
481 let serde_default = if is_option_type(ty) {
483 quote! { #[serde(default)] }
484 } else {
485 quote! {}
486 };
487
488 field_tokens.push(quote! {
489 #field_doc_attr
490 #serde_default
491 #vis #param_name: #ty
492 });
493
494 param_names.push(param_name);
495 }
496 }
497
498 let required_args: Vec<String> = args
500 .required
501 .unwrap_or_else(|| param_names.iter().map(|n| n.to_string()).collect());
502
503 let params_struct_name = format_ident!("{}Parameters", struct_name);
504 let static_name = format_ident!("{}", fn_name_str.to_uppercase());
505
506 let call_impl = if is_async {
508 quote! {
509 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
510 #fn_name(#(args.#param_names,)*).await
511 }
512 }
513 } else {
514 quote! {
515 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
516 #fn_name(#(args.#param_names,)*)
517 }
518 }
519 };
520
521 let rig_core = rig_core_path();
522 let schemars_crate = format!("{}::schemars", rig_core.to_string().replace(' ', ""));
523 let expanded = quote! {
524 #[derive(serde::Deserialize, #rig_core::schemars::JsonSchema)]
525 #[schemars(crate = #schemars_crate)]
526 #vis struct #params_struct_name {
527 #(#field_tokens,)*
528 }
529
530 #cleaned_fn
531
532 #[derive(Default)]
533 #vis struct #struct_name;
534
535 impl #rig_core::tool::Tool for #struct_name {
536 const NAME: &'static str = #tool_name;
537
538 type Args = #params_struct_name;
539 type Output = #output_type;
540 type Error = #error_type;
541
542 fn name(&self) -> String {
543 #tool_name.to_string()
544 }
545
546 async fn definition(&self, _prompt: String) -> #rig_core::completion::ToolDefinition {
547 let mut schema = serde_json::to_value(
548 #rig_core::schemars::schema_for!(#params_struct_name)
549 ).expect("schema serialization");
550 schema["required"] = serde_json::json!([#(#required_args),*]);
551
552 #rig_core::completion::ToolDefinition {
553 name: #tool_name.to_string(),
554 description: #tool_description.to_string(),
555 parameters: schema,
556 }
557 }
558
559 #call_impl
560 }
561
562 #vis static #static_name: #struct_name = #struct_name;
563 };
564
565 TokenStream::from(expanded)
566}