errer_derive 0.13.1

Flexible error management for Rust: a middle ground between failure and SNAFU.
Documentation
extern crate errer;

extern crate proc_macro;
extern crate proc_macro2;
extern crate syn;
extern crate quote;

use std::string::ToString;

use proc_macro2::*;
use quote::{quote, ToTokens};
use syn::{ punctuated::Punctuated, parenthesized, spanned::Spanned, parse::{Parse, Parser, ParseStream, ParseBuffer},
    DeriveInput, Data, Field, Fields,
    Token, Member, Ident, Index, Attribute};

/// Name of the field used as the error source when an enum has an outer `#[errer(from)]`
/// and a variant has named fields; tuple variants use their first field instead
/// (see `get_primary_field`).
const FROM_DEFAULT: &str = "from";

/// Builds the binding name `idx_<i>` used when destructuring the `i`-th field of a
/// tuple-like variant.
fn idx_to_ident(i: u32) -> Ident {
    let name = format!("idx_{}", i);
    let span = Span::call_site();
    Ident::new(&name, span)
}

/// Derives the local binding `field_<ident>` for a named field, reusing the span of the
/// identifier it was built from.
fn unique_name(ident: &Ident) -> Ident {
    let prefixed = format!("field_{}", ident);
    let span = ident.span();
    Ident::new(&prefixed, span)
}

/// One parsed entry of an `#[errer(...)]` attribute.
enum ErrerAttrib {
    /// `from` / `from = member`: selects the field that holds the wrapped source error.
    /// A `From<FieldType>` impl is generated only when that field is the only field and
    /// no `context` applies.
    From(Option<Member>),
    /// `context` / `context = Name`: generates a struct holding every field except the
    /// `from` field (so `from` is required), and implements
    /// `errer::IntoErrorContext<FromFieldType, Self>` for it; `into_target` puts the
    /// source error into the `from` field and the struct's fields into the rest.
    /// Intended usage, e.g. `io::Error::new(..).context(CreatingUser(id))`, where `id`
    /// fills the non-`from` field (the `context` method lives in the `errer` crate).
    /// The optional ident names the generated struct; otherwise structs use
    /// `<Name>Context` and enum variants use the variant name.
    Context(Option<Ident>),
    /// `display = "fmt", args...`: the format string, the comma after it (if any), and the
    /// members passed as format arguments. Used to implement `Display`.
    Display(syn::LitStr, Option<Token![,]>, Punctuated<Member, Token![,]>),
    /// `std` (outer only has an effect): makes the generated `errer::ErrorCompat::error_source`
    /// return the `from` field instead of `None`.
    Std
}

/// An `ErrerAttrib` paired with the span at which its keyword appeared, so later
/// validation errors can point at the offending attribute entry.
struct ErrerAttribSpanned {
    span: Span,
    attrib: ErrerAttrib
}

/// Shorthand for creating a `syn::Error` located at a span.
trait SpanExt {
    /// Creates a `syn::Error` at this span carrying `msg`.
    fn error<M: std::fmt::Display>(self, msg: M) -> syn::Error;
}

impl SpanExt for Span {
    fn error<M: std::fmt::Display>(self, msg: M) -> syn::Error {
        syn::Error::new(self, msg)
    }
}

/// Parses one `name` or `name = value` entry from inside `#[errer(...)]`.
///
/// `outer` is true for attributes on the type itself (rather than on a variant or field);
/// `is_struct` distinguishes structs from enums. Entries that are not allowed in the given
/// position produce an error spanned at the entry's keyword.
///
/// Note: `display` treats everything after its format string (and the following comma)
/// as format arguments, so it must be the last entry in its attribute list.
fn parse_err_attrib(input: ParseStream, outer: bool, is_struct: bool) -> syn::Result<ErrerAttribSpanned> {
    let span = input.cursor().span();
    let keyword = input.parse::<Ident>()?;
    let has_value = input.peek(Token![=]);
    if has_value {
        input.parse::<Token![=]>()?;
    }

    let located = |attrib: ErrerAttrib| ErrerAttribSpanned { span, attrib };
    let keyword_text = keyword.to_string();

    match (keyword_text.as_str(), has_value) {
        ("context", _) if is_struct && !outer =>
            Err(span.error("context can only be applied on the outside of a struct")),
        ("context", true) if outer && !is_struct =>
            Err(span.error("context as an outer enum attribute does not use a name. try applying this on the inside of the enum instead")),
        ("context", true) => Ok(located(ErrerAttrib::Context(Some(input.parse()?)))),
        ("context", false) => Ok(located(ErrerAttrib::Context(None))),

        ("from", _) if outer && is_struct =>
            Err(span.error("outer from on structs are not used. use from on a property instead")),
        ("from", true) if outer || is_struct =>
            Err(span.error("from members are not permitted in outer attributes or struct properties. try using this on a variant instead")),
        ("from", true) => Ok(located(ErrerAttrib::From(Some(input.parse()?)))),
        ("from", false) => Ok(located(ErrerAttrib::From(None))),

        ("display", false) => Err(span.error("a format string must be provided")),
        ("display", true) => {
            let format = input.parse::<syn::LitStr>()?;
            // a separating comma is expected only when more tokens follow the literal
            let separator = if input.is_empty() {
                None
            } else {
                Some(input.parse::<Token![,]>()?)
            };
            let args = input.parse_terminated::<Member, Token![,]>(Member::parse)?;

            Ok(located(ErrerAttrib::Display(format, separator, args)))
        },

        ("std", false) => Ok(located(ErrerAttrib::Std)),

        _ => Err(span.error(format!("unrecognized attribute {}", keyword)))
    }
}

/// Parses every entry of a parenthesized `#[errer(...)]` attribute; attributes with any
/// other path yield an empty list.
///
/// Entries are comma-separated. A trailing comma is rejected: after it the parser tries
/// to read another entry and fails on the missing identifier.
fn parse_err_attribs(x: Attribute, outer: bool, is_struct: bool) -> syn::Result<Vec<ErrerAttribSpanned>> {
    if !x.path.is_ident("errer") {
        return Ok(Vec::new());
    }

    Parser::parse2(|stream: ParseStream| {
        let body: ParseBuffer;
        parenthesized!(body in stream);

        let mut entries = Vec::new();
        if body.is_empty() {
            return Ok(entries);
        }

        // first entry, then `, entry` pairs until the group is exhausted
        entries.push(parse_err_attrib(&body, outer, is_struct)?);
        while !body.is_empty() {
            body.parse::<Token![,]>()?;
            entries.push(parse_err_attrib(&body, outer, is_struct)?);
        }

        Ok(entries)
    }, x.tts)
}

/// Parses the `errer` entries of every attribute in `x` and concatenates them in order.
/// Parsing stops at the first attribute that fails.
fn parse_err_attribs_vec(x: Vec<Attribute>, outer: bool, is_struct: bool) -> syn::Result<Vec<ErrerAttribSpanned>> {
    let per_attribute = x
        .into_iter()
        .map(|attr| parse_err_attribs(attr, outer, is_struct))
        .collect::<syn::Result<Vec<_>>>()?;

    Ok(per_attribute.into_iter().flatten().collect())
}

fn member_ident(span: &Span, member: Member) -> syn::Result<Ident> {
    if let Member::Named(i) = member { Ok(i) }
        else { Err(span.error("expected identifier")) }
}

fn member_idx(span: &Span, member: Member) -> syn::Result<Index> {
    if let Member::Unnamed(i) = member { Ok(i) }
        else { Err(span.error("expected index")) }
}

/// A field together with its position. `get_primary_field` yields `Some(i)` for tuple
/// fields and `None` for named fields; the struct path in `derive_errer_res` always
/// passes `Some(i)`.
type IndexedField<'a> = (&'a Field, Option<usize>);

fn get_primary_field<'a>(span: &Span, mem: Option<Member>, default_ident: &str, f: &'a Fields) -> syn::Result<IndexedField<'a>> {
    let field = match f {
        Fields::Named(nf) => {
            let mem_name = mem.map(|x| member_ident(&span, x)).transpose()?;
            
            nf.named.iter().find(|x|
                match &mem_name {
                    Some(name) => x.ident.as_ref().unwrap() == name,
                    None => x.ident.as_ref().unwrap() == default_ident
                })
                .map(|x| (x, None))
        },
        Fields::Unnamed(uf) => {
            let mem_idx = mem.map(|x| member_idx(&span, x)).transpose()?;
            
            match mem_idx {
                Some(idx) => {
                    let idx = idx.index as usize;
                    uf.unnamed.iter().enumerate().find(|(i, _)| *i == idx).map(|(i, x)| (x, Some(i)))
                },
                None => uf.unnamed.iter().next().map(|x| (x, Some(0)))
            }
        },
        Fields::Unit =>
            return Err(span.error("cannot get field from unit enum/struct"))
    };
    
    field.ok_or_else(|| span.error("field not found"))
}

fn get_fields_len(f: &Fields) -> usize {
    match f {
        Fields::Named(nf) => nf.named.len(),
        Fields::Unnamed(uf) => uf.unnamed.len(),
        Fields::Unit => 0
    }
}

/// Builds the pattern tail (written after the variant path) that binds the selected field
/// to `x` and ignores the others: `{ name: x, .. }` for named fields, `(_, x, _)` for
/// tuple fields. Tuple fields require `field.1` to be `Some`; unit fields are not allowed.
fn extractor_variant_field(fields: &Fields, field: IndexedField) -> TokenStream {
    let (selected, position) = field;

    match fields {
        Fields::Unit => unreachable!(),
        Fields::Named(_) => {
            let binding = selected.ident.as_ref().unwrap();

            quote! ( { #binding: x, .. })
        },
        Fields::Unnamed(unnamed) => {
            let target = position.unwrap();
            let slots = unnamed.unnamed.iter().enumerate()
                .map(|(i, _)| if i == target { quote!(x) } else { quote!(_) });

            quote! ( (#(#slots),*) )
        },
    }
}

/// Appends a catch-all arm `_ => default` when `vec` holds fewer arms than `len`
/// (the number of variants), keeping the generated `match` exhaustive.
fn make_default(vec: &mut Vec<TokenStream>, default: TokenStream, len: usize) {
    if vec.len() >= len {
        return;
    }
    vec.push(quote! ( _ => #default ));
}

/// Builds the pieces of the struct generated for a `context` attribute.
///
/// The context struct holds every field of `fields` except the `from` field given by
/// `field`. Returns `(ctx_struct, ctx_construct)`:
/// * `ctx_struct`: the body following `struct Name` (`;`, `{ .. }` or `( .. );`);
/// * `ctx_construct`: the tail of the expression (after the target path) that rebuilds
///   the original fields from `self`, with the source error bound as `ctx`.
///
/// For tuple fields `field.1` must be `Some`; `fields` must not be unit.
fn make_context_struct(fields: &Fields, field: IndexedField) -> (TokenStream, TokenStream) {
    let (field, field_i) = field;

    // Every field except the `from` field, in declaration order. `#[errer(..)]` helper
    // attributes are dropped: the generated struct does not derive `Errer`, so the
    // compiler would reject them there.
    let context_fields: Vec<Field> = {
        let kept: Vec<&Field> = match fields {
            Fields::Named(nf) => {
                let from_name = field.ident.as_ref().unwrap();

                nf.named.iter().filter(|x| x.ident.as_ref().unwrap() != from_name).collect()
            },
            Fields::Unnamed(uf) => {
                let from_i = field_i.unwrap();

                uf.unnamed.iter().enumerate()
                    .filter_map(|(i, x)| if i == from_i { None } else { Some(x) })
                    .collect()
            },
            Fields::Unit => unreachable!() // from requires a field
        };

        kept.into_iter()
            .cloned()
            .map(|mut f| {
                f.attrs.retain(|attr| !attr.path.is_ident("errer"));
                f
            })
            .collect()
    };

    let ctx_construct = match fields {
        Fields::Named(_) => {
            // The source error goes into the `from` field by *name*. Interpolating the
            // positional `field_i` here would emit `{ : ctx, .. }` (enum variants) or
            // `{ 0usize: ctx, .. }` (structs), neither of which compiles.
            let from_name = field.ident.as_ref().unwrap();
            let names = context_fields.iter().map(|f| f.ident.as_ref().unwrap());
            let names_from_self = names.clone();

            quote! ( { #from_name: ctx, #(#names: self.#names_from_self),* } )
        },
        Fields::Unnamed(uf) => {
            let from_i = field_i.unwrap();
            // context-struct indices skip the `from` slot, so they are tracked separately
            let mut ctx_idx = 0;

            let values = (0..uf.unnamed.len()).map(|num| {
                if num == from_i {
                    quote! (ctx)
                } else {
                    let idx = syn::Index::from(ctx_idx);
                    ctx_idx += 1;

                    quote! (self.#idx)
                }
            });

            quote! ( ( #(#values),* ) )
        },
        Fields::Unit => unreachable!()
    };

    let ctx_struct = if context_fields.is_empty() {
        quote!(;)
    } else {
        match fields {
            Fields::Named(_) => quote! ( { #(#context_fields),* } ),
            Fields::Unnamed(_) => quote! ( ( #(#context_fields),* ); ),
            Fields::Unit => unreachable!()
        }
    };

    (ctx_struct, ctx_construct)
}

fn derive_errer_res(input: proc_macro::TokenStream) -> syn::Result<proc_macro2::TokenStream> {
    let input: DeriveInput = Parser::parse(DeriveInput::parse, input)?;
    let mut stream = TokenStream::new();

    let is_struct;
    match &input.data {
        Data::Enum(_) => is_struct = false,
        Data::Struct(_) => is_struct = true,
        Data::Union(_) => return Err(input.span().error("Errer cannot be derived on unions"))
    }

    let vis = input.vis;
    let name = input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

    //parse outer attributes
    let mut display_overarching = None;
    let mut from = false;
    let mut context = None;
    let mut std = false;

    {
        for ErrerAttribSpanned {span, attrib} in parse_err_attribs_vec(input.attrs, true, is_struct)? {
            match attrib {
                ErrerAttrib::Display(s, comma, members) =>
                    display_overarching = Some((span, s, comma, members)),
                ErrerAttrib::From(None) => from = true,
                ErrerAttrib::Context(id) => context = Some((span, id)),
                ErrerAttrib::Std => std = true,

                _ => ()
            }
        }
    }

    let mut display_branches = Vec::new();
    let mut from_branches = Vec::new();

    match input.data {
        Data::Enum(x) => {
            let v_len = x.variants.len();

            //parse inner attributes and implement display, from, and store context fields
            for variant in x.variants.into_iter() {
                let v_span = variant.span();
                let (attrs, fields, variant_name) = (variant.attrs, variant.fields, variant.ident);
                let variant_path = quote! { #name::#variant_name }; //resolve name

                let mut from =
                    if from { get_primary_field(&v_span, None, FROM_DEFAULT, &fields).ok() } else { None };
                let mut context = context.as_ref().map(|x| (x.0, None)); //since enums do not support context ids

                for ErrerAttribSpanned {span, attrib} in parse_err_attribs_vec(attrs, false, false)? {
                    match attrib {
                        ErrerAttrib::Display(s, comma, members) => {
                            match &fields {
                                Fields::Named(_) => {
                                    let mut members_i = Vec::new();
                                    for x in members { members_i.push(member_ident(&span, x)?) }

                                    let mut members_i_unique = Vec::new();
                                    for x in &members_i { members_i_unique.push(unique_name(&x)); }
                                    let members_i_unique_ref = &members_i_unique;
                                    
                                    display_branches.push(quote! { #variant_path { #(#members_i: #members_i_unique_ref,)* .. } => write!(x, #s #comma #(#members_i_unique),*), });
                                },
                                Fields::Unnamed(uf) => {
                                    let mut members_i = Vec::new();
                                    for x in members { members_i.push(member_idx(&span, x)?.index) }

                                    let mut names_needed = Vec::new();

                                    let mut names = TokenStream::new();
                                    uf.unnamed.iter().enumerate().for_each(|(i, _)| {
                                        let i = i as u32;
                                        
                                        if members_i.contains(&i) {
                                            let ident = idx_to_ident(i);
                                            names_needed.push(ident.clone());

                                            ident.to_tokens(&mut names);
                                        } else {
                                            Token![_](Span::call_site()).to_tokens(&mut names);
                                        }

                                        Token![,](Span::call_site()).to_tokens(&mut names);
                                    });

                                    display_branches.push(quote! { #variant_path ( #names ) => write!(x, #s #comma #(#names_needed),*), });
                                },
                                Fields::Unit =>
                                    display_branches.push(quote! { #variant_path => write!(x, #s #comma #(#members),*), })
                            }
                        },

                        ErrerAttrib::From(mem) =>
                            from = Some(get_primary_field(&span, mem, FROM_DEFAULT, &fields)?),
                        
                        ErrerAttrib::Context(id) => context = Some((span, id)),
                        
                        _ => ()
                    }
                }

                if std {
                    for x in from {
                        let extractor = extractor_variant_field(&fields, x);
                        from_branches.push(quote!(#variant_path #extractor => Some(x), ));
                    }
                }

                if let Some((Field {ident, ty, ..}, _)) = &from {
                    if get_fields_len(&fields) == 1 && context.is_none() {
                        let constructor = match &ident {
                            Some(f_name) => quote! ( #variant_path { #f_name: x } ),
                            None => quote! ( #variant_path (x) )
                        };


                        stream.extend(quote! {
                            impl #impl_generics std::convert::From<#ty> for #name #ty_generics #where_clause {
                                fn from(x: #ty) -> Self {
                                    #constructor
                                }
                            }
                        });
                    }
                }

                if let Some((span, id)) = context {
                    if let Some(from) = from {
                        let (ctx_struct, ctx_constructor) = make_context_struct(&fields, from);
                        let ctx_name = id.as_ref().unwrap_or(&variant_name);
                        let ty = &from.0.ty;

                        stream.extend(quote! {
                            #vis struct #ctx_name #ctx_struct

                            impl #impl_generics errer::IntoErrorContext<#ty, #name #ty_generics> for #ctx_name #where_clause {
                                fn into_target(self, ctx: #ty) -> #name #ty_generics {
                                    #variant_path #ctx_constructor
                                }
                            }
                        });
                    } else {
                        return Err(span.error("context depends on from; from is not specified"));
                    }
                }
            }

            let cause = {
                if from_branches.len() > 0 {
                    make_default(&mut from_branches, quote!(None), v_len);

                    quote! {
                        match self {
                            #(#from_branches)*
                        }
                    }

                } else { quote!(None) }
            };

            stream.extend(quote! {
                impl #impl_generics errer::ErrorCompat for #name #ty_generics #where_clause {
                    fn error_source(&self) -> Option<&(dyn std::error::Error + 'static)> {
                        #cause
                    }
                }
            });

            if display_branches.len() > 0 {
                make_default(&mut display_branches, quote!( Ok(()) ), v_len);

                let writer = quote! {
                    match self {
                        #(#display_branches)*
                    }
                };

                let impl_disp = match display_overarching {
                    Some((span, s, _, members)) => {
                        if members.len() > 0 {
                            return Err(span.error("no extra formatting arguments are permitted in an outer enum display attribute"));
                        }

                        quote! {
                            impl #impl_generics std::fmt::Display for #name #ty_generics #where_clause {
                                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                                    use std::fmt::Write;

                                    let mut x = String::new();
                                    #writer?;

                                    write!(f, #s, x)
                                }
                            }
                        }
                    },
                    None => quote! {
                        impl #impl_generics std::fmt::Display for #name #ty_generics #where_clause {
                            fn fmt(&self, x: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                                #writer
                            }
                        }
                    }
                };

                stream.extend(impl_disp);
            }
        },
        Data::Struct(x) => {
            let mut from = None;

            let f_len = get_fields_len(&x.fields);
            for (i, field) in x.fields.iter().enumerate() {
                let Field {attrs, ..} = field;
                let mut set_from = false;

                for ErrerAttribSpanned {attrib, ..} in parse_err_attribs_vec(attrs.clone(), false, true)? {
                    match attrib {
                        ErrerAttrib::From(None) => set_from = true,
                        _ => ()
                    }
                }

                if set_from { from = Some((field, i)); }
            }

            let mut cause = quote!( None );
            if let Some((Field {ident, ty, ..}, i)) = from {
                if f_len == 1 && context.is_none() {
                    let constructor = match &ident {
                        Some(f_name) => quote! ( #name { #f_name: x } ),
                        None => quote! ( #name (x) )
                    };


                    stream.extend(quote! {
                        impl #impl_generics std::convert::From<#ty> for #name #ty_generics #where_clause {
                            fn from(x: #ty) -> Self {
                                #constructor
                            }
                        }
                    });
                }

                if std {
                    let mut getter = TokenStream::new();
                    match &ident {
                        Some(f_name) => f_name.to_tokens(&mut getter),
                        None => syn::Index::from(i).to_tokens(&mut getter)
                    }

                    cause = quote! ( Some(&self.#getter) );
                }
            }

            stream.extend(quote! {
                impl #impl_generics errer::ErrorCompat for #name #ty_generics #where_clause {
                    fn error_source(&self) -> Option<&(dyn std::error::Error + 'static)> {
                        #cause
                    }
                }
            });

            if let Some((span, id)) = context {
                if let Some(from) = &from {
                    let (ctx_struct, ctx_constructor) = make_context_struct(&x.fields, (from.0, Some(from.1)));
                    let ctx_name = id.unwrap_or_else(|| Ident::new(&format!("{}Context", name), Span::call_site()));
                    let ty = &from.0.ty;

                    stream.extend(quote! {
                            #vis struct #ctx_name #ctx_struct

                            impl #impl_generics errer::IntoErrorContext<#ty, #name #ty_generics> for #ctx_name #where_clause {
                                fn into_target(self, ctx: #ty) -> #name #ty_generics {
                                    #name #ctx_constructor
                                }
                            }
                        });
                } else {
                    return Err(span.error("context depends on from; from is not specified"));
                }
            }

            if let Some((_, s, comma, members)) = display_overarching {
                let members = members.iter();

                stream.extend(quote! {
                   impl #impl_generics std::fmt::Display for #name #ty_generics #where_clause {
                        fn fmt(&self, x: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                            write!(x, #s #comma #(self.#members),*)
                        }
                    }
                });
            }
        },
        _ => ()
    }

    Ok(stream)
}

/// Entry point for `#[derive(Errer)]`. Any error from `derive_errer_res` is turned into
/// compiler diagnostics via `syn::Error::to_compile_error`.
#[proc_macro_derive(Errer, attributes(errer))]
pub fn derive_errer(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let generated = derive_errer_res(input).unwrap_or_else(|err| err.to_compile_error());
    proc_macro::TokenStream::from(generated)
}