#, "?style=flat-square&logo=rust)](https://crates.io/crates/", env!("CARGO_PKG_NAME"), ")")]
#, "?style=flat-square&logo=docs.rs)](https://docs.rs/", env!("CARGO_PKG_NAME"), ")")]
#"]
#, "-blue?style=flat-square&logo=rust)")]
#![doc = concat!(env!("CARGO_PKG_NAME"), " = ", "\"", env!("CARGO_PKG_VERSION_MAJOR"), ".", env!("CARGO_PKG_VERSION_MINOR"), "\"")]
use proc_macro::TokenStream;
use quote::quote;
use std::{
collections::{HashMap, HashSet},
hash::Hash,
};
use syn::{
parenthesized,
parse::{Parse, ParseBuffer, ParseStream},
parse_macro_input, parse_quote,
punctuated::Punctuated,
spanned::Spanned,
visit_mut::VisitMut,
Attribute, Error, Expr, Field, Ident, Item, ItemStruct, Token, Type, TypeArray,
};
#[proc_macro_attribute]
pub fn subdef(args: TokenStream, input: TokenStream) -> TokenStream {
let mut adt = parse_macro_input!(input as Item);
let args = proc_macro2::TokenStream::from(args);
let mut errors = Vec::new();
let mut expanded_adts = Vec::new();
let mut always_applicable_attrs = Vec::new();
let mut labelled_attrs = HashMap::new();
let mut applicable_labels = HashSet::new();
expand_subdef_attrs(
&mut vec![parse_quote!(#[subdef(#args)])],
&mut always_applicable_attrs,
&mut labelled_attrs,
&mut applicable_labels,
&mut errors,
);
expand_adt(
&mut adt,
&mut expanded_adts,
&mut errors,
&mut always_applicable_attrs,
&mut labelled_attrs,
&mut applicable_labels,
);
let errors = errors
.into_iter()
.reduce(|mut errors, error| {
errors.combine(error);
errors
})
.map(Error::into_compile_error);
quote! {
#errors
#adt
#(#expanded_adts)*
}
.into()
}
fn expand_adt(
adt: &mut Item,
expanded_adts: &mut Vec<Item>,
errors: &mut Vec<Error>,
always_applicable_attrs: &mut Vec<proc_macro2::TokenStream>,
labelled_attrs: &mut HashMap<IdentHash, proc_macro2::TokenStream>,
applicable_labels: &mut HashSet<IdentHash>,
) {
let (attrs, fields): (_, Box<dyn Iterator<Item = &mut Field>>) = match adt {
Item::Struct(adt) => (&mut adt.attrs, Box::new(adt.fields.iter_mut())),
Item::Enum(adt) => (
&mut adt.attrs,
Box::new(
adt.variants
.iter_mut()
.flat_map(|variant| variant.fields.iter_mut()),
),
),
Item::Union(adt) => (&mut adt.attrs, Box::new(adt.fields.named.iter_mut())),
item => {
errors.push(Error::new(
item.span(),
"expected `struct`, `enum`, or `union`",
));
return;
}
};
expand_subdef_attrs(
attrs,
always_applicable_attrs,
labelled_attrs,
applicable_labels,
errors,
);
for field in fields {
let mut always_applicable_attrs = always_applicable_attrs.clone();
let mut labelled_attrs = labelled_attrs.clone();
let mut applicable_labels = applicable_labels.clone();
expand_field(
field,
expanded_adts,
errors,
&mut always_applicable_attrs,
&mut labelled_attrs,
&mut applicable_labels,
);
}
}
/// Lifts a nested type definition out of a field's array-length position.
///
/// A field is handled only when its type is `[T; { <item> }]`, where the block
/// has no attributes and no label and holds exactly one `struct`, `enum`, or
/// `union`. For such a field:
/// 1. every `_` in `T` is replaced by the nested item's name, unless the item
///    declares generic parameters;
/// 2. the nested item is expanded recursively with the given state;
/// 3. the field's type becomes `T`, and the nested item is appended to
///    `expanded_adts` after anything its own fields produced.
///
/// Every other field is left unchanged.
fn expand_field(
    field: &mut Field,
    expanded_adts: &mut Vec<Item>,
    errors: &mut Vec<Error>,
    always_applicable_attrs: &mut Vec<proc_macro2::TokenStream>,
    labelled_attrs: &mut HashMap<IdentHash, proc_macro2::TokenStream>,
    applicable_labels: &mut HashSet<IdentHash>,
) {
    let Type::Array(TypeArray {
        elem,
        len: Expr::Block(block),
        ..
    }) = &mut field.ty
    else {
        return;
    };
    if !block.attrs.is_empty() || block.label.is_some() {
        return;
    }
    let [syn::Stmt::Item(nested)] = block.block.stmts.as_mut_slice() else {
        return;
    };
    let (name, is_generic) = match &*nested {
        Item::Struct(ItemStruct {
            ident, generics, ..
        })
        | Item::Enum(syn::ItemEnum {
            ident, generics, ..
        })
        | Item::Union(syn::ItemUnion {
            ident, generics, ..
        }) => (ident.clone(), generics.lt_token.is_some()),
        _ => return,
    };

    ReplaceTyInferWithIdent {
        replace_with: name,
        is_generic,
    }
    .visit_type_mut(elem);

    expand_adt(
        nested,
        expanded_adts,
        errors,
        always_applicable_attrs,
        labelled_attrs,
        applicable_labels,
    );

    // Take the nested item out of the block; the block itself is discarded
    // once the field type is replaced below.
    let nested = std::mem::replace(nested, Item::Verbatim(TokenStream::new().into()));
    field.ty = *elem.clone();
    expanded_adts.push(nested);
}
/// Type visitor that rewrites every `_` placeholder into a path naming a
/// nested item, e.g. turning `Vec<_>` into `Vec<Name>`.
struct ReplaceTyInferWithIdent {
    /// Identifier written in place of each `_`; it takes the span of the `_`
    /// it replaces.
    replace_with: Ident,
    /// When set, no placeholder is replaced. Presumably this is because a bare
    /// name would be missing the item's generic arguments.
    is_generic: bool,
}

impl syn::visit_mut::VisitMut for ReplaceTyInferWithIdent {
    fn visit_type_mut(&mut self, ty: &mut crate::Type) {
        // Descend first so placeholders nested anywhere inside `ty` are handled.
        syn::visit_mut::visit_type_mut(self, ty);

        if self.is_generic {
            return;
        }
        let Type::Infer(placeholder) = ty else {
            return;
        };
        let span = placeholder.span();
        let mut name = self.replace_with.clone();
        name.set_span(span);
        *ty = Type::Path(syn::TypePath {
            qself: None,
            path: name.into(),
        });
    }
}
/// `Ident` wrapper usable as a `HashMap`/`HashSet` key.
///
/// Hashing uses the identifier's string form, so spans never affect the hash.
/// Equality delegates to `Ident`'s own `PartialEq`.
#[derive(Clone, Debug, Eq)]
struct IdentHash(Ident);

impl PartialEq for IdentHash {
    fn eq(&self, other: &Self) -> bool {
        let IdentHash(lhs) = self;
        let IdentHash(rhs) = other;
        rhs == lhs
    }
}

impl Hash for IdentHash {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        let text: String = self.0.to_string();
        text.hash(state);
    }
}
fn expand_subdef_attrs(
adt_attrs: &mut Vec<Attribute>,
always_applicable_attrs: &mut Vec<proc_macro2::TokenStream>,
labelled_attrs: &mut HashMap<IdentHash, proc_macro2::TokenStream>,
applicable_labels: &mut HashSet<IdentHash>,
errors: &mut Vec<Error>,
) {
let mut skip_just_this_time = HashSet::new();
let mut apply_just_this_time = HashSet::new();
let mut extracted = Vec::new();
adt_attrs.retain_mut(|attr| {
if attr.path().is_ident("subdef") {
extracted.push(std::mem::replace(attr, parse_quote!(#[dummy])));
false
} else {
true
}
});
for attr in extracted {
let subdefs = match attr
.parse_args_with(Punctuated::<AttrSubdefSingle, Token![,]>::parse_terminated)
{
Ok(subdefs) => subdefs,
Err(err) => {
errors.push(err);
continue;
}
};
for subdef in subdefs {
match subdef {
AttrSubdefSingle::Attr { attr } => {
always_applicable_attrs.push(attr);
}
AttrSubdefSingle::AttrLabel { label, attr } => {
labelled_attrs.insert(IdentHash(label.clone()), attr);
applicable_labels.insert(IdentHash(label));
}
AttrSubdefSingle::Skip(labels) => {
for label in labels {
skip_just_this_time.insert(IdentHash(label));
}
}
AttrSubdefSingle::SkipRecursively(labels) => {
for label in labels {
applicable_labels.remove(&IdentHash(label));
}
}
AttrSubdefSingle::Apply(labels) => {
for label in labels {
apply_just_this_time.insert(IdentHash(label));
}
}
AttrSubdefSingle::ApplyRecursively(labels) => {
for label in labels {
applicable_labels.insert(IdentHash(label));
}
}
}
}
}
let applicable_labels = applicable_labels
.union(&apply_just_this_time)
.cloned()
.collect::<HashSet<_>>();
let applicable_labels = applicable_labels
.difference(&skip_just_this_time)
.collect::<HashSet<_>>();
let attrs = always_applicable_attrs
.iter()
.chain(
labelled_attrs
.iter()
.filter_map(|(label, attr)| applicable_labels.contains(label).then_some(attr)),
)
.cloned();
adt_attrs.splice(0..0, attrs.map(|attr| parse_quote!(#[#attr])));
}
/// Comma-separated list of label identifiers, e.g. the `a, b` in `skip(a, b)`.
type LabelList = Punctuated<Ident, Token![,]>;

/// One comma-separated entry inside a `#[subdef(...)]` attribute.
enum AttrSubdefSingle {
    /// `<attr>` — an unlabelled attribute body, applied to this item and to
    /// every item nested inside it.
    Attr { attr: proc_macro2::TokenStream },
    /// `label = <attr>` — an attribute body applied only while `label` is
    /// active. Defining it also activates the label.
    AttrLabel {
        label: Ident,
        attr: proc_macro2::TokenStream,
    },
    /// `skip(labels...)` — do not apply these labels to this item only.
    Skip(LabelList),
    /// `skip_recursively(labels...)` — deactivate these labels for this item
    /// and the items nested inside it.
    SkipRecursively(LabelList),
    /// `apply(labels...)` — apply these labels to this item only.
    Apply(LabelList),
    /// `apply_recursively(labels...)` — activate these labels for this item
    /// and the items nested inside it.
    ApplyRecursively(LabelList),
}
impl Parse for AttrSubdefSingle {
fn parse(input: ParseStream) -> syn::Result<Self> {
let single = if input.parse::<Option<kw::skip>>()?.is_some() {
Self::Skip
} else if input.parse::<Option<kw::skip_recursively>>()?.is_some() {
Self::SkipRecursively
} else if input.parse::<Option<kw::apply>>()?.is_some() {
Self::Apply
} else if input.parse::<Option<kw::apply_recursively>>()?.is_some() {
Self::ApplyRecursively
} else if input.peek2(Token![=]) {
let label = input.parse::<Ident>()?;
input.parse::<Token![=]>()?;
return Ok(Self::AttrLabel {
label,
attr: parse_until_comma(input)?,
});
} else {
return Ok(Self::Attr {
attr: parse_until_comma(input)?,
});
};
let labels;
parenthesized!(labels in input);
Ok(single(labels.parse_terminated(Ident::parse, Token![,])?))
}
}
/// Consumes token trees up to, but not including, the next top-level `,` or the
/// end of input, and returns them as a token stream.
///
/// A delimited group is a single token tree, so commas inside `(...)`, `[...]`,
/// or `{...}` do not stop collection. The result is empty if the input starts
/// with `,` or is already exhausted.
fn parse_until_comma(input: &ParseBuffer) -> syn::Result<proc_macro2::TokenStream> {
    let mut trees: Vec<proc_macro2::TokenTree> = Vec::new();
    loop {
        if input.is_empty() || input.peek(Token![,]) {
            break;
        }
        trees.push(input.parse()?);
    }
    Ok(trees.into_iter().collect())
}
mod kw {
syn::custom_keyword!(skip);
syn::custom_keyword!(skip_recursively);
syn::custom_keyword!(apply);
syn::custom_keyword!(apply_recursively);
}