macron_impl_error/
lib.rs

1#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md"))]
2
//! See the [macron documentation](https://docs.rs/macron) for full usage details.
4
5use proc_macro::TokenStream;
6use quote::quote;
7
8/// The implementation of trait [Error](std::error::Error)
9#[proc_macro_derive(Error, attributes(source))]
10pub fn impl_error(input: TokenStream) -> TokenStream {
11    let syn::DeriveInput { ident, data, .. } = syn::parse_macro_input!(input as syn::DeriveInput);
12    
13    match data {
14        // impl Struct:
15        syn::Data::Struct(st) => {
16            let source = match st.fields {
17                syn::Fields::Named(fields) => {
18                    let src_field = fields.named
19                        .iter()
20                        .find(|f| f.ident
21                            .as_ref()
22                            .map(|i| i == "source")
23                            .unwrap_or(false)
24                        );
25            
26                    match src_field {
27                        Some(field) if is_option_type(&field.ty) => quote! { self.source.as_deref() },
28                        Some(_) => quote! { Some(self.source) },
29                        None => quote! { None }
30                    }
31                },
32
33                _ => quote! { None }
34            };
35            
36            quote! {
37                impl ::std::error::Error for #ident {
38                    fn source(&self) -> Option<&(dyn ::std::error::Error + 'static)> {
39                        #source
40                    }
41                }
42            }.into()
43        },
44
45        // impl Enum:
46        syn::Data::Enum(en) => {
47            let vars = en.variants
48                .into_iter()
49                .map(|syn::Variant { ident: var_ident, fields, .. }| {
50                    match fields {
51                        // Fields::Named { ..: .. }
52                        syn::Fields::Named(fields) => {
53                            let src_field = fields.named
54                                .iter()
55                                .find(|f| f.ident
56                                    .as_ref()
57                                    .map(|i| i == "source")
58                                    .unwrap_or(false)
59                                );
60                    
61                            match src_field {
62                                Some(field) if is_option_type(&field.ty) => quote! { Self::#var_ident { source, .. } => source.as_deref() },
63                                Some(_) => quote! { Self::#var_ident { source, .. } => Some(source) },
64                                None => quote! { Self::#var_ident { .. } => None }
65                            }
66                        },
67
68                        // Fields::Unnamed(.., ..)
69                        syn::Fields::Unnamed(fields) => {
70                             let src_field = fields.unnamed
71                                .iter()
72                                .enumerate()
73                                .find(|(_, field)| field.attrs
74                                    .iter()
75                                    .any(|attr| attr.path().is_ident("source"))
76                                );
77
78                            let src_idx = src_field.map(|(idx, _)| idx).unwrap_or(0);
79                            let stubs = (0..src_idx).into_iter().map(|_| quote! { _ });
80
81                            match src_field {
82                                Some((_, field)) if is_option_type(&field.ty) => quote! { Self::#var_ident(#(#stubs,)* source, ..) => source.as_deref() },
83                                Some(_) => quote! { Self::#var_ident(#(#stubs,)* source, ..) => Some(source) },
84                                None => quote! { Self::#var_ident(..) => None }
85                            }
86                        },
87
88                        // Self::Unit
89                        syn::Fields::Unit => quote! { Self::#var_ident => None }
90                    }
91                });
92            
93            quote! {
94                impl ::std::error::Error for #ident {
95                    fn source(&self) -> Option<&(dyn ::std::error::Error + 'static)> {
96                        match &self {
97                            #(
98                                #vars,
99                            )*
100                        }
101                    }
102                }
103            }.into()
104        },
105
106        _ => panic!("the expected a 'struct' or 'enum'")
107    }
108}
109
110// Do check if a type is "Option"
111fn is_option_type(ty: &syn::Type) -> bool {
112    if let syn::Type::Path(type_path) = ty {
113        type_path.path.segments.last().map_or(false, |seg| seg.ident == "Option")
114    } else {
115        false
116    }
117}