// rig_tool_macro/lib.rs

use proc_macro::TokenStream;
use quote::quote;
use syn::parse::Parse;
use syn::{parse_macro_input, FnArg, ItemFn, LitStr, Pat, PatType, ReturnType, Type};

6// Add this struct to parse the description attribute
7struct ToolAttr {
8    description: Option<String>,
9}
10
11impl Parse for ToolAttr {
12    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
13        let mut description = None;
14
15        if !input.is_empty() {
16            let name: syn::Ident = input.parse()?;
17            if name == "description" {
18                let _: syn::Token![=] = input.parse()?;
19                let desc: LitStr = input.parse()?;
20                description = Some(desc.value());
21            }
22        }
23
24        Ok(ToolAttr { description })
25    }
26}
27
/// Converts a `snake_case` identifier into `PascalCase`.
///
/// Every `_`-separated segment gets its first character upper-cased (which
/// may expand to several characters, e.g. `ß` becomes `SS`) while the rest of
/// the segment is kept verbatim. Underscores are dropped, so empty segments
/// (leading, trailing or doubled underscores) contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for word in s.split('_') {
        let mut rest = word.chars();
        if let Some(head) = rest.next() {
            pascal.extend(head.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}

40fn get_json_type(ty: &Type) -> proc_macro2::TokenStream {
41    match ty {
42        Type::Path(type_path) => {
43            let segment = &type_path.path.segments[0];
44            let type_name = segment.ident.to_string();
45
46            // Handle Vec types
47            if type_name == "Vec" {
48                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
49                    if let syn::GenericArgument::Type(inner_type) = &args.args[0] {
50                        let inner_json_type = get_json_type(inner_type);
51                        return quote! {
52                            "type": "array",
53                            "items": { #inner_json_type }
54                        };
55                    }
56                }
57                return quote! { "type": "array" };
58            }
59
60            // Handle primitive types
61            match type_name.as_str() {
62                "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "f32" | "f64" => {
63                    quote! { "type": "number" }
64                }
65                "String" | "str" => {
66                    quote! { "type": "string" }
67                }
68                "bool" => {
69                    quote! { "type": "boolean" }
70                }
71                // Handle other types as objects
72                _ => {
73                    quote! { "type": "object" }
74                }
75            }
76        }
77        _ => quote! { "type": "object" },
78    }
79}
80
81#[proc_macro_attribute]
82pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
83    let attr = parse_macro_input!(attr as ToolAttr);
84    let input_fn = parse_macro_input!(item as ItemFn);
85
86    let fn_name = &input_fn.sig.ident;
87    let fn_name_str = fn_name.to_string();
88    let struct_name = quote::format_ident!("{}Tool", to_pascal_case(&fn_name_str));
89    let static_name = quote::format_ident!("{}", to_pascal_case(&fn_name_str));
90    let error_name = quote::format_ident!("{}Error", struct_name);
91
92    // Extract return type
93    let return_type = if let ReturnType::Type(_, ty) = &input_fn.sig.output {
94        if let Type::Path(type_path) = ty.as_ref() {
95            if type_path.path.segments[0].ident == "Result" {
96                if let syn::PathArguments::AngleBracketed(args) =
97                    &type_path.path.segments[0].arguments
98                {
99                    if let syn::GenericArgument::Type(t) = &args.args[0] {
100                        t
101                    } else {
102                        panic!("Expected type argument in Result")
103                    }
104                } else {
105                    panic!("Expected angle bracketed arguments in Result")
106                }
107            } else {
108                ty.as_ref()
109            }
110        } else {
111            ty.as_ref()
112        }
113    } else {
114        panic!("Function must return a Result")
115    };
116
117    let args = input_fn.sig.inputs.iter().filter_map(|arg| {
118        if let FnArg::Typed(PatType { pat, ty, .. }) = arg {
119            Some((pat, ty))
120        } else {
121            None
122        }
123    });
124
125    let arg_names: Vec<_> = args.clone().map(|(pat, _)| pat).collect();
126    let arg_types: Vec<_> = args.clone().map(|(_, ty)| ty).collect();
127    let json_types: Vec<_> = arg_types.iter().map(|ty| get_json_type(ty)).collect();
128
129    let args_struct_name = quote::format_ident!("{}Args", to_pascal_case(&fn_name_str));
130
131    let call_impl = if input_fn.sig.asyncness.is_some() {
132        quote! {
133            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
134                #fn_name(#(args.#arg_names),*).await
135                    .map_err(|e| Self::Error::ExecutionError(e.to_string()))
136            }
137        }
138    } else {
139        quote! {
140            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
141                #fn_name(#(args.#arg_names),*)
142                    .map_err(|e| Self::Error::ExecutionError(e.to_string()))
143            }
144        }
145    };
146
147    // Modify the definition implementation to use the description
148    let description = match attr.description {
149        Some(desc) => quote! { #desc.to_string() },
150        None => quote! { format!("Function to {}", Self::NAME) },
151    };
152
153    let expanded = quote! {
154        #[derive(Debug, thiserror::Error)]
155        pub enum #error_name {
156            #[error("Tool execution failed: {0}")]
157            ExecutionError(String),
158        }
159
160        #[derive(Debug, Clone, Copy, serde::Deserialize, serde::Serialize)]
161        pub struct #struct_name;
162
163        #[derive(Debug, serde::Deserialize, serde::Serialize)]
164        pub struct #args_struct_name {
165            #(#arg_names: #arg_types),*
166        }
167
168        #input_fn
169
170        impl rig::tool::Tool for #struct_name {
171            const NAME: &'static str = #fn_name_str;
172
173            type Error = #error_name;
174            type Args = #args_struct_name;
175            type Output = #return_type;
176
177            async fn definition(&self, _prompt: String) -> rig::completion::ToolDefinition {
178                rig::completion::ToolDefinition {
179                    name: Self::NAME.to_string(),
180                    description: #description,
181                    parameters: serde_json::json!({
182                        "type": "object",
183                        "properties": {
184                            #(
185                                stringify!(#arg_names): {
186                                    #json_types,
187                                    "description": format!("Parameter {}", stringify!(#arg_names))
188                                }
189                            ),*
190                        },
191                    }),
192                }
193            }
194
195            #call_impl
196        }
197
198        pub static #static_name: #struct_name = #struct_name;
199    };
200
201    expanded.into()
202}