use super::WORD_SIZE;
use crate::BitVec;
/// A borrowed view over a bit vector and a mask of the same length, combined
/// word by word through a binary operation.
///
/// The struct only stores the two references and the operation; the combined
/// words are computed on demand by the query methods.
#[derive(Debug, Clone)]
pub struct MaskedBitVec<'a, 'b, F>
where
    F: Fn(u64, u64) -> u64,
{
    /// The underlying bit vector; its words are the first operand of `bin_op`.
    vec: &'a BitVec,
    /// The mask; its words are the second operand of `bin_op`.
    mask: &'b BitVec,
    /// Operation applied to each pair of corresponding 64-bit words.
    bin_op: F,
}
impl<'a, 'b, F> MaskedBitVec<'a, 'b, F>
where
F: Fn(u64, u64) -> u64,
{
#[inline]
pub(crate) fn new(vec: &'a BitVec, mask: &'b BitVec, bin_op: F) -> Result<Self, String> {
if vec.len != mask.len {
return Err(String::from(
"mask cannot have different length than vector",
));
}
Ok(MaskedBitVec { vec, mask, bin_op })
}
/// Returns an iterator over the combined words: each item is `bin_op`
/// applied to one word of the vector and the corresponding word of the mask.
///
/// The iteration stops as soon as either word sequence is exhausted.
#[inline]
fn iter_limbs<'s>(&'s self) -> impl Iterator<Item = u64> + 's
where
    'a: 's,
    'b: 's,
{
    let combine = &self.bin_op;
    let vec_words = self.vec.data.iter();
    let mask_words = self.mask.data.iter();
    vec_words
        .zip(mask_words)
        .map(move |(lhs, rhs)| combine(*lhs, *rhs))
}
/// Returns the masked bit at `pos` as `0` or `1`, or `None` if `pos` is
/// not smaller than the length of the vector.
#[inline]
#[must_use]
pub fn get(&self, pos: usize) -> Option<u64> {
    let in_range = pos < self.vec.len;
    in_range.then(|| self.get_unchecked(pos))
}
/// Returns the masked bit at `pos` as `0` or `1` without comparing `pos`
/// against the length of the vector.
///
/// The word containing `pos` is still accessed by index, so a position
/// beyond the stored words is expected to panic rather than read out of
/// bounds.
#[inline]
#[must_use]
pub fn get_unchecked(&self, pos: usize) -> u64 {
    let word = pos / WORD_SIZE;
    let shift = pos % WORD_SIZE;
    let combined = (self.bin_op)(self.vec.data[word], self.mask.data[word]);
    // Move the requested bit to position 0 and discard everything above it.
    (combined >> shift) & 1
}
/// Returns whether the masked bit at `pos` is set, or `None` if `pos` is
/// not smaller than the length of the vector.
#[inline]
#[must_use]
pub fn is_bit_set(&self, pos: usize) -> Option<bool> {
    let in_range = pos < self.vec.len;
    in_range.then(|| self.is_bit_set_unchecked(pos))
}
/// Returns whether the masked bit at `pos` is set, without comparing `pos`
/// against the length of the vector.
#[inline]
#[must_use]
pub fn is_bit_set_unchecked(&self, pos: usize) -> bool {
    let bit = self.get_unchecked(pos);
    bit != 0
}
/// Returns `len` masked bits starting at bit position `pos`.
///
/// Returns `None` if `len` is zero or larger than `WORD_SIZE`, or if the
/// range `pos..pos + len` does not lie entirely within the vector. This
/// includes the case where `pos + len` is not representable as a `usize`.
#[inline]
#[must_use]
pub fn get_bits(&self, pos: usize, len: usize) -> Option<u64> {
    if len == 0 || len > WORD_SIZE {
        return None;
    }
    // `pos + len` can overflow for positions close to `usize::MAX`; such a
    // range is out of bounds, so report it as `None` instead of panicking
    // (debug builds) or wrapping past the bounds check (release builds).
    let end = pos.checked_add(len)?;
    (end <= self.vec.len).then(|| self.get_bits_unchecked(pos, len))
}
/// Returns `len` masked bits starting at bit position `pos`, without
/// comparing the range against the length of the vector.
///
/// Bit `pos` ends up in the least significant bit of the result. The
/// range may straddle two adjacent words. `len` must not exceed
/// `WORD_SIZE`; this is only verified in debug builds.
///
/// The words are still accessed by index, so a range beyond the stored
/// words is expected to panic rather than read out of bounds.
#[must_use]
#[inline]
pub fn get_bits_unchecked(&self, pos: usize, len: usize) -> u64 {
    debug_assert!(len <= WORD_SIZE);
    let word = pos / WORD_SIZE;
    let offset = pos % WORD_SIZE;

    // Mask with the lowest `len` bits set. A full-width request cannot be
    // written as `(1 << len) - 1` (the shift would overflow), and reducing
    // `len` modulo the word size would yield an empty mask, so that case is
    // handled explicitly.
    let keep = if len >= WORD_SIZE {
        u64::MAX
    } else {
        (1u64 << len) - 1
    };

    let low = (self.bin_op)(self.vec.data[word], self.mask.data[word]) >> offset;
    if offset + len <= WORD_SIZE {
        // The whole range lies within a single word.
        low & keep
    } else {
        // The range continues in the next word: its low bits become the high
        // bits of the result. `offset` is non-zero here, so the shift amount
        // is smaller than the word size.
        let high = (self.bin_op)(self.vec.data[word + 1], self.mask.data[word + 1])
            << (WORD_SIZE - offset);
        (low | high) & keep
    }
}
/// Returns the number of zero bits in the masked vector, computed as the
/// length of the vector minus the number of set bits.
#[inline]
#[must_use]
pub fn count_zeros(&self) -> u64 {
    let total_bits = self.vec.len as u64;
    let set_bits = self.count_ones();
    total_bits - set_bits
}
/// Returns the number of set bits in the masked vector.
///
/// Completely used words are counted directly. If the length is not a
/// multiple of `WORD_SIZE`, only the low `len % WORD_SIZE` bits of the last
/// word are counted.
#[inline]
#[must_use]
// The `unwrap`s below are only reached when the length is not a multiple of
// the word size, i.e. the vector is non-empty; presumably a non-empty
// `BitVec` always stores at least one word.
#[allow(clippy::missing_panics_doc)]
pub fn count_ones(&self) -> u64 {
    let full_words = self.vec.len / WORD_SIZE;
    let tail_bits = self.vec.len % WORD_SIZE;

    let in_full_words: u64 = self
        .iter_limbs()
        .take(full_words)
        .map(|limb| u64::from(limb.count_ones()))
        .sum();
    if tail_bits == 0 {
        return in_full_words;
    }

    // Bits above `tail_bits` in the last word are not part of the vector
    // and must not contribute to the count.
    let last_word = (self.bin_op)(
        *self.vec.data.last().unwrap(),
        *self.mask.data.last().unwrap(),
    );
    let tail_mask = (1u64 << tail_bits) - 1;
    in_full_words + u64::from((last_word & tail_mask).count_ones())
}
/// Materializes the masked view into a new `BitVec` by collecting the
/// combined words and copying the length of the underlying vector.
#[inline]
#[must_use]
pub fn to_bit_vec(&self) -> BitVec {
    let data = self.iter_limbs().collect();
    let len = self.vec.len;
    BitVec { data, len }
}
}