errorset/
lib.rs

1#![recursion_limit = "512"]
2#![doc = include_str!("../README.md")]
3
4extern crate proc_macro;
5
6use core::panic;
7
8use convert_case::{Case, Casing};
9use proc_macro::TokenStream;
10use proc_macro2::Span;
11use quote::quote;
12use std::collections::HashSet;
13use syn::{
14    parse::{Parse, ParseStream, Result},
15    parse_macro_input,
16    punctuated::Punctuated,
17    token::PathSep,
18    Ident, ImplItemFn, ItemFn, ItemImpl, PathArguments, PathSegment, ReturnType, Token, Type, TypePath, Visibility,
19};
20
21struct ErrorsetArgs {
22    visibility: Visibility,
23    module: Option<Ident>,
24}
25
26impl Parse for ErrorsetArgs {
27    fn parse(input: ParseStream) -> Result<Self> {
28        let mut module = None;
29
30        // try parse Visibility of module
31        let visibility: Visibility = input.parse()?;
32        // try parse module definition like `mod "module_name"`
33        let lookahead = input.lookahead1();
34        if lookahead.peek(Token![mod]) {
35            input.parse::<Token![mod]>()?;
36            let mod_name: Ident = input.parse()?;
37            module = Some(mod_name);
38        }
39
40        Ok(ErrorsetArgs { visibility, module })
41    }
42}
43
44#[proc_macro_attribute]
45pub fn errorset(attr: TokenStream, item: TokenStream) -> TokenStream {
46    let args = parse_macro_input!(attr as ErrorsetArgs);
47    let input = parse_macro_input!(item as syn::Item);
48
49    match input {
50        syn::Item::Fn(item_fn) => handle_function(&args, item_fn),
51        syn::Item::Impl(item_impl) => handle_impl_block(&args, item_impl),
52        _ => panic!("errorset can only be applied to functions or impl blocks"),
53    }
54}
55
56struct Output {
57    enum_def: proc_macro2::TokenStream,
58    fn_def: proc_macro2::TokenStream,
59}
60
61fn process_fn(args: &ErrorsetArgs, item_fn: &ItemFn) -> Result<Option<Output>> {
62    // Extract the function name and convert it to camel-case for the enum name
63    let fn_name = &item_fn.sig.ident;
64    let enum_name = Ident::new(
65        &format!("{}Errors", fn_name.to_string().to_case(Case::Pascal)),
66        Span::call_site(),
67    );
68
69    // Extract the return type from the function signature
70    let output_type = match &item_fn.sig.output {
71        ReturnType::Type(_, ty) => ty,
72        _ => {
73            return Err(syn::Error::new_spanned(
74                &item_fn.sig.output,
75                "Function must have a valid return type",
76            ))
77        }
78    };
79
80    let (new_return_type, err_types) = if let Type::Path(TypePath { path, .. }) = &**output_type {
81        if let Some(last_segment) = path.segments.last() {
82            if let PathArguments::AngleBracketed(ref params) = last_segment.arguments {
83                if params.args.len() != 2 {
84                    return Err(syn::Error::new_spanned(
85                        &params.args,
86                        "Expected exactly 2 generic arguments",
87                    ));
88                }
89
90                match params.args.iter().nth(1).unwrap() {
91                    syn::GenericArgument::Type(Type::Tuple(tuple)) => {
92                        let mut punctuated = Punctuated::<PathSegment, PathSep>::new();
93                        for seg in path.segments.iter() {
94                            punctuated.push_value(seg.ident.clone().into());
95                            // Add separator if there are more segments
96                            if punctuated.len() < path.segments.len() {
97                                punctuated.push_punct(PathSep::default());
98                            }
99                        }
100                        let new_path = syn::Path {
101                            leading_colon: path.leading_colon.clone(),
102                            segments: punctuated,
103                        };
104
105                        // Create new return type with the same name and the first generic parameter
106                        // The second parameter is the enum with error types
107                        let first_generic_arg = params.args.iter().next().unwrap();
108                        let new_return_type = if let Some(module) = &args.module {
109                            quote! {
110                                #new_path<#first_generic_arg, #module::#enum_name>
111                            }
112                        } else {
113                            quote! {
114                                #new_path<#first_generic_arg, #enum_name>
115                            }
116                        };
117                        let err_types = tuple.elems.clone();
118                        (new_return_type, err_types)
119                    }
120                    syn::GenericArgument::Type(Type::Paren(_)) | syn::GenericArgument::Type(Type::Path(_)) => {
121                        // If the second argument is defined as `(Error1)`, it does not determined as a tuple, just leave it as is
122                        // The same if the second argument is a regular type
123                        return Ok(None);
124                    }
125                    other => {
126                        return Err(syn::Error::new_spanned(
127                            other,
128                            "Expected the second generic argument to be a tuple",
129                        ));
130                    }
131                }
132            } else {
133                return Err(syn::Error::new_spanned(
134                    last_segment,
135                    "Expected angle-bracketed generic arguments",
136                ));
137            }
138        } else {
139            return Err(syn::Error::new_spanned(
140                path,
141                "Expected a valid type path for the generic type",
142            ));
143        }
144    } else {
145        return Err(syn::Error::new_spanned(
146            output_type,
147            "Function must return a generic type with 2 parameters",
148        ));
149    };
150
151    // Generate enum variants for each error type
152    let mut seen = HashSet::new();
153    let enum_variants = err_types
154        .iter()
155        .filter(|ty| match ty {
156            Type::Path(TypePath { path, .. }) => seen.insert(path.segments.last().unwrap().ident.to_string()),
157            _ => true,
158        })
159        .map(|ty| {
160            let ty_name = match ty {
161                Type::Path(TypePath { path, .. }) => path.segments.last().unwrap().ident.clone(),
162                _ => return quote! {}, // skip invalid
163            };
164            quote! {
165                #[error(transparent)]
166                #ty_name(#[from] #ty),
167            }
168        });
169
170    // Generate the enum definition
171    let enum_vis = if args.module.is_some() {
172        // use pub visibility for the enum if it's inside a module
173        syn::Visibility::Public(Default::default())
174    } else {
175        item_fn.vis.clone()
176    };
177    let enum_def = quote! {
178        #[derive(::thiserror::Error, Debug)]
179        #enum_vis enum #enum_name {
180            #(#enum_variants)*
181        }
182    };
183
184    let fn_sig = &item_fn.sig;
185    let fn_attrs = &item_fn.attrs;
186    let fn_vis = &item_fn.vis;
187    let fn_body = &item_fn.block;
188
189    let mut new_sig = fn_sig.clone();
190    new_sig.output = syn::parse2(quote! { -> #new_return_type }).unwrap();
191
192    // Generate the modified function with the new return type
193    let new_fn = quote! {
194        #(#fn_attrs)*
195        #fn_vis #new_sig
196        #fn_body
197    };
198
199    Ok(Some(Output { enum_def, fn_def: new_fn }))
200}
201
202fn handle_function(args: &ErrorsetArgs, item_fn: ItemFn) -> TokenStream {
203    match process_fn(args, &item_fn) {
204        Ok(Some(Output { enum_def, fn_def })) => {
205            if let Some(module) = &args.module {
206                let vis = &args.visibility;
207                quote! {
208                    #vis mod #module {
209                        use super::*;
210                        #enum_def
211                    }
212                    #fn_def
213                }
214            } else {
215                quote! {
216                    #enum_def
217                    #fn_def
218                }
219            }
220        }
221        Ok(None) => quote! { #item_fn },
222        Err(e) => e.to_compile_error(),
223    }
224    .into()
225}
226
227fn handle_impl_block(args: &ErrorsetArgs, item_impl: ItemImpl) -> TokenStream {
228    let mut new_items = Vec::new();
229    let mut new_enums = Vec::new();
230
231    for item in item_impl.items {
232        if let syn::ImplItem::Fn(method) = &item {
233            let mut new_attrs = Vec::new();
234            let mut marked = false;
235
236            for attr in &method.attrs {
237                if attr.path().is_ident("errorset") {
238                    if attr.meta.require_path_only().is_err() {
239                        return syn::Error::new_spanned(
240                            attr,
241                            "errorset attribute must not have arguments inside impl blocks",
242                        )
243                        .to_compile_error()
244                        .into();
245                    }
246                    marked = true;
247                } else {
248                    new_attrs.push(attr.clone());
249                }
250            }
251
252            if !marked {
253                new_items.push(item);
254                continue;
255            }
256
257            let item_fn = ItemFn {
258                attrs: new_attrs,
259                vis: method.vis.clone(),
260                sig: method.sig.clone(),
261                block: Box::new(method.block.clone()),
262            };
263
264            match process_fn(args, &item_fn) {
265                Ok(Some(Output { enum_def, fn_def })) => {
266                    let impl_item = syn::parse2::<ImplItemFn>(fn_def).expect("Invalid method reparse");
267                    new_items.push(impl_item.into());
268                    new_enums.push(enum_def);
269                }
270                Ok(None) => new_items.push(item),
271                Err(e) => return e.to_compile_error().into(),
272            }
273        } else {
274            new_items.push(item);
275        }
276    }
277
278    let new_impl_block = ItemImpl { items: new_items, ..item_impl };
279
280    if let Some(module) = &args.module {
281        // create module if new_enums is not empty
282        // otherwise, just add new_impl_block
283        if new_enums.is_empty() {
284            quote! {
285                #new_impl_block
286            }
287        } else {
288            let vis = &args.visibility;
289            quote! {
290                #vis mod #module {
291                    use super::*;
292                    #(#new_enums)*
293                }
294                #new_impl_block
295            }
296        }
297    } else {
298        quote! {
299            #(#new_enums)*
300            #new_impl_block
301        }
302    }
303    .into()
304}