use crate::Bitmap;
use super::{bitmask_for_key, index_for_key, vec::VecBitmap};
/// A sparse bitmap that stores only the `usize` data words ("blocks") that
/// have been materialised by a set bit.
///
/// `block_map` records which blocks are present; `bitmap` holds just those
/// words, packed in ascending block order, so a block's position in `bitmap`
/// is the number of present blocks below it.
///
/// Equality is structural (derived): two values holding the same keys can
/// compare unequal if one still stores an all-zero word, e.g. after
/// `set(key, false)` cleared the last set bit of a word (the word is kept).
///
/// NOTE(review): with the `serde` feature, `max_key` is only serialized in
/// debug builds, so the encoded form differs between debug and release
/// builds — confirm this is acceptable for any persisted data.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CompressedBitmap {
    /// One presence bit per data block; a set bit means that block has a
    /// word stored in `bitmap`.
    block_map: Vec<usize>,
    /// Dense storage for the present blocks only, in ascending block order.
    bitmap: Vec<usize>,
    /// Inclusive upper bound on keys, kept only in debug builds so `set` can
    /// `debug_assert!` against it.
    #[cfg(debug_assertions)]
    max_key: usize,
}
impl CompressedBitmap {
    /// Creates an empty bitmap that accepts keys in `0..=max_key`.
    ///
    /// Only the block map (one presence bit per `usize` data word) is
    /// allocated here; data words are allocated lazily by [`Self::set`].
    pub fn new(max_key: usize) -> Self {
        // `index_for_key(max_key)` is the *index* of the last data word, and
        // that word's presence bit lives in block-map word
        // `index_for_key(last_word)`, so exactly one more block-map word than
        // that index is always needed. Rounding on `last_word % BITS` instead
        // left the block map one word short whenever `last_word` was a
        // multiple of the word width (zero words for any small `max_key`).
        let last_word = index_for_key(max_key);
        let block_map = vec![0; index_for_key(last_word) + 1];
        CompressedBitmap {
            bitmap: Vec::new(),
            block_map,
            #[cfg(debug_assertions)]
            max_key,
        }
    }

    /// Approximate memory footprint in bytes: the allocated capacity of both
    /// word vectors plus the inline size of the struct itself.
    pub fn size(&self) -> usize {
        (self.block_map.capacity() * std::mem::size_of::<usize>())
            + (self.bitmap.capacity() * std::mem::size_of::<usize>())
            + std::mem::size_of_val(self)
    }

    /// Releases spare capacity held by both internal vectors.
    pub fn shrink_to_fit(&mut self) {
        self.bitmap.shrink_to_fit();
        self.block_map.shrink_to_fit();
    }

    /// Unsets every key. The block map keeps its length (so the bitmap still
    /// accepts the same key range) and stored words are dropped, but the
    /// allocations are retained.
    pub fn clear(&mut self) {
        self.block_map.fill(0);
        self.bitmap.clear();
    }

    /// Sets `key` to `value`.
    ///
    /// Setting a key in a block with no stored word inserts a new word at the
    /// block's sorted position (shifting later words). Clearing a key never
    /// removes its word, even if the word becomes all zeros.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `key > max_key`. In any build, panics if the
    /// key's block lies beyond the block map.
    pub fn set(&mut self, key: usize, value: bool) {
        #[cfg(debug_assertions)]
        debug_assert!(key <= self.max_key, "key {} > {} max", key, self.max_key);
        let block_index = index_for_key(key);
        let block_map_index = index_for_key(block_index);
        let block_map_bitmask = bitmask_for_key(block_index);
        let present = self.block_map[block_map_index] & block_map_bitmask != 0;
        if !present && !value {
            // Clearing a key in a block with no stored word is a no-op.
            return;
        }

        let offset = self.physical_index(block_map_index, block_map_bitmask);
        let key_mask = bitmask_for_key(key);
        if !present {
            // `offset` counts the stored words below this block, so it is at
            // most `self.bitmap.len()`; `insert` therefore also covers the
            // append case.
            self.bitmap.insert(offset, key_mask);
            self.block_map[block_map_index] |= block_map_bitmask;
        } else if value {
            self.bitmap[offset] |= key_mask;
        } else {
            self.bitmap[offset] &= !key_mask;
        }
    }

    /// Returns whether `key` is set.
    ///
    /// # Panics
    ///
    /// Panics if the key's block lies beyond the block map.
    pub fn get(&self, key: usize) -> bool {
        let block_index = index_for_key(key);
        let block_map_index = index_for_key(block_index);
        let block_map_bitmask = bitmask_for_key(block_index);
        if self.block_map[block_map_index] & block_map_bitmask == 0 {
            // Absent block: skip the (linear) rank computation entirely.
            return false;
        }
        let offset = self.physical_index(block_map_index, block_map_bitmask);
        self.bitmap[offset] & bitmask_for_key(key) != 0
    }

    /// Returns the union of `self` and `other`.
    ///
    /// # Panics
    ///
    /// Panics if the two block maps differ in length (and, in debug builds,
    /// if the two `max_key`s differ).
    pub fn or(&self, other: &Self) -> Self {
        #[cfg(debug_assertions)]
        debug_assert_eq!(self.max_key, other.max_key);
        assert_eq!(self.block_map.len(), other.block_map.len());

        // Walk both block maps in lockstep; each step corresponds to the same
        // block in both inputs and yields its physical index where present.
        let bitmap = BlockMapIter::new(self)
            .zip(BlockMapIter::new(other))
            .filter_map(|pair| match pair {
                (None, None) => None,
                (Some(l), None) => Some(self.bitmap[l]),
                (None, Some(r)) => Some(other.bitmap[r]),
                (Some(l), Some(r)) => Some(self.bitmap[l] | other.bitmap[r]),
            })
            .collect::<Vec<_>>();
        let block_map = self
            .block_map
            .iter()
            .zip(&other.block_map)
            .map(|(l, r)| l | r)
            .collect::<Vec<_>>();
        debug_assert_eq!(
            block_map.iter().map(|v| v.count_ones()).sum::<u32>() as usize,
            bitmap.len()
        );
        Self {
            block_map,
            bitmap,
            #[cfg(debug_assertions)]
            max_key: self.max_key,
        }
    }

    /// Position in `self.bitmap` of the block whose presence bit is
    /// `block_map_bitmask` in block-map word `block_map_index`: the number of
    /// present blocks ordered before it. This is the block's existing index
    /// when it is present, or its insertion point when it is not.
    fn physical_index(&self, block_map_index: usize, block_map_bitmask: usize) -> usize {
        let preceding_words: usize = self.block_map[..block_map_index]
            .iter()
            .map(|word| word.count_ones() as usize)
            .sum();
        // `block_map_bitmask` is a single bit, so `- 1` masks every lower bit.
        let lower_bits = self.block_map[block_map_index] & (block_map_bitmask - 1);
        preceding_words + lower_bits.count_ones() as usize
    }
}
/// Walks a [`CompressedBitmap`]'s block map one bit (one data block) at a
/// time, in ascending block order.
///
/// Each step yields `Some(physical_index)` when the block has a stored word
/// (its index in the bitmap's dense word vector) or `None` when it does not.
/// Iteration ends after the last bit of the last block-map word.
#[derive(Debug)]
struct BlockMapIter<'a> {
    /// Bitmap whose block map is being walked.
    bitmap: &'a CompressedBitmap,
    /// Index of the block-map word currently being read.
    block_idx: usize,
    /// Bit position within the current block-map word.
    block_bit: u8,
    /// Physical index the next present block will be reported with.
    physical_idx: usize,
}

impl<'a> BlockMapIter<'a> {
    /// Starts a walk at bit 0 of block-map word 0.
    fn new(bitmap: &'a CompressedBitmap) -> Self {
        Self {
            block_idx: 0,
            block_bit: 0,
            physical_idx: 0,
            bitmap,
        }
    }
}

impl Iterator for BlockMapIter<'_> {
    type Item = Option<usize>;

    fn next(&mut self) -> Option<Self::Item> {
        let word = *self.bitmap.block_map.get(self.block_idx)?;
        let present = (word >> self.block_bit) & 1 == 1;
        let slot = if present { Some(self.physical_idx) } else { None };
        self.physical_idx += usize::from(present);

        // Step to the next bit, rolling over into the following block-map
        // word once every bit of the current one has been visited.
        self.block_bit = (self.block_bit + 1) % usize::BITS as u8;
        if self.block_bit == 0 {
            self.block_idx += 1;
        }
        Some(slot)
    }
}
impl Bitmap for CompressedBitmap {
fn get(&self, key: usize) -> bool {
self.get(key)
}
fn set(&mut self, key: usize, value: bool) {
self.set(key, value)
}
fn byte_size(&self) -> usize {
self.size()
}
fn or(&self, other: &Self) -> Self {
self.or(other)
}
fn new_with_capacity(max_key: usize) -> Self {
Self::new(max_key)
}
}
impl From<VecBitmap> for CompressedBitmap {
    /// Compresses a [`VecBitmap`] by keeping only its non-zero words.
    ///
    /// Non-zero words are appended to the dense word vector in their original
    /// order, and each one's word index is recorded as a set bit in the block
    /// map.
    ///
    /// # Panics
    ///
    /// Panics if the source holds a non-zero word beyond the last word
    /// addressable under its `max_key`.
    fn from(bitmap: VecBitmap) -> Self {
        let (words, max_key) = bitmap.into_parts();
        // `index_for_key(max_key)` is the index of the last data word; its
        // presence bit lives in block-map word `index_for_key(last_word)`, so
        // one more block-map word than that index is always required. (The
        // previous rounding came up one word short whenever `last_word` was a
        // multiple of the word width, including every small `max_key`.)
        let last_word = index_for_key(max_key);
        let mut block_map = vec![0; index_for_key(last_word) + 1];
        let mut compressed = Vec::new();
        for (idx, word) in words.into_iter().enumerate().filter(|&(_, w)| w != 0) {
            compressed.push(word);
            block_map[index_for_key(idx)] |= bitmask_for_key(idx);
        }
        CompressedBitmap {
            block_map,
            bitmap: compressed,
            #[cfg(debug_assertions)]
            max_key,
        }
    }
}
#[cfg(test)]
mod tests {
    use proptest::prelude::*;
    use quickcheck_macros::quickcheck;

    use super::*;

    /// Asserts that, for every key in `0..=$max`, `$bitmap.get(key)` is true
    /// exactly when the key is one of the listed elements.
    macro_rules! contains_only_truthy {
        ($bitmap:ident, $max:expr; $(
            $element:expr
        ),*) => {
            let truthy = vec![$($element,)*];
            for i in 0..=$max {
                assert!($bitmap.get(i) == truthy.contains(&i), "unexpected value {}", i);
            }
        };
    }

    #[test]
    fn test_set_contains() {
        let mut b = CompressedBitmap::new(100);
        b.set(100, true);
        b.set(0, true);
        b.set(42, true);
        contains_only_truthy!(b, 100; 100, 0, 42);
        assert!(b.get(100));
        assert!(b.get(0));
        assert!(b.get(42));
    }

    #[test]
    fn test_clear() {
        let mut b = CompressedBitmap::new(100);
        b.set(100, true);
        b.set(0, true);
        b.set(42, true);
        contains_only_truthy!(b, 100; 100, 0, 42);
        b.clear();
        contains_only_truthy!(b, 100;);
    }

    #[test]
    fn test_set_true_false() {
        let mut b = CompressedBitmap::new(100);
        b.set(42, true);
        assert!(b.get(42));
        b.set(42, false);
        assert!(!b.get(42));
    }

    #[test]
    fn test_block_map_iter() {
        let mut bitmap = CompressedBitmap::new(i16::MAX as _);
        let word = usize::BITS as usize;
        bitmap.set(1, true);
        bitmap.set(word * 4, true);
        bitmap.set(word * 64, true);
        bitmap.set(word * 65, true);
        bitmap.set(word * 128, true);

        let mut iter = BlockMapIter::new(&bitmap).enumerate();
        assert_eq!(iter.next().unwrap(), (0, Some(0)));
        assert_eq!(iter.next().unwrap(), (1, None));
        assert_eq!(iter.next().unwrap(), (2, None));
        assert_eq!(iter.next().unwrap(), (3, None));
        assert_eq!(iter.next().unwrap(), (4, Some(1)));

        let mut iter = iter.filter_map(|(idx, block)| block.map(|v| (idx, v)));
        assert_eq!(iter.next().unwrap(), (64, 2));
        assert_eq!(iter.next().unwrap(), (65, 3));
        assert_eq!(iter.next().unwrap(), (128, 4));
        assert!(iter.next().is_none());
    }

    // The max-key check in `set` is a `debug_assert!`, so this panic only
    // occurs in debug builds; without the gate `cargo test --release` fails.
    #[cfg(debug_assertions)]
    #[quickcheck]
    #[should_panic]
    fn test_panic_exceeds_max(max: u16) {
        let max = max as usize;
        let mut b = CompressedBitmap::new(max);
        b.set(max + 1, true);
    }

    #[quickcheck]
    fn test_set_contains_prop(mut vals: Vec<u16>) {
        vals.truncate(10);
        let mut b = CompressedBitmap::new(u16::MAX.into());
        for v in &vals {
            b.set(*v as usize, true);
        }
        // Inclusive range: `u16::MAX` itself is a valid (and generated) key.
        for i in 0..=u16::MAX {
            assert!(
                b.get(i as usize) == vals.contains(&i),
                "unexpected value {}",
                i
            );
        }
    }

    #[quickcheck]
    fn test_or(mut a: Vec<u16>, mut b: Vec<u16>) {
        a.truncate(10);
        let mut bitmap_a = CompressedBitmap::new(u16::MAX.into());
        for v in &a {
            bitmap_a.set(*v as usize, true);
        }
        b.truncate(10);
        let mut bitmap_b = CompressedBitmap::new(u16::MAX.into());
        for v in &b {
            bitmap_b.set(*v as usize, true);
        }
        let merged = bitmap_a.or(&bitmap_b);
        for i in 0..=u16::MAX {
            let want_hit = a.contains(&i) || b.contains(&i);
            assert!(
                merged.get(i as usize) == want_hit,
                "unexpected value {} want={:?}",
                i,
                want_hit
            );
        }
    }

    #[cfg(feature = "serde")]
    #[test]
    fn serde() {
        let mut b = CompressedBitmap::new(100);
        b.set(1, true);
        b.set(2, false);
        b.set(3, true);
        contains_only_truthy!(b, 100; 1, 3);
        let encoded = serde_json::to_string(&b).unwrap();
        let decoded: CompressedBitmap = serde_json::from_str(&encoded).unwrap();
        contains_only_truthy!(decoded, 100; 1, 3);
    }

    const MAX_KEY: usize = 1028;

    proptest! {
        #[test]
        fn prop_compress(
            values in prop::collection::hash_set(0..MAX_KEY, 0..20),
        ) {
            let mut b = VecBitmap::new_with_capacity(MAX_KEY);
            for v in &values {
                b.set(*v, true);
            }
            let b = CompressedBitmap::from(b);
            for i in 0..MAX_KEY {
                // `prop_assert_eq!` lets proptest shrink the failing input.
                prop_assert_eq!(b.get(i), values.contains(&i), "key {}", i);
            }
        }
    }
}