use std::borrow::Cow;
use std::fmt::Debug;
use std::fmt::Formatter;
use std::marker::PhantomData;
use vortex_error::VortexResult;
use crate::BitBuffer;
use crate::BitBufferMut;
/// A fixed-width view of `NB * 8` bits packed into `NB` bytes, where bit `i` lives in bit
/// `i % 8` (least-significant first) of byte `i / 8`, together with a cached set-bit count.
pub struct BitView<'a, const NB: usize> {
    /// Number of set bits in `bits`; the iteration methods branch on it to take fast paths.
    true_count: usize,
    /// Backing bytes, either borrowed from the caller or owned (e.g. a zero-padded tail copy).
    bits: Cow<'a, [u8; NB]>,
}
impl<const NB: usize> Debug for BitView<'_, NB> {
    /// Formats as `BitView[<width>] { true_count: .., bits: [..] }`, where `<width>` is the
    /// number of bits in the view (`NB * 8`).
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let title = format!("BitView[{}]", NB * 8);
        let mut builder = f.debug_struct(&title);
        builder.field("true_count", &self.true_count);
        builder.field("bits", &self.as_raw());
        builder.finish()
    }
}
impl<const NB: usize> BitView<'static, NB> {
const ALL_TRUE: [u8; NB] = [u8::MAX; NB];
const ALL_FALSE: [u8; NB] = [0; NB];
pub const fn all_true() -> Self {
unsafe { BitView::new_unchecked(&Self::ALL_TRUE, NB * 8) }
}
pub const fn all_false() -> Self {
unsafe { BitView::new_unchecked(&Self::ALL_FALSE, 0) }
}
}
impl<'a, const NB: usize> BitView<'a, NB> {
    /// Number of bits covered by one view.
    pub const N: usize = NB * 8;
    /// Number of `usize` words covered by one view.
    pub const N_WORDS: usize = NB * 8 / (usize::BITS as usize);

    /// Compile-time check that `NB` is a multiple of 8, so the bytes split into whole words and
    /// no trailing byte is silently skipped by the word-wise loops below.
    ///
    /// Associated constants are only evaluated when something references them, so every
    /// constructor evaluates this with `let () = Self::_ASSERT_MULTIPLE_OF_8;`.
    const _ASSERT_MULTIPLE_OF_8: () = assert!(
        NB.is_multiple_of(8),
        "NB must be a multiple of 8 for N to be a multiple of 64"
    );

    /// Creates a view borrowing `bits`, counting its set bits up front.
    ///
    /// Bit `i` of the view is bit `i % 8` (least-significant first) of byte `i / 8`.
    pub fn new(bits: &'a [u8; NB]) -> Self {
        let () = Self::_ASSERT_MULTIPLE_OF_8;
        let true_count = count_set_bits(bits);
        BitView {
            bits: Cow::Borrowed(bits),
            true_count,
        }
    }

    /// Creates a view that owns `bits`, counting its set bits up front.
    pub fn new_owned(bits: [u8; NB]) -> Self {
        let () = Self::_ASSERT_MULTIPLE_OF_8;
        let true_count = count_set_bits(&bits);
        BitView {
            bits: Cow::Owned(bits),
            true_count,
        }
    }

    /// Creates a view borrowing `bits` with a caller-supplied set-bit count.
    ///
    /// # Safety
    ///
    /// `true_count` must equal the number of set bits in `bits`. The iteration methods trust it
    /// to pick their fast paths (e.g. `iter_ones` reports every index when it equals `N`).
    pub(crate) const unsafe fn new_unchecked(bits: &'a [u8; NB], true_count: usize) -> Self {
        let () = Self::_ASSERT_MULTIPLE_OF_8;
        BitView {
            bits: Cow::Borrowed(bits),
            true_count,
        }
    }

    /// Creates a view borrowing `bits`, which must be exactly `NB` bytes long.
    ///
    /// # Panics
    ///
    /// Panics if `bits.len() != NB`.
    pub fn from_slice(bits: &'a [u8]) -> Self {
        assert_eq!(bits.len(), NB);
        let bits_array: &'a [u8; NB] = bits
            .try_into()
            .expect("slice length was just checked to equal NB");
        BitView::new(bits_array)
    }

    /// Creates an owned view whose first `n_true` bits are set and the rest unset.
    ///
    /// # Panics
    ///
    /// Panics if `n_true > N`.
    pub fn with_prefix(n_true: usize) -> Self {
        let () = Self::_ASSERT_MULTIPLE_OF_8;
        assert!(n_true <= Self::N);
        let mut bits = [0u8; NB];
        let full_bytes = n_true / 8;
        let remaining_bits = n_true % 8;
        bits[..full_bytes].fill(u8::MAX);
        if remaining_bits > 0 {
            // `full_bytes < NB` here: a non-zero remainder implies `n_true < N`.
            bits[full_bytes] = (1u8 << remaining_bits) - 1;
        }
        Self {
            bits: Cow::Owned(bits),
            true_count: n_true,
        }
    }

    /// Returns the number of set bits in the view.
    pub fn true_count(&self) -> usize {
        self.true_count
    }

    /// Iterates over the view as `N_WORDS` little-endian `usize` words, so bit `i` of the view
    /// is bit `i % usize::BITS` of word `i / usize::BITS` on every target.
    pub fn iter_words(&self) -> impl Iterator<Item = usize> + '_ {
        self.as_raw().chunks_exact(WORD_BYTES).map(load_le_word)
    }

    /// Calls `f` with the index of every set bit, in ascending order.
    pub fn iter_ones<F>(&self, mut f: F)
    where
        F: FnMut(usize),
    {
        match self.true_count {
            0 => {}
            n if n == Self::N => (0..Self::N).for_each(&mut f),
            _ => {
                let mut bit_idx = 0;
                for mut raw in self.iter_words() {
                    while raw != 0 {
                        let bit_pos = raw.trailing_zeros();
                        f(bit_idx + bit_pos as usize);
                        // Clear the lowest set bit.
                        raw &= raw - 1;
                    }
                    bit_idx += WORD_BITS;
                }
            }
        }
    }

    /// Calls `f` with the index of every set bit, in ascending order, stopping at the first
    /// error.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `f`; no further indices are visited after it.
    pub fn try_iter_ones<F>(&self, mut f: F) -> VortexResult<()>
    where
        F: FnMut(usize) -> VortexResult<()>,
    {
        match self.true_count {
            0 => Ok(()),
            n if n == Self::N => {
                for i in 0..Self::N {
                    f(i)?;
                }
                Ok(())
            }
            _ => {
                let mut bit_idx = 0;
                for mut raw in self.iter_words() {
                    while raw != 0 {
                        let bit_pos = raw.trailing_zeros();
                        f(bit_idx + bit_pos as usize)?;
                        // Clear the lowest set bit.
                        raw &= raw - 1;
                    }
                    bit_idx += WORD_BITS;
                }
                Ok(())
            }
        }
    }

    /// Calls `f` with the index of every unset bit, in ascending order.
    pub fn iter_zeros<F>(&self, mut f: F)
    where
        F: FnMut(usize),
    {
        match self.true_count {
            0 => (0..Self::N).for_each(&mut f),
            n if n == Self::N => {}
            _ => {
                let mut bit_idx = 0;
                for mut raw in self.iter_words() {
                    while raw != usize::MAX {
                        let bit_pos = raw.trailing_ones();
                        f(bit_idx + bit_pos as usize);
                        // Set the lowest unset bit so the next pass finds the following zero.
                        raw |= 1usize << bit_pos;
                    }
                    bit_idx += WORD_BITS;
                }
            }
        }
    }

    /// Calls `f` once per maximal run of consecutive set bits, in ascending order of `start`.
    ///
    /// Runs spanning word boundaries are reported as a single [`BitSlice`].
    pub fn iter_slices<F>(&self, mut f: F)
    where
        F: FnMut(BitSlice),
    {
        if self.true_count == 0 {
            return;
        }
        // The run currently being accumulated; it is handed to `f` only once a gap is seen.
        let mut open: Option<BitSlice> = None;
        for (word_idx, word) in self.iter_words().enumerate() {
            let word_start = word_idx * WORD_BITS;
            match word {
                // An all-zero word ends any open run.
                0 => {
                    if let Some(run) = open.take() {
                        f(run);
                    }
                }
                usize::MAX => Self::push_run(&mut open, &mut f, word_start, WORD_BITS),
                _ => {
                    // `pos` is the in-word index of bit 0 of `rest`, so every run in the word,
                    // not just the first, gets its correct absolute start.
                    let mut pos = 0;
                    let mut rest = word;
                    while rest != 0 {
                        let zeros = rest.trailing_zeros() as usize;
                        pos += zeros;
                        rest >>= zeros;
                        let ones = rest.trailing_ones() as usize;
                        Self::push_run(&mut open, &mut f, word_start + pos, ones);
                        pos += ones;
                        // `ones < WORD_BITS`: `word != usize::MAX`, and after any shift the top
                        // bit of `rest` is zero, so this shift never overflows.
                        rest >>= ones;
                    }
                }
            }
        }
        if let Some(run) = open {
            f(run);
        }
    }

    /// Adds the set-bit run `[start, start + len)` to `open`: extends it when contiguous,
    /// otherwise flushes the previous run to `f` and starts a new one.
    fn push_run<F>(open: &mut Option<BitSlice>, f: &mut F, start: usize, len: usize)
    where
        F: FnMut(BitSlice),
    {
        match open {
            Some(run) if run.start + run.len == start => run.len += len,
            _ => {
                if let Some(finished) = open.replace(BitSlice { start, len }) {
                    f(finished);
                }
            }
        }
    }

    /// Returns the raw bytes backing the view.
    pub fn as_raw(&self) -> &[u8; NB] {
        self.bits.as_ref()
    }
}

/// Number of bits in one `usize` word, the unit the `BitView` loops operate on.
const WORD_BITS: usize = usize::BITS as usize;

/// Number of bytes in one `usize` word.
const WORD_BYTES: usize = std::mem::size_of::<usize>();

/// Decodes a `WORD_BYTES`-long chunk as a little-endian `usize`, so bit `i` of the word is bit
/// `i % 8` of byte `i / 8` on every target, matching the byte-level bit order.
fn load_le_word(chunk: &[u8]) -> usize {
    usize::from_le_bytes(
        chunk
            .try_into()
            .expect("chunks_exact(WORD_BYTES) yields exactly WORD_BYTES bytes"),
    )
}

/// Counts the set bits in `bits`, one `usize` word at a time.
fn count_set_bits(bits: &[u8]) -> usize {
    bits.chunks_exact(WORD_BYTES)
        .map(|chunk| load_le_word(chunk).count_ones() as usize)
        .sum()
}
/// A run of consecutive set bits as reported by `BitView::iter_slices`: bits
/// `start..start + len` are all set.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BitSlice {
    /// Index of the first set bit in the run.
    pub start: usize,
    /// Number of consecutive set bits in the run.
    pub len: usize,
}
impl BitBuffer {
    /// Returns an iterator of `NB`-byte [`BitView`]s covering this buffer from its first byte.
    ///
    /// When the byte length is not a multiple of `NB`, the last view is a zero-padded copy.
    ///
    /// # Panics
    ///
    /// Panics if `self.offset()` is non-zero.
    pub fn iter_bit_views<const NB: usize>(&self) -> impl Iterator<Item = BitView<'_, NB>> + '_ {
        let bit_offset = self.offset();
        assert_eq!(bit_offset, 0, "BitView iteration requires zero bit offset");
        let bytes: &[u8] = self.inner().as_ref();
        BitViewIterator::new(bytes, self.len())
    }
}
impl BitBufferMut {
    /// Returns an iterator of `NB`-byte [`BitView`]s covering this buffer from its first byte.
    ///
    /// When the byte length is not a multiple of `NB`, the last view is a zero-padded copy.
    ///
    /// # Panics
    ///
    /// Panics if `self.offset()` is non-zero.
    pub fn iter_bit_views<const NB: usize>(&self) -> impl Iterator<Item = BitView<'_, NB>> + '_ {
        let bit_offset = self.offset();
        assert_eq!(bit_offset, 0, "BitView iteration requires zero bit offset");
        let bytes: &[u8] = self.inner().as_ref();
        BitViewIterator::new(bytes, self.len())
    }
}
/// Iterator over consecutive `NB`-byte [`BitView`]s of a byte slice.
pub(super) struct BitViewIterator<'a, const NB: usize> {
    /// Total number of views to yield: `bits.len()` divided by `NB`, rounded up.
    n_views: usize,
    /// Index of the next view to yield.
    view_idx: usize,
    /// The packed bytes being viewed.
    bits: &'a [u8],
    /// Ties the const parameter `NB` to the type; no `[u8; NB]` is actually stored.
    _phantom: PhantomData<[u8; NB]>,
}
impl<'a, const NB: usize> BitViewIterator<'a, NB> {
fn new(bits: &'a [u8], len: usize) -> Self {
debug_assert_eq!(len.div_ceil(8), bits.len());
let n_views = bits.len().div_ceil(NB);
BitViewIterator {
bits,
view_idx: 0,
n_views,
_phantom: PhantomData,
}
}
}
impl<'a, const NB: usize> Iterator for BitViewIterator<'a, NB> {
    type Item = BitView<'a, NB>;

    /// Yields the next `NB`-byte view. A full chunk is borrowed; the final chunk, when shorter
    /// than `NB` bytes, is copied into a zero-padded owned view.
    fn next(&mut self) -> Option<Self::Item> {
        if self.view_idx >= self.n_views {
            return None;
        }
        let bits: &'a [u8] = self.bits;
        let start_byte = self.view_idx * NB;
        let end_byte = start_byte + NB;
        let view = match bits.get(start_byte..end_byte) {
            Some(full) => BitView::from_slice(full),
            None => {
                let tail = &bits[start_byte..];
                let mut padded = [0u8; NB];
                padded[..tail.len()].copy_from_slice(tail);
                BitView::new_owned(padded)
            }
        };
        self.view_idx += 1;
        Some(view)
    }

    /// The number of remaining views is known exactly, which lets `collect` preallocate.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.n_views.saturating_sub(self.view_idx);
        (remaining, Some(remaining))
    }
}

impl<const NB: usize> ExactSizeIterator for BitViewIterator<'_, NB> {}

// Once `view_idx` reaches `n_views` it is never advanced again, so `next` keeps returning `None`.
impl<const NB: usize> std::iter::FusedIterator for BitViewIterator<'_, NB> {}
#[cfg(test)]
mod tests {
    use super::*;

    const NB: usize = 128;
    const N: usize = NB * 8;

    /// Returns an `NB`-byte buffer with exactly the given bit indices set
    /// (bit `i` is bit `i % 8` of byte `i / 8`).
    fn bits_with(indices: &[usize]) -> [u8; NB] {
        let mut bits = [0u8; NB];
        for &idx in indices {
            bits[idx / 8] |= 1u8 << (idx % 8);
        }
        bits
    }

    /// Collects the `(start, len)` of every run reported by `iter_slices`.
    fn collect_slices(view: &BitView<'_, NB>) -> Vec<(usize, usize)> {
        let mut slices = Vec::new();
        view.iter_slices(|slice| slices.push((slice.start, slice.len)));
        slices
    }

    #[test]
    fn test_iter_ones_empty() {
        let bits = [0; NB];
        let view = BitView::<NB>::new(&bits);
        let mut ones = Vec::new();
        view.iter_ones(|idx| ones.push(idx));
        assert_eq!(ones, Vec::<usize>::new());
        assert_eq!(view.true_count(), 0);
    }

    #[test]
    fn test_iter_ones_all_set() {
        let view = BitView::<NB>::all_true();
        let mut ones = Vec::new();
        view.iter_ones(|idx| ones.push(idx));
        assert_eq!(ones.len(), N);
        assert_eq!(ones, (0..N).collect::<Vec<_>>());
        assert_eq!(view.true_count(), N);
    }

    #[test]
    fn test_iter_zeros_empty() {
        let bits = [0; NB];
        let view = BitView::<NB>::new(&bits);
        let mut zeros = Vec::new();
        view.iter_zeros(|idx| zeros.push(idx));
        assert_eq!(zeros.len(), N);
        assert_eq!(zeros, (0..N).collect::<Vec<_>>());
    }

    #[test]
    fn test_iter_zeros_all_set() {
        let view = BitView::<NB>::all_true();
        let mut zeros = Vec::new();
        view.iter_zeros(|idx| zeros.push(idx));
        assert_eq!(zeros, Vec::<usize>::new());
    }

    #[test]
    fn test_iter_ones_single_bit() {
        let mut bits = [0; NB];
        bits[0] = 1;
        let view = BitView::new(&bits);
        let mut ones = Vec::new();
        view.iter_ones(|idx| ones.push(idx));
        assert_eq!(ones, vec![0]);
        assert_eq!(view.true_count(), 1);
    }

    #[test]
    fn test_iter_zeros_single_bit_unset() {
        let mut bits = [u8::MAX; NB];
        bits[0] = u8::MAX ^ 1;
        let view = BitView::new(&bits);
        let mut zeros = Vec::new();
        view.iter_zeros(|idx| zeros.push(idx));
        assert_eq!(zeros, vec![0]);
    }

    #[test]
    fn test_iter_ones_multiple_bits_first_word() {
        let mut bits = [0; NB];
        bits[0] = 0b1010101;
        let view = BitView::new(&bits);
        let mut ones = Vec::new();
        view.iter_ones(|idx| ones.push(idx));
        assert_eq!(ones, vec![0, 2, 4, 6]);
        assert_eq!(view.true_count(), 4);
    }

    #[test]
    fn test_iter_zeros_multiple_bits_first_word() {
        let mut bits = [u8::MAX; NB];
        bits[0] = !0b1010101;
        let view = BitView::new(&bits);
        let mut zeros = Vec::new();
        view.iter_zeros(|idx| zeros.push(idx));
        assert_eq!(zeros, vec![0, 2, 4, 6]);
    }

    #[test]
    fn test_iter_zeros_across_word_boundary() {
        let mut bits = [u8::MAX; NB];
        // Clear bit 63 (top of word 0) and bit 64 (bottom of word 1).
        bits[7] = 0x7F;
        bits[8] = 0xFE;
        let view = BitView::new(&bits);
        let mut zeros = Vec::new();
        view.iter_zeros(|idx| zeros.push(idx));
        assert_eq!(zeros, vec![63, 64]);
        assert_eq!(view.true_count(), N - 2);
    }

    #[test]
    fn test_lsb_bit_ordering() {
        let mut bits = [0; NB];
        bits[0] = 0b11111111;
        let view = BitView::new(&bits);
        let mut ones = Vec::new();
        view.iter_ones(|idx| ones.push(idx));
        assert_eq!(ones, vec![0, 1, 2, 3, 4, 5, 6, 7]);
        assert_eq!(view.true_count(), 8);
    }

    #[test]
    fn test_all_false_static() {
        let view = BitView::<NB>::all_false();
        let mut ones = Vec::new();
        let mut zeros = Vec::new();
        view.iter_ones(|idx| ones.push(idx));
        view.iter_zeros(|idx| zeros.push(idx));
        assert_eq!(ones, Vec::<usize>::new());
        assert_eq!(zeros, (0..N).collect::<Vec<_>>());
        assert_eq!(view.true_count(), 0);
    }

    #[test]
    fn test_compatibility_with_mask_all_true() {
        let view = BitView::<NB>::all_true();
        let mut bitview_ones = Vec::new();
        view.iter_ones(|idx| bitview_ones.push(idx));
        let expected_indices: Vec<usize> = (0..N).collect();
        assert_eq!(bitview_ones, expected_indices);
        assert_eq!(view.true_count(), N);
    }

    #[test]
    fn test_compatibility_with_mask_all_false() {
        let view = BitView::<NB>::all_false();
        let mut bitview_ones = Vec::new();
        view.iter_ones(|idx| bitview_ones.push(idx));
        let mut bitview_zeros = Vec::new();
        view.iter_zeros(|idx| bitview_zeros.push(idx));
        assert_eq!(bitview_ones, Vec::<usize>::new());
        assert_eq!(bitview_zeros, (0..N).collect::<Vec<_>>());
        assert_eq!(view.true_count(), 0);
    }

    #[test]
    fn test_compatibility_with_mask_from_indices() {
        let indices = vec![0, 10, 20, 63, 64, 100, 500, 1023];
        let bits = bits_with(&indices);
        let view = BitView::new(&bits);
        let mut bitview_ones = Vec::new();
        view.iter_ones(|idx| bitview_ones.push(idx));
        assert_eq!(bitview_ones, indices);
        assert_eq!(view.true_count(), indices.len());
    }

    #[test]
    fn test_compatibility_with_mask_slices() {
        let slices = vec![(0, 10), (100, 110), (500, 510)];
        let mut bits = [0; NB];
        for (start, end) in &slices {
            for idx in *start..*end {
                let word_idx = idx / 8;
                let bit_idx = idx % 8;
                bits[word_idx] |= 1u8 << bit_idx;
            }
        }
        let view = BitView::new(&bits);
        let mut bitview_ones = Vec::new();
        view.iter_ones(|idx| bitview_ones.push(idx));
        let mut expected_indices = Vec::new();
        for (start, end) in &slices {
            expected_indices.extend(*start..*end);
        }
        assert_eq!(bitview_ones, expected_indices);
        assert_eq!(view.true_count(), expected_indices.len());
    }

    #[test]
    fn test_try_iter_ones_visits_set_bits() {
        let bits = bits_with(&[3, 64, 65, 900]);
        let view = BitView::new(&bits);
        let mut seen = Vec::new();
        view.try_iter_ones(|idx| {
            seen.push(idx);
            Ok(())
        })
        .unwrap();
        assert_eq!(seen, vec![3, 64, 65, 900]);
    }

    #[test]
    fn test_try_iter_ones_all_true() {
        let view = BitView::<NB>::all_true();
        let mut count = 0;
        view.try_iter_ones(|_| {
            count += 1;
            Ok(())
        })
        .unwrap();
        assert_eq!(count, N);
    }

    #[test]
    fn test_iter_slices_runs_in_separate_words() {
        let indices: Vec<usize> = (0..10).chain(100..110).chain(500..510).collect();
        let bits = bits_with(&indices);
        let view = BitView::new(&bits);
        assert_eq!(collect_slices(&view), vec![(0, 10), (100, 10), (500, 10)]);
    }

    #[test]
    fn test_iter_slices_full_word_runs() {
        let indices: Vec<usize> = (64..192).chain([300]).collect();
        let bits = bits_with(&indices);
        let view = BitView::new(&bits);
        assert_eq!(collect_slices(&view), vec![(64, 128), (300, 1)]);
    }

    #[test]
    fn test_from_slice_round_trip() {
        let bits = bits_with(&[1, 2, 3]);
        let view = BitView::<NB>::from_slice(&bits[..]);
        assert_eq!(view.as_raw(), &bits);
        assert_eq!(view.true_count(), 3);
    }

    #[test]
    #[should_panic]
    fn test_from_slice_rejects_wrong_length() {
        let bits = [0u8; NB - 1];
        let _ = BitView::<NB>::from_slice(&bits);
    }

    #[test]
    fn test_with_prefix() {
        let empty = BitView::<NB>::with_prefix(0);
        assert_eq!(empty.true_count(), 0);
        assert_eq!(collect_slices(&empty), Vec::<(usize, usize)>::new());
        for i in 1..=N {
            let view = BitView::<NB>::with_prefix(i);
            assert_eq!(view.true_count(), i);
            assert_eq!(collect_slices(&view), vec![(0, i)]);
            let mut ones = 0;
            view.iter_ones(|_| ones += 1);
            assert_eq!(ones, i);
        }
    }
}