1use proc_macro::TokenStream;
2use quote::quote;
3use syn::parse::Parse;
4use syn::{parse_macro_input, FnArg, ItemFn, LitStr, PatType, ReturnType, Type};
5
6struct ToolAttr {
8 description: Option<String>,
9}
10
11impl Parse for ToolAttr {
12 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
13 let mut description = None;
14
15 if !input.is_empty() {
16 let name: syn::Ident = input.parse()?;
17 if name == "description" {
18 let _: syn::Token![=] = input.parse()?;
19 let desc: LitStr = input.parse()?;
20 description = Some(desc.value());
21 }
22 }
23
24 Ok(ToolAttr { description })
25 }
26}
27
/// Converts a `snake_case` identifier into `PascalCase`.
///
/// The input is split on `_`. Each non-empty piece gets its first character
/// upper-cased (Unicode-aware, so one char may expand to several), and the rest
/// of the piece is appended unchanged. Empty pieces, produced by leading,
/// trailing or repeated underscores, contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for word in s.split('_') {
        let mut rest = word.chars();
        if let Some(head) = rest.next() {
            pascal.extend(head.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}
39
40fn get_json_type(ty: &Type) -> proc_macro2::TokenStream {
41 match ty {
42 Type::Path(type_path) => {
43 let segment = &type_path.path.segments[0];
44 let type_name = segment.ident.to_string();
45
46 if type_name == "Vec" {
48 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
49 if let syn::GenericArgument::Type(inner_type) = &args.args[0] {
50 let inner_json_type = get_json_type(inner_type);
51 return quote! {
52 "type": "array",
53 "items": { #inner_json_type }
54 };
55 }
56 }
57 return quote! { "type": "array" };
58 }
59
60 match type_name.as_str() {
62 "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "f32" | "f64" => {
63 quote! { "type": "number" }
64 }
65 "String" | "str" => {
66 quote! { "type": "string" }
67 }
68 "bool" => {
69 quote! { "type": "boolean" }
70 }
71 _ => {
73 quote! { "type": "object" }
74 }
75 }
76 }
77 _ => quote! { "type": "object" },
78 }
79}
80
81#[proc_macro_attribute]
82pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
83 let attr = parse_macro_input!(attr as ToolAttr);
84 let input_fn = parse_macro_input!(item as ItemFn);
85
86 let fn_name = &input_fn.sig.ident;
87 let fn_name_str = fn_name.to_string();
88 let struct_name = quote::format_ident!("{}Tool", to_pascal_case(&fn_name_str));
89 let static_name = quote::format_ident!("{}", to_pascal_case(&fn_name_str));
90 let error_name = quote::format_ident!("{}Error", struct_name);
91
92 let return_type = if let ReturnType::Type(_, ty) = &input_fn.sig.output {
94 if let Type::Path(type_path) = ty.as_ref() {
95 if type_path.path.segments[0].ident == "Result" {
96 if let syn::PathArguments::AngleBracketed(args) =
97 &type_path.path.segments[0].arguments
98 {
99 if let syn::GenericArgument::Type(t) = &args.args[0] {
100 t
101 } else {
102 panic!("Expected type argument in Result")
103 }
104 } else {
105 panic!("Expected angle bracketed arguments in Result")
106 }
107 } else {
108 ty.as_ref()
109 }
110 } else {
111 ty.as_ref()
112 }
113 } else {
114 panic!("Function must return a Result")
115 };
116
117 let args = input_fn.sig.inputs.iter().filter_map(|arg| {
118 if let FnArg::Typed(PatType { pat, ty, .. }) = arg {
119 Some((pat, ty))
120 } else {
121 None
122 }
123 });
124
125 let arg_names: Vec<_> = args.clone().map(|(pat, _)| pat).collect();
126 let arg_types: Vec<_> = args.clone().map(|(_, ty)| ty).collect();
127 let json_types: Vec<_> = arg_types.iter().map(|ty| get_json_type(ty)).collect();
128
129 let args_struct_name = quote::format_ident!("{}Args", to_pascal_case(&fn_name_str));
130
131 let call_impl = if input_fn.sig.asyncness.is_some() {
132 quote! {
133 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
134 #fn_name(#(args.#arg_names),*).await
135 .map_err(|e| Self::Error::ExecutionError(e.to_string()))
136 }
137 }
138 } else {
139 quote! {
140 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
141 #fn_name(#(args.#arg_names),*)
142 .map_err(|e| Self::Error::ExecutionError(e.to_string()))
143 }
144 }
145 };
146
147 let description = match attr.description {
149 Some(desc) => quote! { #desc.to_string() },
150 None => quote! { format!("Function to {}", Self::NAME) },
151 };
152
153 let expanded = quote! {
154 #[derive(Debug, thiserror::Error)]
155 pub enum #error_name {
156 #[error("Tool execution failed: {0}")]
157 ExecutionError(String),
158 }
159
160 #[derive(Debug, Clone, Copy, serde::Deserialize, serde::Serialize)]
161 pub struct #struct_name;
162
163 #[derive(Debug, serde::Deserialize, serde::Serialize)]
164 pub struct #args_struct_name {
165 #(#arg_names: #arg_types),*
166 }
167
168 #input_fn
169
170 impl rig::tool::Tool for #struct_name {
171 const NAME: &'static str = #fn_name_str;
172
173 type Error = #error_name;
174 type Args = #args_struct_name;
175 type Output = #return_type;
176
177 async fn definition(&self, _prompt: String) -> rig::completion::ToolDefinition {
178 rig::completion::ToolDefinition {
179 name: Self::NAME.to_string(),
180 description: #description,
181 parameters: serde_json::json!({
182 "type": "object",
183 "properties": {
184 #(
185 stringify!(#arg_names): {
186 #json_types,
187 "description": format!("Parameter {}", stringify!(#arg_names))
188 }
189 ),*
190 },
191 }),
192 }
193 }
194
195 #call_impl
196 }
197
198 pub static #static_name: #struct_name = #struct_name;
199 };
200
201 expanded.into()
202}