backerror_macros/
lib.rs

1#![allow(unused_imports, dead_code)]
2use proc_macro::TokenStream;
3use proc_macro2::Ident;
4use quote::{quote, ToTokens};
5use syn::{
6    parse_macro_input, punctuated::Punctuated, Item, ItemEnum, ItemStruct, Meta, Path, Token,
7};
8
/// No-op build of the `#[backerror]` attribute.
///
/// This variant is compiled only when the `release_off` feature is enabled AND
/// debug assertions are disabled (e.g. a release build with `release_off`). In
/// that configuration the annotated item is emitted exactly as written: no field
/// is wrapped in `backerror::LocatedError` and no extra `From` impls are added.
#[cfg(all(feature = "release_off", not(debug_assertions)))]
#[proc_macro_attribute]
pub fn backerror(_args: TokenStream, input: TokenStream) -> TokenStream {
    // Pass the item through verbatim.
    input
}
15
16/// Helper attribute macro to enhance `thiserror::Error`, which adds `backerror::LocatedError` to the error type.
17/// ```ignore
18/// use backerror::backerror;
19/// use thiserror::Error;
20///
21/// #[backerror]
22/// #[derive(Debug, Error)]
23/// pub enum MyError1 {
24///     #[error("{0}")]
25///     IoError(#[from] std::io::Error),
26/// }
27///
28/// #[backerror]
29/// #[derive(Debug, Error)]
30/// #[error(transparent)]
31/// pub struct MyError(#[from] std::io::Error);
32///
33/// ```
34#[cfg(any(not(feature = "release_off"), debug_assertions))]
35#[proc_macro_attribute]
36pub fn backerror(_args: TokenStream, input: TokenStream) -> TokenStream {
37    let input2 = input.clone();
38    let item = parse_macro_input!(input2 as Item);
39
40    match item {
41        Item::Enum(item_enum) => backerror_enum(item_enum, input),
42        Item::Struct(item_struct) => backerror_struct(item_struct, input),
43        _ => input,
44    }
45}
46
47/// enum error
48///
49/// ```ignore
50/// #[backerror]
51/// #[derive(Debug, Error)]
52/// pub enum MyError {
53///     #[error("{0}")]
54///     IoError(#[from] std::io::Error),
55/// }
56/// ```
57fn backerror_enum(mut item_enum: ItemEnum, input: TokenStream) -> TokenStream {
58    // check whether the enum derives thiserror::Error
59    if !check_derive_thiserror(&item_enum.attrs) {
60        return input;
61    }
62
63    let mut error_types = Vec::new();
64
65    for variant in item_enum.variants.iter_mut() {
66        let fields = &mut variant.fields;
67        enhance_fields(fields, &mut error_types);
68    }
69
70    if let Ok(impls) = generate_from_impl(&item_enum.ident, &error_types) {
71        let ret = quote! {
72            #item_enum
73            #impls
74        };
75
76        ret.into()
77    } else {
78        input
79    }
80}
81
82/// transparent struct
83///
84/// ```ignore
85/// #[backerror]
86/// #[derive(Debug, Error)]
87/// #[error(transparent)]
88/// pub struct MyError(#[from] std::io::Error);
89/// ```
90fn backerror_struct(mut item_struct: ItemStruct, input: TokenStream) -> TokenStream {
91    // check whether the struct derives thiserror::Error
92    if !check_derive_thiserror(&item_struct.attrs) || !check_transparent_struct(&item_struct.attrs)
93    {
94        return input;
95    }
96
97    let mut error_types = Vec::new();
98
99    let fields = &mut item_struct.fields;
100    enhance_fields(fields, &mut error_types);
101
102    if let Ok(impls) = generate_from_impl(&item_struct.ident, &error_types) {
103        let ret = quote! {
104            #item_struct
105            #impls
106        };
107
108        ret.into()
109    } else {
110        input
111    }
112}
113
114fn generate_from_impl(
115    ident: &Ident,
116    error_types: &Vec<String>,
117) -> Result<proc_macro2::TokenStream, syn::Error> {
118    if error_types.is_empty() {
119        return Err(syn::Error::new(
120            proc_macro2::Span::call_site(),
121            "no attribute found",
122        ));
123    }
124
125    let mut impls = Vec::new();
126    for e in error_types {
127        let from_ty: Path = syn::parse_str(e)?;
128        let block = quote! {
129            impl From<#from_ty> for #ident {
130                #[track_caller]
131                fn from(e: #from_ty) -> Self {
132                    #ident::from(backerror::LocatedError::from(e))
133                }
134            }
135        };
136        impls.push(block);
137    }
138
139    Ok(quote! {
140        #(#impls)*
141    })
142}
143
144/// enhance fiels from `#[from] T` to `#[from] backerror::LocatedError<T>`
145fn enhance_fields(fields: &mut syn::Fields, errors: &mut Vec<String>) {
146    match fields {
147        syn::Fields::Unnamed(fs) => {
148            for field in fs.unnamed.iter_mut() {
149                if check_attr_from(&field.attrs) {
150                    let orig_ty = field.ty.clone().into_token_stream().to_string();
151                    let ty = format!("backerror::LocatedError<{}>", orig_ty);
152                    if let Ok(new_type) = syn::parse_str(&ty) {
153                        errors.push(orig_ty);
154                        field.ty = new_type;
155                    } else {
156                        println!("failed to parse {}", ty);
157                    }
158                } else {
159                    // println!("transparent struct field without #[from]");
160                }
161            }
162        }
163        syn::Fields::Unit | syn::Fields::Named(_) => {
164            // do nothing
165        }
166    }
167}
168
169/// check `#[derive(Error)]`
170fn check_derive_thiserror(attrs: &Vec<syn::Attribute>) -> bool {
171    for attr in attrs {
172        if attr.path().is_ident("derive") {
173            if let Ok(nested) =
174                attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
175            {
176                for meta in nested {
177                    match meta {
178                        Meta::Path(path) => {
179                            // #[derive(Error)]
180                            if path.is_ident("Error") {
181                                return true;
182                            }
183                            // #[derive(thiserror::Error)]
184                            let path = path.into_token_stream().to_string();
185                            if path.contains("thiserror") && path.contains("Error") {
186                                return true;
187                            }
188                        }
189
190                        _ => {}
191                    }
192                }
193            }
194        }
195    }
196    return false;
197}
198
199/// check `#[error(transparent)]`
200fn check_transparent_struct(attrs: &Vec<syn::Attribute>) -> bool {
201    for attr in attrs {
202        if attr.path().is_ident("error") {
203            if let Ok(nested) =
204                attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
205            {
206                for meta in nested {
207                    match meta {
208                        Meta::Path(path) => {
209                            // #[error(transparent)]
210                            if path.is_ident("transparent") {
211                                return true;
212                            }
213                        }
214
215                        _ => {}
216                    }
217                }
218            }
219        }
220    }
221    return false;
222}
223
224/// check `#[from]`
225fn check_attr_from(attrs: &Vec<syn::Attribute>) -> bool {
226    for attr in attrs {
227        if attr.path().is_ident("from") {
228            return true;
229        }
230    }
231    return false;
232}