use std::{
fmt::{Debug, Display},
ops::{BitAnd, BitOr},
};
/// A fixed-capacity set of at most 16 player indices, stored as one bit per
/// player in a `u16` (bit `i` set means player `i` is a member).
///
/// `Default` is the empty set. `Hash` is derived alongside `PartialEq`/`Eq`, so
/// it hashes exactly the backing `u16` and always agrees with equality.
#[derive(Default, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PlayerBitSet {
    /// Bit `i` is set iff player `i` is in the set.
    set: u16,
}
impl Debug for PlayerBitSet {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "PlayerBitSet[")?;
for idx in 0..16 {
if self.get(idx) {
write!(f, "A")?;
} else {
write!(f, "_")?;
}
}
write!(f, "]")
}
}
impl Display for PlayerBitSet {
    /// Renders all 16 slots, lowest index first, inside square brackets:
    /// `A` marks a slot in the set and `_` an empty one.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("[")?;
        (0..16).try_for_each(|slot| f.write_str(if self.get(slot) { "A" } else { "_" }))?;
        f.write_str("]")
    }
}
impl PlayerBitSet {
    /// Number of player slots available: one per bit of the backing `u16`.
    const CAPACITY: usize = u16::BITS as usize;

    /// Returns the single-bit mask selecting player `idx`, or `0` when `idx`
    /// is not a representable slot.
    ///
    /// A bare `1 << idx` is wrong for `idx >= 16`: it panics in debug builds,
    /// and in release builds the shift amount is masked, so e.g. `idx == 16`
    /// would silently alias player 0.
    fn bit(idx: usize) -> u16 {
        if idx < Self::CAPACITY {
            1 << idx
        } else {
            0
        }
    }

    /// Creates a set containing players `0..players`.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `players > 16`. Release builds clamp to the
    /// full set of all 16 players.
    pub fn new(players: usize) -> Self {
        debug_assert!(
            players <= 16,
            "PlayerBitSet supports at most 16 players, got {players}"
        );
        let set = if players >= Self::CAPACITY {
            u16::MAX
        } else {
            // `players < 16` here, so the shift cannot overflow.
            (1u16 << players) - 1
        };
        Self { set }
    }

    /// Returns the number of players in the set.
    pub fn count(&self) -> usize {
        self.set.count_ones() as usize
    }

    /// Returns `true` if no player is in the set.
    pub fn empty(&self) -> bool {
        self.set == 0
    }

    /// Adds player `idx` to the set; adding a player already present is a no-op.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `idx >= 16`. Release builds ignore an
    /// out-of-range index instead of setting an unrelated bit.
    pub fn enable(&mut self, idx: usize) {
        debug_assert!(
            idx < Self::CAPACITY,
            "PlayerBitSet index {idx} out of range (max 15)"
        );
        self.set |= Self::bit(idx);
    }

    /// Removes player `idx` from the set. Removing a player that is absent,
    /// or whose index is out of range, is a no-op.
    pub fn disable(&mut self, idx: usize) {
        self.set &= !Self::bit(idx);
    }

    /// Returns whether player `idx` is in the set; always `false` for `idx >= 16`.
    pub fn get(&self, idx: usize) -> bool {
        (self.set & Self::bit(idx)) != 0
    }

    /// Returns an iterator over the indices of the players in the set, in
    /// ascending order.
    pub fn ones(self) -> ActivePlayerBitSetIter {
        ActivePlayerBitSetIter { set: self.set }
    }
}
impl BitOr for PlayerBitSet {
type Output = PlayerBitSet;
fn bitor(self, rhs: Self) -> Self::Output {
Self {
set: self.set | rhs.set,
}
}
}
impl BitAnd for PlayerBitSet {
type Output = PlayerBitSet;
fn bitand(self, rhs: Self) -> Self::Output {
Self {
set: self.set & rhs.set,
}
}
}
/// Iterator over the indices of the set bits of a `PlayerBitSet`, in
/// ascending order. Produced by `PlayerBitSet::ones`.
///
/// It reports an exact `size_hint`, so it implements `ExactSizeIterator`
/// (enabling `.len()` and letting `collect` preallocate). Once exhausted it
/// keeps returning `None`, so it also implements `FusedIterator`.
#[derive(Debug, Clone)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct ActivePlayerBitSetIter {
    /// Bits not yet yielded; each call to `next` clears the lowest one.
    set: u16,
}
impl Iterator for ActivePlayerBitSetIter {
    type Item = usize;
    fn next(&mut self) -> Option<Self::Item> {
        if self.set == 0 {
            return None;
        }
        let idx = self.set.trailing_zeros() as usize;
        // `x & (x - 1)` clears the lowest set bit (the one just located);
        // `self.set != 0` here, so the subtraction cannot underflow.
        self.set &= self.set - 1;
        Some(idx)
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.set.count_ones() as usize;
        (remaining, Some(remaining))
    }
}
impl ExactSizeIterator for ActivePlayerBitSetIter {}
impl std::iter::FusedIterator for ActivePlayerBitSetIter {}
// Unit tests for `PlayerBitSet` construction, mutation, iteration, operators
// and formatting.
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_new_count() {
        assert_eq!(7, PlayerBitSet::new(7).count());
    }
    #[test]
    fn test_new_max_players_count() {
        let s = PlayerBitSet::new(16);
        assert_eq!(16, s.count());
        for idx in 0..16 {
            assert!(s.get(idx), "player {idx} should be active");
        }
    }
    #[test]
    fn test_new_max_players_iter() {
        let s = PlayerBitSet::new(16);
        let collected: Vec<usize> = s.ones().collect();
        assert_eq!(collected, (0..16).collect::<Vec<_>>());
    }
    #[test]
    fn test_new_all_valid_widths_exact_mask() {
        for players in 0usize..=16 {
            let s = PlayerBitSet::new(players);
            assert_eq!(s.count(), players, "count mismatch at players={players}");
            for idx in 0..players {
                assert!(s.get(idx), "bit {idx} should be set at players={players}");
            }
            for idx in players..16 {
                assert!(
                    !s.get(idx),
                    "bit {idx} should be clear at players={players}"
                );
            }
        }
    }
    #[test]
    fn test_default_zero_count() {
        assert_eq!(0, PlayerBitSet::default().count());
    }
    #[test]
    fn test_empty() {
        assert!(PlayerBitSet::default().empty());
        let mut s = PlayerBitSet::new(1);
        assert!(!s.empty());
        s.disable(0);
        assert!(s.empty());
    }
    #[test]
    fn test_disable_count() {
        let mut s = PlayerBitSet::new(7);
        assert_eq!(7, s.count());
        s.disable(6);
        assert_eq!(6, s.count());
        s.disable(0);
        assert_eq!(5, s.count());
    }
    #[test]
    fn test_enable_count() {
        let mut s = PlayerBitSet::default();
        assert_eq!(0, s.count());
        s.enable(0);
        assert_eq!(1, s.count());
        s.enable(0);
        assert_eq!(1, s.count());
        s.enable(2);
        assert_eq!(2, s.count());
        s.disable(0);
        assert_eq!(1, s.count());
    }
    #[test]
    fn test_iter() {
        let s = PlayerBitSet::new(2);
        let mut iter = s.ones();
        assert_eq!(Some(0), iter.next());
        assert_eq!(Some(1), iter.next());
        assert_eq!(None, iter.next());
    }
    #[test]
    fn test_iter_with_disabled() {
        let mut s = PlayerBitSet::new(3);
        let mut iter = s.ones();
        assert_eq!(Some(0), iter.next());
        assert_eq!(Some(1), iter.next());
        assert_eq!(Some(2), iter.next());
        assert_eq!(None, iter.next());
        s.disable(0);
        let mut after_iter = s.ones();
        assert_eq!(Some(1), after_iter.next());
        assert_eq!(Some(2), after_iter.next());
        assert_eq!(None, after_iter.next());
    }
    #[test]
    fn test_iter_with_enabled() {
        let mut s = PlayerBitSet::default();
        let mut iter = s.ones();
        assert_eq!(None, iter.next());
        s.enable(3);
        let mut after_iter = s.ones();
        assert_eq!(Some(3), after_iter.next());
        assert_eq!(None, after_iter.next());
    }
    #[test]
    fn test_bitor_union() {
        let mut a = PlayerBitSet::default();
        a.enable(0);
        a.enable(3);
        let mut b = PlayerBitSet::default();
        b.enable(3);
        b.enable(5);
        let union = a | b;
        assert_eq!(union.ones().collect::<Vec<_>>(), vec![0, 3, 5]);
        assert_eq!(union.count(), 3);
    }
    #[test]
    fn test_bitand_intersection() {
        let mut a = PlayerBitSet::default();
        a.enable(0);
        a.enable(3);
        let mut b = PlayerBitSet::default();
        b.enable(3);
        b.enable(5);
        let both = a & b;
        assert_eq!(both.ones().collect::<Vec<_>>(), vec![3]);
        assert!((a & PlayerBitSet::default()).empty());
    }
    #[test]
    fn test_display() {
        let mut s = PlayerBitSet::new(6);
        s.disable(2);
        assert_eq!("[AA_AAA__________]", format!("{s}"));
    }
    #[test]
    fn test_get() {
        let mut s = PlayerBitSet::default();
        s.enable(0);
        s.enable(2);
        assert!(s.get(0));
        assert!(!s.get(1));
        assert!(s.get(2));
        s.disable(0);
        assert!(!s.get(0));
        assert!(!s.get(1));
        assert!(s.get(2));
    }
    #[test]
    fn test_hash() {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        fn hash_it<T: Hash>(t: &T) -> u64 {
            let mut s = DefaultHasher::new();
            t.hash(&mut s);
            s.finish()
        }
        let mut s1 = PlayerBitSet::default();
        s1.enable(0);
        s1.enable(2);
        let mut s2 = PlayerBitSet::default();
        s2.enable(0);
        s2.enable(2);
        let mut s3 = PlayerBitSet::default();
        s3.enable(1);
        assert_eq!(hash_it(&s1), hash_it(&s2));
        assert_ne!(hash_it(&s1), hash_it(&s3));
    }
    #[test]
    fn test_debug() {
        let mut s = PlayerBitSet::new(4);
        s.disable(1);
        // Players 0, 2 and 3 remain; the other 12 slots are empty.
        assert_eq!("PlayerBitSet[A_AA____________]", format!("{s:?}"));
    }
}