use std::{fmt, ops::Range};
/// Reasons `BlockedBitVec::new` rejects its `BLOCK_SIZE_BITS` parameter.
#[derive(Debug, Eq, PartialEq)]
pub enum BlockSizeError {
    /// `BLOCK_SIZE_BITS` is at least 64 but is not a power of two.
    NonPowerOfTwo,
    /// `BLOCK_SIZE_BITS` is smaller than 64.
    TooSmall,
}

impl fmt::Display for BlockSizeError {
    /// Writes the bare variant name, i.e. the same text the derived `Debug` impl produces.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let variant_name = match self {
            Self::NonPowerOfTwo => "NonPowerOfTwo",
            Self::TooSmall => "TooSmall",
        };
        f.write_str(variant_name)
    }
}

// No `source()` override: this error never wraps another error.
impl std::error::Error for BlockSizeError {}
/// Number of low bits of a bit index that select a bit inside one `u64` word (log2 of 64).
const BIT_MASK_LEN: u32 = u64::BITS.ilog2();
/// Mask with exactly the low `BIT_MASK_LEN` bits set; extracts the in-word bit position.
const BIT_MASK: u64 = !(u64::MAX << BIT_MASK_LEN);
/// A bit vector backed by `u64` words and addressed in fixed-size blocks of
/// `BLOCK_SIZE_BITS` bits (`BLOCK_SIZE_BITS / 64` words per block).
///
/// The `new` constructor only accepts block sizes that are a power of two and at least 64.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct BlockedBitVec<const BLOCK_SIZE_BITS: usize> {
/// Backing words; block `i` occupies words `i * (BLOCK_SIZE_BITS / 64)` up to the next block.
bits: Vec<u64>,
}
impl<const BLOCK_SIZE_BITS: usize> BlockedBitVec<BLOCK_SIZE_BITS> {
    /// Number of `u64` words that make up one block.
    const BLOCK_SIZE: usize = BLOCK_SIZE_BITS / 64;

    /// Base-2 logarithm of `BLOCK_SIZE`, used to convert a word count into a block count
    /// with a shift. Computed on the `usize` value directly so no narrowing cast can
    /// truncate it.
    const LOG2_BLOCK_SIZE: u32 = Self::BLOCK_SIZE.ilog2();

    /// Returns `true` when `BLOCK_SIZE_BITS` is a power of two.
    ///
    /// Zero is reported as "not a power of two" rather than panicking.
    const fn is_power_of_2() -> bool {
        BLOCK_SIZE_BITS.is_power_of_two()
    }

    /// Creates a bit vector of `num_blocks` blocks with every bit cleared.
    ///
    /// # Errors
    ///
    /// * [`BlockSizeError::TooSmall`] if `BLOCK_SIZE_BITS` is less than 64.
    /// * [`BlockSizeError::NonPowerOfTwo`] if `BLOCK_SIZE_BITS` is not a power of two.
    ///
    /// # Panics
    ///
    /// Panics if the total number of `u64` words (`num_blocks * BLOCK_SIZE_BITS / 64`)
    /// does not fit in a `usize`.
    pub fn new(num_blocks: usize) -> Result<Self, BlockSizeError> {
        if BLOCK_SIZE_BITS < 64 {
            return Err(BlockSizeError::TooSmall);
        }
        if !Self::is_power_of_2() {
            return Err(BlockSizeError::NonPowerOfTwo);
        }
        // A plain `*` wraps silently in release builds and would allocate a vector that is
        // too small for the requested number of blocks; fail loudly instead.
        let num_words = num_blocks
            .checked_mul(Self::BLOCK_SIZE)
            .expect("BlockedBitVec::new: total word count overflows usize");
        Ok(Self {
            bits: vec![0u64; num_words],
        })
    }

    /// Returns the range of word indices in the backing storage occupied by block `index`.
    #[inline]
    const fn block_range(index: usize) -> Range<usize> {
        let block_index = index * Self::BLOCK_SIZE;
        block_index..(block_index + Self::BLOCK_SIZE)
    }

    /// Returns the number of blocks held, i.e. the backing word count divided by the
    /// words-per-block (done as a right shift by `LOG2_BLOCK_SIZE`).
    #[inline]
    pub fn num_blocks(&self) -> usize {
        self.bits.len() >> Self::LOG2_BLOCK_SIZE
    }

    /// Borrows the words of block `index`.
    ///
    /// # Panics
    ///
    /// Panics if the block's word range lies outside the backing storage.
    #[inline]
    pub fn get_block(&self, index: usize) -> &[u64] {
        &self.bits[Self::block_range(index)]
    }

    /// Mutably borrows the words of block `index`.
    ///
    /// # Panics
    ///
    /// Panics if the block's word range lies outside the backing storage.
    #[inline]
    pub fn get_block_mut(&mut self, index: usize) -> &mut [u64] {
        &mut self.bits[Self::block_range(index)]
    }

    /// Splits a bit index (relative to the start of a block) into the index of the word
    /// containing it and a mask with only that bit set.
    #[inline]
    const fn coordinate(bit_index: usize) -> (usize, u64) {
        // High bits pick the word; the low `BIT_MASK_LEN` bits pick the bit inside it.
        let index = bit_index >> BIT_MASK_LEN;
        let bit = 1u64 << ((bit_index as u64) & BIT_MASK);
        (index, bit)
    }

    /// Sets bit `bit_index` of `block`.
    ///
    /// # Panics
    ///
    /// Panics if `bit_index / 64` is not a valid index into `block`.
    #[inline]
    pub fn set_for_block(block: &mut [u64], bit_index: usize) {
        let (index, bit) = Self::coordinate(bit_index);
        block[index] |= bit;
    }

    /// Returns `true` if bit `bit_index` of `block` is set.
    ///
    /// # Panics
    ///
    /// Panics if `bit_index / 64` is not a valid index into `block`.
    #[inline]
    pub fn check_for_block(block: &[u64], bit_index: usize) -> bool {
        let (index, bit) = Self::coordinate(bit_index);
        block[index] & bit != 0
    }

    /// Borrows the entire backing storage as a flat slice of words.
    #[inline]
    pub fn as_slice(&self) -> &[u64] {
        &self.bits
    }
}
impl<const BLOCK_SIZE_BITS: usize> From<Vec<u64>> for BlockedBitVec<BLOCK_SIZE_BITS> {
    /// Wraps an existing vector of words, zero-padding it up to a whole number of blocks
    /// and then releasing any spare capacity.
    ///
    /// NOTE(review): unlike `new`, this conversion does not validate `BLOCK_SIZE_BITS`.
    /// A value below 64 makes the remainder computation divide by zero (panic); confirm
    /// whether non-power-of-two sizes are meant to be accepted here.
    fn from(mut bits: Vec<u64>) -> Self {
        let words_per_block = BLOCK_SIZE_BITS / 64;
        let leftover = bits.len() % words_per_block;
        if leftover != 0 {
            // Grow in place with zero words so the final block is complete.
            let padded_len = bits.len() + (words_per_block - leftover);
            bits.resize(padded_len, 0);
        }
        bits.shrink_to_fit();
        Self { bits }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;
    use std::collections::HashSet;

    // Construction must reject block sizes below 64 bits or not a power of two, and accept
    // power-of-two sizes from 64 bits upwards.
    #[test]
    fn test_build() {
        // Smaller than a single 64-bit word.
        assert_eq!(BlockedBitVec::<1>::new(10), Err(BlockSizeError::TooSmall));
        assert_eq!(BlockedBitVec::<63>::new(10), Err(BlockSizeError::TooSmall));

        // Large enough, but not a power of two.
        assert_eq!(
            BlockedBitVec::<65>::new(10),
            Err(BlockSizeError::NonPowerOfTwo)
        );
        assert_eq!(
            BlockedBitVec::<129>::new(10),
            Err(BlockSizeError::NonPowerOfTwo)
        );

        // Valid sizes.
        assert!(BlockedBitVec::<64>::new(10).is_ok());
        assert!(BlockedBitVec::<128>::new(10).is_ok());
        assert!(BlockedBitVec::<256>::new(10).is_ok());
        assert!(BlockedBitVec::<512>::new(10).is_ok());
        assert!(BlockedBitVec::<1024>::new(10).is_ok());
    }

    // Sets random bits and checks that a bit reads as set only after it has been inserted.
    #[test]
    fn test_only_random_inserts_are_contained() {
        let mut vec = BlockedBitVec::<64>::new(10).unwrap();
        let mut control = HashSet::new();
        let mut rng = rand::thread_rng();
        for _ in 0..100000 {
            let block_index = rng.gen_range(0..vec.num_blocks());
            let bit_index = rng.gen_range(0..64);

            // `insert` returns true only the first time this coordinate is seen; in that
            // case the bit must still be clear in the vector.
            let first_time = control.insert((block_index, bit_index));
            if first_time {
                let block = vec.get_block(block_index);
                assert!(!BlockedBitVec::<64>::check_for_block(block, bit_index));
            }

            // After setting, the bit must always read back as set.
            let block_mut = vec.get_block_mut(block_index);
            BlockedBitVec::<64>::set_for_block(block_mut, bit_index);
            assert!(BlockedBitVec::<64>::check_for_block(block_mut, bit_index));
        }
    }
}