use proc_macro::TokenStream;
use quote::{ToTokens, quote};
/// Entry point for the `#[error(...)]` attribute macro on structs.
///
/// Expands to:
/// - the struct, tagged with `#[::macro_magic::export_tokens]`, with this
///   macro's helper attributes (`#[error]`, `#[error_source]`) removed;
/// - `#[derive(::core::fmt::Debug)]`, unless the struct already derives `Debug`;
/// - a `Display` impl built from the attribute's format string and arguments;
/// - an `Error` impl whose `source()` honours a `#[error_source]` field.
///
/// If `item` does not parse as a struct, the parse error is emitted as a
/// compile error instead.
pub(crate) fn error(attr: TokenStream, item: TokenStream) -> TokenStream {
    let attr_ts = proc_macro2::TokenStream::from(attr);
    let item_ts = proc_macro2::TokenStream::from(item);
    let item_ast = match syn::parse2::<syn::ItemStruct>(item_ts) {
        Ok(item_ast) => item_ast,
        Err(err) => return err.to_compile_error().into(),
    };
    let (format_str, format_args) = parse_error_attr(&attr_ts, &item_ast.ident);
    let display_impl = generate_display(&item_ast, &format_str, &format_args);
    let error_impl = generate_error(&item_ast);
    // `Error` requires `Debug`. Only add the derive when the user has not written
    // one themselves; otherwise the two `Debug` impls would conflict.
    let derive_debug = if already_derives_debug(&item_ast) {
        proc_macro2::TokenStream::new()
    } else {
        quote! { #[derive(::core::fmt::Debug)] }
    };
    let struct_def = strip_error_attrs(&item_ast);
    let expanded = quote! {
        #[::macro_magic::export_tokens]
        #derive_debug
        #struct_def
        #display_impl
        #error_impl
    };
    expanded.into()
}

/// Returns `true` when `item_ast` carries a `#[derive(...)]` whose list contains
/// a path ending in `Debug` (e.g. `Debug`, `core::fmt::Debug`).
///
/// `derive` attributes whose arguments do not parse as a comma-separated list
/// of paths are ignored.
fn already_derives_debug(item_ast: &syn::ItemStruct) -> bool {
    item_ast
        .attrs
        .iter()
        .filter(|attr| attr.path().is_ident("derive"))
        .filter_map(|attr| {
            attr.parse_args_with(
                syn::punctuated::Punctuated::<syn::Path, syn::Token![,]>::parse_terminated,
            )
            .ok()
        })
        .flatten()
        .any(|path| {
            path.segments
                .last()
                .is_some_and(|segment| segment.ident == "Debug")
        })
}
/// Splits the `#[error(...)]` attribute arguments into a format string and
/// the tokens to pass as its format arguments.
///
/// Recognised shapes (the leading literal may also sit one level inside a
/// single delimited group of any delimiter):
/// - nothing: the message is the struct's name, with no arguments;
/// - `"fmt"`: the format string, with no arguments;
/// - `"fmt", args...`: the format string plus every token after the separator.
///
/// Any other shape falls back to the struct's name with no arguments.
fn parse_error_attr(
    attr_ts: &proc_macro2::TokenStream,
    struct_name: &syn::Ident,
) -> (syn::LitStr, proc_macro2::TokenStream) {
    let mut tokens = attr_ts.clone().into_iter();
    let parsed = match tokens.next() {
        Some(proc_macro2::TokenTree::Literal(lit)) => Some(split_format_literal(lit, tokens)),
        Some(proc_macro2::TokenTree::Group(group)) => {
            // Look exactly one level inside the group. Tokens after the group
            // itself are not inspected.
            let mut inner = group.stream().into_iter();
            match inner.next() {
                Some(proc_macro2::TokenTree::Literal(lit)) => {
                    Some(split_format_literal(lit, inner))
                }
                _ => None,
            }
        }
        _ => None,
    };
    parsed.unwrap_or_else(|| {
        (
            syn::LitStr::new(&struct_name.to_string(), struct_name.span()),
            proc_macro2::TokenStream::new(),
        )
    })
}

/// Converts the leading literal into the format string and collects the tokens
/// that follow the separator as the format arguments.
///
/// String literals are parsed as `syn::LitStr` and re-emitted unchanged. Escape
/// sequences (`\n`, `\"`, `\u{..}`) and raw strings (`r"..."`) therefore keep
/// their meaning instead of being re-escaped into literal backslash text. A
/// non-string literal is used verbatim (its source text) as the message.
fn split_format_literal(
    lit: proc_macro2::Literal,
    mut rest: impl Iterator<Item = proc_macro2::TokenTree>,
) -> (syn::LitStr, proc_macro2::TokenStream) {
    let literal_tokens =
        proc_macro2::TokenStream::from(proc_macro2::TokenTree::from(lit.clone()));
    let format_str = syn::parse2::<syn::LitStr>(literal_tokens)
        .unwrap_or_else(|_| syn::LitStr::new(&lit.to_string(), lit.span()));
    // The token right after the literal is expected to be the `,` separating the
    // format string from its arguments. It is skipped whatever it is.
    rest.next();
    (format_str, rest.collect())
}
/// Generates the `::core::fmt::Display` impl for `item_ast`.
///
/// The generated `fmt` body is `::core::write!(f, <format_str>, <format_args>)`.
/// When `format_args` is empty, the separating comma is omitted. The arguments
/// are evaluated inside `fmt(&self, ..)`, so they may refer to `self`. The
/// impl reuses the struct's generics and where-clause.
fn generate_display(
    item_ast: &syn::ItemStruct,
    format_str: &syn::LitStr,
    format_args: &proc_macro2::TokenStream,
) -> proc_macro2::TokenStream {
    let ident = &item_ast.ident;
    let (impl_generics, ty_generics, where_clause) = item_ast.generics.split_for_impl();
    // Only emit the comma when there is something to follow it.
    let trailing_args = if format_args.is_empty() {
        proc_macro2::TokenStream::new()
    } else {
        quote! { , #format_args }
    };
    // `::core::write!` is fully qualified (like the other generated paths) so a
    // user-defined `write` macro at the expansion site cannot be picked up.
    quote! {
        impl #impl_generics ::core::fmt::Display for #ident #ty_generics #where_clause {
            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
                ::core::write!(f, #format_str #trailing_args)
            }
        }
    }
}
fn generate_error(item_ast: &syn::ItemStruct) -> proc_macro2::TokenStream {
let ident = &item_ast.ident;
let (impl_generics, ty_generics, where_clause) = item_ast.generics.split_for_impl();
let mut source_field = None;
for field in item_ast.fields.iter() {
for attr in &field.attrs {
if attr.path().is_ident("error_source") {
source_field = Some(field);
}
}
}
if let Some(field) = source_field {
let field_ident = field.ident.as_ref().unwrap();
quote! {
impl #impl_generics ::core::error::Error for #ident #ty_generics #where_clause {
fn source(&self) -> Option<&(dyn ::core::error::Error + 'static)> {
Some(&self.#field_ident)
}
}
}
} else {
quote! {
impl #impl_generics ::core::error::Error for #ident #ty_generics #where_clause {}
}
}
}
/// Re-emits the struct definition without this macro's helper attributes.
///
/// `#[error]` is removed from the struct's own attributes and `#[error_source]`
/// from every field. All other attributes and tokens are kept as written.
fn strip_error_attrs(item_ast: &syn::ItemStruct) -> proc_macro2::TokenStream {
    let named = |attr: &syn::Attribute, name: &str| attr.path().is_ident(name);
    let mut stripped = item_ast.clone();
    stripped.attrs.retain(|attr| !named(attr, "error"));
    stripped
        .fields
        .iter_mut()
        .for_each(|field| field.attrs.retain(|attr| !named(attr, "error_source")));
    stripped.into_token_stream()
}