extern crate errer;
extern crate proc_macro;
extern crate proc_macro2;
extern crate syn;
extern crate quote;
use std::string::ToString;
use proc_macro2::*;
use quote::{quote, ToTokens};
use syn::{ punctuated::Punctuated, parenthesized, spanned::Spanned, parse::{Parse, Parser, ParseStream, ParseBuffer},
DeriveInput, Data, Field, Fields,
Token, Member, Ident, Index, Attribute};
/// Field name that a `from` attribute without an explicit member selects on
/// named-field variants (tuple variants fall back to their first field instead).
const FROM_DEFAULT: &str = "from";
/// Builds the pattern binding `idx_<i>` used for the `i`-th field of a tuple variant.
fn idx_to_ident(i: u32) -> Ident {
    let binding = format!("idx_{}", i);
    Ident::new(&binding, Span::call_site())
}
/// Derives the binding `field_<name>` for a named field, keeping the field's own span.
fn unique_name(ident: &Ident) -> Ident {
    let span = ident.span();
    let binding = format!("field_{}", ident);
    Ident::new(&binding, span)
}
/// One item parsed out of an `#[errer(...)]` attribute list.
enum ErrerAttrib {
/// `from` or `from = member`: the field that holds the wrapped source error.
From(Option<Member>),
/// `context`: generate a context struct plus an `errer::IntoErrorContext` impl.
Context,
/// `display = "fmt", a, b`: format string, the comma after it (if any), and the
/// members passed as format arguments.
Display(syn::LitStr, Option<Token![,]>, Punctuated<Member, Token![,]>),
/// `std`: make `error_source` return the `from` field.
Std
}
/// A parsed `ErrerAttrib` together with the span of its first token, used for diagnostics.
struct ErrerAttribSpanned {
span: Span,
attrib: ErrerAttrib
}
trait SpanExt {
fn error<M: std::fmt::Display>(self, msg: M) -> syn::Error;
}
impl SpanExt for Span {
fn error<M: std::fmt::Display>(self, msg: M) -> syn::Error {
syn::Error::new(self, msg)
}
}
/// Parses one item of an `#[errer(...)]` list: `context`, `std`, `from`, `from = member`
/// or `display = "fmt", members...`.
///
/// * `outer` — the attribute sits on the type itself rather than on a variant/field.
/// * `is_struct` — the derive target is a struct.
///
/// A `display` item consumes the rest of the stream: everything after its format string
/// (and following comma) is parsed as comma-separated format-argument members.
///
/// # Errors
/// Fails on unknown names, on a value given to `context`/`std`, on `display` without a
/// format string, and on `from` placements that the derive does not support.
fn parse_err_attrib(input: ParseStream, outer: bool, is_struct: bool) -> syn::Result<ErrerAttribSpanned> {
    let start = input.cursor().span();
    let id = input.parse::<Ident>()?;
    let eq = input.peek(Token![=]);
    if eq {
        input.parse::<Token![=]>()?;
    }
    let attrib = match &*id.to_string() {
        // Flags: report the actual problem instead of "unrecognized attribute".
        "context" | "std" if eq => {
            return Err(start.error(format!("{} does not take a value", id)))
        },
        "context" => ErrerAttrib::Context,
        "std" => ErrerAttrib::Std,
        "from" => {
            if outer && is_struct {
                return Err(start.error("outer from on structs are not used. use from on a property instead"));
            }
            if (outer || is_struct) && eq {
                return Err(start.error("from members are not permitted in outer attributes or struct properties. try using this on a variant instead"));
            }
            if eq {
                ErrerAttrib::From(Some(input.parse::<Member>()?))
            } else {
                ErrerAttrib::From(None)
            }
        },
        "display" => {
            if !eq {
                return Err(start.error("a format string must be provided"));
            }
            let lstr = input.parse::<syn::LitStr>()?;
            let comma = if input.is_empty() { None } else { Some(input.parse::<Token![,]>()?) };
            // Everything that follows is treated as format-argument members.
            let members = input.parse_terminated::<Member, Token![,]>(Member::parse)?;
            ErrerAttrib::Display(lstr, comma, members)
        },
        _ => return Err(start.error(format!("unrecognized attribute {}", id)))
    };
    Ok(ErrerAttribSpanned { span: start, attrib })
}
/// Parses every item of `x` when it is an `#[errer(...)]` attribute; any other
/// attribute yields an empty list.
///
/// Items are comma-separated and a trailing comma is accepted.
///
/// # Errors
/// Propagates parse errors from the parenthesized list or from `parse_err_attrib`.
fn parse_err_attribs(x: Attribute, outer: bool, is_struct: bool) -> syn::Result<Vec<ErrerAttribSpanned>> {
    if !x.path.is_ident("errer") {
        return Ok(Vec::new());
    }
    Parser::parse2(|pstream: ParseStream| {
        let inside: ParseBuffer;
        parenthesized!(inside in pstream);
        let mut attribs = Vec::new();
        while !inside.is_empty() {
            attribs.push(parse_err_attrib(&inside, outer, is_struct)?);
            if inside.is_empty() {
                break;
            }
            // Separator; the loop condition then tolerates a trailing comma.
            inside.parse::<Token![,]>()?;
        }
        Ok(attribs)
    }, x.tts)
}
/// Parses the `#[errer(...)]` items of every attribute in `x`, in order, stopping at
/// the first error.
fn parse_err_attribs_vec(x: Vec<Attribute>, outer: bool, is_struct: bool) -> syn::Result<Vec<ErrerAttribSpanned>> {
    x.into_iter().try_fold(Vec::new(), |mut collected, attr| -> syn::Result<Vec<ErrerAttribSpanned>> {
        collected.extend(parse_err_attribs(attr, outer, is_struct)?);
        Ok(collected)
    })
}
fn member_ident(span: &Span, member: Member) -> syn::Result<Ident> {
if let Member::Named(i) = member { Ok(i) }
else { Err(span.error("expected identifier")) }
}
fn member_idx(span: &Span, member: Member) -> syn::Result<Index> {
if let Member::Unnamed(i) = member { Ok(i) }
else { Err(span.error("expected index")) }
}
/// A selected field plus its position: `Some(i)` for tuple fields, `None` for named fields.
type IndexedField<'a> = (&'a Field, Option<usize>);
fn get_primary_field<'a>(span: &Span, mem: Option<Member>, default_ident: &str, f: &'a Fields) -> syn::Result<IndexedField<'a>> {
let field = match f {
Fields::Named(nf) => {
let mem_name = mem.map(|x| member_ident(&span, x)).transpose()?;
nf.named.iter().find(|x|
match &mem_name {
Some(name) => x.ident.as_ref().unwrap() == name,
None => x.ident.as_ref().unwrap() == default_ident
})
.map(|x| (x, None))
},
Fields::Unnamed(uf) => {
let mem_idx = mem.map(|x| member_idx(&span, x)).transpose()?;
match mem_idx {
Some(idx) => {
let idx = idx.index as usize;
uf.unnamed.iter().enumerate().find(|(i, _)| *i == idx).map(|(i, x)| (x, Some(i)))
},
None => uf.unnamed.iter().next().map(|x| (x, Some(0)))
}
},
Fields::Unit =>
return Err(span.error("cannot get field from unit enum/struct"))
};
field.ok_or_else(|| span.error("field not found"))
}
fn get_fields_len(f: &Fields) -> usize {
match f {
Fields::Named(nf) => nf.named.len(),
Fields::Unnamed(uf) => uf.unnamed.len(),
Fields::Unit => 0
}
}
/// Builds the pattern suffix that binds the selected field to `x`:
/// `{ name: x, .. }` for named fields, or `(_, x, _)` for tuple fields.
///
/// Panics on unit fields or on a tuple field without an index.
fn extractor_variant_field(fields: &Fields, field: IndexedField) -> TokenStream {
    let (target, position) = field;
    match fields {
        Fields::Unit => unreachable!(),
        Fields::Named(_) => {
            let name = target.ident.as_ref().unwrap();
            quote! ( { #name: x, .. })
        }
        Fields::Unnamed(uf) => {
            let wanted = position.unwrap();
            let slots = uf.unnamed.iter().enumerate().map(|(slot, _)| {
                if slot == wanted { quote!(x) } else { quote!(_) }
            });
            quote! ( (#(#slots),*) )
        }
    }
}
/// Appends a catch-all `_ => default` arm when `vec` has fewer arms than `len`
/// (i.e. some variants are not otherwise matched).
fn make_default(vec: &mut Vec<TokenStream>, default: TokenStream, len: usize) {
    if len > vec.len() {
        let fallback = quote! ( _ => #default );
        vec.push(fallback);
    }
}
/// Builds the pieces of a context type for `fields`, given the selected `from` field.
///
/// Returns `(body, constructor)`:
/// * `body` — the struct body holding every field except the `from` field
///   (`{ a: A, .. }` for named fields, `( A, .. );` for tuple fields, semicolon included);
/// * `constructor` — the field list that rebuilds the original shape from `self`
///   (the context value) and `ctx` (the value placed into the `from` field).
///
/// Only visibility, name and type are copied. Attributes on the source fields are
/// dropped because they may be helpers of other derives (or `#[errer]` itself), which
/// would not resolve on the generated struct.
///
/// Panics on unit fields, or on tuple fields without a field index.
fn make_context_struct(fields: &Fields, field: IndexedField) -> (TokenStream, TokenStream) {
    let (field, field_i) = field;
    match fields {
        Fields::Named(nf) => {
            let from_ident = field.ident.as_ref().unwrap();
            let context_fields: Vec<&Field> = nf.named.iter()
                .filter(|x| x.ident.as_ref().unwrap() != from_ident)
                .collect();
            let decls = context_fields.iter().map(|f| {
                let (vis, ident, ty) = (&f.vis, f.ident.as_ref().unwrap(), &f.ty);
                quote!(#vis #ident: #ty)
            });
            let idents = context_fields.iter().map(|f| f.ident.as_ref().unwrap());
            let idents_cloned = idents.clone();
            (quote! ( { #(#decls),* } ),
             quote! ( { #from_ident: ctx, #(#idents: self.#idents_cloned),* } ))
        },
        Fields::Unnamed(uf) => {
            let from_i = field_i.expect("tuple fields always carry an index");
            let decls = uf.unnamed.iter().enumerate()
                .filter(|(i, _)| *i != from_i)
                .map(|(_, f)| {
                    let (vis, ty) = (&f.vis, &f.ty);
                    quote!(#vis #ty)
                });
            // The context tuple skips the `from` slot, so its indices lag behind after it.
            let mut next = 0usize;
            let values = (0..uf.unnamed.len()).map(|num| {
                if num == from_i {
                    quote! (ctx)
                } else {
                    let idx = syn::Index::from(next);
                    next += 1;
                    quote! (self.#idx)
                }
            });
            (quote! ( ( #(#decls),* ); ),
             quote! ( ( #(#values),* ) ))
        },
        _ => unreachable!()
    }
}
/// Core of `#[derive(Errer)]`: parses `input` and returns the generated items.
///
/// Attributes on the type itself (parsed with `outer = true`) set defaults: `display`,
/// `from` (enums only), `context` and `std`. Per-variant attributes (enums) and per-field
/// `from` attributes (structs) refine them. The output always contains an
/// `errer::ErrorCompat` impl and may additionally contain:
///
/// * `std::convert::From<T>`, when the `from` field is the only field and no `context`
///   applies;
/// * a context struct with an `errer::IntoErrorContext` impl, when `context` applies
///   (named after the variant for enums, `<Name>Context` for structs);
/// * `std::fmt::Display`: for enums only if at least one variant has a `display`
///   attribute, for structs only if the type itself has one.
///
/// `error_source` returns the `from` field only when the outer `std` attribute is
/// present; otherwise it returns `None`.
///
/// # Errors
/// Returns a spanned `syn::Error` for unions, malformed attributes, `context` without a
/// `from` field, `display` members of the wrong kind or out of range, and extra
/// arguments on an outer enum `display` attribute.
fn derive_errer_res(input: proc_macro::TokenStream) -> syn::Result<proc_macro2::TokenStream> {
    let input: DeriveInput = Parser::parse(DeriveInput::parse, input)?;
    let mut stream = TokenStream::new();
    let is_struct = match &input.data {
        Data::Enum(_) => false,
        Data::Struct(_) => true,
        Data::Union(_) => return Err(input.span().error("Errer cannot be derived on unions"))
    };
    let vis = input.vis;
    let name = input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

    // Settings taken from the attributes on the type itself.
    let mut display_overarching = None;
    let mut from = false;
    let mut context = None;
    let mut std = false;
    for ErrerAttribSpanned {span, attrib} in parse_err_attribs_vec(input.attrs, true, is_struct)? {
        match attrib {
            ErrerAttrib::Display(s, comma, members) =>
                display_overarching = Some((span, s, comma, members)),
            ErrerAttrib::From(None) => from = true,
            ErrerAttrib::Context => context = Some(span),
            ErrerAttrib::Std => std = true,
            // `from = member` is rejected for outer attributes by `parse_err_attrib`.
            _ => ()
        }
    }

    let mut display_branches = Vec::new();
    let mut from_branches = Vec::new();
    match input.data {
        Data::Enum(x) => {
            let v_len = x.variants.len();
            for variant in x.variants.into_iter() {
                let v_span = variant.span();
                let (attrs, fields, variant_name) = (variant.attrs, variant.fields, variant.ident);
                let variant_path = quote! { #name::#variant_name };
                // An outer `from` selects each variant's default field when it has one.
                let mut from =
                    if from { get_primary_field(&v_span, None, FROM_DEFAULT, &fields).ok() } else { None };
                let mut context = context;
                for ErrerAttribSpanned {span, attrib} in parse_err_attribs_vec(attrs, false, false)? {
                    match attrib {
                        ErrerAttrib::Display(s, comma, members) => {
                            match &fields {
                                Fields::Named(_) => {
                                    let mut members_i = Vec::new();
                                    for member in members { members_i.push(member_ident(&span, member)?) }
                                    // Bind each field to a `field_`-prefixed name; presumably this
                                    // avoids shadowing the formatter `x` — confirm.
                                    let mut members_i_unique = Vec::new();
                                    for member in &members_i { members_i_unique.push(unique_name(member)); }
                                    let members_i_unique_ref = &members_i_unique;
                                    display_branches.push(quote! { #variant_path { #(#members_i: #members_i_unique_ref,)* .. } => write!(x, #s #comma #(#members_i_unique),*), });
                                },
                                Fields::Unnamed(uf) => {
                                    let field_count = uf.unnamed.len();
                                    let mut members_i = Vec::new();
                                    for member in members {
                                        let idx = member_idx(&span, member)?.index;
                                        if idx as usize >= field_count {
                                            return Err(span.error("field index out of range"));
                                        }
                                        members_i.push(idx);
                                    }
                                    // Pattern: `idx_<i>` for referenced fields, `_` for the rest.
                                    let mut names = TokenStream::new();
                                    uf.unnamed.iter().enumerate().for_each(|(i, _)| {
                                        let i = i as u32;
                                        if members_i.contains(&i) {
                                            idx_to_ident(i).to_tokens(&mut names);
                                        } else {
                                            Token![_](Span::call_site()).to_tokens(&mut names);
                                        }
                                        Token![,](Span::call_site()).to_tokens(&mut names);
                                    });
                                    // The format arguments must be the bindings from the pattern;
                                    // interpolating the raw `u32` indices would emit literals such
                                    // as `0u32` and print the index instead of the field value.
                                    let member_bindings = members_i.iter().map(|&i| idx_to_ident(i));
                                    display_branches.push(quote! { #variant_path ( #names ) => write!(x, #s #comma #(#member_bindings),*), });
                                },
                                Fields::Unit =>
                                    display_branches.push(quote! { #variant_path => write!(x, #s #comma #(#members),*), })
                            }
                        },
                        ErrerAttrib::From(mem) =>
                            from = Some(get_primary_field(&span, mem, FROM_DEFAULT, &fields)?),
                        ErrerAttrib::Context => context = Some(span),
                        // `std` is only read from the type-level attributes.
                        _ => ()
                    }
                }
                if std {
                    if let Some(field) = from {
                        let extractor = extractor_variant_field(&fields, field);
                        from_branches.push(quote!(#variant_path #extractor => Some(x), ));
                    }
                }
                if let Some((Field {ident, ty, ..}, _)) = &from {
                    // A `From` impl is only unambiguous when the source is the sole field.
                    if get_fields_len(&fields) == 1 && context.is_none() {
                        let constructor = match &ident {
                            Some(f_name) => quote! ( #variant_path { #f_name: x } ),
                            None => quote! ( #variant_path (x) )
                        };
                        stream.extend(quote! {
                            impl #impl_generics std::convert::From<#ty> for #name #ty_generics #where_clause {
                                fn from(x: #ty) -> Self {
                                    #constructor
                                }
                            }
                        });
                    }
                }
                if let Some(span) = context {
                    if let Some(from) = from {
                        let (ctx_struct, ctx_constructor) = make_context_struct(&fields, from);
                        let ty = &from.0.ty;
                        stream.extend(quote! {
                            #vis struct #variant_name #ctx_struct
                            impl #impl_generics errer::IntoErrorContext<#ty, #name #ty_generics> for #variant_name #where_clause {
                                fn into_target(self, ctx: #ty) -> #name #ty_generics {
                                    #variant_path #ctx_constructor
                                }
                            }
                        });
                    } else {
                        return Err(span.error("context depends on from; from is not specified"));
                    }
                }
            }
            let cause = if from_branches.is_empty() {
                quote!(None)
            } else {
                make_default(&mut from_branches, quote!(None), v_len);
                quote! {
                    match self {
                        #(#from_branches)*
                    }
                }
            };
            stream.extend(quote! {
                impl #impl_generics errer::ErrorCompat for #name #ty_generics #where_clause {
                    fn error_source(&self) -> Option<&(dyn std::error::Error + 'static)> {
                        #cause
                    }
                }
            });
            if !display_branches.is_empty() {
                make_default(&mut display_branches, quote!( Ok(()) ), v_len);
                let writer = quote! {
                    match self {
                        #(#display_branches)*
                    }
                };
                let impl_disp = match display_overarching {
                    // The outer format string wraps the per-variant text, passed as its only argument.
                    Some((span, s, _, members)) => {
                        if members.len() > 0 {
                            return Err(span.error("no extra formatting arguments are permitted in an outer enum display attribute"));
                        }
                        quote! {
                            impl #impl_generics std::fmt::Display for #name #ty_generics #where_clause {
                                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                                    use std::fmt::Write;
                                    let mut x = String::new();
                                    #writer?;
                                    write!(f, #s, x)
                                }
                            }
                        }
                    },
                    None => quote! {
                        impl #impl_generics std::fmt::Display for #name #ty_generics #where_clause {
                            fn fmt(&self, x: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                                #writer
                            }
                        }
                    }
                };
                stream.extend(impl_disp);
            }
        },
        Data::Struct(x) => {
            // The last field marked `#[errer(from)]` wins.
            let mut from = None;
            let f_len = get_fields_len(&x.fields);
            for (i, field) in x.fields.iter().enumerate() {
                let Field {attrs, ..} = field;
                let mut set_from = false;
                for ErrerAttribSpanned {attrib, ..} in parse_err_attribs_vec(attrs.clone(), false, true)? {
                    if let ErrerAttrib::From(None) = attrib {
                        set_from = true;
                    }
                }
                if set_from { from = Some((field, i)); }
            }
            if let Some(span) = context {
                if let Some(from) = &from {
                    let (ctx_struct, ctx_constructor) = make_context_struct(&x.fields, (from.0, Some(from.1)));
                    let ctx_name = Ident::new(&format!("{}Context", name), Span::call_site());
                    let ty = &from.0.ty;
                    stream.extend(quote! {
                        #vis struct #ctx_name #ctx_struct
                        impl #impl_generics errer::IntoErrorContext<#ty, #name #ty_generics> for #ctx_name #where_clause {
                            fn into_target(self, ctx: #ty) -> #name #ty_generics {
                                #name #ctx_constructor
                            }
                        }
                    });
                } else {
                    return Err(span.error("context depends on from; from is not specified"));
                }
            }
            let mut cause = quote!( None );
            if let Some((Field {ident, ty, ..}, i)) = from {
                if f_len == 1 && context.is_none() {
                    let constructor = match &ident {
                        Some(f_name) => quote! ( #name { #f_name: x } ),
                        None => quote! ( #name (x) )
                    };
                    stream.extend(quote! {
                        impl #impl_generics std::convert::From<#ty> for #name #ty_generics #where_clause {
                            fn from(x: #ty) -> Self {
                                #constructor
                            }
                        }
                    });
                }
                if std {
                    let mut getter = TokenStream::new();
                    match &ident {
                        Some(f_name) => f_name.to_tokens(&mut getter),
                        None => syn::Index::from(i).to_tokens(&mut getter)
                    }
                    cause = quote! ( Some(&self.#getter) );
                }
            }
            stream.extend(quote! {
                impl #impl_generics errer::ErrorCompat for #name #ty_generics #where_clause {
                    fn error_source(&self) -> Option<&(dyn std::error::Error + 'static)> {
                        #cause
                    }
                }
            });
            if let Some((_, s, comma, members)) = display_overarching {
                let members = members.iter();
                stream.extend(quote! {
                    impl #impl_generics std::fmt::Display for #name #ty_generics #where_clause {
                        fn fmt(&self, x: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                            write!(x, #s #comma #(self.#members),*)
                        }
                    }
                });
            }
        },
        // Unions were rejected above.
        _ => ()
    }
    Ok(stream)
}
/// Entry point for `#[derive(Errer)]` (helper attribute: `errer`).
///
/// Any error from the expansion is turned into `compile_error!` tokens via
/// `syn::Error::to_compile_error`, so it is reported at the offending span.
#[proc_macro_derive(Errer, attributes(errer))]
pub fn derive_errer(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let expanded = derive_errer_res(input).unwrap_or_else(|err| err.to_compile_error());
    proc_macro::TokenStream::from(expanded)
}