use super::super::BitVector;
/// Width in bits of one block; a block is exactly one `u64` storage word.
const BLOCK_BITS: usize = u64::BITS as usize;
/// Number of consecutive blocks that share one superblock counter.
const BLOCKS_PER_SUPERBLOCK: usize = 8;
/// Number of bits covered by a single superblock.
const SUPERBLOCK_BITS: usize = BLOCKS_PER_SUPERBLOCK * BLOCK_BITS;
/// One select sample is recorded for every this-many ones (and, separately, zeros).
const SELECT_SAMPLE_RATE: usize = 1 << 12;
/// A bit vector augmented with auxiliary tables for rank and select queries.
///
/// The auxiliary tables are derived entirely from `inner` by
/// `SuccinctBitVector::from_bitvec`. Field order is part of the serialized
/// layout produced by the derived `Serialize` implementation.
#[derive(Debug, Clone, serde::Serialize)]
pub struct SuccinctBitVector {
/// The underlying bit storage that is being indexed.
inner: BitVector,
/// Number of set bits before the start of each superblock
/// (`SUPERBLOCK_BITS` bits each). A trailing entry holding the total number
/// of set bits is appended when the length is a multiple of `SUPERBLOCK_BITS`.
superblock_ranks: Vec<u32>,
/// For each storage word, the number of set bits between the start of its
/// superblock and the start of that word.
block_ranks: Vec<u16>,
/// Entry `i` is the first bit position of the storage word that contains the
/// one with zero-based index `i * SELECT_SAMPLE_RATE`.
select1_samples: Vec<u32>,
/// Entry `i` is the first bit position of the storage word that contains the
/// zero with zero-based index `i * SELECT_SAMPLE_RATE`.
select0_samples: Vec<u32>,
/// Total number of set bits among the first `inner.len()` bits.
ones_count: usize,
}
impl<'de> serde::Deserialize<'de> for SuccinctBitVector {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(serde::Deserialize)]
struct Raw {
inner: BitVector,
#[allow(dead_code)]
superblock_ranks: Vec<u32>,
#[allow(dead_code)]
block_ranks: Vec<u16>,
#[allow(dead_code)]
select1_samples: Vec<u32>,
#[allow(dead_code)]
select0_samples: Vec<u32>,
#[allow(dead_code)]
ones_count: usize,
}
let raw = Raw::deserialize(deserializer)?;
Ok(Self::from_bitvec(raw.inner))
}
}
impl SuccinctBitVector {
/// Builds an indexed bit vector from a slice of booleans.
///
/// The slice is first packed with `BitVector::from_bools`, then indexed with
/// [`SuccinctBitVector::from_bitvec`].
#[must_use]
pub fn from_bools(bools: &[bool]) -> Self {
    Self::from_bitvec(BitVector::from_bools(bools))
}
/// Builds the rank/select index over `inner`, taking ownership of it.
///
/// A single pass over the storage words records:
/// * the number of ones before every superblock (`SUPERBLOCK_BITS` bits),
/// * for every word, the number of ones since the start of its superblock,
/// * for every `SELECT_SAMPLE_RATE`-th one (and, separately, zero), the first bit
///   position of the word that contains it.
///
/// Bits of a storage word that lie at or beyond `inner.len()` are ignored.
///
/// # Panics
///
/// Panics if `inner` holds more than `u32::MAX` bits. All counters and sampled
/// positions are stored as `u32`; without this check a longer vector would wrap or
/// truncate silently in release builds and produce wrong query results.
#[must_use]
pub fn from_bitvec(inner: BitVector) -> Self {
    let len = inner.len();
    assert!(
        u32::try_from(len).is_ok(),
        "SuccinctBitVector indexes at most u32::MAX bits, got {len}"
    );

    let data = inner.data();
    // Capacities are sized from the word count, which avoids the overflow-prone
    // `len + N - 1` rounding; the `+ 2` leaves room for a partial superblock and
    // the optional trailing total.
    let mut superblock_ranks = Vec::with_capacity(data.len() / BLOCKS_PER_SUPERBLOCK + 2);
    let mut block_ranks = Vec::with_capacity(data.len());
    let mut select1_samples = Vec::new();
    let mut select0_samples = Vec::new();

    // Both running totals are bounded by `len`, which was just checked to fit u32.
    let mut cumulative_ones: u32 = 0;
    let mut cumulative_zeros: u32 = 0;
    let mut superblock_start_ones: u32 = 0;

    for (block_idx, &word) in data.iter().enumerate() {
        let bit_pos = block_idx * BLOCK_BITS;

        if block_idx % BLOCKS_PER_SUPERBLOCK == 0 {
            superblock_ranks.push(cumulative_ones);
            superblock_start_ones = cumulative_ones;
        }
        let relative_rank = u16::try_from(cumulative_ones - superblock_start_ones)
            .expect("relative rank within superblock fits u16");
        block_ranks.push(relative_rank);

        // Number of bits of this word that belong to the vector. Saturating so that
        // a spare trailing storage word counts as empty instead of underflowing.
        let bits_in_word = len.saturating_sub(bit_pos).min(BLOCK_BITS);
        let mask = if bits_in_word == BLOCK_BITS {
            u64::MAX
        } else {
            (1u64 << bits_in_word) - 1
        };
        let word_ones = (word & mask).count_ones();
        let word_zeros = (!word & mask).count_ones();

        let next_ones = cumulative_ones + word_ones;
        let next_zeros = cumulative_zeros + word_zeros;

        // A sample is only pushed for a word containing at least one valid bit, so
        // its start position is below `len` and therefore fits in u32.
        let sample_pos =
            || u32::try_from(bit_pos).expect("sampled word starts below len, which fits u32");
        // Sample `i` is emitted by the word in which the running count first
        // exceeds `i * SELECT_SAMPLE_RATE`, i.e. the word holding that one/zero.
        while select1_samples.len() * SELECT_SAMPLE_RATE < next_ones as usize {
            select1_samples.push(sample_pos());
        }
        while select0_samples.len() * SELECT_SAMPLE_RATE < next_zeros as usize {
            select0_samples.push(sample_pos());
        }

        cumulative_ones = next_ones;
        cumulative_zeros = next_zeros;
    }

    // Append the grand total when the recorded superblocks end exactly at `len`
    // (this includes the empty vector, where nothing has been recorded yet).
    if superblock_ranks.len() * SUPERBLOCK_BITS <= len {
        superblock_ranks.push(cumulative_ones);
    }

    Self {
        inner,
        superblock_ranks,
        block_ranks,
        select1_samples,
        select0_samples,
        ones_count: cumulative_ones as usize,
    }
}
/// Returns the number of bits in the vector, as reported by the underlying `BitVector`.
#[must_use]
pub fn len(&self) -> usize {
self.inner.len()
}
/// Returns whether the underlying `BitVector` reports itself as empty.
#[must_use]
pub fn is_empty(&self) -> bool {
self.inner.is_empty()
}
/// Returns the total number of set bits, precomputed when the index was built.
#[must_use]
pub fn count_ones(&self) -> usize {
self.ones_count
}
/// Returns the total number of unset bits, i.e. the length minus the number of ones.
#[must_use]
pub fn count_zeros(&self) -> usize {
self.len() - self.ones_count
}
/// Looks up the bit at `index` by delegating to the underlying `BitVector::get`.
///
/// NOTE(review): `None` is presumably returned for an out-of-range `index`;
/// confirm against `BitVector::get`.
#[must_use]
pub fn get(&self, index: usize) -> Option<bool> {
self.inner.get(index)
}
/// Borrows the underlying `BitVector` without the auxiliary index tables.
#[must_use]
pub fn inner(&self) -> &BitVector {
&self.inner
}
/// Consumes the index and returns the underlying `BitVector`, dropping the auxiliary tables.
#[must_use]
pub fn into_inner(self) -> BitVector {
self.inner
}
/// Returns the number of set bits in positions `[0, pos)`.
///
/// `rank1(0)` is `0`, and any `pos` at or beyond the length yields the total
/// number of ones. Otherwise the answer is the sum of the superblock counter, the
/// word's counter relative to its superblock, and a popcount of the bits below
/// `pos` inside the word.
#[must_use]
pub fn rank1(&self, pos: usize) -> usize {
    if pos == 0 {
        return 0;
    }
    if pos >= self.len() {
        return self.ones_count;
    }

    let word_idx = pos / BLOCK_BITS;
    let superblock_ones = self.superblock_ranks[pos / SUPERBLOCK_BITS] as usize;
    let block_ones = self
        .block_ranks
        .get(word_idx)
        .map_or(0, |&relative| usize::from(relative));
    // Ones strictly below `pos` inside its own word; nothing to add when `pos`
    // sits exactly on a word boundary.
    let partial_ones = match (pos % BLOCK_BITS, self.inner.data().get(word_idx)) {
        (0, _) | (_, None) => 0,
        (offset, Some(&word)) => (word & ((1u64 << offset) - 1)).count_ones() as usize,
    };

    superblock_ones + block_ones + partial_ones
}
/// Returns the number of unset bits in positions `[0, pos)`.
///
/// `pos` is clamped to the length first, so the result never exceeds the total
/// number of zeros.
#[must_use]
pub fn rank0(&self, pos: usize) -> usize {
    let clamped = usize::min(pos, self.len());
    let ones_before = self.rank1(clamped);
    clamped - ones_before
}
/// Returns the position of the set bit with zero-based index `k`, or `None` when
/// the vector contains `k` or fewer ones.
///
/// The search starts at the superblock indicated by the select sample for `k`,
/// narrows to a superblock via the cumulative counters, scans that superblock's
/// words for the one holding the target, and finishes inside the word.
#[must_use]
pub fn select1(&self, k: usize) -> Option<usize> {
    if k >= self.ones_count {
        return None;
    }

    // A sample gives a position at or before the answer; fall back to the start.
    let hint = self
        .select1_samples
        .get(k / SELECT_SAMPLE_RATE)
        .map_or(0, |&pos| pos as usize);
    let wanted = k + 1;
    let superblock = self.binary_search_superblock(wanted, hint / SUPERBLOCK_BITS);
    let ones_before_superblock = self.superblock_ranks[superblock] as usize;

    let first_block = superblock * BLOCKS_PER_SUPERBLOCK;
    let end_block = (first_block + BLOCKS_PER_SUPERBLOCK).min(self.block_ranks.len());
    // Last word of the superblock that still has fewer than `wanted` ones before it.
    let block = (first_block..end_block)
        .take_while(|&i| ones_before_superblock + (self.block_ranks[i] as usize) < wanted)
        .last()
        .unwrap_or(first_block);

    let ones_before_block = ones_before_superblock + self.block_ranks[block] as usize;
    let index_in_word = k - ones_before_block;

    let word = *self.inner.data().get(block)?;
    Self::select_in_word(word, index_in_word)
        .map(|bit| block * BLOCK_BITS + bit)
        .filter(|&position| position < self.len())
}
/// Returns the position of the unset bit with zero-based index `k`, or `None` when
/// the vector contains `k` or fewer zeros.
///
/// Starting from the select sample for `k`, a binary search over bit positions
/// finds the smallest position `p` for which `rank0(p + 1)` exceeds `k`.
#[must_use]
pub fn select0(&self, k: usize) -> Option<usize> {
    if k >= self.count_zeros() {
        return None;
    }

    let mut low = self
        .select0_samples
        .get(k / SELECT_SAMPLE_RATE)
        .map_or(0, |&pos| pos as usize);
    let mut high = self.len();
    while low < high {
        let middle = low + (high - low) / 2;
        if self.rank0(middle + 1) > k {
            // The k-th zero is at `middle` or earlier.
            high = middle;
        } else {
            low = middle + 1;
        }
    }

    // Confirm the candidate really is the k-th zero before reporting it.
    Some(low).filter(|&candidate| {
        candidate < self.len() && self.rank0(candidate + 1) == k + 1
    })
}
/// Finds the last superblock at or after `start` whose cumulative count of ones is
/// below `target_rank`, returning `start` itself when no later entry qualifies.
///
/// Relies on `superblock_ranks` being non-decreasing, which holds because the
/// table stores running totals.
fn binary_search_superblock(&self, target_rank: usize, start: usize) -> usize {
    // Entries strictly after `start` that are still below the target form a
    // prefix of the tail; its length is how far to advance from `start`.
    self.superblock_ranks
        .get(start + 1..)
        .map_or(start, |tail| {
            start + tail.partition_point(|&ones| (ones as usize) < target_rank)
        })
}
/// Returns the bit index (0 = least significant) of the set bit with zero-based
/// index `k` inside `word`, or `None` when `word` has `k` or fewer set bits.
fn select_in_word(word: u64, k: usize) -> Option<usize> {
    if k >= word.count_ones() as usize {
        return None;
    }

    // Number of set bits that still have to be skipped before reaching the target.
    let mut to_skip = k;
    // Little-endian byte order walks the word from its least significant byte up.
    for (byte_idx, &byte) in word.to_le_bytes().iter().enumerate() {
        let ones_here = byte.count_ones() as usize;
        if to_skip >= ones_here {
            to_skip -= ones_here;
            continue;
        }
        // The target lives in this byte: drop its `to_skip` lowest set bits, after
        // which the lowest remaining set bit is the one being looked for.
        let mut bits = byte;
        for _ in 0..to_skip {
            bits &= bits - 1;
        }
        return Some(byte_idx * 8 + bits.trailing_zeros() as usize);
    }
    None
}
/// Returns the number of bytes occupied by the elements of the four auxiliary
/// tables (superblock counters, block counters and both select sample lists).
///
/// Only stored elements are counted, not spare `Vec` capacity.
#[must_use]
pub fn auxiliary_size_bytes(&self) -> usize {
    let u32_entries =
        self.superblock_ranks.len() + self.select1_samples.len() + self.select0_samples.len();
    let u16_entries = self.block_ranks.len();
    u32_entries * std::mem::size_of::<u32>() + u16_entries * std::mem::size_of::<u16>()
}
/// Returns the combined size in bytes of the bit storage words and the auxiliary tables.
#[must_use]
pub fn size_bytes(&self) -> usize {
    let payload_bytes = self.inner.data().len() * std::mem::size_of::<u64>();
    payload_bytes + self.auxiliary_size_bytes()
}
/// Returns the size of the auxiliary tables as a fraction of the bit storage size.
///
/// An index over empty storage reports `0.0` rather than dividing by zero.
#[must_use]
pub fn space_overhead(&self) -> f64 {
    match self.inner.data().len() * 8 {
        0 => 0.0,
        payload_bytes => self.auxiliary_size_bytes() as f64 / payload_bytes as f64,
    }
}
}
/// The default value is the index built over `BitVector::new()`.
impl Default for SuccinctBitVector {
    fn default() -> Self {
        let empty = BitVector::new();
        Self::from_bitvec(empty)
    }
}
/// Unit tests for rank, select and the bookkeeping accessors.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty() {
        let sbv = SuccinctBitVector::from_bools(&[]);
        assert!(sbv.is_empty());
        assert_eq!(sbv.rank1(0), 0);
        assert_eq!(sbv.rank0(0), 0);
        assert_eq!(sbv.select1(0), None);
        assert_eq!(sbv.select0(0), None);
    }

    #[test]
    fn test_all_zeros() {
        let sbv = SuccinctBitVector::from_bools(&[false; 100]);
        assert_eq!(sbv.count_ones(), 0);
        assert_eq!(sbv.count_zeros(), 100);
        assert_eq!(sbv.rank1(50), 0);
        assert_eq!(sbv.rank0(50), 50);
        assert_eq!(sbv.select1(0), None);
        assert_eq!(sbv.select0(0), Some(0));
        assert_eq!(sbv.select0(99), Some(99));
    }

    #[test]
    fn test_all_ones() {
        let sbv = SuccinctBitVector::from_bools(&[true; 100]);
        assert_eq!(sbv.count_ones(), 100);
        assert_eq!(sbv.count_zeros(), 0);
        assert_eq!(sbv.rank1(50), 50);
        assert_eq!(sbv.rank0(50), 0);
        assert_eq!(sbv.select1(0), Some(0));
        assert_eq!(sbv.select1(99), Some(99));
        assert_eq!(sbv.select0(0), None);
    }

    #[test]
    fn test_small() {
        let sbv = SuccinctBitVector::from_bools(&[true, false, true, true, false]);
        assert_eq!(sbv.len(), 5);
        assert_eq!(sbv.count_ones(), 3);
        assert_eq!(sbv.count_zeros(), 2);

        assert_eq!(sbv.rank1(0), 0);
        assert_eq!(sbv.rank1(1), 1);
        assert_eq!(sbv.rank1(2), 1);
        assert_eq!(sbv.rank1(3), 2);
        assert_eq!(sbv.rank1(4), 3);
        assert_eq!(sbv.rank1(5), 3);

        assert_eq!(sbv.rank0(0), 0);
        assert_eq!(sbv.rank0(1), 0);
        assert_eq!(sbv.rank0(2), 1);
        assert_eq!(sbv.rank0(3), 1);
        assert_eq!(sbv.rank0(4), 1);
        assert_eq!(sbv.rank0(5), 2);

        assert_eq!(sbv.select1(0), Some(0));
        assert_eq!(sbv.select1(1), Some(2));
        assert_eq!(sbv.select1(2), Some(3));
        assert_eq!(sbv.select1(3), None);

        assert_eq!(sbv.select0(0), Some(1));
        assert_eq!(sbv.select0(1), Some(4));
        assert_eq!(sbv.select0(2), None);
    }

    #[test]
    fn test_rank_select_consistency() {
        let bits: Vec<bool> = (0..1000).map(|i| i % 3 == 0).collect();
        let sbv = SuccinctBitVector::from_bools(&bits);

        for k in 0..sbv.count_ones() {
            let pos = sbv.select1(k).expect("select1 should succeed");
            assert_eq!(sbv.rank1(pos), k, "rank1(select1({k})) != {k}");
            assert_eq!(sbv.get(pos), Some(true), "bit at select1({k}) should be 1");
        }

        for k in 0..sbv.count_zeros() {
            let pos = sbv.select0(k).expect("select0 should succeed");
            assert_eq!(sbv.rank0(pos), k, "rank0(select0({k})) != {k}");
            assert_eq!(sbv.get(pos), Some(false), "bit at select0({k}) should be 0");
        }
    }

    #[test]
    fn test_large() {
        let bits: Vec<bool> = (0..10000).map(|i| i % 7 == 0).collect();
        let sbv = SuccinctBitVector::from_bools(&bits);
        let expected_ones = bits.iter().filter(|&&b| b).count();
        assert_eq!(sbv.count_ones(), expected_ones);
        for pos in [0, 100, 500, 1000, 5000, 9999] {
            let expected_rank = bits[..pos].iter().filter(|&&b| b).count();
            assert_eq!(sbv.rank1(pos), expected_rank, "rank1({pos}) mismatch");
        }
    }

    #[test]
    fn test_boundary_conditions() {
        let bits: Vec<bool> = (0..128).map(|i| i < 64).collect();
        let sbv = SuccinctBitVector::from_bools(&bits);
        assert_eq!(sbv.rank1(64), 64);
        assert_eq!(sbv.rank1(128), 64);
        assert_eq!(sbv.select1(63), Some(63));
        assert_eq!(sbv.select1(64), None);
    }

    #[test]
    fn test_superblock_boundary() {
        let bits: Vec<bool> = (0..1024).map(|i| i < 512).collect();
        let sbv = SuccinctBitVector::from_bools(&bits);
        assert_eq!(sbv.rank1(512), 512);
        assert_eq!(sbv.rank1(1024), 512);
    }

    /// Exercises select queries whose index is at or beyond `SELECT_SAMPLE_RATE`,
    /// so that sample entries other than the first one are actually consulted.
    #[test]
    fn test_select_beyond_sample_rate() {
        // Alternating pattern: ones at even positions, zeros at odd positions.
        let bits: Vec<bool> = (0..20_000).map(|i| i % 2 == 0).collect();
        let sbv = SuccinctBitVector::from_bools(&bits);
        assert_eq!(sbv.count_ones(), 10_000);
        assert_eq!(sbv.count_zeros(), 10_000);
        for k in [
            0,
            SELECT_SAMPLE_RATE - 1,
            SELECT_SAMPLE_RATE,
            SELECT_SAMPLE_RATE + 1,
            2 * SELECT_SAMPLE_RATE - 1,
            2 * SELECT_SAMPLE_RATE,
            9_999,
        ] {
            assert_eq!(sbv.select1(k), Some(2 * k), "select1({k}) mismatch");
            assert_eq!(sbv.select0(k), Some(2 * k + 1), "select0({k}) mismatch");
        }
        assert_eq!(sbv.select1(10_000), None);
        assert_eq!(sbv.select0(10_000), None);
    }

    #[test]
    fn test_space_overhead() {
        let bits: Vec<bool> = (0..1_000_000).map(|i| i % 2 == 0).collect();
        let sbv = SuccinctBitVector::from_bools(&bits);
        let overhead = sbv.space_overhead();
        assert!(
            overhead < 0.40,
            "Space overhead {overhead} is too high (expected < 40%)",
        );
    }

    #[test]
    fn test_from_bitvec() {
        let bv = BitVector::from_bools(&[true, false, true]);
        let sbv = SuccinctBitVector::from_bitvec(bv.clone());
        assert_eq!(sbv.inner().to_bools(), bv.to_bools());
        assert_eq!(sbv.count_ones(), 2);
    }

    #[test]
    fn test_into_inner() {
        let original = BitVector::from_bools(&[true, false, true]);
        let sbv = SuccinctBitVector::from_bitvec(original.clone());
        let recovered = sbv.into_inner();
        assert_eq!(recovered.to_bools(), original.to_bools());
    }
}