use super::RankSelectOps;
use crate::error::{Result, ZiporaError};
use crate::succinct::BitVector;
/// Number of bits covered by one rank-cache line.
const LINE_BITS: usize = 512;
/// Number of 64-bit storage words that make up one rank-cache line.
const WORDS_PER_LINE: usize = LINE_BITS / u64::BITS as usize;
/// Rank-cache entry for one `LINE_BITS`-bit line.
///
/// * `base` — number of set bits in all lines before this one.
/// * `rela` — seven 9-bit fields. Field `i` (bits `9*i .. 9*i + 9`) holds the
///   number of set bits in words `0..=i` of this line, and bit 63 is kept
///   clear. Read it through `get_rela`.
///
/// `repr(C, packed)` removes padding, so an entry is 12 bytes (`u32` + `u64`)
/// instead of 16. Because fields may be unaligned, they must be read by value
/// (e.g. `entry.base as usize`). Taking a reference to a field is rejected by
/// the compiler.
#[derive(Debug, Clone, Copy)]
#[repr(C, packed)]
struct RankCacheSE512 {
base: u32, rela: u64, }
/// Returns how many set bits precede word `k` inside a line.
///
/// `rela` packs cumulative 9-bit counts: field `i` (bits `9*i .. 9*i + 9`)
/// stores the popcount of words `0..=i`. Nothing precedes word 0, so `k == 0`
/// yields 0 without reading `rela`; otherwise field `k - 1` is extracted.
#[inline(always)]
fn get_rela(rela: u64, k: usize) -> usize {
    match k.checked_sub(1) {
        None => 0,
        Some(field) => ((rela >> (field * 9)) & 0x1FF) as usize,
    }
}
/// Rank/select index over a `BitVector` that keeps one rank-cache entry per
/// 512-bit line.
///
/// `rank1` adds a line's `base`, the in-line cumulative count from `rela` and
/// a popcount of the partial word. `select0`/`select1` binary-search the line
/// cache, optionally narrowed by the sampled select caches.
pub struct RankSelectSE512 {
/// The bits being indexed.
bv: BitVector,
/// One entry per line plus a trailing sentinel whose `base` is the total
/// number of set bits.
rank_cache: Vec<RankCacheSE512>,
/// Optional line-index hints for `select0`: entries `j` and `j + 1` bound the
/// binary search for zero ranks in `[512 * j, 512 * (j + 1))`.
sel0_cache: Option<Vec<u32>>,
/// Same as `sel0_cache`, for `select1`.
sel1_cache: Option<Vec<u32>>,
/// Number of bits in `bv`.
size: usize,
/// Number of clear bits (`size - max_rank1`).
max_rank0: usize,
/// Number of set bits.
max_rank1: usize,
}
// NOTE(review): both aliases name the same type. Line bases are `u32` either
// way, so `_64` does not widen anything — confirm this is intended.
pub type RankSelectSE512_32 = RankSelectSE512;
pub type RankSelectSE512_64 = RankSelectSE512;
impl RankSelectSE512 {
pub fn new(bv: BitVector) -> Result<Self> {
Self::with_options(bv, true, true)
}
pub fn with_options(bv: BitVector, speed_select0: bool, speed_select1: bool) -> Result<Self> {
let size = bv.len();
let blocks = bv.blocks();
let nlines = size.div_ceil(LINE_BITS);
let mut rank_cache = Vec::with_capacity(nlines + 1);
let mut cumulative = 0u32;
for i in 0..nlines {
let mut rela = 0u64;
let mut r = 0u64;
for j in 0..WORDS_PER_LINE {
let word_idx = i * WORDS_PER_LINE + j;
let pc = if word_idx < blocks.len() {
blocks[word_idx].count_ones() as u64
} else {
0
};
r += pc;
rela |= r << (j * 9);
}
rela &= u64::MAX >> 1;
rank_cache.push(RankCacheSE512 {
base: cumulative,
rela,
});
cumulative += r as u32;
}
rank_cache.push(RankCacheSE512 {
base: cumulative,
rela: 0,
});
let max_rank1 = cumulative as usize;
let max_rank0 = size - max_rank1;
let sel0_cache = if speed_select0 && max_rank0 > 0 {
Some(Self::build_select_cache(
&rank_cache,
max_rank0,
nlines,
false,
))
} else {
None
};
let sel1_cache = if speed_select1 && max_rank1 > 0 {
Some(Self::build_select_cache(
&rank_cache,
max_rank1,
nlines,
true,
))
} else {
None
};
Ok(Self {
bv,
rank_cache,
sel0_cache,
sel1_cache,
size,
max_rank0,
max_rank1,
})
}
fn build_select_cache(
rank_cache: &[RankCacheSE512],
max_rank: usize,
nlines: usize,
is_rank1: bool,
) -> Vec<u32> {
let slots = max_rank.div_ceil(LINE_BITS);
let mut cache = vec![0u32; slots + 1];
cache[0] = 0;
for j in 1..slots {
let mut k = cache[j - 1] as usize;
while k < nlines {
let rank_at_k = if is_rank1 {
rank_cache[k].base as usize
} else {
k * LINE_BITS - rank_cache[k].base as usize
};
if (is_rank1 && rank_at_k >= LINE_BITS * j)
|| (!is_rank1 && rank_at_k > LINE_BITS * j)
{
break;
}
k += 1;
}
cache[j] = k as u32;
}
cache[slots] = nlines as u32;
cache
}
#[inline(always)]
fn popcount_trail(word: u64, bit_count: usize) -> usize {
if bit_count == 0 {
return 0;
}
if bit_count >= 64 {
return word.count_ones() as usize;
}
(word & ((1u64 << bit_count) - 1)).count_ones() as usize
}
#[inline]
fn select_in_word(word: u64, k: usize) -> usize {
crate::algorithms::bit_ops::select_in_word(word, k)
}
fn upper_bound(&self, rank: usize, is_rank1: bool) -> usize {
let cache = if is_rank1 {
&self.sel1_cache
} else {
&self.sel0_cache
};
let (mut lo, mut hi) = if let Some(c) = cache {
let slot = rank / LINE_BITS;
(c[slot] as usize, c[slot + 1] as usize)
} else {
(0, self.rank_cache.len() - 1)
};
while lo < hi {
let mid = (lo + hi) / 2;
let val = if is_rank1 {
self.rank_cache[mid].base as usize
} else {
mid * LINE_BITS - self.rank_cache[mid].base as usize
};
if val <= rank {
lo = mid + 1;
} else {
hi = mid;
}
}
lo
}
#[inline]
pub fn max_rank0(&self) -> usize {
self.max_rank0
}
pub fn max_rank1(&self) -> usize {
self.max_rank1
}
#[inline]
pub fn mem_size(&self) -> usize {
self.bv.blocks().len() * 8
+ self.rank_cache.len() * std::mem::size_of::<RankCacheSE512>()
+ self.sel0_cache.as_ref().map_or(0, |c| c.len() * 4)
+ self.sel1_cache.as_ref().map_or(0, |c| c.len() * 4)
}
}
impl RankSelectOps for RankSelectSE512 {
    /// Counts set bits in `[0, bitpos)`.
    ///
    /// # Panics
    ///
    /// Panics if `bitpos > self.len()`.
    #[inline(always)]
    fn rank1(&self, bitpos: usize) -> usize {
        assert!(bitpos <= self.size);
        if bitpos == 0 {
            return 0;
        }
        // When `bitpos == size` falls on a line boundary, this reads the
        // trailing sentinel entry.
        let rc = self.rank_cache[bitpos / LINE_BITS];
        let word_in_line = (bitpos % LINE_BITS) / 64;
        // Storage words that do not exist read as zero.
        let word = self.bv.blocks().get(bitpos / 64).copied().unwrap_or(0);
        rc.base as usize + get_rela(rc.rela, word_in_line) + Self::popcount_trail(word, bitpos % 64)
    }

    /// Counts clear bits in `[0, pos)`.
    ///
    /// # Panics
    ///
    /// Panics if `pos > self.len()`.
    #[inline(always)]
    fn rank0(&self, pos: usize) -> usize {
        pos - self.rank1(pos)
    }

    /// Returns the position of the `k`-th (0-based) set bit.
    ///
    /// # Errors
    ///
    /// Returns an error if `k >= self.count_ones()`.
    fn select1(&self, k: usize) -> Result<usize> {
        if k >= self.max_rank1 {
            return Err(ZiporaError::invalid_data("select1 out of range"));
        }
        let lo = self.upper_bound(k, true);
        assert!(lo > 0);
        let block = lo - 1;
        let rc = self.rank_cache[block];
        let line_start = block * LINE_BITS;
        let target = k - rc.base as usize;
        let blocks = self.bv.blocks();
        // The highest word whose preceding in-line count is <= target holds the bit.
        for j in (0..WORDS_PER_LINE).rev() {
            let rank_before_j = get_rela(rc.rela, j);
            if target >= rank_before_j {
                if let Some(&word) = blocks.get(block * WORDS_PER_LINE + j) {
                    return Ok(line_start
                        + j * 64
                        + Self::select_in_word(word, target - rank_before_j));
                }
            }
        }
        Err(ZiporaError::invalid_data("select1 internal error"))
    }

    /// Returns the position of the `k`-th (0-based) clear bit.
    ///
    /// # Errors
    ///
    /// Returns an error if `k` is not below the number of clear bits.
    #[inline]
    fn select0(&self, k: usize) -> Result<usize> {
        if k >= self.max_rank0 {
            return Err(ZiporaError::invalid_data("select0 out of range"));
        }
        let lo = self.upper_bound(k, false);
        assert!(lo > 0);
        let block = lo - 1;
        let rc = self.rank_cache[block];
        let line_start = block * LINE_BITS;
        let zeros_before_line = line_start - rc.base as usize;
        let target = k - zeros_before_line;
        let blocks = self.bv.blocks();
        for j in (0..WORDS_PER_LINE).rev() {
            let zeros_before_j = j * 64 - get_rela(rc.rela, j);
            if target >= zeros_before_j {
                // Missing storage words are all zeros.
                let word = blocks.get(block * WORDS_PER_LINE + j).copied().unwrap_or(0);
                return Ok(line_start + j * 64 + Self::select_in_word(!word, target - zeros_before_j));
            }
        }
        Err(ZiporaError::invalid_data("select0 internal error"))
    }

    /// Number of indexed bits.
    fn len(&self) -> usize {
        self.size
    }

    /// Number of set bits.
    fn count_ones(&self) -> usize {
        self.max_rank1
    }

    /// Returns the bit at `index`, or `None` if `index >= self.len()`.
    ///
    /// Storage words that do not exist read as `false`.
    fn get(&self, index: usize) -> Option<bool> {
        if index >= self.size {
            return None;
        }
        let word = self.bv.blocks().get(index / 64).copied().unwrap_or(0);
        Some((word >> (index % 64)) & 1 == 1)
    }

    /// Extra memory relative to the raw bit payload, as a percentage.
    fn space_overhead_percent(&self) -> f64 {
        if self.size == 0 {
            return 0.0;
        }
        let bit_bytes = self.size.div_ceil(8);
        // The storage may be shorter than `size` (every accessor above tolerates
        // missing words), so `mem_size` can fall below `bit_bytes`.
        let overhead = self.mem_size().saturating_sub(bit_bytes);
        (overhead as f64 / bit_bytes as f64) * 100.0
    }
}
/// Compact debug view: the bit count and number of set bits, without dumping
/// the bit vector or caches.
impl std::fmt::Debug for RankSelectSE512 {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut view = f.debug_struct("RankSelectSE512");
        view.field("size", &self.size);
        view.field("max_rank1", &self.max_rank1);
        view.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    fn make_bv(pattern: &[bool]) -> BitVector {
        let mut bv = BitVector::new();
        for &b in pattern {
            bv.push(b).unwrap();
        }
        bv
    }

    fn make_rs(pattern: &[bool]) -> RankSelectSE512 {
        RankSelectSE512::new(make_bv(pattern)).unwrap()
    }

    /// Deterministic xorshift pattern with roughly one set bit in four.
    fn pseudo_random_pattern(len: usize, seed: u64) -> Vec<bool> {
        let mut x = seed;
        (0..len)
            .map(|_| {
                x ^= x << 13;
                x ^= x >> 7;
                x ^= x << 17;
                x & 3 == 0
            })
            .collect()
    }

    /// Compares every rank, select and get answer against a linear scan.
    fn check_against_naive(rs: &RankSelectSE512, pattern: &[bool]) {
        let ones: Vec<usize> = pattern
            .iter()
            .enumerate()
            .filter(|&(_, &b)| b)
            .map(|(i, _)| i)
            .collect();
        let zeros: Vec<usize> = pattern
            .iter()
            .enumerate()
            .filter(|&(_, &b)| !b)
            .map(|(i, _)| i)
            .collect();
        assert_eq!(rs.len(), pattern.len());
        assert_eq!(rs.count_ones(), ones.len());
        assert_eq!(rs.max_rank1(), ones.len());
        assert_eq!(rs.max_rank0(), zeros.len());

        let mut expected_rank1 = 0;
        for (i, &b) in pattern.iter().enumerate() {
            assert_eq!(rs.rank1(i), expected_rank1, "rank1({})", i);
            assert_eq!(rs.rank0(i), i - expected_rank1, "rank0({})", i);
            assert_eq!(rs.get(i), Some(b), "get({})", i);
            if b {
                expected_rank1 += 1;
            }
        }
        assert_eq!(rs.rank1(pattern.len()), ones.len());
        assert_eq!(rs.get(pattern.len()), None);

        for (k, &pos) in ones.iter().enumerate() {
            assert_eq!(rs.select1(k).unwrap(), pos, "select1({})", k);
        }
        for (k, &pos) in zeros.iter().enumerate() {
            assert_eq!(rs.select0(k).unwrap(), pos, "select0({})", k);
        }
        assert!(rs.select1(ones.len()).is_err());
        assert!(rs.select0(zeros.len()).is_err());
    }

    /// Runs `check_against_naive` with every combination of select caches.
    fn check_all_options(pattern: &[bool]) {
        for &(speed0, speed1) in &[(true, true), (true, false), (false, true), (false, false)] {
            let rs = RankSelectSE512::with_options(make_bv(pattern), speed0, speed1).unwrap();
            check_against_naive(&rs, pattern);
        }
    }

    #[test]
    fn test_basic() {
        let rs = make_rs(&[true, false, true, false, true]);
        assert_eq!(rs.len(), 5);
        assert_eq!(rs.count_ones(), 3);
        assert_eq!(rs.rank1(1), 1);
        assert_eq!(rs.rank1(5), 3);
        assert_eq!(rs.select1(0).unwrap(), 0);
        assert_eq!(rs.select1(2).unwrap(), 4);
    }

    #[test]
    fn test_invariant() {
        let pattern: Vec<bool> = (0..3000).map(|i| i % 7 == 0).collect();
        let rs = make_rs(&pattern);
        for i in 0..=rs.len() {
            assert_eq!(rs.rank0(i) + rs.rank1(i), i, "invariant at {}", i);
        }
    }

    #[test]
    fn test_roundtrip() {
        let pattern: Vec<bool> = (0..2000).map(|i| i % 5 == 0).collect();
        let rs = make_rs(&pattern);
        for k in 0..rs.count_ones() {
            let pos = rs.select1(k).unwrap();
            assert_eq!(rs.get(pos), Some(true));
        }
    }

    #[test]
    fn test_crossing_512_boundary() {
        let mut pattern = vec![false; 600];
        pattern[0] = true;
        pattern[511] = true;
        pattern[512] = true;
        pattern[599] = true;
        let rs = make_rs(&pattern);
        assert_eq!(rs.count_ones(), 4);
        assert_eq!(rs.select1(0).unwrap(), 0);
        assert_eq!(rs.select1(1).unwrap(), 511);
        assert_eq!(rs.select1(2).unwrap(), 512);
        assert_eq!(rs.select1(3).unwrap(), 599);
    }

    #[test]
    fn test_large() {
        let pattern: Vec<bool> = (0..10000).map(|i| i % 13 == 0).collect();
        let rs = make_rs(&pattern);
        let expected = (0..10000).filter(|i| i % 13 == 0).count();
        assert_eq!(rs.count_ones(), expected);
        assert_eq!(rs.select1(0).unwrap(), 0);
        assert_eq!(rs.select1(1).unwrap(), 13);
    }

    #[test]
    fn test_get_rela() {
        let rela: u64 = 5 | (12 << 9) | (20 << 18);
        assert_eq!(get_rela(rela, 0), 0);
        assert_eq!(get_rela(rela, 1), 5);
        assert_eq!(get_rela(rela, 2), 12);
        assert_eq!(get_rela(rela, 3), 20);
    }

    #[test]
    fn test_empty() {
        let rs = make_rs(&[]);
        assert_eq!(rs.len(), 0);
        assert_eq!(rs.rank1(0), 0);
        assert!(rs.select1(0).is_err());
        assert!(rs.select0(0).is_err());
        assert_eq!(rs.get(0), None);
        check_all_options(&[]);
    }

    #[test]
    fn test_select0() {
        let pattern: Vec<bool> = (0..1000).map(|i| i % 3 == 0).collect();
        let rs = make_rs(&pattern);
        assert_eq!(rs.select0(0).unwrap(), 1);
        assert_eq!(rs.select0(1).unwrap(), 2);
        check_against_naive(&rs, &pattern);
    }

    #[test]
    fn test_matches_naive_random() {
        check_all_options(&pseudo_random_pattern(3000, 0x9E37_79B9_7F4A_7C15));
    }

    #[test]
    fn test_matches_naive_line_multiple() {
        // Exactly four lines: rank1(len) lands on the trailing sentinel entry.
        check_all_options(&pseudo_random_pattern(2048, 0x0123_4567_89AB_CDEF));
    }

    #[test]
    fn test_sparse_with_empty_lines() {
        let mut pattern = vec![false; 4096];
        for &p in &[5, 2000, 2001, 4095] {
            pattern[p] = true;
        }
        check_all_options(&pattern);
    }

    #[test]
    fn test_all_ones() {
        check_all_options(&vec![true; 1100]);
    }

    #[test]
    fn test_all_zeros() {
        check_all_options(&vec![false; 1100]);
    }
}