use proc_macro::TokenStream;
use quote::quote;
use syn::{
Attribute, Expr, Generics, Ident, Result, Token, Type, Visibility, WhereClause,
parse::{Parse, ParseStream},
punctuated::Punctuated,
};
/// One named field inside the `default!` struct body, optionally followed by
/// `= <expr>` giving the value used when building the default instance.
struct FieldWithValue {
    /// Outer attributes written on the field; re-emitted unchanged.
    attrs: Vec<Attribute>,
    /// Field visibility (`pub`, `pub(crate)`, or inherited).
    vis: Visibility,
    /// Field name.
    ident: Ident,
    /// Declared field type.
    ty: Type,
    /// Initializer written after `=`, if any.
    default: Option<Expr>,
}

impl Parse for FieldWithValue {
    /// Parses `#[attr]* vis name: Type` with an optional trailing `= expr`.
    fn parse(input: ParseStream) -> Result<Self> {
        let attrs = Attribute::parse_outer(input)?;
        let vis: Visibility = input.parse()?;
        let ident: Ident = input.parse()?;
        let _colon: Token![:] = input.parse()?;
        let ty: Type = input.parse()?;

        // `Option<Token![=]>` consumes the `=` only when it is the next token.
        let eq_sign: Option<Token![=]> = input.parse()?;
        let default = match eq_sign {
            Some(_) => Some(input.parse::<Expr>()?),
            None => None,
        };

        Ok(Self {
            attrs,
            vis,
            ident,
            ty,
            default,
        })
    }
}
struct StructDef {
attrs: Vec<Attribute>,
vis: Visibility,
ident: Ident,
generics: Generics,
fields: Punctuated<FieldWithValue, Token![,]>,
}
impl Parse for StructDef {
fn parse(input: ParseStream) -> Result<Self> {
let attrs = input.call(Attribute::parse_outer)?;
let vis = input.parse::<Visibility>()?;
input.parse::<Token![struct]>()?;
let ident = input.parse::<Ident>()?;
let mut generics = input.parse::<Generics>()?;
if input.peek(Token![where]) {
generics.where_clause = Some(input.parse::<WhereClause>()?);
}
let content;
syn::braced!(content in input);
let fields = content.parse_terminated(FieldWithValue::parse, Token![,])?;
Ok(StructDef {
attrs,
vis,
ident,
generics,
fields,
})
}
}
#[proc_macro]
pub fn default(item: TokenStream) -> TokenStream {
let input = match syn::parse::<StructDef>(item) {
Ok(input) => input,
Err(e) => return e.to_compile_error().into(),
};
let struct_name = &input.ident;
let vis = &input.vis;
let mut use_const_default = false;
let mut attrs = input.attrs.clone();
attrs.retain(|attr| {
if attr.path().is_ident("const_default") {
use_const_default = true;
false } else {
true }
});
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let fields_no_defaults = input.fields.iter().map(|f| {
let field_name = &f.ident;
let field_ty = &f.ty;
let field_vis = &f.vis;
let field_attrs = &f.attrs;
quote! { #(#field_attrs)* #field_vis #field_name: #field_ty }
});
let field_defaults = input.fields.iter().map(|f| {
let field_name = &f.ident;
if let Some(default_expr) = &f.default {
quote! { #field_name: #default_expr }
} else {
quote! { #field_name: std::default::Default::default() }
}
});
let const_default_impl = if use_const_default {
let const_field_defaults = field_defaults.clone();
quote! {
impl #impl_generics #struct_name #ty_generics #where_clause {
pub const fn const_default() -> Self {
Self {
#(#const_field_defaults,)*
}
}
}
}
} else {
quote! {}
};
let expanded = quote! {
#(#attrs)*
#vis struct #struct_name #ty_generics #where_clause {
#(#fields_no_defaults,)*
}
impl #impl_generics std::default::Default for #struct_name #ty_generics #where_clause {
fn default() -> Self {
Self {
#(#field_defaults,)*
}
}
}
#const_default_impl
};
TokenStream::from(expanded)
}