/// A fixed-length bitmap with one bit per row, packed into `u64` words.
///
/// Row `i` is stored in `words[i / 64]` at bit position `i % 64` (least-significant bit
/// first). The constructors and combinators in this impl keep the unused high bits of a
/// partially filled final word cleared; [`BitMask::count_ones`] and [`BitMask::iter_set`]
/// operate on whole words and rely on that to never report rows at or beyond `nrows`.
#[derive(Debug, Clone)]
pub struct BitMask {
    /// Packed bits; the constructors below size this to `nwords_for(nrows)` words.
    pub(crate) words: Vec<u64>,
    /// Number of logical rows (bits) in the mask.
    pub(crate) nrows: usize,
}

impl BitMask {
    /// Creates a mask of `nrows` rows in which every row is set.
    ///
    /// Bits beyond `nrows` in a partially used final word are left cleared.
    pub fn all_true(nrows: usize) -> Self {
        let nwords = nwords_for(nrows);
        let mut words = vec![u64::MAX; nwords];
        let tail = nrows % 64;
        // Only the low `tail` bits of a partial final word belong to real rows; keeping the
        // rest clear is what lets `count_ones` and `iter_set` work on whole words.
        if tail != 0 && nwords > 0 {
            words[nwords - 1] = (1u64 << tail) - 1;
        }
        BitMask { words, nrows }
    }

    /// Creates a mask of `nrows` rows in which no row is set.
    pub fn all_false(nrows: usize) -> Self {
        BitMask {
            words: vec![0u64; nwords_for(nrows)],
            nrows,
        }
    }

    /// Builds a mask whose row `i` is set exactly when `bools[i]` is `true`.
    pub fn from_bools(bools: &[bool]) -> Self {
        // Each 64-element chunk becomes one word; a shorter final chunk leaves the
        // word's high (padding) bits clear.
        let words = bools
            .chunks(64)
            .map(|chunk| {
                chunk
                    .iter()
                    .enumerate()
                    .fold(0u64, |word, (bit, &b)| word | (u64::from(b) << bit))
            })
            .collect();
        BitMask {
            words,
            nrows: bools.len(),
        }
    }

    /// Returns whether row `i` is set.
    ///
    /// # Panics
    /// In debug builds, panics if `i >= nrows`. In any build, panics if `i / 64` is past
    /// the end of the word storage.
    #[inline]
    pub fn get(&self, i: usize) -> bool {
        debug_assert!(i < self.nrows);
        (self.words[i / 64] >> (i % 64)) & 1 == 1
    }

    /// Sets row `i`.
    ///
    /// # Panics
    /// In debug builds, panics if `i >= nrows`. In any build, panics if `i / 64` is past
    /// the end of the word storage. In release builds an index that falls in the final
    /// word's padding is not caught and sets a bit past `nrows`, which `count_ones` would
    /// then include.
    #[inline]
    pub fn set(&mut self, i: usize) {
        debug_assert!(i < self.nrows);
        self.words[i / 64] |= 1u64 << (i % 64);
    }

    /// Clears row `i`.
    ///
    /// # Panics
    /// In debug builds, panics if `i >= nrows`. In any build, panics if `i / 64` is past
    /// the end of the word storage.
    #[inline]
    pub fn clear(&mut self, i: usize) {
        debug_assert!(i < self.nrows);
        self.words[i / 64] &= !(1u64 << (i % 64));
    }

    /// Returns the number of set rows (population count over all words).
    pub fn count_ones(&self) -> usize {
        self.words.iter().map(|w| w.count_ones() as usize).sum()
    }

    /// Returns the row-wise intersection of `self` and `other`.
    ///
    /// # Panics
    /// Panics if the two masks have different `nrows`.
    pub fn and(&self, other: &BitMask) -> BitMask {
        self.combine(other, "and", |a, b| a & b)
    }

    /// Returns the row-wise union of `self` and `other`.
    ///
    /// # Panics
    /// Panics if the two masks have different `nrows`.
    pub fn or(&self, other: &BitMask) -> BitMask {
        self.combine(other, "or", |a, b| a | b)
    }

    /// Applies `op` word by word to `self` and `other`, producing a mask of the same size.
    ///
    /// `op` must map `(0, 0)` to `0` so that cleared padding bits stay cleared.
    ///
    /// # Panics
    /// Panics if the two masks have different `nrows`; `op_name` names the public
    /// operation in the panic message.
    fn combine(&self, other: &BitMask, op_name: &str, op: impl Fn(u64, u64) -> u64) -> BitMask {
        assert_eq!(
            self.nrows, other.nrows,
            "BitMask::{}: nrows mismatch ({} vs {})",
            op_name, self.nrows, other.nrows
        );
        let words = self
            .words
            .iter()
            .zip(&other.words)
            .map(|(&a, &b)| op(a, b))
            .collect();
        BitMask {
            words,
            nrows: self.nrows,
        }
    }

    /// Iterates over the indices of set rows in ascending order.
    pub fn iter_set(&self) -> impl Iterator<Item = usize> + '_ {
        self.words.iter().enumerate().flat_map(|(word_idx, &word)| {
            let base = word_idx * 64;
            let mut remaining = word;
            std::iter::from_fn(move || {
                if remaining == 0 {
                    return None;
                }
                let bit = remaining.trailing_zeros() as usize;
                // Clear the lowest set bit so the next call yields the following one.
                remaining &= remaining - 1;
                Some(base + bit)
            })
        })
    }

    /// Returns the number of logical rows in the mask.
    pub fn nrows(&self) -> usize {
        self.nrows
    }

    /// Returns the number of `u64` words backing the mask.
    pub fn nwords(&self) -> usize {
        self.words.len()
    }

    /// Returns the bytes occupied by the packed words (`nwords() * 8`).
    ///
    /// This excludes the `Vec` header and any spare capacity.
    pub fn size_bytes(&self) -> usize {
        self.words.len() * std::mem::size_of::<u64>()
    }
}
#[inline]
pub fn nwords_for(nrows: usize) -> usize {
(nrows + 63) / 64
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects the indices of set rows so tests can compare them directly.
    fn set_indices(mask: &BitMask) -> Vec<usize> {
        mask.iter_set().collect()
    }

    #[test]
    fn test_all_true() {
        let mask = BitMask::all_true(100);
        assert_eq!(mask.count_ones(), 100);
        for i in 0..100 {
            assert!(mask.get(i));
        }
    }

    #[test]
    fn test_all_true_word_multiples() {
        // Exact multiples of 64 have no partial tail word to mask.
        for &nrows in &[0usize, 64, 128] {
            let mask = BitMask::all_true(nrows);
            assert_eq!(mask.count_ones(), nrows);
            assert_eq!(mask.nwords(), nrows / 64);
        }
    }

    #[test]
    fn test_all_true_partial_tail() {
        let mask = BitMask::all_true(65);
        assert_eq!(mask.nwords(), 2);
        assert_eq!(set_indices(&mask), (0..65).collect::<Vec<_>>());
    }

    #[test]
    fn test_all_false() {
        let mask = BitMask::all_false(100);
        assert_eq!(mask.count_ones(), 0);
    }

    #[test]
    fn test_iter_set() {
        let mask = BitMask::from_bools(&[true, false, true, false, true]);
        assert_eq!(set_indices(&mask), vec![0, 2, 4]);
    }

    #[test]
    fn test_and() {
        let a = BitMask::from_bools(&[true, true, false, false]);
        let b = BitMask::from_bools(&[true, false, true, false]);
        let c = a.and(&b);
        assert_eq!(set_indices(&c), vec![0]);
    }

    #[test]
    fn test_or() {
        let a = BitMask::from_bools(&[true, true, false, false]);
        let b = BitMask::from_bools(&[true, false, true, false]);
        let c = a.or(&b);
        assert_eq!(set_indices(&c), vec![0, 1, 2]);
        assert_eq!(c.nrows(), 4);
    }

    #[test]
    #[should_panic(expected = "nrows mismatch")]
    fn test_and_nrows_mismatch_panics() {
        let _ = BitMask::all_true(3).and(&BitMask::all_true(4));
    }

    #[test]
    #[should_panic]
    fn test_or_nrows_mismatch_panics() {
        let _ = BitMask::all_true(3).or(&BitMask::all_true(4));
    }

    #[test]
    fn test_set_clear() {
        let mut mask = BitMask::all_false(70);
        mask.set(3);
        mask.set(69);
        assert!(mask.get(3));
        assert!(mask.get(69));
        mask.clear(3);
        assert!(!mask.get(3));
        assert_eq!(set_indices(&mask), vec![69]);
        assert_eq!(mask.count_ones(), 1);
    }

    #[test]
    fn test_iter_set_word_boundary() {
        let mut bools = vec![false; 128];
        bools[0] = true;
        bools[63] = true;
        bools[64] = true;
        bools[127] = true;
        let mask = BitMask::from_bools(&bools);
        assert_eq!(set_indices(&mask), vec![0, 63, 64, 127]);
    }

    #[test]
    fn test_size_bytes() {
        let mask = BitMask::all_true(1_000_000);
        assert_eq!(mask.size_bytes(), 125_000);
    }

    #[test]
    fn test_nwords_for() {
        assert_eq!(nwords_for(0), 0);
        assert_eq!(nwords_for(1), 1);
        assert_eq!(nwords_for(64), 1);
        assert_eq!(nwords_for(65), 2);
    }
}