bilge-impl 0.3.0

Use bitsized types as if they were a feature of Rust.
Documentation
use itertools::Itertools;
use proc_macro2::{Ident, TokenStream};
use proc_macro_error2::{abort, abort_call_site};
use quote::quote;
use syn::{punctuated::Iter, Data, DeriveInput, Fields, Type, Variant};

use crate::shared::{self, discriminant_assigner::DiscriminantAssigner, enum_fills_bitsize, fallback::Fallback, unreachable, BitSize};

/// Expands the `FromBits` derive for the item in `item`.
///
/// Structs get a pair of `From` impls between the struct and its backing
/// arbitrary-int type. Enums get a `From<arb_int>` impl built from per-variant
/// match arms, plus the reverse conversion. Any other kind of item is routed to
/// `unreachable`.
pub(super) fn from_bits(item: TokenStream) -> TokenStream {
    let input = parse(item);
    let (data, arb_int, name, internal_bitsize, fallback) = analyze(&input);

    let impls = match data {
        Data::Enum(enum_data) => {
            let arms = analyze_enum(enum_data.variants.iter(), name, internal_bitsize, fallback.as_ref(), &arb_int);
            generate_enum(arb_int, name, arms, fallback)
        }
        Data::Struct(struct_data) => generate_struct(arb_int, name, &struct_data.fields),
        _ => unreachable(()),
    };

    generate_common(impls)
}

fn parse(item: TokenStream) -> DeriveInput {
    shared::parse_derive(item)
}

fn analyze(derive_input: &DeriveInput) -> (&syn::Data, TokenStream, &Ident, BitSize, Option<Fallback>) {
    shared::analyze_derive(derive_input, false)
}

/// Validates an enum for `FromBits` and builds its match arms.
///
/// Aborts in either of these cases:
/// - the enum does not fill its bitsize and there is no `#[fallback]` variant;
/// - the enum fills its bitsize and a `#[fallback]` is present anyway.
///
/// Returns `(from_int_arms, to_int_arms)`, with one entry per variant in declaration order:
/// - The from-int arm maps a discriminant to `Self::Variant`. For the fallback
///   variant it is an empty token stream, because the caller's catch-all arm
///   handles that variant.
/// - The to-int arm normally comes from `shared::to_int_match_arm`. For a
///   value-carrying fallback it returns the stored number instead.
fn analyze_enum(
    variants: Iter<Variant>, name: &Ident, internal_bitsize: BitSize, fallback: Option<&Fallback>, arb_int: &TokenStream,
) -> (Vec<TokenStream>, Vec<TokenStream>) {
    validate_enum_variants(variants.clone(), fallback);

    let variant_count = variants.len();
    let enum_is_filled = enum_fills_bitsize(internal_bitsize, variant_count);
    match (enum_is_filled, fallback.is_some()) {
        (false, false) => {
            abort_call_site!("enum doesn't fill its bitsize"; help = "you need to use `#[derive(TryFromBits)]` instead, or specify one of the variants as #[fallback]")
        }
        (true, true) => {
            // pointing the diagnostic at `#[fallback]` itself would be nicer, but isn't straightforward
            abort_call_site!("enum already has {} variants", variant_count; help = "remove the `#[fallback]` attribute")
        }
        _ => {}
    }

    let mut assigner = DiscriminantAssigner::new(internal_bitsize);
    let mut from_int_arms = Vec::new();
    let mut to_int_arms = Vec::new();

    for variant in variants {
        let variant_name = &variant.ident;
        let variant_value = assigner.assign_unsuffixed(variant);

        // (is this the fallback variant?, does it carry the raw number?)
        let (is_fallback, carries_value) = match fallback {
            Some(Fallback::Unit(ident)) => (variant_name == ident, false),
            Some(Fallback::WithValue(ident)) => {
                let hit = variant_name == ident;
                (hit, hit)
            }
            None => (false, false),
        };

        // the fallback variant contributes no explicit arm; the catch-all arm covers it
        from_int_arms.push(if is_fallback {
            quote!()
        } else {
            quote! { #variant_value => Self::#variant_name, }
        });

        to_int_arms.push(if carries_value {
            quote! { #name::#variant_name(number) => number, }
        } else {
            shared::to_int_match_arm(name, variant_name, arb_int, variant_value)
        });
    }

    (from_int_arms, to_int_arms)
}

/// Emits `impl From<#arb_int> for #enum_type` plus the reverse conversion.
///
/// `match_arms` is the `(from_int, to_int)` pair built by `analyze_enum`. The
/// generated `from` matches on `number.value()`. It tries every explicit arm
/// first, then falls through to a catch-all arm chosen by `fallback`:
/// - `WithValue(v)` forwards the raw number into variant `v`;
/// - `Unit(v)` maps to the unit variant `v`;
/// - no fallback emits a `panic!` arm. Presumably this arm is dead because the
///   caller already ensured the enum fills its bitsize.
///
/// The to-int arms are handed to `shared::generate_from_enum_impl`. With the
/// `nightly` feature, the `From<#arb_int>` impl is emitted as `impl const`, and
/// `const_` is also passed to that helper.
fn generate_enum(
    arb_int: TokenStream, enum_type: &Ident, match_arms: (Vec<TokenStream>, Vec<TokenStream>), fallback: Option<Fallback>,
) -> TokenStream {
    let (from_arms, to_arms) = match_arms;

    let const_ = if cfg!(feature = "nightly") { quote!(const) } else { quote!() };

    let from_enum_impl = shared::generate_from_enum_impl(&arb_int, enum_type, to_arms, &const_);

    let catch_all_arm = match fallback {
        None => {
            // constness: unreachable!() is not const yet, so an explicit panic! is emitted
            quote! {
                _ => ::core::panic!("unreachable: arbitrary_int already validates that this is unreachable")
            }
        }
        Some(Fallback::Unit(variant)) => quote! { _ => Self::#variant, },
        Some(Fallback::WithValue(variant)) => quote! { _ => Self::#variant(number), },
    };

    quote! {
        impl #const_ ::core::convert::From<#arb_int> for #enum_type {
            fn from(number: #arb_int) -> Self {
                match number.value() {
                    #( #from_arms )*
                    #catch_all_arm
                }
            }
        }
        #from_enum_impl
    }
}

/// a type is considered "filled" if it implements `Bitsized` with `BITS == N`,
/// and additionally is allowed to have any unsigned value from `0` to `2^N - 1`.
/// such a type can then safely implement `From<uN>`.
/// a filled type automatically implements the trait `Filled` thanks to a blanket impl.
/// the check generated by this function will prevent compilation if `ty` is not `Filled`.
fn generate_filled_check_for(ty: &Type, vec: &mut Vec<TokenStream>) {
    use Type::*;
    match ty {
        Path(_) => {
            let assume = quote! { ::bilge::assume_filled::<#ty>(); };
            vec.push(assume);
        }
        Tuple(tuple) => {
            for elem in &tuple.elems {
                generate_filled_check_for(elem, vec)
            }
        }
        Array(array) => generate_filled_check_for(&array.elem, vec),
        _ => unreachable(()),
    }
}

/// Emits the two `From` impls between `#struct_type` and its backing `#arb_int`.
///
/// The int → struct direction first runs a "filled" assertion for every field
/// type. It then wraps the integer as `Self { value }`. The struct → int
/// direction simply returns `value.value`. With the `nightly` feature, both
/// impls are emitted as `impl const`.
fn generate_struct(arb_int: TokenStream, struct_type: &Ident, fields: &Fields) -> TokenStream {
    let const_ = if cfg!(feature = "nightly") { quote!(const) } else { quote!() };

    let mut checks = Vec::new();
    fields.iter().for_each(|field| generate_filled_check_for(&field.ty, &mut checks));

    // one assertion per distinct type suffices; keep the first occurrence of each
    let checks: Vec<TokenStream> = checks.into_iter().unique_by(|check| check.to_string()).collect();

    quote! {
        impl #const_ ::core::convert::From<#arb_int> for #struct_type {
            fn from(value: #arb_int) -> Self {
                #( #checks )*
                Self { value }
            }
        }
        impl #const_ ::core::convert::From<#struct_type> for #arb_int {
            fn from(value: #struct_type) -> Self {
                value.value
            }
        }
    }
}

/// Final assembly step for the derive output.
///
/// Currently it passes the generated impls through unchanged.
fn generate_common(expanded: TokenStream) -> TokenStream {
    expanded
}

/// Aborts on the first non-fallback variant that is not a unit variant.
///
/// The `#[fallback]` variant is skipped. Per the original author, it has already
/// been validated, and there is at most one. The help text depends on whether a
/// fallback exists: without one, the user can also add one instead.
fn validate_enum_variants(variants: Iter<Variant>, fallback: Option<&Fallback>) {
    let offending = variants
        .filter(|variant| !matches!(fallback, Some(fb) if fb.is_fallback_variant(&variant.ident)))
        .find(|variant| !matches!(variant.fields, Fields::Unit));

    if let Some(variant) = offending {
        let help_message = match fallback {
            Some(_) => "change this variant to a unit",
            None => "add a fallback variant or change this variant to a unit",
        };
        abort!(variant, "FromBits only supports unit variants for variants without `#[fallback]`"; help = help_message);
    }
}