use std::ops::{BitAnd, BitOr, BitXor, Not, RangeBounds};
use crate::bit::ops::{bitwise_binary_op, bitwise_unary_op};
use crate::bit::{
BitChunks, BitIndexIterator, BitIterator, BitSliceIterator, UnalignedBitChunk,
get_bit_unchecked,
};
use crate::{Alignment, BitBufferMut, Buffer, BufferMut, ByteBuffer, buffer};
/// A bit-packed buffer of booleans backed by a [`ByteBuffer`].
///
/// Logical bit `i` (for `i < len`) is stored at bit position `offset + i` of `buffer`.
/// Bits before `offset` and after `offset + len` may be present in `buffer` and carry
/// arbitrary values; equality and counting only consider the logical range.
/// [`BitBufferMut`] is the mutable counterpart (see `freeze` / `into_mut`).
#[derive(Debug, Clone, Eq)]
// NOTE(review): the derived `Deserialize` writes the fields directly and does not re-run the
// bounds checks done by `new` / `new_with_offset`; unchecked reads (`value_unchecked`,
// `get_bit_unchecked`) rely on those bounds — confirm deserialized input is trusted.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BitBuffer {
    // Backing bytes; may extend past the logical bit range on either side.
    buffer: ByteBuffer,
    // Bit offset of logical bit 0 within `buffer`.
    offset: usize,
    // Number of logical bits.
    len: usize,
}
impl PartialEq for BitBuffer {
    /// Buffers are equal when they have the same length and the same logical bits.
    ///
    /// Offsets and any padding bits outside the logical range do not participate: the
    /// padded chunk iterators are compared word by word.
    fn eq(&self, other: &Self) -> bool {
        // Equal lengths imply equal chunk counts, so `Iterator::eq` compares the same
        // number of words that the pairwise zip would.
        self.len == other.len
            && self
                .chunks()
                .iter_padded()
                .eq(other.chunks().iter_padded())
    }
}
impl BitBuffer {
    /// Creates a buffer of `len` bits backed by `buffer`, with logical bit 0 at bit 0 of `buffer`.
    ///
    /// The backing buffer is passed through `aligned(Alignment::none())` before being stored.
    ///
    /// # Panics
    ///
    /// Panics if `buffer` holds fewer than `len` bits.
    pub fn new(buffer: ByteBuffer, len: usize) -> Self {
        // `saturating_mul` keeps the size check overflow-free, matching `new_with_offset`.
        assert!(
            buffer.len().saturating_mul(8) >= len,
            "provided ByteBuffer not large enough to back BoolBuffer with len {len}"
        );
        let buffer = buffer.aligned(Alignment::none());
        Self {
            buffer,
            len,
            offset: 0,
        }
    }

    /// Creates a buffer of `len` bits backed by `buffer`, with logical bit 0 located `offset`
    /// bits into `buffer`.
    ///
    /// The backing buffer is passed through `aligned(Alignment::none())` before being stored.
    ///
    /// # Panics
    ///
    /// Panics if `offset + len` exceeds the number of bits in `buffer`.
    pub fn new_with_offset(buffer: ByteBuffer, len: usize, offset: usize) -> Self {
        assert!(
            len.saturating_add(offset) <= buffer.len().saturating_mul(8),
            "provided ByteBuffer (len={}) not large enough to back BoolBuffer with offset {offset} len {len}",
            buffer.len()
        );
        let buffer = buffer.aligned(Alignment::none());
        Self {
            buffer,
            offset,
            len,
        }
    }

    /// Creates a buffer of `len` bits, all `true`.
    ///
    /// Padding bits past `len` in the final byte are set as well.
    pub fn new_set(len: usize) -> Self {
        let byte_len = len.div_ceil(8);
        let buffer = buffer![0xFF; byte_len];
        Self {
            buffer,
            len,
            offset: 0,
        }
    }

    /// Creates a buffer of `len` bits, all `false` (zero-filled bytes).
    pub fn new_unset(len: usize) -> Self {
        let byte_len = len.div_ceil(8);
        let buffer = Buffer::zeroed(byte_len);
        Self {
            buffer,
            len,
            offset: 0,
        }
    }

    /// Creates a buffer with zero bits.
    pub fn empty() -> Self {
        Self::new_set(0)
    }

    /// Creates a buffer of `len` bits, each equal to `value`.
    pub fn full(value: bool, len: usize) -> Self {
        if value {
            Self::new_set(len)
        } else {
            Self::new_unset(len)
        }
    }

    /// Builds a buffer of `len` bits where bit `i` is `f(i)`.
    ///
    /// `f` is invoked exactly once per index, in increasing order from `0` to `len - 1`.
    pub fn collect_bool<F: FnMut(usize) -> bool>(len: usize, mut f: F) -> Self {
        // `buffer` holds packed `u64` words (one per 64 bits, rounded up), so reserve exactly
        // that many words; `push_unchecked` below relies on this capacity.
        let mut buffer = BufferMut::with_capacity(len.div_ceil(64));
        let chunks = len / 64;
        let remainder = len % 64;
        for chunk in 0..chunks {
            let mut packed = 0;
            for bit_idx in 0..64 {
                let i = bit_idx + chunk * 64;
                packed |= (f(i) as u64) << bit_idx;
            }
            // SAFETY: at most `len.div_ceil(64)` words are pushed in total, which is the
            // capacity reserved above.
            unsafe { buffer.push_unchecked(packed) }
        }
        if remainder != 0 {
            let mut packed = 0;
            for bit_idx in 0..remainder {
                let i = bit_idx + chunks * 64;
                packed |= (f(i) as u64) << bit_idx;
            }
            // SAFETY: this is the final (`chunks + 1`-th) word, still within the reserved capacity.
            unsafe { buffer.push_unchecked(packed) }
        }
        // NOTE(review): `buffer` counts u64 words, so truncating to a byte count does not
        // shrink it; the frozen buffer keeps whole words. Confirm whether a byte-exact length
        // was intended.
        buffer.truncate(len.div_ceil(8));
        Self::new(buffer.freeze().into_byte_buffer(), len)
    }

    /// Number of logical bits.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` if the buffer holds no bits.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Bit offset of logical bit 0 within the backing buffer.
    #[inline(always)]
    pub fn offset(&self) -> usize {
        self.offset
    }

    /// The backing bytes, including any bits outside `offset..offset + len`.
    #[inline(always)]
    pub fn inner(&self) -> &ByteBuffer {
        &self.buffer
    }

    /// Returns logical bit `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index >= self.len()`.
    #[inline]
    pub fn value(&self, index: usize) -> bool {
        assert!(index < self.len);
        // SAFETY: `index < len` was just checked, and the constructors guarantee the buffer
        // covers `offset + len` bits.
        unsafe { self.value_unchecked(index) }
    }

    /// Returns logical bit `index` without bounds checking.
    ///
    /// # Safety
    ///
    /// `index` must be less than `self.len()`.
    #[inline]
    pub unsafe fn value_unchecked(&self, index: usize) -> bool {
        // SAFETY: the caller guarantees `index < len`, so `index + offset` lies inside the buffer.
        unsafe { get_bit_unchecked(self.buffer.as_ptr(), index + self.offset) }
    }

    /// Returns the bits in `range` as a new `BitBuffer`.
    ///
    /// The result references a clone of the same backing `ByteBuffer` with an adjusted offset;
    /// no bits are copied.
    ///
    /// # Panics
    ///
    /// Panics if the range is inverted or extends past `self.len()`.
    pub fn slice(&self, range: impl RangeBounds<usize>) -> Self {
        let start = match range.start_bound() {
            std::ops::Bound::Included(&s) => s,
            std::ops::Bound::Excluded(&s) => s + 1,
            std::ops::Bound::Unbounded => 0,
        };
        let end = match range.end_bound() {
            std::ops::Bound::Included(&e) => e + 1,
            std::ops::Bound::Excluded(&e) => e,
            std::ops::Bound::Unbounded => self.len,
        };
        assert!(start <= end);
        assert!(start <= self.len);
        assert!(end <= self.len);
        let len = end - start;
        Self::new_with_offset(self.buffer.clone(), len, self.offset + start)
    }

    /// Drops whole unused bytes around the logical bits so the offset becomes `offset % 8`.
    pub fn shrink_offset(self) -> Self {
        let bit_offset = self.offset % 8;
        let len = self.len;
        let buffer = self.into_inner();
        BitBuffer::new_with_offset(buffer, len, bit_offset)
    }

    /// Returns an [`UnalignedBitChunk`] view over the logical bits.
    pub fn unaligned_chunks(&self) -> UnalignedBitChunk<'_> {
        UnalignedBitChunk::new(self.buffer.as_slice(), self.offset, self.len)
    }

    /// Returns a [`BitChunks`] view over the logical bits.
    pub fn chunks(&self) -> BitChunks<'_> {
        BitChunks::new(self.buffer.as_slice(), self.offset, self.len)
    }

    /// Number of logical bits that are set.
    pub fn true_count(&self) -> usize {
        self.unaligned_chunks().count_ones()
    }

    /// Number of logical bits that are unset.
    pub fn false_count(&self) -> usize {
        self.len - self.true_count()
    }

    /// Iterates over the logical bits in order.
    pub fn iter(&self) -> BitIterator<'_> {
        BitIterator::new(self.buffer.as_slice(), self.offset, self.len)
    }

    /// Returns a [`BitIndexIterator`] over the logical bits.
    pub fn set_indices(&self) -> BitIndexIterator<'_> {
        BitIndexIterator::new(self.buffer.as_slice(), self.offset, self.len)
    }

    /// Returns a [`BitSliceIterator`] over the logical bits.
    pub fn set_slices(&self) -> BitSliceIterator<'_> {
        BitSliceIterator::new(self.buffer.as_slice(), self.offset, self.len)
    }

    /// Returns an equivalent buffer whose logical bits start at bit 0 of its backing bytes.
    ///
    /// A byte-aligned offset is handled by slicing the backing bytes without copying. Any
    /// other offset goes through `bitwise_unary_op` with the identity operation.
    pub fn sliced(&self) -> Self {
        if self.offset % 8 == 0 {
            // The end byte must be measured from the start of the backing buffer
            // (`offset + len`), not from `len` alone; otherwise any offset >= 8 produces a
            // too-short or inverted byte range.
            let start_byte = self.offset / 8;
            let end_byte = (self.offset + self.len).div_ceil(8);
            return Self::new(self.buffer.slice(start_byte..end_byte), self.len);
        }
        bitwise_unary_op(self, |a| a)
    }
}
impl BitBuffer {
    /// Consumes the buffer and returns only the backing bytes that contain logical bits.
    ///
    /// The result starts at byte `offset / 8`, so logical bit 0 sits at bit `offset % 8` of the
    /// returned bytes. That residual offset is not carried by the return value.
    pub fn into_inner(self) -> ByteBuffer {
        let word_start = self.offset / 8;
        let word_end = (self.offset + self.len).div_ceil(8);
        self.buffer.slice(word_start..word_end)
    }

    /// Attempts to convert into a [`BitBufferMut`] by converting the whole backing buffer.
    ///
    /// The offset and length are preserved as-is. If the backing buffer's `try_into_mut`
    /// fails, the buffer is handed back as `Err` with the same offset and length.
    pub fn try_into_mut(self) -> Result<BitBufferMut, Self> {
        match self.buffer.try_into_mut() {
            Ok(buffer) => Ok(BitBufferMut::from_buffer(buffer, self.offset, self.len)),
            Err(buffer) => Err(BitBuffer::new_with_offset(buffer, self.len, self.offset)),
        }
    }

    /// Converts into a [`BitBufferMut`] that holds only the bytes containing logical bits.
    pub fn into_mut(self) -> BitBufferMut {
        // `into_inner` drops the leading `offset / 8` whole bytes, so only the sub-byte part of
        // the offset still applies to the trimmed buffer (same reasoning as `shrink_offset`).
        let offset = self.offset % 8;
        let len = self.len;
        let inner = self.into_inner().into_mut();
        BitBufferMut::from_buffer(inner, offset, len)
    }
}
impl From<&[bool]> for BitBuffer {
fn from(value: &[bool]) -> Self {
BitBufferMut::from(value).freeze()
}
}
impl From<Vec<bool>> for BitBuffer {
fn from(value: Vec<bool>) -> Self {
BitBufferMut::from(value).freeze()
}
}
impl FromIterator<bool> for BitBuffer {
fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
BitBufferMut::from_iter(iter).freeze()
}
}
impl BitOr for &BitBuffer {
    type Output = BitBuffer;

    /// Word-wise OR of the two buffers into a newly allocated result.
    fn bitor(self, rhs: Self) -> Self::Output {
        bitwise_binary_op(self, rhs, |lhs_word, rhs_word| lhs_word | rhs_word)
    }
}

impl BitOr<&BitBuffer> for BitBuffer {
    type Output = BitBuffer;

    /// Delegates to the `&BitBuffer | &BitBuffer` implementation.
    fn bitor(self, rhs: &BitBuffer) -> Self::Output {
        &self | rhs
    }
}
impl BitAnd for &BitBuffer {
    type Output = BitBuffer;

    /// Word-wise AND of the two buffers into a newly allocated result.
    fn bitand(self, rhs: Self) -> Self::Output {
        bitwise_binary_op(self, rhs, |lhs_word, rhs_word| lhs_word & rhs_word)
    }
}

impl BitAnd<BitBuffer> for &BitBuffer {
    type Output = BitBuffer;

    /// Borrows the owned right-hand side and delegates to `&BitBuffer & &BitBuffer`.
    fn bitand(self, rhs: BitBuffer) -> Self::Output {
        self & &rhs
    }
}

impl BitAnd<&BitBuffer> for BitBuffer {
    type Output = BitBuffer;

    /// Borrows the owned left-hand side and delegates to `&BitBuffer & &BitBuffer`.
    fn bitand(self, rhs: &BitBuffer) -> Self::Output {
        &self & rhs
    }
}
impl Not for &BitBuffer {
    type Output = BitBuffer;

    /// Word-wise complement of the buffer into a newly allocated result.
    fn not(self) -> Self::Output {
        bitwise_unary_op(self, |word| !word)
    }
}

impl Not for BitBuffer {
    type Output = BitBuffer;

    /// Delegates to the `!&BitBuffer` implementation.
    fn not(self) -> Self::Output {
        !&self
    }
}
impl BitXor for &BitBuffer {
    type Output = BitBuffer;

    /// Word-wise XOR of the two buffers into a newly allocated result.
    fn bitxor(self, rhs: Self) -> Self::Output {
        bitwise_binary_op(self, rhs, |lhs_word, rhs_word| lhs_word ^ rhs_word)
    }
}

impl BitXor<&BitBuffer> for BitBuffer {
    type Output = BitBuffer;

    /// Delegates to the `&BitBuffer ^ &BitBuffer` implementation.
    fn bitxor(self, rhs: &BitBuffer) -> Self::Output {
        &self ^ rhs
    }
}
impl BitBuffer {
    /// Returns `self & !rhs`: the bits set in `self` and unset in `rhs`.
    pub fn bitand_not(&self, rhs: &BitBuffer) -> BitBuffer {
        bitwise_binary_op(self, rhs, |a, b| a & !b)
    }

    /// Calls `f(index, is_set)` for every logical bit, with `index` running from `0` to
    /// `self.len() - 1` in order.
    #[inline]
    pub fn iter_bits<F>(&self, mut f: F)
    where
        F: FnMut(usize, bool),
    {
        if self.len == 0 {
            return;
        }
        let start_bit = self.offset;
        let end_bit = self.offset + self.len;
        // Single up-front bounds check: the subslice covers exactly the bytes holding logical
        // bits, so the loop below needs no raw-pointer arithmetic or per-byte checks.
        let bytes = &self.buffer.as_slice()[start_bit / 8..end_bit.div_ceil(8)];
        // `len > 0`, so `bytes` is non-empty.
        let last = bytes.len() - 1;
        let mut idx = 0;
        for (i, &byte) in bytes.iter().enumerate() {
            // The first byte may start mid-byte (bit offset); the last may end mid-byte.
            let lo = if i == 0 { start_bit % 8 } else { 0 };
            let hi = if i == last { (end_bit - 1) % 8 + 1 } else { 8 };
            for bit in lo..hi {
                f(idx, (byte & (1 << bit)) != 0);
                idx += 1;
            }
        }
    }
}
impl<'a> IntoIterator for &'a BitBuffer {
type Item = bool;
type IntoIter = BitIterator<'a>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use crate::bit::BitBuffer;
    use crate::{ByteBuffer, buffer};

    #[test]
    fn test_bool() {
        // Every byte is 0b1000_0000: only bit 7 of each byte is set.
        let buffer: ByteBuffer = buffer![1 << 7; 1024];
        let bools = BitBuffer::new(buffer, 1024 * 8);
        assert_eq!(bools.len(), 1024 * 8);
        assert!(!bools.is_empty());
        assert_eq!(bools.true_count(), 1024);
        assert_eq!(bools.false_count(), 1024 * 7);
        for word in 0..1024 {
            for bit in 0..8 {
                if bit == 7 {
                    assert!(bools.value(word * 8 + bit));
                } else {
                    assert!(!bools.value(word * 8 + bit));
                }
            }
        }
        let sliced = bools.slice(64..72);
        assert_eq!(sliced.len(), 8);
        assert!(!sliced.is_empty());
        assert_eq!(sliced.true_count(), 1);
        assert_eq!(sliced.false_count(), 7);
        for bit in 0..8 {
            if bit == 7 {
                assert!(sliced.value(bit));
            } else {
                assert!(!sliced.value(bit));
            }
        }
    }

    #[test]
    fn test_padded_equality() {
        let buf1 = BitBuffer::new_set(64);
        let buf2 = BitBuffer::collect_bool(64, |x| x < 32);
        for i in 0..32 {
            assert_eq!(buf1.value(i), buf2.value(i), "Bit {} should be the same", i);
        }
        for i in 32..64 {
            assert_ne!(buf1.value(i), buf2.value(i), "Bit {} should differ", i);
        }
        assert_eq!(
            buf1.slice(0..32),
            buf2.slice(0..32),
            "Buffer slices with same bits should be equal (`PartialEq` needs `iter_padded()`)"
        );
        assert_ne!(
            buf1.slice(32..64),
            buf2.slice(32..64),
            "Buffer slices with different bits should not be equal (`PartialEq` needs `iter_padded()`)"
        );
    }

    #[test]
    fn test_slice_offset_calculation() {
        let buf = BitBuffer::collect_bool(16, |_| true);
        let sliced = buf.slice(10..16);
        assert_eq!(sliced.offset(), 10);
    }

    #[rstest]
    #[case(5)]
    #[case(8)]
    #[case(10)]
    #[case(13)]
    #[case(16)]
    #[case(23)]
    #[case(100)]
    fn test_iter_bits(#[case] len: usize) {
        let buf = BitBuffer::collect_bool(len, |i| i % 2 == 0);
        let mut collected = Vec::new();
        buf.iter_bits(|idx, is_set| {
            collected.push((idx, is_set));
        });
        assert_eq!(collected.len(), len);
        for (idx, is_set) in collected {
            assert_eq!(is_set, idx % 2 == 0);
        }
    }

    #[rstest]
    #[case(3, 5)]
    #[case(3, 8)]
    #[case(5, 10)]
    #[case(2, 16)]
    #[case(8, 16)]
    #[case(9, 16)]
    #[case(17, 16)]
    fn test_iter_bits_with_offset(#[case] offset: usize, #[case] len: usize) {
        let total_bits = offset + len;
        let buf = BitBuffer::collect_bool(total_bits, |i| i % 2 == 0);
        let buf_with_offset = BitBuffer::new_with_offset(buf.inner().clone(), len, offset);
        let mut collected = Vec::new();
        buf_with_offset.iter_bits(|idx, is_set| {
            collected.push((idx, is_set));
        });
        assert_eq!(collected.len(), len);
        for (idx, is_set) in collected {
            assert_eq!(is_set, (offset + idx) % 2 == 0);
        }
    }

    #[rstest]
    #[case(8, 10)]
    #[case(9, 7)]
    #[case(16, 8)]
    #[case(17, 10)]
    fn test_iter_bits_catches_wrong_byte_offset(#[case] offset: usize, #[case] len: usize) {
        let total_bits = offset + len;
        // Whole bytes alternate between all-set and all-unset, so reading the wrong byte is
        // visible even when the in-byte bit position happens to be right.
        let buf = BitBuffer::collect_bool(total_bits, |i| (i / 8) % 2 == 0);
        let buf_with_offset = BitBuffer::new_with_offset(buf.inner().clone(), len, offset);
        let mut collected = Vec::new();
        buf_with_offset.iter_bits(|idx, is_set| {
            collected.push((idx, is_set));
        });
        assert_eq!(collected.len(), len);
        for (idx, is_set) in collected {
            let bit_position = offset + idx;
            let byte_index = bit_position / 8;
            let expected_is_set = byte_index % 2 == 0;
            assert_eq!(
                is_set, expected_is_set,
                "Bit mismatch at index {}: expected {} got {}",
                bit_position, expected_is_set, is_set
            );
        }
    }
}