macron_impl_error/
lib.rs

1//! See the documentation here [macron documentation](https://docs.rs/macron)
2
3use proc_macro::TokenStream;
4use quote::quote;
5
6/// The implementation of trait Error
7#[proc_macro_derive(Error, attributes(source))]
8pub fn impl_error(input: TokenStream) -> TokenStream {
9    let syn::DeriveInput { ident, data, .. } = syn::parse_macro_input!(input as syn::DeriveInput);
10    
11    match data {
12        // impl Struct:
13        syn::Data::Struct(st) => {
14            let source = match st.fields {
15                syn::Fields::Named(fields) => {
16                    let src_field = fields.named
17                        .iter()
18                        .find(|f| f.ident
19                            .as_ref()
20                            .map(|i| i == "source")
21                            .unwrap_or(false)
22                        );
23            
24                    match src_field {
25                        Some(field) if is_option_type(&field.ty) => quote! { self.source.as_deref() },
26
27                        Some(_) => quote! { Some(self.source) },
28
29                        None => quote! { None }
30                    }
31                },
32
33                _ => quote! { None }
34            };
35            
36            quote! {
37                impl ::std::error::Error for #ident {
38                    fn source(&self) -> Option<&(dyn ::std::error::Error + 'static)> {
39                        #source
40                    }
41                }
42            }.into()
43        },
44
45        // impl Enum:
46        syn::Data::Enum(en) => {
47            let vars = en.variants
48                .into_iter()
49                .map(|syn::Variant { ident: var_ident, fields, .. }| {
50                    match fields {
51                        // Fields::Named { ..: .. }
52                        syn::Fields::Named(fields) => {
53                            let src_field = fields.named
54                                .iter()
55                                .find(|f| f.ident
56                                    .as_ref()
57                                    .map(|i| i == "source")
58                                    .unwrap_or(false)
59                                );
60                    
61                            match src_field {
62                                Some(field) if is_option_type(&field.ty) => quote! { Self::#var_ident { source, .. } => source.as_deref() },
63
64                                Some(_) => quote! { Self::#var_ident { source, .. } => Some(source) },
65
66                                None => quote! { Self::#var_ident(_) => None }
67                            }
68                        },
69
70                        // Fields::Unnamed(.., ..)
71                        syn::Fields::Unnamed(fields) => {
72                            let src_field = fields.unnamed
73                                .iter()
74                                .enumerate()
75                                .find(|(_, field)| field.attrs
76                                    .iter()
77                                    .any(|attr| attr.path().is_ident("source"))
78                                );
79
80                            let src_idx = src_field.map(|(idx, _)| idx).unwrap_or(0);
81                            let stubs = (0..src_idx).into_iter().map(|_| quote! { _ });
82
83                            match src_field {
84                                Some((_, field)) if is_option_type(&field.ty) => quote! { Self::#var_ident(#(#stubs,)* source) => source.as_deref() },
85
86                                Some(_) => quote! { Self::#var_ident(#(#stubs,)* source, ..) => Some(source) },
87
88                                None => quote! { Self::#var_ident(_) => None }
89                            }
90                        },
91
92                        // Self::Unit
93                        syn::Fields::Unit => quote! { Self::#var_ident => None }
94                    }
95                });
96            
97            quote! {
98                impl ::std::error::Error for #ident {
99                    fn source(&self) -> Option<&(dyn ::std::error::Error + 'static)> {
100                        match &self {
101                            #(
102                                #vars,
103                            )*
104                        }
105                    }
106                }
107            }.into()
108        },
109
110        _ => panic!("the expected a 'struct' or 'enum'")
111    }
112}
113
114// Do check if a type is "Option"
115fn is_option_type(ty: &syn::Type) -> bool {
116    if let syn::Type::Path(type_path) = ty {
117        type_path.path.segments.last().map_or(false, |seg| seg.ident == "Option")
118    } else {
119        false
120    }
121}