use crate::bit_chunk_iterator::{UnalignedBitChunk, UnalignedBitChunkIterator};
use crate::bit_util::{ceil, get_bit_raw};
/// Iterator over the individual bits of a packed bitmask, yielding each bit as a `bool`.
///
/// The bits still to be yielded are those in `current_offset..end_offset`; the range is
/// consumed from the front by `next` and from the back by `next_back`.
#[derive(Debug, Clone)]
pub struct BitIterator<'a> {
    /// Bitmask bytes being read.
    buffer: &'a [u8],
    /// Bit index returned by the next call to `next`.
    current_offset: usize,
    /// One past the bit index returned by the next call to `next_back`.
    end_offset: usize,
}
impl<'a> BitIterator<'a> {
    /// Creates an iterator over the `len` bits of `buffer` starting at bit `offset`.
    ///
    /// Bits are read least-significant-bit first within each byte.
    ///
    /// # Panics
    ///
    /// Panics if `offset + len` overflows `usize`, or if `buffer` is shorter than the
    /// `ceil(offset + len, 8)` bytes needed to hold the requested bits.
    pub fn new(buffer: &'a [u8], offset: usize, len: usize) -> Self {
        let end_offset = offset
            .checked_add(len)
            .expect("BitIterator offset + len overflows usize");
        // Validate the bounds once here so `next`/`next_back` can read bits unchecked.
        let required_len = ceil(end_offset, 8);
        assert!(
            buffer.len() >= required_len,
            "BitIterator buffer too small, expected {required_len} got {}",
            buffer.len()
        );
        Self {
            buffer,
            current_offset: offset,
            end_offset,
        }
    }
}
impl Iterator for BitIterator<'_> {
    type Item = bool;

    /// Yields the bit at `current_offset` and moves the front of the range forward.
    fn next(&mut self) -> Option<Self::Item> {
        let position = self.current_offset;
        (position != self.end_offset).then(|| {
            self.current_offset = position + 1;
            // SAFETY: `BitIterator::new` asserts `buffer.len() >= ceil(end_offset, 8)`,
            // and `position < end_offset` here, so byte `position / 8` lies in `buffer`.
            unsafe { get_bit_raw(self.buffer.as_ptr(), position) }
        })
    }

    /// Reports the exact number of bits left between the two ends of the range.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.end_offset - self.current_offset;
        (left, Some(left))
    }
}
impl ExactSizeIterator for BitIterator<'_> {}
impl DoubleEndedIterator for BitIterator<'_> {
    /// Yields the bit just before `end_offset` and moves the back of the range inward.
    fn next_back(&mut self) -> Option<Self::Item> {
        (self.current_offset != self.end_offset).then(|| {
            self.end_offset -= 1;
            // SAFETY: the decremented `end_offset` is still below the `end_offset`
            // validated against `buffer.len()` in `BitIterator::new`.
            unsafe { get_bit_raw(self.buffer.as_ptr(), self.end_offset) }
        })
    }
}
/// Iterator over the runs of consecutive set bits in a bitmask slice, yielding each run
/// as a half-open `(start, end)` range of bit indices.
#[derive(Debug)]
pub struct BitSliceIterator<'a> {
/// Source of the 64-bit chunks that follow `current_chunk`.
iter: UnalignedBitChunkIterator<'a>,
/// Number of bits in the slice. Reset to 0 when a run extends through the last chunk,
/// which makes every later `next` call return `None`.
len: usize,
/// Index reported for bit 0 of `current_chunk`; starts at `-lead_padding()` and
/// advances by 64 for each chunk pulled from `iter`.
current_offset: i64,
/// Chunk currently being scanned; bits of runs already yielded are cleared.
current_chunk: u64,
}
impl<'a> BitSliceIterator<'a> {
    /// Creates an iterator over the runs of consecutive set bits among the `len` bits
    /// of `buffer` beginning at bit `offset`.
    pub fn new(buffer: &'a [u8], offset: usize, len: usize) -> Self {
        let unaligned = UnalignedBitChunk::new(buffer, offset, len);
        let mut iter = unaligned.iter();
        // Start negative so that bit `lead_padding()` of the first chunk maps to index 0.
        let current_offset = -(unaligned.lead_padding() as i64);
        let current_chunk = iter.next().unwrap_or_default();
        Self {
            current_chunk,
            current_offset,
            len,
            iter,
        }
    }

    /// Skips over all-zero chunks until `current_chunk` contains a set bit.
    ///
    /// Returns the offset of that chunk together with the position of its lowest set
    /// bit, or `None` once the underlying chunks are exhausted.
    fn advance_to_set_bit(&mut self) -> Option<(i64, u32)> {
        while self.current_chunk == 0 {
            self.current_chunk = self.iter.next()?;
            self.current_offset += 64;
        }
        Some((self.current_offset, self.current_chunk.trailing_zeros()))
    }
}
impl Iterator for BitSliceIterator<'_> {
    type Item = (usize, usize);

    /// Returns the next half-open `(start, end)` run of consecutive set bits.
    fn next(&mut self) -> Option<Self::Item> {
        // A zeroed `len` marks that the final run has already been emitted.
        if self.len == 0 {
            return None;
        }

        let (chunk_start, first_set) = self.advance_to_set_bit()?;
        let range_start = (chunk_start + first_set as i64) as usize;

        // Fill in the bits below the run's start so that `trailing_ones` counts from
        // bit 0 up to the first unset bit after the run.
        self.current_chunk |= (1 << first_set) - 1;

        // A completely full chunk means the run continues into the following chunk.
        while self.current_chunk == u64::MAX {
            match self.iter.next() {
                Some(chunk) => {
                    self.current_chunk = chunk;
                    self.current_offset += 64;
                }
                // No chunks remain, so the run extends to the end of the slice.
                None => return Some((range_start, std::mem::take(&mut self.len))),
            }
        }

        let run_end = self.current_chunk.trailing_ones();
        // Clear the run's bits so the next call resumes scanning after it.
        self.current_chunk &= !((1 << run_end) - 1);
        Some((range_start, (self.current_offset + run_end as i64) as usize))
    }
}
/// Iterator over the positions of set bits in a bitmask slice, yielding each as a `usize`.
#[derive(Debug)]
pub struct BitIndexIterator<'a> {
/// Set bits of the current chunk not yet yielded; each `next` clears the lowest one.
current_chunk: u64,
/// Index reported for bit 0 of `current_chunk`; starts at `-lead_padding()` and
/// advances by 64 for each chunk pulled from `iter`.
chunk_offset: i64,
/// Source of the 64-bit chunks that follow `current_chunk`.
iter: UnalignedBitChunkIterator<'a>,
}
impl<'a> BitIndexIterator<'a> {
    /// Creates an iterator over the positions of set bits among the `len` bits of
    /// `buffer` beginning at bit `offset`.
    pub fn new(buffer: &'a [u8], offset: usize, len: usize) -> Self {
        let unaligned = UnalignedBitChunk::new(buffer, offset, len);
        let mut iter = unaligned.iter();
        let current_chunk = iter.next().unwrap_or_default();
        // Bit `lead_padding()` of the first chunk is reported as position 0.
        let chunk_offset = -(unaligned.lead_padding() as i64);
        Self {
            iter,
            chunk_offset,
            current_chunk,
        }
    }
}
impl Iterator for BitIndexIterator<'_> {
    type Item = usize;

    /// Returns the position of the next set bit, pulling further chunks as needed.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        while self.current_chunk == 0 {
            self.current_chunk = self.iter.next()?;
            self.chunk_offset += 64;
        }
        let lowest = self.current_chunk.trailing_zeros();
        // Clear the lowest set bit so the following call finds the next one.
        self.current_chunk &= self.current_chunk - 1;
        Some((self.chunk_offset + i64::from(lowest)) as usize)
    }
}
/// Iterator over the positions of set bits in a bitmask slice, yielding each as a `u32`.
#[derive(Debug)]
pub struct BitIndexU32Iterator<'a> {
/// Set bits of the current chunk not yet yielded; each `next` clears the lowest one.
curr: u64,
/// Index reported for bit 0 of `curr`; starts at `-lead_padding()` and advances by 64
/// for each chunk pulled from `iter`.
chunk_offset: i64,
/// Source of the 64-bit chunks that follow `curr`.
iter: UnalignedBitChunkIterator<'a>,
}
impl<'a> BitIndexU32Iterator<'a> {
    /// Creates an iterator over the positions of set bits, as `u32`, among the `len`
    /// bits of `buffer` beginning at bit `offset`.
    pub fn new(buffer: &'a [u8], offset: usize, len: usize) -> Self {
        let unaligned = UnalignedBitChunk::new(buffer, offset, len);
        let mut chunk_iter = unaligned.iter();
        let first = chunk_iter.next();
        Self {
            curr: first.unwrap_or_default(),
            // Bit `lead_padding()` of the first chunk is reported as position 0.
            chunk_offset: -(unaligned.lead_padding() as i64),
            iter: chunk_iter,
        }
    }
}
impl Iterator for BitIndexU32Iterator<'_> {
    type Item = u32;

    /// Returns the position of the next set bit as a `u32`, pulling further chunks as
    /// needed.
    #[inline(always)]
    fn next(&mut self) -> Option<u32> {
        loop {
            if self.curr != 0 {
                let tz = self.curr.trailing_zeros();
                // Clear the lowest set bit so the next call finds the following one.
                self.curr &= self.curr - 1;
                // NOTE(review): `as u32` silently truncates positions above `u32::MAX`;
                // callers are presumably expected to keep `len` within `u32` range.
                return Some((self.chunk_offset + tz as i64) as u32);
            }
            // Current chunk exhausted: load the next one, or stop when none remain.
            self.curr = self.iter.next()?;
            self.chunk_offset += 64;
        }
    }
}
/// Calls `f` with every valid index in `0..len`, stopping at the first error.
///
/// * `len` - number of slots to visit
/// * `offset` - bit offset passed to the bitmap iterator when `nulls` is consulted
/// * `null_count` - number of null slots; must not exceed `len`
/// * `nulls` - validity bitmap in which a set bit marks a valid slot; only read when
///   `0 < null_count < len`
///
/// # Errors
///
/// Returns the first `Err` produced by `f`; later indices are not visited.
///
/// # Panics
///
/// Panics if `0 < null_count < len` and `nulls` is `None`. In debug builds, also panics
/// if `null_count > len`.
#[inline]
pub fn try_for_each_valid_idx<E, F: FnMut(usize) -> Result<(), E>>(
    len: usize,
    offset: usize,
    null_count: usize,
    nulls: Option<&[u8]>,
    f: F,
) -> Result<(), E> {
    debug_assert!(
        null_count <= len,
        "null_count ({null_count}) exceeds len ({len})"
    );
    if null_count == 0 {
        // No nulls: every index is valid, so the bitmap need not be consulted.
        (0..len).try_for_each(f)
    } else if null_count == len {
        // Every slot is null: nothing to visit.
        Ok(())
    } else {
        let nulls = nulls.expect("validity bitmap is required when 0 < null_count < len");
        BitIndexIterator::new(nulls, offset, len).try_for_each(f)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bit_iterator_size_hint() {
        let mut b = BitIterator::new(&[0b00000011], 0, 2);
        assert_eq!(
            b.size_hint(),
            (2, Some(2)),
            "Expected size_hint to be (2, Some(2))"
        );
        b.next();
        assert_eq!(
            b.size_hint(),
            (1, Some(1)),
            "Expected size_hint to be (1, Some(1)) after one bit consumed"
        );
        b.next();
        assert_eq!(
            b.size_hint(),
            (0, Some(0)),
            "Expected size_hint to be (0, Some(0)) after all bits consumed"
        );
    }

    #[test]
    fn test_bit_iterator() {
        let mask = &[0b00010010, 0b00100011, 0b00000101, 0b00010001, 0b10010011];
        let actual: Vec<_> = BitIterator::new(mask, 0, 5).collect();
        assert_eq!(actual, &[false, true, false, false, true]);
        let actual: Vec<_> = BitIterator::new(mask, 4, 5).collect();
        assert_eq!(actual, &[true, false, false, false, true]);
        let actual: Vec<_> = BitIterator::new(mask, 12, 14).collect();
        assert_eq!(
            actual,
            &[
                false, true, false, false, true, false, true, false, false, false, false, false,
                true, false
            ]
        );
        assert_eq!(BitIterator::new(mask, 0, 0).count(), 0);
        assert_eq!(BitIterator::new(mask, 40, 0).count(), 0);
    }

    #[test]
    fn test_bit_iterator_rev() {
        let mask = &[0b00010010, 0b00100011];
        let actual: Vec<_> = BitIterator::new(mask, 0, 5).rev().collect();
        assert_eq!(actual, &[true, false, false, true, false]);

        // Iterating from the back must visit exactly the forward bits in reverse.
        let forward: Vec<_> = BitIterator::new(mask, 3, 10).collect();
        let mut backward: Vec<_> = BitIterator::new(mask, 3, 10).rev().collect();
        backward.reverse();
        assert_eq!(forward, backward);
    }

    #[test]
    fn test_bit_iterator_mixed_ends() {
        let mask = &[0b00010010];
        let mut iter = BitIterator::new(mask, 0, 5);
        assert_eq!(iter.next(), Some(false));
        assert_eq!(iter.next_back(), Some(true));
        assert_eq!(iter.len(), 3);
        assert_eq!(iter.next_back(), Some(false));
        assert_eq!(iter.next(), Some(true));
        assert_eq!(iter.next(), Some(false));
        assert_eq!(iter.next(), None);
        assert_eq!(iter.next_back(), None);
        assert_eq!(iter.len(), 0);
    }

    #[test]
    #[should_panic(expected = "BitIterator buffer too small, expected 3 got 2")]
    fn test_bit_iterator_bounds() {
        let mask = &[223, 23];
        BitIterator::new(mask, 17, 0);
    }

    #[test]
    fn test_bit_slice_iterator() {
        let mask = &[0b00010010, 0b00100011];
        let actual: Vec<_> = BitSliceIterator::new(mask, 0, 16).collect();
        assert_eq!(actual, vec![(1, 2), (4, 5), (8, 10), (13, 14)]);

        let actual: Vec<_> = BitSliceIterator::new(mask, 4, 8).collect();
        assert_eq!(actual, vec![(0, 1), (4, 6)]);

        assert_eq!(BitSliceIterator::new(mask, 0, 0).count(), 0);
        assert_eq!(BitSliceIterator::new(&[0u8, 0], 0, 16).count(), 0);
    }

    #[test]
    fn test_bit_slice_iterator_spans_chunks() {
        // A run that fills the only chunk completely ends at `len`.
        let ones = [0xFFu8; 8];
        let actual: Vec<_> = BitSliceIterator::new(&ones, 0, 64).collect();
        assert_eq!(actual, vec![(0, 64)]);

        // A run crossing a 64-bit chunk boundary is reported as a single range.
        let mut buf = vec![0u8; 16];
        for bit in 60..68 {
            let byte = (bit / 8) as usize;
            let bit_in_byte = bit % 8;
            buf[byte] |= 1 << bit_in_byte;
        }
        let actual: Vec<_> = BitSliceIterator::new(&buf, 0, 128).collect();
        assert_eq!(actual, vec![(60, 68)]);
    }

    #[test]
    fn test_try_for_each_valid_idx() {
        let mask: &[u8] = &[0b00010010];

        let mut seen = vec![];
        try_for_each_valid_idx::<(), _>(4, 0, 0, None, |i| {
            seen.push(i);
            Ok(())
        })
        .unwrap();
        assert_eq!(seen, vec![0, 1, 2, 3]);

        let mut seen = vec![];
        try_for_each_valid_idx::<(), _>(5, 0, 3, Some(mask), |i| {
            seen.push(i);
            Ok(())
        })
        .unwrap();
        assert_eq!(seen, vec![1, 4]);

        // All slots null: the bitmap is never needed and `f` is never called.
        let mut calls = 0;
        try_for_each_valid_idx::<(), _>(5, 0, 5, None, |_| {
            calls += 1;
            Ok(())
        })
        .unwrap();
        assert_eq!(calls, 0);

        let result = try_for_each_valid_idx(5, 0, 3, Some(mask), |i| {
            if i == 4 {
                Err(i)
            } else {
                Ok(())
            }
        });
        assert_eq!(result, Err(4));
    }

    #[test]
    fn test_bit_index_u32_iterator_basic() {
        let mask = &[0b00010010, 0b00100011];
        let result: Vec<u32> = BitIndexU32Iterator::new(mask, 0, 16).collect();
        let expected: Vec<u32> = BitIndexIterator::new(mask, 0, 16)
            .map(|i| i as u32)
            .collect();
        assert_eq!(result, expected);
        let result: Vec<u32> = BitIndexU32Iterator::new(mask, 4, 8).collect();
        let expected: Vec<u32> = BitIndexIterator::new(mask, 4, 8)
            .map(|i| i as u32)
            .collect();
        assert_eq!(result, expected);
        let result: Vec<u32> = BitIndexU32Iterator::new(mask, 10, 4).collect();
        let expected: Vec<u32> = BitIndexIterator::new(mask, 10, 4)
            .map(|i| i as u32)
            .collect();
        assert_eq!(result, expected);
        let result: Vec<u32> = BitIndexU32Iterator::new(mask, 0, 0).collect();
        let expected: Vec<u32> = BitIndexIterator::new(mask, 0, 0)
            .map(|i| i as u32)
            .collect();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_bit_index_u32_iterator_all_set() {
        let mask = &[0xFF, 0xFF];
        let result: Vec<u32> = BitIndexU32Iterator::new(mask, 0, 16).collect();
        let expected: Vec<u32> = BitIndexIterator::new(mask, 0, 16)
            .map(|i| i as u32)
            .collect();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_bit_index_u32_iterator_none_set() {
        let mask = &[0x00, 0x00];
        let result: Vec<u32> = BitIndexU32Iterator::new(mask, 0, 16).collect();
        let expected: Vec<u32> = BitIndexIterator::new(mask, 0, 16)
            .map(|i| i as u32)
            .collect();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_bit_index_u32_cross_chunk() {
        let mut buf = vec![0u8; 16];
        for bit in 60..68 {
            let byte = (bit / 8) as usize;
            let bit_in_byte = bit % 8;
            buf[byte] |= 1 << bit_in_byte;
        }
        let offset = 58;
        let len = 10;
        let result: Vec<u32> = BitIndexU32Iterator::new(&buf, offset, len).collect();
        let expected: Vec<u32> = BitIndexIterator::new(&buf, offset, len)
            .map(|i| i as u32)
            .collect();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_bit_index_u32_unaligned_offset() {
        let mask = &[0b0110_1100, 0b1010_0000];
        let offset = 2;
        let len = 12;
        let result: Vec<u32> = BitIndexU32Iterator::new(mask, offset, len).collect();
        let expected: Vec<u32> = BitIndexIterator::new(mask, offset, len)
            .map(|i| i as u32)
            .collect();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_bit_index_u32_long_all_set() {
        let len = 200;
        let bytes = vec![0xFFu8; ceil(len, 8)];
        let result: Vec<u32> = BitIndexU32Iterator::new(&bytes, 0, len).collect();
        let expected: Vec<u32> = BitIndexIterator::new(&bytes, 0, len)
            .map(|i| i as u32)
            .collect();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_bit_index_u32_none_set() {
        let len = 50;
        let bytes = vec![0u8; ceil(len, 8)];
        let result: Vec<u32> = BitIndexU32Iterator::new(&bytes, 0, len).collect();
        let expected: Vec<u32> = BitIndexIterator::new(&bytes, 0, len)
            .map(|i| i as u32)
            .collect();
        assert_eq!(result, expected);
    }
}