use std::ops::Range;
/// Number of low bits of a bit index that select a bit inside a single `u64` word
/// (`log2(64) = 6`).
const BIT_MASK_LEN: u32 = u64::BITS.ilog2();
/// Mask that keeps exactly the low `BIT_MASK_LEN` bits of a bit index (`0b11_1111`).
const BIT_MASK: u64 = u64::MAX >> (u64::BITS - BIT_MASK_LEN);
/// A bit vector stored as `u64` words and addressed in fixed-size blocks of
/// `BLOCK_SIZE_BITS` bits (`BLOCK_SIZE_BITS / 64` consecutive words per block).
#[derive(Clone, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BlockedBitVec<const BLOCK_SIZE_BITS: usize> {
/// Backing words; block `i` occupies words `i * (BLOCK_SIZE_BITS / 64)` onward.
bits: Vec<u64>,
}
impl<const BLOCK_SIZE_BITS: usize> BlockedBitVec<BLOCK_SIZE_BITS> {
    /// Number of `u64` words that make up one block.
    ///
    /// Evaluating this constant fails at compile time when `BLOCK_SIZE_BITS` is smaller than
    /// 64, because a block of zero words cannot hold any bits and would turn the block
    /// arithmetic below into a division by zero.
    const BLOCK_SIZE: usize = {
        assert!(BLOCK_SIZE_BITS >= 64, "BLOCK_SIZE_BITS must be at least 64");
        BLOCK_SIZE_BITS / 64
    };

    /// Returns the range of word indices in the backing storage covered by block `index`.
    #[inline]
    const fn block_range(index: usize) -> Range<usize> {
        let start = index * Self::BLOCK_SIZE;
        start..(start + Self::BLOCK_SIZE)
    }

    /// Returns the number of whole blocks held by this vector.
    #[inline]
    pub fn num_blocks(&self) -> usize {
        // A plain division is correct for every block size, not only powers of two; for
        // power-of-two sizes the compiler still lowers it to a shift.
        self.bits.len() / Self::BLOCK_SIZE
    }

    /// Returns the words of block `i`.
    ///
    /// # Panics
    ///
    /// Panics if block `i` lies outside the backing storage.
    #[inline]
    pub fn get_block(&self, i: usize) -> &[u64] {
        &self.bits[Self::block_range(i)]
    }

    /// Returns the words of block `index` for modification.
    ///
    /// # Panics
    ///
    /// Panics if block `index` lies outside the backing storage.
    #[inline]
    pub fn get_block_mut(&mut self, index: usize) -> &mut [u64] {
        &mut self.bits[Self::block_range(index)]
    }

    /// Splits a bit index into the index of the word holding the bit and a mask with only
    /// that bit set.
    #[inline]
    const fn coordinate(bit_index: usize) -> (usize, u64) {
        let word = bit_index.wrapping_shr(BIT_MASK_LEN);
        let mask = 1u64 << (bit_index as u64 & BIT_MASK);
        (word, mask)
    }

    /// Sets bit `bit_index` of `block` and returns whether it was already set.
    ///
    /// # Panics
    ///
    /// Panics if `bit_index / 64` is not a valid index into `block`.
    #[inline]
    pub fn set_for_block(block: &mut [u64], bit_index: usize) -> bool {
        let (word, mask) = Self::coordinate(bit_index);
        // Borrow the word once so the read and the write share a single bounds check.
        let slot = &mut block[word];
        let previously_contained = *slot & mask != 0;
        *slot |= mask;
        previously_contained
    }

    /// Returns whether bit `bit_index` of `block` is set.
    ///
    /// # Panics
    ///
    /// Panics if `bit_index / 64` is not a valid index into `block`.
    #[inline]
    pub fn check_for_block(block: &[u64], bit_index: usize) -> bool {
        let (word, mask) = Self::coordinate(bit_index);
        block[word] & mask != 0
    }

    /// Returns all backing words, block after block.
    #[inline]
    pub fn as_slice(&self) -> &[u64] {
        &self.bits
    }

    /// Resets every bit to zero while keeping the current length.
    #[inline]
    pub fn clear(&mut self) {
        self.bits.fill(0);
    }
}
impl<const BLOCK_SIZE_BITS: usize> From<Vec<u64>> for BlockedBitVec<BLOCK_SIZE_BITS> {
    /// Wraps `bits`, appending zero words until its length is a whole number of blocks.
    ///
    /// Existing words are kept unchanged and excess capacity is released.
    ///
    /// # Panics
    ///
    /// Panics if `BLOCK_SIZE_BITS` is smaller than 64, since a block would then contain
    /// zero words.
    fn from(mut bits: Vec<u64>) -> Self {
        let words_per_block = BLOCK_SIZE_BITS / 64;
        let leftover = bits.len() % words_per_block;
        if leftover != 0 {
            // Grow in place rather than building a temporary vector just to extend from it.
            let padded_len = bits.len() + (words_per_block - leftover);
            bits.resize(padded_len, 0);
        }
        bits.shrink_to_fit();
        Self { bits }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;
    use std::collections::HashSet;

    /// Converting from a `Vec<u64>` pads up to the next whole block and never further.
    #[test]
    fn test_to_from_vec() {
        fn to_from_<const N: usize>(size: usize) {
            let words_per_block = N / 64;
            let b: BlockedBitVec<N> = vec![0u64; size].into();
            let len = b.as_slice().len();
            assert_eq!(b.num_blocks() * N, len * 64);
            assert_eq!(len % words_per_block, 0);
            assert!(size <= len);
            // `len` is measured in words, so the padding bound has to be in words as well;
            // comparing against `N` (bits) would accept up to 64 times too much padding.
            assert!(len < size + words_per_block);
        }
        for size in 1..=10009 {
            to_from_::<64>(size);
            to_from_::<128>(size);
            to_from_::<256>(size);
            to_from_::<512>(size);
        }
    }

    /// A bit reads as set exactly when it has been inserted before.
    #[test]
    fn test_only_random_inserts_are_contained() {
        let mut vec = BlockedBitVec::<64>::from(vec![0; 80]);
        let mut control = HashSet::new();
        let mut rng = rand::rng();
        for _ in 0..100000 {
            let block_index = rng.random_range(0..vec.num_blocks());
            let bit_index = rng.random_range(0..64);

            // Before inserting, membership must agree with the control set.
            let expected = control.contains(&(block_index, bit_index));
            assert_eq!(
                BlockedBitVec::<64>::check_for_block(vec.get_block(block_index), bit_index),
                expected
            );

            let newly_inserted = control.insert((block_index, bit_index));
            let block_mut = vec.get_block_mut(block_index);
            let previously_contained = BlockedBitVec::<64>::set_for_block(block_mut, bit_index);
            assert_eq!(previously_contained, !newly_inserted);
            assert!(BlockedBitVec::<64>::check_for_block(block_mut, bit_index));
        }

        // Final sweep: every bit of every block matches the control set.
        for block_index in 0..vec.num_blocks() {
            let block = vec.get_block(block_index);
            for bit_index in 0..64 {
                assert_eq!(
                    BlockedBitVec::<64>::check_for_block(block, bit_index),
                    control.contains(&(block_index, bit_index))
                );
            }
        }
    }
}