use itertools::Itertools;
use proc_macro2::{Ident, TokenStream};
use proc_macro_error2::{abort, abort_call_site};
use quote::quote;
use syn::{punctuated::Iter, Data, DeriveInput, Fields, Type, Variant};
use crate::shared::{self, discriminant_assigner::DiscriminantAssigner, enum_fills_bitsize, fallback::Fallback, unreachable, BitSize};
/// Entry point of the `FromBits` derive: turns the annotated item into its
/// generated conversion impls.
///
/// The input is parsed and analyzed first; structs and enums are then handed
/// to their dedicated generators. Any other kind of item is treated as
/// unreachable here.
pub(super) fn from_bits(item: TokenStream) -> TokenStream {
    let derive_input = parse(item);
    let (derive_data, arb_int, name, internal_bitsize, fallback) = analyze(&derive_input);

    let conversions = match derive_data {
        Data::Enum(enum_data) => {
            // The arms are built while `fallback` is only borrowed; ownership is
            // handed to the generator afterwards.
            let match_arms = analyze_enum(enum_data.variants.iter(), name, internal_bitsize, fallback.as_ref(), &arb_int);
            generate_enum(arb_int, name, match_arms, fallback)
        }
        Data::Struct(struct_data) => generate_struct(arb_int, name, &struct_data.fields),
        _ => unreachable(()),
    };

    generate_common(conversions)
}
/// Parses the derive macro's input tokens into a [`DeriveInput`].
///
/// Thin wrapper that forwards directly to `shared::parse_derive`.
fn parse(item: TokenStream) -> DeriveInput {
shared::parse_derive(item)
}
/// Runs the shared derive analysis on `derive_input`.
///
/// Returns, in order: the item's [`Data`], a token stream for the integer type
/// the generated impls convert from/to, the item's identifier, its [`BitSize`],
/// and the fallback variant if one was found.
///
/// NOTE(review): the meaning of the `false` argument is defined by
/// `shared::analyze_derive`; presumably it selects the non-`Try` flavour of the
/// derive — confirm against that function.
fn analyze(derive_input: &DeriveInput) -> (&syn::Data, TokenStream, &Ident, BitSize, Option<Fallback>) {
shared::analyze_derive(derive_input, false)
}
fn analyze_enum(
variants: Iter<Variant>, name: &Ident, internal_bitsize: BitSize, fallback: Option<&Fallback>, arb_int: &TokenStream,
) -> (Vec<TokenStream>, Vec<TokenStream>) {
validate_enum_variants(variants.clone(), fallback);
let enum_is_filled = enum_fills_bitsize(internal_bitsize, variants.len());
if !enum_is_filled && fallback.is_none() {
abort_call_site!("enum doesn't fill its bitsize"; help = "you need to use `#[derive(TryFromBits)]` instead, or specify one of the variants as #[fallback]")
}
if enum_is_filled && fallback.is_some() {
abort_call_site!("enum already has {} variants", variants.len(); help = "remove the `#[fallback]` attribute")
}
let mut assigner = DiscriminantAssigner::new(internal_bitsize);
let is_fallback = |variant_name| {
if let Some(Fallback::Unit(name) | Fallback::WithValue(name)) = fallback {
variant_name == name
} else {
false
}
};
let is_value_fallback = |variant_name| {
if let Some(Fallback::WithValue(name)) = fallback {
variant_name == name
} else {
false
}
};
variants
.map(|variant| {
let variant_name = &variant.ident;
let variant_value = assigner.assign_unsuffixed(variant);
let from_int_match_arm = if is_fallback(variant_name) {
quote!()
} else {
quote! { #variant_value => Self::#variant_name, }
};
let to_int_match_arm = if is_value_fallback(variant_name) {
quote! { #name::#variant_name(number) => number, }
} else {
shared::to_int_match_arm(name, variant_name, arb_int, variant_value)
};
(from_int_match_arm, to_int_match_arm)
})
.unzip()
}
/// Emits the conversion impls for an enum.
///
/// Produces `impl From<arb_int> for enum_type`, which matches on
/// `number.value()` using the supplied integer-to-variant arms followed by a
/// catch-all arm, and appends the tokens returned by
/// `shared::generate_from_enum_impl` for the variant-to-integer arms.
///
/// The catch-all arm yields the fallback variant when one exists (passing
/// `number` along for a value-carrying fallback) and panics otherwise.
/// With the `nightly` feature enabled the impl is emitted as `impl const`.
fn generate_enum(
    arb_int: TokenStream, enum_type: &Ident, match_arms: (Vec<TokenStream>, Vec<TokenStream>), fallback: Option<Fallback>,
) -> TokenStream {
    let const_ = if cfg!(feature = "nightly") { quote!(const) } else { quote!() };

    let wildcard_arm = match fallback {
        None => quote! {
            _ => ::core::panic!("unreachable: arbitrary_int already validates that this is unreachable")
        },
        Some(Fallback::Unit(variant)) => quote! {
            _ => Self::#variant,
        },
        Some(Fallback::WithValue(variant)) => quote! {
            _ => Self::#variant(number),
        },
    };

    let (int_to_variant_arms, variant_to_int_arms) = match_arms;
    let reverse_impl = shared::generate_from_enum_impl(&arb_int, enum_type, variant_to_int_arms, &const_);

    quote! {
        impl #const_ ::core::convert::From<#arb_int> for #enum_type {
            fn from(number: #arb_int) -> Self {
                match number.value() {
                    #( #int_to_variant_arms )*
                    #wildcard_arm
                }
            }
        }
        #reverse_impl
    }
}
fn generate_filled_check_for(ty: &Type, vec: &mut Vec<TokenStream>) {
use Type::*;
match ty {
Path(_) => {
let assume = quote! { ::bilge::assume_filled::<#ty>(); };
vec.push(assume);
}
Tuple(tuple) => {
for elem in &tuple.elems {
generate_filled_check_for(elem, vec)
}
}
Array(array) => generate_filled_check_for(&array.elem, vec),
_ => unreachable(()),
}
}
/// Emits the conversion impls for a struct.
///
/// Produces `impl From<arb_int> for struct_type`, which runs one
/// `assume_filled` statement per distinct field type before wrapping the
/// value, and `impl From<struct_type> for arb_int`, which returns the wrapped
/// `value` field. With the `nightly` feature enabled both are emitted as
/// `impl const`.
fn generate_struct(arb_int: TokenStream, struct_type: &Ident, fields: &Fields) -> TokenStream {
    let mut filled_checks: Vec<TokenStream> = Vec::new();
    fields
        .iter()
        .for_each(|field| generate_filled_check_for(&field.ty, &mut filled_checks));

    // Several fields may share a type; keep only the first check for each,
    // comparing the checks by their rendered text.
    let distinct_checks = filled_checks.into_iter().unique_by(|check| check.to_string());

    let const_ = if cfg!(feature = "nightly") { quote!(const) } else { quote!() };

    quote! {
        impl #const_ ::core::convert::From<#arb_int> for #struct_type {
            fn from(value: #arb_int) -> Self {
                #( #distinct_checks )*
                Self { value }
            }
        }
        impl #const_ ::core::convert::From<#struct_type> for #arb_int {
            fn from(value: #struct_type) -> Self {
                value.value
            }
        }
    }
}
/// Produces the final output of the derive from the generated impls.
///
/// The body re-emits `expanded` through `quote!` without adding any tokens
/// around it.
fn generate_common(expanded: TokenStream) -> TokenStream {
quote! {
#expanded
}
}
/// Ensures every variant other than the fallback variant is a unit variant.
///
/// Aborts macro expansion at the first offending variant, with a help message
/// that depends on whether a fallback variant exists.
fn validate_enum_variants(variants: Iter<Variant>, fallback: Option<&Fallback>) {
    // The hint is the same for every variant, so pick it once up front.
    let help_message = if fallback.is_none() {
        "add a fallback variant or change this variant to a unit"
    } else {
        "change this variant to a unit"
    };

    let mut remaining = variants;
    let offender = remaining.find(|variant| {
        let is_fallback_variant = matches!(fallback, Some(fb) if fb.is_fallback_variant(&variant.ident));
        !is_fallback_variant && !matches!(variant.fields, Fields::Unit)
    });

    if let Some(variant) = offender {
        abort!(variant, "FromBits only supports unit variants for variants without `#[fallback]`"; help = help_message);
    }
}