/// Bit vector with rank and select queries over a fixed number of bits.
///
/// The data words and the rank directory are interleaved in a single buffer so
/// that a query touches one contiguous 10-word record per 512-bit block.
#[derive(Clone)]
pub struct BitVector {
/// Ten `u64` words per 512-bit block: word 0 is the number of set bits before
/// the block, word 1 packs seven 9-bit running counts (field `j - 1` is the
/// number of set bits in the block's data words `0..j`), and words 2..10 are
/// the eight data words. One extra record follows the last block; its word 0
/// holds the total number of set bits.
storage: Vec<u64>,
/// Search hints for `select1`: as built by `new`, entry `s` is the first block
/// (or the trailing record) preceded by at least `512 * s` set bits.
select1_index: Vec<u32>,
/// Search hints for `select0`: same as `select1_index`, counting clear bits.
select0_index: Vec<u32>,
/// Number of addressable bits; queries treat positions `>= len` as absent.
len: usize,
}
/// Compact debug view: prints the bit length and the number of set bits rather
/// than dumping the raw storage words.
impl std::fmt::Debug for BitVector {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Count the set bits up front so the builder calls below stay trivial.
        let ones = self.rank1(self.len);
        let mut view = f.debug_struct("BitVector");
        view.field("len", &self.len);
        view.field("ones", &ones);
        view.finish()
    }
}
impl BitVector {
pub fn new(bits: &[u64], len: usize) -> Self {
let num_blocks = len.div_ceil(512);
let mut storage = vec![0u64; num_blocks * 10 + 10]; let mut select1_index = Vec::new();
let mut select0_index = Vec::new();
let mut total_rank = 0u64;
let mut next_select1_threshold = 0u64;
let mut next_select0_threshold = 0u64;
for i in 0..num_blocks {
let base = i * 10;
storage[base] = total_rank;
let total_zeros = (i as u64 * 512) - total_rank;
while total_rank >= next_select1_threshold {
select1_index.push(i as u32);
next_select1_threshold += 512;
}
while total_zeros >= next_select0_threshold {
select0_index.push(i as u32);
next_select0_threshold += 512;
}
let mut relative_ranks = 0u64;
let mut current_rel = 0u64;
for j in 0..8 {
let data_idx = i * 8 + j;
let word = if data_idx < bits.len() {
bits[data_idx]
} else {
0
};
storage[base + 2 + j] = word;
if j > 0 {
relative_ranks |= current_rel << (9 * (j - 1));
}
current_rel += word.count_ones() as u64;
}
storage[base + 1] = relative_ranks;
total_rank += current_rel;
}
let last_base = num_blocks * 10;
storage[last_base] = total_rank;
let total_zeros = (num_blocks as u64 * 512) - total_rank;
while total_rank >= next_select1_threshold {
select1_index.push(num_blocks as u32);
next_select1_threshold += 512;
}
while total_zeros >= next_select0_threshold {
select0_index.push(num_blocks as u32);
next_select0_threshold += 512;
}
Self {
storage,
select1_index,
select0_index,
len,
}
}
pub fn from_ones(positions: impl Iterator<Item = usize>, len: usize) -> Self {
let num_words = len.div_ceil(64);
let mut bits = vec![0u64; num_words];
for pos in positions {
if pos < len {
bits[pos / 64] |= 1u64 << (pos % 64);
}
}
Self::new(&bits, len)
}
pub fn from_parts(
storage: Vec<u64>,
select1_index: Vec<u32>,
select0_index: Vec<u32>,
len: usize,
) -> crate::error::Result<Self> {
if storage.len() < 10 {
return Err(crate::error::Error::InvalidEncoding(
"bitvec storage too small (need >= 10 words)".to_string(),
));
}
if !storage.len().is_multiple_of(10) {
return Err(crate::error::Error::InvalidEncoding(
"bitvec storage len must be multiple of 10".to_string(),
));
}
let num_blocks = storage.len() / 10 - 1;
let max_bits = num_blocks * 512;
if len > max_bits {
return Err(crate::error::Error::InvalidEncoding(format!(
"bitvec len ({len}) exceeds capacity of {num_blocks} blocks ({max_bits} bits)"
)));
}
Ok(Self {
storage,
select1_index,
select0_index,
len,
})
}
pub fn to_bytes(&self) -> Vec<u8> {
let mut out = Vec::new();
out.extend_from_slice(b"SBITBV01");
out.extend_from_slice(&(self.len as u64).to_le_bytes());
out.extend_from_slice(&(self.storage.len() as u64).to_le_bytes());
for &w in &self.storage {
out.extend_from_slice(&w.to_le_bytes());
}
out.extend_from_slice(&(self.select1_index.len() as u64).to_le_bytes());
for &w in &self.select1_index {
out.extend_from_slice(&w.to_le_bytes());
}
out.extend_from_slice(&(self.select0_index.len() as u64).to_le_bytes());
for &w in &self.select0_index {
out.extend_from_slice(&w.to_le_bytes());
}
out
}
pub fn from_bytes(bytes: &[u8]) -> crate::error::Result<Self> {
use crate::error::ByteReader;
let mut r = ByteReader::new(bytes);
r.read_magic(b"SBITBV01", "BitVector")?;
let len = r.read_u64()? as usize;
let storage_len = r.read_u64()? as usize;
let storage = r.read_u64_vec(storage_len)?;
let select1_len = r.read_u64()? as usize;
let select1_index = r.read_u32_vec(select1_len)?;
let select0_len = r.read_u64()? as usize;
let select0_index = r.read_u32_vec(select0_len)?;
r.expect_eof("BitVector")?;
Self::from_parts(storage, select1_index, select0_index, len)
}
pub fn len(&self) -> usize {
self.len
}
pub fn is_empty(&self) -> bool {
self.len == 0
}
pub fn heap_bytes(&self) -> usize {
self.storage.len() * 8 + self.select1_index.len() * 4 + self.select0_index.len() * 4
}
pub fn get(&self, i: usize) -> bool {
if i >= self.len {
return false;
}
let block_idx = i / 512;
let word_in_block = (i % 512) / 64;
let bit_in_word = i % 64;
let word = self.storage[block_idx * 10 + 2 + word_in_block];
(word & (1u64 << bit_in_word)) != 0
}
pub fn rank1(&self, i: usize) -> usize {
if i == 0 {
return 0;
}
let i = i.min(self.len);
let block_idx = i / 512;
let sub_block_idx = (i % 512) / 64;
let bit_offset = i % 64;
let base = block_idx * 10;
let mut rank = self.storage[base] as usize;
if sub_block_idx > 0 {
let relative_ranks = self.storage[base + 1];
rank += ((relative_ranks >> (9 * (sub_block_idx - 1))) & 0x1FF) as usize;
}
let word = self.storage[base + 2 + sub_block_idx];
let mask = (1u64 << bit_offset).wrapping_sub(1);
rank += (word & mask).count_ones() as usize;
rank
}
pub fn rank0(&self, i: usize) -> usize {
i - self.rank1(i)
}
pub fn select1(&self, k: usize) -> Option<usize> {
if k >= self.rank1(self.len) {
return None;
}
let target = k + 1;
let select_idx = k / 512;
let mut block_low = self.select1_index[select_idx] as usize;
let mut block_high = if select_idx + 1 < self.select1_index.len() {
self.select1_index[select_idx + 1] as usize + 1
} else {
self.storage.len() / 10
};
while block_low < block_high {
let mid = block_low + (block_high - block_low) / 2;
if self.storage[mid * 10] < target as u64 {
block_low = mid + 1;
} else {
block_high = mid;
}
}
let block_idx = block_low - 1;
let mut remaining_k = target - (self.storage[block_idx * 10] as usize);
let relative_ranks = self.storage[block_idx * 10 + 1];
let mut sub_block_idx = 0;
for j in 1..8 {
let rel_rank = ((relative_ranks >> (9 * (j - 1))) & 0x1FF) as usize;
if rel_rank < remaining_k {
sub_block_idx = j;
} else {
break;
}
}
if sub_block_idx > 0 {
let rel_rank = ((relative_ranks >> (9 * (sub_block_idx - 1))) & 0x1FF) as usize;
remaining_k -= rel_rank;
}
let word = self.storage[block_idx * 10 + 2 + sub_block_idx];
let pos_in_word = self.select_in_word(word, remaining_k - 1);
Some(block_idx * 512 + sub_block_idx * 64 + pos_in_word)
}
pub fn select0(&self, k: usize) -> Option<usize> {
if k >= self.rank0(self.len) {
return None;
}
let target = k + 1;
let select_idx = k / 512;
let mut block_low = self.select0_index[select_idx] as usize;
let mut block_high = if select_idx + 1 < self.select0_index.len() {
self.select0_index[select_idx + 1] as usize + 1
} else {
self.storage.len() / 10
};
while block_low < block_high {
let mid = block_low + (block_high - block_low) / 2;
let rank0_at_mid = (mid * 512) - (self.storage[mid * 10] as usize);
if rank0_at_mid < target {
block_low = mid + 1;
} else {
block_high = mid;
}
}
let block_idx = block_low - 1;
let mut remaining_k =
target - ((block_idx * 512) - (self.storage[block_idx * 10] as usize));
let relative_ranks1 = self.storage[block_idx * 10 + 1];
let mut sub_block_idx = 0;
for j in 1..8 {
let rel_rank1 = ((relative_ranks1 >> (9 * (j - 1))) & 0x1FF) as usize;
let rel_rank0 = (j * 64) - rel_rank1;
if rel_rank0 < remaining_k {
sub_block_idx = j;
} else {
break;
}
}
if sub_block_idx > 0 {
let rel_rank1 = ((relative_ranks1 >> (9 * (sub_block_idx - 1))) & 0x1FF) as usize;
remaining_k -= (sub_block_idx * 64) - rel_rank1;
}
let word = !self.storage[block_idx * 10 + 2 + sub_block_idx];
let pos_in_word = self.select_in_word(word, remaining_k - 1);
Some(block_idx * 512 + sub_block_idx * 64 + pos_in_word)
}
fn select_in_word(&self, word: u64, k: usize) -> usize {
#[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))]
{
unsafe {
let mask = 1u64 << k;
let res = core::arch::x86_64::_pdep_u64(mask, word);
return res.trailing_zeros() as usize;
}
}
let mut w = word;
for _ in 0..k {
debug_assert_ne!(w, 0, "select_in_word: k exceeds popcount");
w &= w.wrapping_sub(1);
}
debug_assert_ne!(w, 0, "select_in_word: k exceeds popcount");
w.trailing_zeros() as usize
}
pub fn ones(&self) -> OnesIter<'_> {
OnesIter {
bv: self,
remaining: self.rank1(self.len),
block: 0,
word_in_block: 0,
current_word: self.first_data_word(0),
base_pos: 0,
}
}
pub fn zeros(&self) -> ZerosIter<'_> {
ZerosIter {
bv: self,
remaining: self.rank0(self.len),
block: 0,
word_in_block: 0,
current_word: self.first_data_word_inverted(0),
base_pos: 0,
}
}
fn first_data_word(&self, block: usize) -> u64 {
if block * 10 + 2 < self.storage.len() {
self.storage[block * 10 + 2]
} else {
0
}
}
fn first_data_word_inverted(&self, block: usize) -> u64 {
if block * 10 + 2 < self.storage.len() {
!self.storage[block * 10 + 2]
} else {
0
}
}
}
/// Iterator over the positions of set bits of a [`BitVector`], in increasing
/// order.
pub struct OnesIter<'a> {
    /// Vector being scanned.
    bv: &'a BitVector,
    /// Number of positions still to be yielded; backs `size_hint`.
    remaining: usize,
    /// Index of the 512-bit block holding the current word.
    block: usize,
    /// Index (0..8) of the current data word inside `block`.
    word_in_block: usize,
    /// Bits of the current word that have not been yielded yet.
    current_word: u64,
    /// Bit position of bit 0 of the current word.
    base_pos: usize,
}

impl Iterator for OnesIter<'_> {
    type Item = usize;

    fn next(&mut self) -> Option<usize> {
        // Bounding the scan by `remaining` means the call after the last hit
        // returns at once instead of walking every trailing word, and the
        // counter can never be decremented below zero.
        while self.remaining > 0 {
            if self.current_word != 0 {
                let bit = self.current_word.trailing_zeros() as usize;
                // Clear the lowest set bit.
                self.current_word &= self.current_word.wrapping_sub(1);
                let pos = self.base_pos + bit;
                if pos < self.bv.len {
                    self.remaining -= 1;
                    return Some(pos);
                }
                // A bit at or past `len` is padding; nothing valid follows it.
                break;
            }
            self.word_in_block += 1;
            if self.word_in_block >= 8 {
                self.block += 1;
                self.word_in_block = 0;
            }
            // Skip the two header words of the block record.
            let idx = self.block * 10 + 2 + self.word_in_block;
            if idx >= self.bv.storage.len() {
                break;
            }
            self.base_pos = self.block * 512 + self.word_in_block * 64;
            if self.base_pos >= self.bv.len {
                break;
            }
            self.current_word = self.bv.storage[idx];
        }
        // Exhausted: keep `size_hint` truthful even if the initial count was
        // larger than the number of bits actually present.
        self.remaining = 0;
        None
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.remaining, Some(self.remaining))
    }
}

impl ExactSizeIterator for OnesIter<'_> {}
/// Iterator over the positions of clear bits of a [`BitVector`], in increasing
/// order.
pub struct ZerosIter<'a> {
    /// Vector being scanned.
    bv: &'a BitVector,
    /// Number of positions still to be yielded; backs `size_hint`.
    remaining: usize,
    /// Index of the 512-bit block holding the current word.
    block: usize,
    /// Index (0..8) of the current data word inside `block`.
    word_in_block: usize,
    /// Complement of the current data word, minus the bits already yielded, so
    /// that every pending zero shows up as a set bit.
    current_word: u64,
    /// Bit position of bit 0 of the current word.
    base_pos: usize,
}

impl Iterator for ZerosIter<'_> {
    type Item = usize;

    fn next(&mut self) -> Option<usize> {
        // Bounding the scan by `remaining` means the call after the last hit
        // returns at once instead of walking every trailing word, and the
        // counter can never be decremented below zero.
        while self.remaining > 0 {
            if self.current_word != 0 {
                let bit = self.current_word.trailing_zeros() as usize;
                // Clear the lowest set bit.
                self.current_word &= self.current_word.wrapping_sub(1);
                let pos = self.base_pos + bit;
                if pos < self.bv.len {
                    self.remaining -= 1;
                    return Some(pos);
                }
                // A position at or past `len` is padding; nothing valid follows.
                break;
            }
            self.word_in_block += 1;
            if self.word_in_block >= 8 {
                self.block += 1;
                self.word_in_block = 0;
            }
            // Skip the two header words of the block record.
            let idx = self.block * 10 + 2 + self.word_in_block;
            if idx >= self.bv.storage.len() {
                break;
            }
            self.base_pos = self.block * 512 + self.word_in_block * 64;
            if self.base_pos >= self.bv.len {
                break;
            }
            self.current_word = !self.bv.storage[idx];
        }
        // Exhausted: keep `size_hint` truthful even if the initial count was
        // larger than the number of zeros actually present.
        self.remaining = 0;
        None
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.remaining, Some(self.remaining))
    }
}

impl ExactSizeIterator for ZerosIter<'_> {}
#[cfg(test)]
mod tests {
    use super::*;

    /// 128-bit fixture with the low bits `1011` in word 0 and `1101` in word 1.
    fn two_word_vector() -> BitVector {
        BitVector::new(&[0b1011, 0b1101], 128)
    }

    /// `rank1` and `get` on a small two-word vector.
    #[test]
    fn test_bitvector_rank_basic() {
        let bv = two_word_vector();
        let expected_ranks = [(0usize, 0usize), (1, 1), (4, 3)];
        for &(i, rank) in &expected_ranks {
            assert_eq!(bv.rank1(i), rank);
        }
        assert!(bv.get(0));
        assert!(!bv.get(2));
    }

    /// `select1`/`select0` on a single word, including the past-the-end case.
    #[test]
    fn test_bitvector_select_basic() {
        let bv = BitVector::new(&[0b1011], 64);
        let ones = [Some(0), Some(1), Some(3), None];
        for (k, want) in ones.iter().enumerate() {
            assert_eq!(bv.select1(k), *want);
        }
        let zeros = [Some(2), Some(4)];
        for (k, want) in zeros.iter().enumerate() {
            assert_eq!(bv.select0(k), *want);
        }
    }

    /// Rank and bit access survive a `to_bytes`/`from_bytes` round trip.
    #[test]
    fn test_bitvector_serialization_roundtrip() {
        let encoded = two_word_vector().to_bytes();
        let restored = BitVector::from_bytes(&encoded).unwrap();
        assert_eq!(restored.len(), 128);
        assert_eq!(restored.rank1(4), 3);
        assert!(restored.get(0));
        assert!(!restored.get(2));
    }

    /// Two records give one data block: 512 bits fit, 513 do not.
    #[test]
    fn test_bitvector_from_parts_rejects_bad_len() {
        let zeroed = || vec![0u64; 20];
        assert!(BitVector::from_parts(zeroed(), Vec::new(), Vec::new(), 512).is_ok());
        assert!(BitVector::from_parts(zeroed(), Vec::new(), Vec::new(), 513).is_err());
    }

    /// A header announcing `u64::MAX` storage words must fail, not allocate.
    #[test]
    fn test_bitvector_from_bytes_rejects_allocation_bomb() {
        let bit_len = 0u64.to_le_bytes();
        let storage_words = u64::MAX.to_le_bytes();
        let parts: [&[u8]; 3] = [&b"SBITBV01"[..], &bit_len, &storage_words];
        let bytes = parts.concat();
        assert!(BitVector::from_bytes(&bytes).is_err());
    }

    /// Every query on a zero-length vector reports "nothing there".
    #[test]
    fn test_bitvector_empty() {
        let bv = BitVector::new(&[], 0);
        assert!(bv.is_empty());
        assert_eq!(bv.len(), 0);
        assert_eq!(bv.rank1(0), 0);
        assert_eq!(bv.rank0(0), 0);
        assert_eq!(bv.select1(0), None);
        assert_eq!(bv.select0(0), None);
        assert!(!bv.get(0));
        assert_eq!(bv.ones().count(), 0);
        assert_eq!(bv.zeros().count(), 0);
    }

    /// `get` past the logical length is `false` rather than a panic.
    #[test]
    fn test_bitvector_get_oob_returns_false() {
        let bv = BitVector::new(&[0xFFFF], 16);
        for &inside in &[0usize, 15] {
            assert!(bv.get(inside));
        }
        for &outside in &[16usize, 1000] {
            assert!(!bv.get(outside));
        }
    }

    /// Rank and select stay correct across the 512-bit block boundary.
    #[test]
    fn test_bitvector_cross_block_boundary() {
        let all_set = [u64::MAX; 9];
        let bv = BitVector::new(&all_set, 576);
        assert_eq!(bv.rank1(512), 512);
        assert_eq!(bv.rank1(576), 576);
        for k in 0..576 {
            assert_eq!(bv.select1(k), Some(k), "select1({k}) failed");
        }
        assert_eq!(bv.select1(576), None);
    }

    /// Select answers are unchanged by a serialization round trip.
    #[test]
    fn test_bitvector_serialization_verifies_select() {
        let original = two_word_vector();
        let restored = BitVector::from_bytes(&original.to_bytes()).unwrap();
        for &k in &[0usize, 2] {
            assert_eq!(restored.select1(k), original.select1(k));
        }
        for &k in &[0usize, 2] {
            assert_eq!(restored.select0(k), original.select0(k));
        }
    }

    /// `ones()` lists set positions in increasing order.
    #[test]
    fn test_bitvector_ones_iter() {
        let bv = BitVector::new(&[0b1011], 64);
        let positions = bv.ones().collect::<Vec<usize>>();
        assert_eq!(positions, vec![0, 1, 3]);
    }

    /// `zeros()` stops at the logical length, not at the word boundary.
    #[test]
    fn test_bitvector_zeros_iter() {
        let bv = BitVector::new(&[0b1011], 4);
        let positions = bv.zeros().collect::<Vec<usize>>();
        assert_eq!(positions, vec![2]);
    }
}