derive-enum-error 0.0.1

Derive macro for `std::error::Error`
Documentation
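//! # `derive_enum_error`
//!
//! Derive macro for `std::error::Error`. Types annotated with `#[error(display = "...")]`
//! also get a derived `std::fmt::Display` implementation.
//!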
//! ## Deriving error sources
//!
//! Mark the field that holds the underlying error with `#[error(source)]` (a bare
//! `#[source]` attribute is accepted as well):
//!
//! ```
//! use derive_enum_error::Error;
//!
//! use std::io;
//!
//! /// `MyError::source` will return a reference to the `io_error` field
//! #[derive(Debug, Error)]
//! #[error(display = "An error occurred.")]
//! struct MyError {
//!     #[error(source)]
//!     io_error: io::Error,
//! }
//! #
//! # fn main() {}
//! ```
//!
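//! ## Formatting fields
//!
//! The first argument of `#[error(...)]` must be `display = "..."`, a format string used for
//! the derived `Display` implementation. Any further arguments name the fields substituted
//! into it, in order. If any variant of an enum has such an attribute, every variant must.
//!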
//! ```rust
//! use derive_enum_error::Error;
//!
//! use std::path::PathBuf;
//!
//! #[derive(Debug, Error)]
//! pub enum FormatError {
//!     #[error(display = "invalid header (expected: {:?}, got: {:?})", expected, found)]
//!     InvalidHeader {
//!         expected: String,
//!         found: String,
//!     },
//!     // Tuple fields are referenced by their index prefixed with `_` (`_0`, `_1`, ...)
//!     #[error(display = "missing attribute: {:?}", _0)]
//!     MissingAttribute(String),
//!
//! }
//!
//! #[derive(Debug, Error)]
//! pub enum LoadingError {
//!     #[error(display = "could not decode file")]
//!     FormatError(#[error(source)] FormatError),
//!     #[error(display = "could not find file: {:?}", path)]
//!     NotFound { path: PathBuf },
//! }
//! #
//! # fn main() {}
//! ```
//!
//! ## Printing the error
//!
//! To report an error together with everything that caused it, walk the `source()` chain:
//!
//! ```
//! use std::error::Error;
//!
//! fn print_error(e: &dyn Error) {
//!     eprintln!("error: {}", e);
//!     let mut source = e.source();
//!     while let Some(e) = source {
//!         eprintln!("sourced by: {}", e);
//!         source = e.source();
//!     }
//! }
//! ```
//!
#![recursion_limit = "192"]

use proc_macro2::TokenStream;
use quote::quote_spanned;
use synstructure::decl_derive;

// Crate-local replacement for `quote::quote!`: every invocation is routed through
// `quote_spanned!` with `Span::call_site()`, so the generated tokens resolve names as if they
// had been written at the `#[derive(Error)]` site.
macro_rules! quote {
    ($($tokens:tt)*) => {
        quote_spanned!(proc_macro2::Span::call_site() => $($tokens)*)
    };
}

// Registers the `Error` derive (expanded by `fn error`) together with its helper attributes
// `#[error(...)]` and `#[source]`.
decl_derive!([Error, attributes(error, source)] => error);

/// Expands `#[derive(Error)]` for a struct or enum.
///
/// Always generates an `impl ::std::error::Error` whose `source()` returns the field marked
/// `#[source]` or `#[error(source)]` in the matched variant, or `None` when that variant has
/// no such field. An `impl ::std::fmt::Display` is generated as well when at least one variant
/// carries an `#[error(...)]` attribute (see `display_body`).
fn error(s: synstructure::Structure) -> TokenStream {
    // The trait-object types below are spelled `dyn std::error::Error` on purpose. The tokens
    // carry the caller's edition (call-site span). A bare trait object without `dyn` is a hard
    // error in edition 2021, and in edition 2015 `dyn` is not treated as a keyword when it is
    // followed by a path starting with `::`.
    let source_body = s.each_variant(|v| match v.bindings().iter().find(is_source) {
        Some(source) => quote! {
            return ::std::option::Option::Some(#source as &(dyn std::error::Error + 'static))
        },
        None => quote!(return ::std::option::Option::None),
    });

    let source_method = quote! {
        // Every match arm returns, so the trailing `None` is never reached; the lint is
        // silenced to keep the expansion warning-free.
        #[allow(unreachable_code)]
        fn source(&self) -> ::std::option::Option<&(dyn std::error::Error + 'static)> {
            match *self { #source_body }
            ::std::option::Option::None
        }
    };

    let error = s.unbound_impl(
        quote!(::std::error::Error),
        quote! {
            fn description(&self) -> &str {
                "description() is deprecated; use Display"
            }

            #source_method
        },
    );

    let display = display_body(&s).map(|display_body| {
        s.unbound_impl(
            quote!(::std::fmt::Display),
            quote! {
                #[allow(unreachable_code)]
                fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
                    match *self { #display_body }
                    write!(f, "An error has occurred.")
                }
            },
        )
    });

    quote! {
        #error
        #display
    }
}

fn display_body(s: &synstructure::Structure) -> Option<quote::__rt::TokenStream> {
    let mut msgs = s.variants().iter().map(|v| find_error_msg(&v.ast().attrs));
    if msgs.all(|msg| msg.is_none()) {
        return None;
    }

    Some(s.each_variant(|v| {
        let msg =
            find_error_msg(&v.ast().attrs).expect("All variants must have display attribute.");
        if msg.nested.is_empty() {
            panic!("Expected at least one argument to error attribute");
        }

        let format_string = match msg.nested[0] {
            syn::NestedMeta::Meta(syn::Meta::NameValue(ref nv)) if nv.ident == "display" => {
                nv.lit.clone()
            }
            _ => panic!(
                "Error attribute must begin `display = \"\"` to control the Display message."
            ),
        };
        let args = msg.nested.iter().skip(1).map(|arg| match *arg {
            syn::NestedMeta::Literal(syn::Lit::Int(ref i)) => {
                let bi = &v.bindings()[i.value() as usize];
                quote!(#bi)
            }
            syn::NestedMeta::Meta(syn::Meta::Word(ref id)) => {
                let id_s = id.to_string();
                if id_s.starts_with("_") {
                    if let Ok(idx) = id_s[1..].parse::<usize>() {
                        let bi = match v.bindings().get(idx) {
                            Some(bi) => bi,
                            None => {
                                panic!(
                                    "display attempted to access field `{}` in `{}::{}` which \
                                     does not exist (there are {} field{})",
                                    idx,
                                    s.ast().ident,
                                    v.ast().ident,
                                    v.bindings().len(),
                                    if v.bindings().len() != 1 { "s" } else { "" }
                                );
                            }
                        };
                        return quote!(#bi);
                    }
                }
                for bi in v.bindings() {
                    if bi.ast().ident.as_ref() == Some(id) {
                        return quote!(#bi);
                    }
                }
                panic!(
                    "Couldn't find field `{}` in `{}::{}`",
                    id,
                    s.ast().ident,
                    v.ast().ident
                );
            }
            _ => panic!("Invalid argument to error attribute!"),
        });

        quote! {
            return write!(f, #format_string #(, #args)*)
        }
    }))
}

/// Returns the one `#[error(...)]` list attribute among `attrs`, if any.
///
/// Attributes that do not parse as a `syn::Meta` are skipped.
///
/// # Panics
///
/// Panics if a second `error` attribute is found, or if an `error` attribute is not written
/// as a parenthesized list.
fn find_error_msg(attrs: &[syn::Attribute]) -> Option<syn::MetaList> {
    let mut found: Option<syn::MetaList> = None;
    for meta in attrs.iter().filter_map(|attr| attr.interpret_meta()) {
        if meta.name() != "error" {
            continue;
        }
        if found.is_some() {
            panic!("Cannot have two display attributes")
        }
        match meta {
            syn::Meta::List(list) => found = Some(list),
            _ => panic!("error attribute must take a list in parentheses"),
        }
    }
    found
}

fn is_source(bi: &&synstructure::BindingInfo) -> bool {
    let mut found_source = false;
    for attr in &bi.ast().attrs {
        if let Some(meta) = attr.interpret_meta() {
            if meta.name() == "source" {
                if found_source {
                    panic!("Cannot have two `source` attributes");
                }
                found_source = true;
            }
            if meta.name() == "error" {
                if let syn::Meta::List(ref list) = meta {
                    if let Some(ref pair) = list.nested.first() {
                        if let &&syn::NestedMeta::Meta(syn::Meta::Word(ref word)) = pair.value() {
                            if word == "source" {
                                if found_source {
                                    panic!("Cannot have two `source` attributes");
                                }
                                found_source = true;
                            }
                        }
                    }
                }
            }
        }
    }
    found_source
}