combination_err/
lib.rs

1#![recursion_limit="128"]
2
3extern crate proc_macro;
4extern crate proc_macro2;
5extern crate quote;
6extern crate syn;
7
8use proc_macro::TokenStream;
9use quote::quote;
10use syn::parse::{ Parse, ParseStream };
11use syn::punctuated::Punctuated;
12use syn::{parse_macro_input, Token};
13use syn::spanned::Spanned;
14
15struct CombinationErrorDescriptions {
16    descriptions: Punctuated<syn::LitStr, Token![,]>,
17}
18
19impl CombinationErrorDescriptions {
20    fn variants_len(&self) -> usize {
21        match self.descriptions.len() {
22            0 => 0,
23            l => l - 1,
24        }
25    }
26}
27
28impl Parse for CombinationErrorDescriptions {
29    fn parse(input: ParseStream) -> syn::Result<Self> {
30        Ok(CombinationErrorDescriptions {
31            descriptions: input.parse_terminated(<syn::LitStr as Parse>::parse)?,
32        })
33    }
34}
35
36#[proc_macro_attribute]
37pub fn combination_err(attribute: TokenStream, item: TokenStream) -> TokenStream {
38    let descriptions = parse_macro_input!(attribute as CombinationErrorDescriptions);
39    let enum_input = parse_macro_input!(item as syn::ItemEnum);
40
41    if descriptions.variants_len() != enum_input.variants.len() {
42        let s = format!("Number of descriptions ({}) does not match number of enum variants ({})", descriptions.variants_len(), enum_input.variants.len());
43        panic!(syn::Error::new(descriptions.descriptions.span(), s));
44    }
45
46    let enum_ident = &enum_input.ident;
47    let enum_generics = &enum_input.generics;
48    let enum_description = {
49        descriptions.descriptions.iter().next().unwrap()
50    };
51    let source_arms = enum_input.variants.iter()
52        .map(|variant| {
53            let variant_ident = &variant.ident;
54            quote! {
55                &#enum_ident::#variant_ident(ref e) => Some(e as &(dyn std::error::Error + 'static)),
56            }
57        });
58    let source_description_arms = descriptions.descriptions.iter()
59        .skip(1)
60        .zip(enum_input.variants.iter())
61        .map(|(description, variant)| {
62            let variant_ident = &variant.ident;
63            quote! {
64                &#enum_ident::#variant_ident(..) => #description,
65            }
66        });
67    let source_description = quote! {
68        match self {
69            #(#source_description_arms)*
70        }
71    };
72
73    let from_implementations = enum_input.variants.iter()
74        .map(|variant| {
75            let variant_ty = variant.fields.iter()
76                .next()
77                .map(|field| field.ty.clone())
78                .unwrap();
79            let variant_ident = &variant.ident;
80            quote! {
81                impl#enum_generics From<#variant_ty> for #enum_ident#enum_generics {
82                    fn from(e: #variant_ty) -> Self {
83                        #enum_ident::#variant_ident(e)
84                    }
85                }
86            }
87        });
88    TokenStream::from(quote! {
89        #enum_input
90
91        impl#enum_generics std::error::Error for #enum_ident#enum_generics {
92            fn description(&self) -> &str {
93                #enum_description
94            }
95
96            fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
97                match self {
98                    #(#source_arms)*
99                }
100            }
101        }
102
103        impl#enum_generics std::fmt::Display for #enum_ident#enum_generics {
104            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105                use std::error::Error;
106                let description: &str = #enum_description;
107                write!(f, "{}", description)?;
108                if let Some(src) = self.source() {
109                    let src_desc = #source_description;
110                    write!(f, ": {}: {}", src_desc, src)?;
111                }
112                Ok(())
113            }
114        }
115
116        #(#from_implementations)*
117    })
118}