use std::fmt;
/// Up to 64 boolean flags packed into a single `u64`.
///
/// Flag `n` lives in bit `n`, counting from the least-significant bit.
/// The derived `Default` is the empty set (raw value `0`).
#[derive(Debug, Default)]
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct Flags64(u64);
/// Up to 32 boolean flags packed into a single `u32`.
///
/// Flag `n` lives in bit `n`, counting from the least-significant bit.
/// The derived `Default` is the empty set (raw value `0`).
#[derive(Debug, Default)]
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct Flags32(u32);
impl Flags64 {
    /// Number of addressable bit positions; valid positions are `0..MAX_FLAGS`.
    pub const MAX_FLAGS: u8 = 64;

    /// Creates an empty flag set (every bit cleared).
    pub const fn new() -> Self {
        Self(0)
    }

    /// Wraps a raw `u64` bit pattern unchanged.
    pub const fn from_raw(value: u64) -> Self {
        Self(value)
    }

    /// Returns the raw `u64` bit pattern.
    pub const fn as_u64(&self) -> u64 {
        self.0
    }

    /// Sets bit `bit` when `value` is `true`, clears it when `value` is `false`.
    ///
    /// # Panics
    ///
    /// Panics if `bit >= 64`. The panic location reported is the caller's.
    #[track_caller]
    pub fn set(&mut self, bit: u8, value: bool) {
        // Explicit check: without it, `1u64 << 64` panics with a generic overflow
        // message in debug builds and silently shifts by `bit % 64` in release.
        assert!(bit < Self::MAX_FLAGS, "Bit position must be < 64");
        let mask = 1u64 << bit;
        if value {
            self.0 |= mask;
        } else {
            self.0 &= !mask;
        }
    }

    /// Returns whether bit `bit` is set.
    ///
    /// # Panics
    ///
    /// Panics if `bit >= 64`. The panic location reported is the caller's.
    #[track_caller]
    pub const fn get(&self, bit: u8) -> bool {
        assert!(bit < Self::MAX_FLAGS, "Bit position must be < 64");
        (self.0 & (1u64 << bit)) != 0
    }

    /// Applies a sequence of `(bit, value)` updates in order via [`Flags64::set`].
    ///
    /// Later entries for the same bit overwrite earlier ones.
    ///
    /// # Panics
    ///
    /// Panics on the first entry whose `bit >= 64`. Entries before it have
    /// already been applied at that point.
    #[track_caller]
    pub fn set_multiple<I>(&mut self, flags: I)
    where
        I: IntoIterator<Item = (u8, bool)>,
    {
        for (bit, value) in flags {
            self.set(bit, value);
        }
    }

    /// Returns the number of set bits.
    pub const fn count_set(&self) -> u32 {
        self.0.count_ones()
    }

    /// Returns `true` if at least one bit is set.
    pub const fn any(&self) -> bool {
        self.0 != 0
    }

    /// Returns `true` if no bit is set.
    pub const fn none(&self) -> bool {
        self.0 == 0
    }

    /// Returns `true` if all 64 bits are set.
    pub const fn all(&self) -> bool {
        self.0 == u64::MAX
    }

    /// Clears every bit.
    pub fn clear(&mut self) {
        self.0 = 0;
    }

    /// In-place union: sets every bit that is set in `other`.
    pub fn merge(&mut self, other: Flags64) {
        self.0 |= other.0;
    }

    /// In-place intersection: keeps only the bits that are also set in `other`.
    pub fn intersect(&mut self, other: Flags64) {
        self.0 &= other.0;
    }

    /// Returns `true` if every bit set in `other` is also set in `self`.
    ///
    /// An empty `other` is always contained.
    pub const fn contains(&self, other: Flags64) -> bool {
        (self.0 & other.0) == other.0
    }
}
impl Flags32 {
    /// Number of addressable bit positions; valid positions are `0..MAX_FLAGS`.
    pub const MAX_FLAGS: u8 = 32;

    /// Creates an empty flag set (every bit cleared).
    pub const fn new() -> Self {
        Self(0)
    }

    /// Wraps a raw `u32` bit pattern unchanged.
    pub const fn from_raw(value: u32) -> Self {
        Self(value)
    }

    /// Returns the raw `u32` bit pattern.
    pub const fn as_u32(&self) -> u32 {
        self.0
    }

    /// Sets bit `bit` when `value` is `true`, clears it when `value` is `false`.
    ///
    /// # Panics
    ///
    /// Panics if `bit >= 32`. The panic location reported is the caller's.
    #[track_caller]
    pub fn set(&mut self, bit: u8, value: bool) {
        // Explicit check: without it, `1u32 << 32` panics with a generic overflow
        // message in debug builds and silently shifts by `bit % 32` in release.
        assert!(bit < Self::MAX_FLAGS, "Bit position must be < 32");
        let mask = 1u32 << bit;
        if value {
            self.0 |= mask;
        } else {
            self.0 &= !mask;
        }
    }

    /// Returns whether bit `bit` is set.
    ///
    /// # Panics
    ///
    /// Panics if `bit >= 32`. The panic location reported is the caller's.
    #[track_caller]
    pub const fn get(&self, bit: u8) -> bool {
        assert!(bit < Self::MAX_FLAGS, "Bit position must be < 32");
        (self.0 & (1u32 << bit)) != 0
    }

    /// Applies a sequence of `(bit, value)` updates in order via [`Flags32::set`].
    ///
    /// Later entries for the same bit overwrite earlier ones.
    ///
    /// # Panics
    ///
    /// Panics on the first entry whose `bit >= 32`. Entries before it have
    /// already been applied at that point.
    #[track_caller]
    pub fn set_multiple<I>(&mut self, flags: I)
    where
        I: IntoIterator<Item = (u8, bool)>,
    {
        for (bit, value) in flags {
            self.set(bit, value);
        }
    }

    /// Returns the number of set bits.
    pub const fn count_set(&self) -> u32 {
        self.0.count_ones()
    }

    /// Returns `true` if at least one bit is set.
    pub const fn any(&self) -> bool {
        self.0 != 0
    }

    /// Returns `true` if no bit is set.
    pub const fn none(&self) -> bool {
        self.0 == 0
    }

    /// Returns `true` if all 32 bits are set (parity with `Flags64::all`).
    pub const fn all(&self) -> bool {
        self.0 == u32::MAX
    }

    /// Clears every bit.
    pub fn clear(&mut self) {
        self.0 = 0;
    }

    /// In-place union: sets every bit that is set in `other`.
    pub fn merge(&mut self, other: Flags32) {
        self.0 |= other.0;
    }

    /// In-place intersection: keeps only the bits that are also set in `other`.
    pub fn intersect(&mut self, other: Flags32) {
        self.0 &= other.0;
    }

    /// Returns `true` if every bit set in `other` is also set in `self`.
    ///
    /// An empty `other` is always contained.
    pub const fn contains(&self, other: Flags32) -> bool {
        (self.0 & other.0) == other.0
    }
}
impl From<Flags64> for u64 {
fn from(f: Flags64) -> u64 {
f.0
}
}
impl From<u64> for Flags64 {
fn from(value: u64) -> Flags64 {
Flags64(value)
}
}
impl From<Flags32> for u32 {
fn from(f: Flags32) -> u32 {
f.0
}
}
impl From<u32> for Flags32 {
fn from(value: u32) -> Flags32 {
Flags32(value)
}
}
/// Renders the flags as exactly 64 binary digits, most-significant bit first,
/// zero-padded on the left.
///
/// Width, fill and `#` flags supplied by the caller's format string are not
/// applied; the output is always the fixed-width digit string.
impl fmt::Binary for Flags64 {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let bits: u64 = self.0;
        f.write_fmt(format_args!("{:064b}", bits))
    }
}
/// Renders the flags as exactly 32 binary digits, most-significant bit first,
/// zero-padded on the left.
///
/// Width, fill and `#` flags supplied by the caller's format string are not
/// applied; the output is always the fixed-width digit string.
impl fmt::Binary for Flags32 {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let bits: u32 = self.0;
        f.write_fmt(format_args!("{:032b}", bits))
    }
}
/// Defines a `pub` newtype over `Flags64` or `Flags32` with one named getter and
/// one `set_<name>` setter per declared flag.
///
/// ```ignore
/// define_flags! {
///     #[derive(Debug, Clone, Copy, PartialEq, Eq)]
///     Permissions: Flags32 {
///         /// Read access.
///         read = 0,
///         write = 1,
///     }
/// }
/// ```
///
/// Generated API:
/// - `new()`: all flags cleared.
/// - `from_raw(value)`: accepts anything `Into<$base>`.
/// - `as_raw()`: returns the underlying `$base` by value.
/// - one `$flag()` getter and one `set_$flag(bool)` setter per flag.
/// - an `impl Default` that delegates to `new()`.
///
/// Outer attributes before the type name (derives, docs, ...) are forwarded to
/// the struct. Do not derive `Default`; an impl is already generated. Attributes
/// before a flag are forwarded to both its getter and its setter.
///
/// NOTE(review): the expansion calls `paste::paste!` by a crate-relative path,
/// so every crate that invokes this macro must itself depend on `paste`.
/// Consider re-exporting it from this crate (e.g. `#[doc(hidden)] pub use paste;`)
/// and calling `$crate::paste::paste!` instead.
#[macro_export]
macro_rules! define_flags {
    (
        $(#[$meta:meta])*
        $name:ident: $base:ty {
            $(
                $(#[$flag_meta:meta])*
                $flag:ident = $bit:expr
            ),* $(,)?
        }
    ) => {
        $(#[$meta])*
        pub struct $name($base);

        impl $name {
            /// Creates a value with every flag cleared.
            pub const fn new() -> Self {
                Self(<$base>::new())
            }

            /// Builds a value from anything convertible into the underlying storage.
            // Deliberately not `const`: calling the trait method `Into::into` is
            // not allowed in a `const fn` on stable Rust. The previous `const`
            // signature made every invocation of this macro fail to compile.
            pub fn from_raw(value: impl Into<$base>) -> Self {
                Self(value.into())
            }

            $(
                $(#[$flag_meta])*
                pub const fn $flag(&self) -> bool {
                    self.0.get($bit)
                }

                paste::paste! {
                    $(#[$flag_meta])*
                    pub fn [<set_ $flag>](&mut self, value: bool) {
                        self.0.set($bit, value);
                    }
                }
            )*

            /// Returns the underlying flag storage by value.
            pub const fn as_raw(&self) -> $base {
                self.0
            }
        }

        impl Default for $name {
            fn default() -> Self {
                Self::new()
            }
        }
    };
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_flags64_basic() {
        let mut flags = Flags64::new();
        assert!(flags.none());
        assert!(!flags.any());
        flags.set(0, true);
        assert!(flags.any());
        assert!(flags.get(0));
        assert!(!flags.get(1));
        flags.set(63, true);
        assert!(flags.get(63));
        assert_eq!(flags.count_set(), 2);
    }

    #[test]
    fn test_flags32_basic() {
        let mut flags = Flags32::new();
        assert!(flags.none());
        flags.set(0, true);
        flags.set(31, true);
        assert_eq!(flags.count_set(), 2);
    }

    #[test]
    fn test_set_false_clears_bit() {
        let mut flags = Flags64::from_raw(0b11);
        flags.set(0, false);
        assert!(!flags.get(0));
        assert!(flags.get(1));
        assert_eq!(flags.as_u64(), 0b10);

        let mut flags32 = Flags32::from_raw(u32::MAX);
        flags32.set(31, false);
        assert_eq!(flags32.as_u32(), u32::MAX >> 1);
    }

    #[test]
    #[should_panic(expected = "Bit position must be < 64")]
    fn test_flags64_set_out_of_range_panics() {
        let mut flags = Flags64::new();
        flags.set(64, true);
    }

    #[test]
    #[should_panic(expected = "Bit position must be < 32")]
    fn test_flags32_get_out_of_range_panics() {
        let _ = Flags32::new().get(32);
    }

    #[test]
    fn test_set_multiple() {
        let mut flags = Flags64::new();
        flags.set_multiple(vec![(0, true), (5, true), (10, true)]);
        assert!(flags.get(0));
        assert!(flags.get(5));
        assert!(flags.get(10));
        assert!(!flags.get(1));
        assert_eq!(flags.count_set(), 3);
    }

    #[test]
    fn test_merge() {
        let mut flags1 = Flags64::new();
        flags1.set(0, true);
        flags1.set(1, true);
        let mut flags2 = Flags64::new();
        flags2.set(2, true);
        flags2.set(3, true);
        flags1.merge(flags2);
        assert!(flags1.get(0));
        assert!(flags1.get(1));
        assert!(flags1.get(2));
        assert!(flags1.get(3));
        assert_eq!(flags1.count_set(), 4);
    }

    #[test]
    fn test_intersect() {
        let mut flags1 = Flags64::new();
        flags1.set(0, true);
        flags1.set(1, true);
        flags1.set(2, true);
        let mut flags2 = Flags64::new();
        flags2.set(1, true);
        flags2.set(2, true);
        flags2.set(3, true);
        flags1.intersect(flags2);
        assert!(!flags1.get(0));
        assert!(flags1.get(1));
        assert!(flags1.get(2));
        assert!(!flags1.get(3));
    }

    #[test]
    fn test_contains() {
        let mut flags1 = Flags64::new();
        flags1.set(0, true);
        flags1.set(1, true);
        flags1.set(2, true);
        let mut flags2 = Flags64::new();
        flags2.set(0, true);
        flags2.set(1, true);
        assert!(flags1.contains(flags2));
        flags2.set(5, true);
        assert!(!flags1.contains(flags2));
        // The empty set is contained in every set.
        assert!(flags1.contains(Flags64::new()));
    }

    #[test]
    fn test_flags32_set_ops() {
        let mut a = Flags32::from_raw(0b0011);
        a.merge(Flags32::from_raw(0b0100));
        assert_eq!(a.as_u32(), 0b0111);
        a.intersect(Flags32::from_raw(0b0110));
        assert_eq!(a.as_u32(), 0b0110);
        assert!(a.contains(Flags32::from_raw(0b0010)));
        assert!(!a.contains(Flags32::from_raw(0b1000)));
        a.clear();
        assert!(a.none());
    }

    #[test]
    fn test_flags64_all() {
        assert!(Flags64::from_raw(u64::MAX).all());
        assert!(!Flags64::from_raw(u64::MAX - 1).all());
        assert!(!Flags64::new().all());
    }

    #[test]
    fn test_clear() {
        let mut flags = Flags64::new();
        flags.set(0, true);
        flags.set(10, true);
        assert!(flags.any());
        flags.clear();
        assert!(flags.none());
    }

    #[test]
    fn test_from_raw() {
        let flags = Flags64::from_raw(0b1010);
        assert!(!flags.get(0));
        assert!(flags.get(1));
        assert!(!flags.get(2));
        assert!(flags.get(3));
    }

    #[test]
    fn test_conversions_roundtrip() {
        let raw64: u64 = Flags64::from(0xDEAD_BEEF_u64).into();
        assert_eq!(raw64, 0xDEAD_BEEF);
        assert_eq!(u32::from(Flags32::from(7u32)), 7);
        assert_eq!(Flags64::from(3u64), Flags64::from_raw(3));
    }

    #[test]
    fn test_binary_format() {
        assert_eq!(
            format!("{:b}", Flags32::from_raw(5)),
            "00000000000000000000000000000101"
        );
        let s = format!("{:b}", Flags64::from_raw(1u64 << 63));
        assert_eq!(s.len(), 64);
        assert!(s.starts_with('1'));
        assert_eq!(s.matches('1').count(), 1);
    }
}