1#![recursion_limit="128"]
2
3extern crate proc_macro;
4extern crate proc_macro2;
5extern crate quote;
6extern crate syn;
7
8use proc_macro::TokenStream;
9use quote::quote;
10use syn::parse::{ Parse, ParseStream };
11use syn::punctuated::Punctuated;
12use syn::{parse_macro_input, Token};
13use syn::spanned::Spanned;
14
/// Parsed arguments of the `combination_err` attribute: a comma-separated
/// list of string literals.
struct CombinationErrorDescriptions {
 /// All literals in the order written. `combination_err` uses the first one
 /// as the description of the enum and the remaining ones as the
 /// descriptions of the enum's variants, in declaration order.
 descriptions: Punctuated<syn::LitStr, Token![,]>,
}
18
19impl CombinationErrorDescriptions {
20 fn variants_len(&self) -> usize {
21 match self.descriptions.len() {
22 0 => 0,
23 l => l - 1,
24 }
25 }
26}
27
28impl Parse for CombinationErrorDescriptions {
29 fn parse(input: ParseStream) -> syn::Result<Self> {
30 Ok(CombinationErrorDescriptions {
31 descriptions: input.parse_terminated(<syn::LitStr as Parse>::parse)?,
32 })
33 }
34}
35
36#[proc_macro_attribute]
37pub fn combination_err(attribute: TokenStream, item: TokenStream) -> TokenStream {
38 let descriptions = parse_macro_input!(attribute as CombinationErrorDescriptions);
39 let enum_input = parse_macro_input!(item as syn::ItemEnum);
40
41 if descriptions.variants_len() != enum_input.variants.len() {
42 let s = format!("Number of descriptions ({}) does not match number of enum variants ({})", descriptions.variants_len(), enum_input.variants.len());
43 panic!(syn::Error::new(descriptions.descriptions.span(), s));
44 }
45
46 let enum_ident = &enum_input.ident;
47 let enum_generics = &enum_input.generics;
48 let enum_description = {
49 descriptions.descriptions.iter().next().unwrap()
50 };
51 let source_arms = enum_input.variants.iter()
52 .map(|variant| {
53 let variant_ident = &variant.ident;
54 quote! {
55 &#enum_ident::#variant_ident(ref e) => Some(e as &(dyn std::error::Error + 'static)),
56 }
57 });
58 let source_description_arms = descriptions.descriptions.iter()
59 .skip(1)
60 .zip(enum_input.variants.iter())
61 .map(|(description, variant)| {
62 let variant_ident = &variant.ident;
63 quote! {
64 &#enum_ident::#variant_ident(..) => #description,
65 }
66 });
67 let source_description = quote! {
68 match self {
69 #(#source_description_arms)*
70 }
71 };
72
73 let from_implementations = enum_input.variants.iter()
74 .map(|variant| {
75 let variant_ty = variant.fields.iter()
76 .next()
77 .map(|field| field.ty.clone())
78 .unwrap();
79 let variant_ident = &variant.ident;
80 quote! {
81 impl#enum_generics From<#variant_ty> for #enum_ident#enum_generics {
82 fn from(e: #variant_ty) -> Self {
83 #enum_ident::#variant_ident(e)
84 }
85 }
86 }
87 });
88 TokenStream::from(quote! {
89 #enum_input
90
91 impl#enum_generics std::error::Error for #enum_ident#enum_generics {
92 fn description(&self) -> &str {
93 #enum_description
94 }
95
96 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
97 match self {
98 #(#source_arms)*
99 }
100 }
101 }
102
103 impl#enum_generics std::fmt::Display for #enum_ident#enum_generics {
104 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105 use std::error::Error;
106 let description: &str = #enum_description;
107 write!(f, "{}", description)?;
108 if let Some(src) = self.source() {
109 let src_desc = #source_description;
110 write!(f, ": {}: {}", src_desc, src)?;
111 }
112 Ok(())
113 }
114 }
115
116 #(#from_implementations)*
117 })
118}