bool_to_bitflags 0.1.3

A library that automatically compacts multiple `bool` fields into a single bitflags field, generating getters and setters for them.
Documentation
use std::borrow::Cow;

use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote};
use syn::{Attribute, Field, Fields, Ident, Token};
use to_arraystring::ToArrayString;

use crate::{
    args::Args,
    derive_hijack::{hijack_derives, HijackOutput},
    error::Error,
    impl_from_into::{impl_from, impl_into},
    impl_get_set::generate_getters_setters,
    strip_spans::strip_spans,
};

/// Data shared by every compacted bool field, whether it was `bool` or `Option<bool>`.
pub struct BoolFieldInner {
    /// Name of the field as written in the original struct.
    pub field_ident: Ident,
    /// Name of the bitflags constant that stores this field's value.
    pub flag_ident: Ident,
    /// Visibility copied from the original field.
    pub vis: syn::Visibility,
    /// Attributes copied from the original field (including any `#[cfg(...)]`).
    pub attrs: Vec<syn::Attribute>,
}

/// A struct field that has been removed from the struct and moved into the bitflags type.
pub enum BoolField {
    /// A plain `bool` field, stored in a single flag.
    Normal(BoolFieldInner),
    /// An `Option<bool>` field, stored in its own flag plus an extra `_OPT_TAG` flag.
    Opt {
        /// Name of the extra `_OPT_TAG` flag constant.
        tag_bit_flag_ident: Ident,
        /// Data for the flag holding the field's value.
        bool_bit: BoolFieldInner,
    },
}

impl BoolField {
    fn from_field(field: &Field) -> Self {
        let field_ident = field.ident.clone().unwrap();
        BoolField::Normal(BoolFieldInner {
            flag_ident: Ident::new(&field_ident.to_string().to_uppercase(), Span::call_site()),
            attrs: field.attrs.clone(),
            vis: field.vis.clone(),
            field_ident,
        })
    }

    fn from_opt_bool_field(field: &Field) -> Self {
        match Self::from_field(field) {
            BoolField::Opt { .. } => unreachable!(),
            BoolField::Normal(bool_bit) => BoolField::Opt {
                tag_bit_flag_ident: format_ident!("{}_OPT_TAG", bool_bit.flag_ident),
                bool_bit,
            },
        }
    }

    pub fn tag_bit_flag_ident(&self) -> Option<&Ident> {
        match self {
            BoolField::Normal(_) => None,
            BoolField::Opt {
                tag_bit_flag_ident, ..
            } => Some(tag_bit_flag_ident),
        }
    }
}

/// Gives uniform access to the shared field data regardless of variant.
impl std::ops::Deref for BoolField {
    type Target = BoolFieldInner;

    fn deref(&self) -> &Self::Target {
        // Both variants carry exactly one `BoolFieldInner`; bind it either way.
        let (BoolField::Normal(shared) | BoolField::Opt { bool_bit: shared, .. }) = self;
        shared
    }
}

/// Builds the single-segment path `ident`, with no leading `::` and no generic arguments.
fn path_from_ident(ident: Ident) -> syn::Path {
    let segment = syn::PathSegment::from(ident);
    syn::Path::from(segment)
}

/// Builds the `pub(crate)` visibility.
pub fn generate_pub_crate() -> syn::Visibility {
    let crate_path = path_from_ident(Ident::new("crate", Span::call_site()));
    let restricted = syn::VisRestricted {
        pub_token: Default::default(),
        paren_token: Default::default(),
        in_token: None,
        path: Box::new(crate_path),
    };
    syn::Visibility::Restricted(restricted)
}

/// Builds a type that is just the bare path `ident` (for example `u8` or `FooGeneratedFlags`).
pub fn ty_from_ident(ident: syn::Ident) -> syn::Type {
    syn::Type::Path(syn::TypePath {
        path: path_from_ident(ident),
        qself: None,
    })
}

/// Yields the outer `#[cfg(...)]` attributes from `attrs`, in their original order.
///
/// Inner attributes (`#![...]`) and any attribute whose meta is not a list with the
/// single-identifier path `cfg` are skipped.
pub fn extract_cfgs(attrs: &[Attribute]) -> impl Iterator<Item = &Attribute> + Clone {
    attrs.iter().filter(|attr| {
        matches!(attr.style, syn::AttrStyle::Outer)
            && matches!(&attr.meta, syn::Meta::List(list) if list.path.is_ident("cfg"))
    })
}

/// Builds the `pub(crate) <field_ident>: <flags_ident>` field that replaces the bool fields.
fn generate_flag_field(flags_ident: Ident, field_ident: Ident) -> Field {
    let vis = generate_pub_crate();
    let ty = ty_from_ident(flags_ident);

    Field {
        ident: Some(field_ident),
        colon_token: Some(Default::default()),
        mutability: syn::FieldMutability::None,
        attrs: vec![],
        vis,
        ty,
    }
}

/// Builds the path arguments `<ty>` (a single type argument, no turbofish `::`).
fn generate_generic(ty: syn::Type) -> syn::PathArguments {
    let mut type_args = syn::punctuated::Punctuated::new();
    type_args.push(syn::GenericArgument::Type(ty));

    syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments {
        colon2_token: None,
        lt_token: Default::default(),
        args: type_args,
        gt_token: Default::default(),
    })
}

/// Returns a `filter` predicate that moves bool fields out of a field list.
///
/// Despite the name, the predicate returns `true` to KEEP a field, i.e. when the field is
/// NOT a bool. Fields whose type path starts with `bool`, or with `Option` carrying exactly
/// the generic argument `<bool>`, are recorded in `bool_fields` and the predicate returns
/// `false`. Only the first path segment is inspected, so qualified forms such as
/// `std::option::Option<bool>` are kept unchanged.
fn is_bool_field(bool_fields: &mut Vec<BoolField>) -> impl FnMut(&Field) -> bool + '_ {
    let bool_ident = Ident::new("bool", Span::call_site());
    let opt_ident = Ident::new("Option", Span::call_site());
    let bool_generic = generate_generic(ty_from_ident(bool_ident.clone()));

    move |field| {
        let syn::Type::Path(type_path) = &field.ty else {
            return true;
        };

        let head = type_path
            .path
            .segments
            .first()
            .expect("field type path has one segment");

        let extracted = if head.ident == opt_ident && head.arguments == bool_generic {
            BoolField::from_opt_bool_field(field)
        } else if head.ident == bool_ident {
            BoolField::from_field(field)
        } else {
            return true;
        };

        bool_fields.push(extracted);
        false
    }
}

fn extract_bool_fields(
    flag_field: Field,
    fields: Fields,
) -> Result<(Fields, Vec<BoolField>), Error> {
    let Fields::Named(mut fields) = fields else {
        return Err(Error::Custom(
            Span::call_site(),
            Cow::Borrowed("bool_to_bitflags: Only structs with named fields are supported!"),
        ));
    };

    let mut bool_fields = Vec::new();
    fields.named = fields
        .named
        .into_iter()
        .filter(is_bool_field(&mut bool_fields))
        .chain(std::iter::once(flag_field))
        .collect();

    Ok((Fields::Named(fields), bool_fields))
}

/// Picks the smallest unsigned integer type (`u8` .. `u128`) with at least `bool_count` bits.
///
/// # Errors
/// Returns `Error::Custom` when `bool_count` exceeds 128.
fn get_flag_size(bool_count: usize) -> Result<syn::Type, Error> {
    const CANDIDATES: [(usize, &str); 5] = [
        (8, "u8"),
        (16, "u16"),
        (32, "u32"),
        (64, "u64"),
        (128, "u128"),
    ];

    let Some(&(_, ty_name)) = CANDIDATES.iter().find(|(bits, _)| bool_count <= *bits) else {
        let err_msg = format!(
            "bool_to_bitflags: Cannot fit {bool_count} bool fields into single bitflags type!"
        );
        return Err(Error::Custom(Span::call_site(), Cow::Owned(err_msg)));
    };

    Ok(ty_from_ident(Ident::new(ty_name, Span::call_site())))
}

/// Generates the `bitflags!` definition backing the compacted struct. When the `typesize`
/// feature is enabled, it also generates an empty `::typesize::TypeSize` impl.
///
/// Flags are laid out in this order, and flag `i` gets the value `1 << i`:
/// 1. one flag per entry of `bool_fields` (named by its `flag_ident`), then
/// 2. one `_OPT_TAG` flag per `Option<bool>` field, in field order.
///
/// Every flag carries the `#[cfg(...)]` attributes of the field it belongs to.
///
/// `flags_size` must be wide enough for all generated flags. Values are computed as `u128`
/// shifts, so at most 128 flags are supported.
fn generate_bitflags_type(
    flags_name: &Ident,
    flags_size: &syn::Type,
    bool_fields: &[BoolField],
    flags_derives: &[TokenStream],
) -> TokenStream {
    // Pair every flag name with the field that owns it, so each flag's cfgs are taken from
    // its own field. Building cfgs and names as separate lists and zipping them by position
    // is fragile: `quote!` repetition stops at the shortest iterator, so a cfg list built for
    // *every* field misaligns with tag names that exist only for `Option<bool>` fields.
    let flags: Vec<(&BoolField, &Ident)> = bool_fields
        .iter()
        .map(|field| (field, &field.flag_ident))
        .chain(
            bool_fields
                .iter()
                .filter_map(|field| field.tag_bit_flag_ident().map(|tag| (field, tag))),
        )
        .collect();

    let flag_cfgs = flags.iter().map(|(field, _)| {
        let cfgs = extract_cfgs(&field.attrs);
        quote!(#(#cfgs)*)
    });
    let flag_names = flags.iter().map(|(_, name)| *name);
    let flag_values = (0..flags.len())
        .map(|i| (1_u128 << i).to_arraystring())
        .map(|value| syn::LitInt::new(value.as_str(), Span::call_site()));

    #[cfg(feature = "typesize")]
    let typesize_impl = Some(quote!(impl ::typesize::TypeSize for #flags_name {}));
    #[cfg(not(feature = "typesize"))]
    let typesize_impl: Option<TokenStream> = None;

    quote!(
        bitflags::bitflags! {
            #(#flags_derives)*
            pub(crate) struct #flags_name: #flags_size {
                #(#flag_cfgs const #flag_names = #flag_values;)*
            }
        }

        #typesize_impl
    )
}

/// Entry point of the `bool_to_bitflags` attribute macro.
///
/// Removes every `bool` / `Option<bool>` field from `struct_item` and adds a
/// `__generated_flags` field of a generated `<Name>GeneratedFlags` bitflags type. It emits:
/// - a span-stripped copy of the input struct renamed `<Name>GeneratedOriginal`,
/// - the impls produced by `impl_from` / `impl_into`,
/// - the bitflags definition,
/// - the compacted struct with its hijacked attributes,
/// - the output of `generate_getters_setters`.
///
/// # Errors
/// Propagates errors from argument parsing, `extract_bool_fields` (for example a struct
/// without named fields), `hijack_derives`, and `get_flag_size` (more than 128 flags).
pub fn bool_to_bitflags(
    args: TokenStream,
    mut struct_item: syn::ItemStruct,
) -> Result<TokenStream, Error> {
    let args = Args::parse(args)?;

    // Hidden flags type should not have the span of the struct's name.
    let flag_field_name = Ident::new("__generated_flags", Span::call_site());
    let flags_name = format_ident!(
        "{}GeneratedFlags",
        struct_item.ident,
        span = Span::call_site()
    );

    let mut original_struct = struct_item.clone();
    original_struct.ident = format_ident!("{}GeneratedOriginal", original_struct.ident);
    strip_spans(&mut original_struct);

    let flag_field = generate_flag_field(flags_name.clone(), flag_field_name.clone());
    let (fields, bool_fields) = extract_bool_fields(flag_field, struct_item.fields)?;
    struct_item.fields = fields;

    let HijackOutput {
        compacted_struct_attrs,
        flags_derives,
    } = hijack_derives(&mut struct_item, &original_struct.ident)?;

    let from_impl = impl_from(
        &struct_item,
        &original_struct.ident,
        &flag_field_name,
        &flags_name,
        &bool_fields,
    );

    let into_impl = impl_into(
        &struct_item,
        &original_struct.ident,
        &flag_field_name,
        &flags_name,
        &bool_fields,
    );

    // Size the flags type by the number of flags it must hold, not the number of fields.
    // Every bool field gets one flag, and each `Option<bool>` field gets an extra `_OPT_TAG`
    // flag on top. Counting fields alone picks a type that is too small (five `Option<bool>`
    // fields need 10 bits, not 8), and above 128 flags it skips the size error and overflows
    // the `1 << i` shift used to compute the flag values.
    let flag_count = bool_fields.len()
        + bool_fields
            .iter()
            .filter(|field| field.tag_bit_flag_ident().is_some())
            .count();
    let flags_size = get_flag_size(flag_count)?;
    let bitflags_def =
        generate_bitflags_type(&flags_name, &flags_size, &bool_fields, &flags_derives);
    let func_impls = generate_getters_setters(
        &struct_item,
        &flags_name,
        &flag_field_name,
        &bool_fields,
        &args,
    );

    Ok(quote!(
        #[allow(clippy::struct_excessive_bools)]
        #original_struct
        #from_impl
        #into_impl

        #bitflags_def
        #(#compacted_struct_attrs)*
        #struct_item
        #func_impls
    ))
}