extern crate proc_macro;
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{parse_macro_input, Data, DeriveInput, Fields};
#[proc_macro_derive(
ModError,
attributes(
error_prefix,
error_display,
error_kind,
error_caption,
error_retryable,
error_http_status,
error_exit_code,
error_fatal
)
)]
pub fn derive_mod_error(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let is_enum = match &input.data {
Data::Enum(_) => true,
Data::Struct(_) => false,
Data::Union(_) => panic!("ModError cannot be derived for unions"),
};
let error_prefix = get_error_prefix(&input.attrs);
let implementation = if is_enum {
implement_for_enum(&input, &error_prefix)
} else {
implement_for_struct(&input, &error_prefix)
};
TokenStream::from(implementation)
}
fn get_error_prefix(attrs: &[syn::Attribute]) -> String {
for attr in attrs {
if attr.path.is_ident("error_prefix") {
if let Some(value) = parse_string_attribute(attr) {
return value;
}
}
}
String::new()
}
fn parse_string_attribute(attr: &syn::Attribute) -> Option<String> {
match attr.parse_meta().ok()? {
syn::Meta::NameValue(meta) => match meta.lit {
syn::Lit::Str(lit) => Some(lit.value()),
_ => None,
},
syn::Meta::List(meta) => match meta.nested.iter().next() {
Some(syn::NestedMeta::Lit(syn::Lit::Str(lit))) => Some(lit.value()),
_ => None,
},
syn::Meta::Path(_) => None,
}
}
fn parse_int_attribute<T>(attr: &syn::Attribute) -> Option<T>
where
T: std::str::FromStr,
T::Err: std::fmt::Display,
{
match attr.parse_meta().ok()? {
syn::Meta::NameValue(meta) => match meta.lit {
syn::Lit::Int(lit) => lit.base10_parse().ok(),
_ => None,
},
syn::Meta::List(meta) => match meta.nested.iter().next() {
Some(syn::NestedMeta::Lit(syn::Lit::Int(lit))) => lit.base10_parse().ok(),
_ => None,
},
syn::Meta::Path(_) => None,
}
}
fn has_flag_attribute(attr: &syn::Attribute, name: &str) -> bool {
attr.path.is_ident(name)
}
fn implement_for_enum(input: &DeriveInput, error_prefix: &str) -> proc_macro2::TokenStream {
let name = &input.ident;
let data_enum = match &input.data {
Data::Enum(data) => data,
_ => panic!("Expected enum"),
};
let mut kind_match_arms = Vec::new();
let mut caption_match_arms = Vec::new();
let mut display_match_arms = Vec::new();
let mut retryable_match_arms = Vec::new();
let mut fatal_match_arms = Vec::new();
let mut status_code_match_arms = Vec::new();
let mut exit_code_match_arms = Vec::new();
for variant in &data_enum.variants {
let variant_name = &variant.ident;
let variant_name_str = variant_name.to_string();
let mut display_format = variant_name_str.clone();
let mut kind_name = variant_name_str.clone();
let mut caption = format!("{}: Error", error_prefix);
let mut retryable = false;
let mut fatal = false;
let mut status_code: u16 = 500;
let mut exit_code: i32 = 1;
for attr in &variant.attrs {
if attr.path.is_ident("error_display") {
if let Some(value) = parse_string_attribute(attr) {
display_format = value;
}
} else if attr.path.is_ident("error_kind") {
if let Some(value) = parse_string_attribute(attr) {
kind_name = value;
}
} else if attr.path.is_ident("error_caption") {
if let Some(value) = parse_string_attribute(attr) {
caption = value;
}
} else if attr.path.is_ident("error_retryable") {
retryable = true;
} else if has_flag_attribute(attr, "error_fatal") {
fatal = true;
} else if attr.path.is_ident("error_http_status") {
if let Some(value) = parse_int_attribute(attr) {
status_code = value;
}
} else if attr.path.is_ident("error_exit_code") {
if let Some(value) = parse_int_attribute(attr) {
exit_code = value;
}
}
}
match &variant.fields {
Fields::Named(fields) => {
let field_names: Vec<_> = fields
.named
.iter()
.map(|f| f.ident.as_ref().unwrap())
.collect();
kind_match_arms.push(quote! {
Self::#variant_name { .. } => #kind_name
});
caption_match_arms.push(quote! {
Self::#variant_name { .. } => #caption
});
display_match_arms.push(quote! {
Self::#variant_name { #(#field_names),* } => format!(#display_format, #(#field_names = #field_names),*)
});
retryable_match_arms.push(quote! {
Self::#variant_name { .. } => #retryable
});
fatal_match_arms.push(quote! {
Self::#variant_name { .. } => #fatal
});
status_code_match_arms.push(quote! {
Self::#variant_name { .. } => #status_code
});
exit_code_match_arms.push(quote! {
Self::#variant_name { .. } => #exit_code
});
}
Fields::Unnamed(fields) => {
let field_count = fields.unnamed.len();
let field_names: Vec<_> =
(0..field_count).map(|i| format_ident!("_{}", i)).collect();
kind_match_arms.push(quote! {
Self::#variant_name(..) => #kind_name
});
caption_match_arms.push(quote! {
Self::#variant_name(..) => #caption
});
let field_pattern_list = field_names.iter().map(|name| quote! { #name, });
display_match_arms.push(quote! {
Self::#variant_name(#(#field_pattern_list)*) => format!(#display_format #(, #field_names)*)
});
retryable_match_arms.push(quote! {
Self::#variant_name(..) => #retryable
});
fatal_match_arms.push(quote! {
Self::#variant_name(..) => #fatal
});
status_code_match_arms.push(quote! {
Self::#variant_name(..) => #status_code
});
exit_code_match_arms.push(quote! {
Self::#variant_name(..) => #exit_code
});
}
Fields::Unit => {
kind_match_arms.push(quote! {
Self::#variant_name => #kind_name
});
caption_match_arms.push(quote! {
Self::#variant_name => #caption
});
display_match_arms.push(quote! {
Self::#variant_name => #display_format.to_string()
});
retryable_match_arms.push(quote! {
Self::#variant_name => #retryable
});
fatal_match_arms.push(quote! {
Self::#variant_name => #fatal
});
status_code_match_arms.push(quote! {
Self::#variant_name => #status_code
});
exit_code_match_arms.push(quote! {
Self::#variant_name => #exit_code
});
}
}
}
quote! {
impl ::std::fmt::Display for #name {
fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
let msg = match self {
#(#display_match_arms,)*
};
write!(f, "{}", msg)
}
}
impl ::error_forge::error::ForgeError for #name {
fn kind(&self) -> &'static str {
match self {
#(#kind_match_arms,)*
}
}
fn caption(&self) -> &'static str {
match self {
#(#caption_match_arms,)*
}
}
fn is_retryable(&self) -> bool {
match self {
#(#retryable_match_arms,)*
}
}
fn is_fatal(&self) -> bool {
match self {
#(#fatal_match_arms,)*
}
}
fn status_code(&self) -> u16 {
match self {
#(#status_code_match_arms,)*
}
}
fn exit_code(&self) -> i32 {
match self {
#(#exit_code_match_arms,)*
}
}
}
impl ::std::error::Error for #name {
fn source(&self) -> Option<&(dyn ::std::error::Error + 'static)> {
None
}
}
}
}
fn implement_for_struct(input: &DeriveInput, error_prefix: &str) -> proc_macro2::TokenStream {
let name = &input.ident;
let name_str = name.to_string();
quote! {
impl ::std::fmt::Display for #name {
fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
write!(f, "{}: Error", #error_prefix)
}
}
impl ::error_forge::error::ForgeError for #name {
fn kind(&self) -> &'static str {
#name_str
}
fn caption(&self) -> &'static str {
concat!(#error_prefix, ": Error")
}
}
impl ::std::error::Error for #name {
fn source(&self) -> Option<&(dyn ::std::error::Error + 'static)> {
None
}
}
}
}