use crate::bitvec::BitVector;
use crate::error::{ByteReader, Result};
use alloc::vec;
use alloc::vec::Vec;
/// Succinct sequence of `u32` symbols drawn from the alphabet `0..sigma`, supporting
/// `access`, `rank` and `select`.
///
/// Despite the name, the layout is a wavelet matrix. Every level is a bit vector over the
/// whole sequence, and a per-level zero count locates the 1-bit half of the next level.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct WaveletTree {
    // One bit vector per level. Level `l` holds bit `depth - 1 - l` of each symbol, in the
    // order produced by the stable 0-then-1 partitions of the levels above it.
    levels: Vec<BitVector>,
    // Number of 0 bits in each level, i.e. the offset where 1-bit symbols start in the next
    // level's order.
    zeros: Vec<usize>,
    // Number of symbols in the sequence.
    len: usize,
    // Alphabet size. `new` asserts every stored symbol is `< sigma`; `from_bytes` trusts the
    // serialized value.
    sigma: u32,
    // Number of levels: the bit width of `sigma - 1` (0 when `sigma <= 1`).
    depth: usize,
}
impl WaveletTree {
    /// Builds the structure over `data`, whose symbols must all be `< sigma`.
    ///
    /// `depth` is the number of bits needed to represent `sigma - 1` (0 when `sigma <= 1`).
    /// Level `l` records bit `depth - 1 - l` of every symbol. The sequence is then stably
    /// partitioned (0-bit symbols first, then 1-bit symbols) to form the order seen by
    /// level `l + 1`.
    ///
    /// # Panics
    ///
    /// Panics if any symbol in `data` is `>= sigma`.
    pub fn new(data: &[u32], sigma: u32) -> Self {
        for (i, &v) in data.iter().enumerate() {
            assert!(
                v < sigma,
                "WaveletTree: symbol {} at index {} >= sigma {}",
                v,
                i,
                sigma
            );
        }
        let depth = if sigma <= 1 {
            0
        } else {
            (u32::BITS - (sigma - 1).leading_zeros()) as usize
        };
        let n = data.len();
        let mut levels = Vec::with_capacity(depth);
        let mut zeros = Vec::with_capacity(depth);
        let mut current: Vec<u32> = data.to_vec();
        // Partition buffers are reused across levels instead of being reallocated per level.
        let mut left: Vec<u32> = Vec::new();
        let mut right: Vec<u32> = Vec::new();
        for level in 0..depth {
            let bit_pos = depth - 1 - level;
            let mut bits = vec![0u64; n.div_ceil(64)];
            left.clear();
            right.clear();
            for (i, &v) in current.iter().enumerate() {
                if (v >> bit_pos) & 1 == 1 {
                    bits[i / 64] |= 1u64 << (i % 64);
                    right.push(v);
                } else {
                    left.push(v);
                }
            }
            // Every 0 bit routed its symbol into `left`, so this is the level's zero count
            // (what `rank0(n)` returns) without another query on the bit vector.
            zeros.push(left.len());
            levels.push(BitVector::new(&bits, n));
            current.clear();
            current.extend_from_slice(&left);
            current.extend_from_slice(&right);
        }
        Self {
            levels,
            zeros,
            len: n,
            sigma,
            depth,
        }
    }

    /// Alphabet size: an exclusive upper bound on stored symbols.
    pub fn sigma(&self) -> u32 {
        self.sigma
    }

    /// Number of symbols in the sequence.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` when the sequence holds no symbols.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Returns the symbol stored at position `i`.
    ///
    /// Performs one `rank1` and one `get` query per level.
    ///
    /// # Panics
    ///
    /// Panics if `i >= len`.
    pub fn access(&self, mut i: usize) -> u32 {
        assert!(
            i < self.len,
            "WaveletTree::access: index {i} >= len {}",
            self.len
        );
        let mut symbol = 0u32;
        for (level, (bv, &zero_count)) in self.levels.iter().zip(&self.zeros).enumerate() {
            let bit_pos = self.depth - 1 - level;
            let ones_before = bv.rank1(i);
            if bv.get(i) {
                symbol |= 1 << bit_pos;
                // 1-bit symbols are placed after all 0-bit symbols in the next level.
                i = zero_count + ones_before;
            } else {
                // Position among the 0-bit symbols, i.e. `rank0(i)`.
                i -= ones_before;
            }
        }
        symbol
    }

    /// Counts occurrences of `symbol` in positions `[0, i)`.
    ///
    /// Symbols outside the alphabet (`symbol >= sigma`) never occur, so they yield 0. Without
    /// this check their high bits would be ignored and they would alias in-alphabet symbols.
    ///
    /// # Panics
    ///
    /// Panics if `i > len`.
    pub fn rank(&self, symbol: u32, i: usize) -> usize {
        assert!(
            i <= self.len,
            "WaveletTree::rank: index {i} > len {}",
            self.len
        );
        if symbol >= self.sigma {
            return 0;
        }
        // In the bottom level, the occurrences of `symbol` form one contiguous run. Mapping the
        // prefix end `i` down along the same path lands just past the occurrences before `i`.
        self.descend(symbol, i) - self.symbol_start(symbol)
    }

    /// Returns the position of the `k`-th (0-based) occurrence of `symbol`, or `None` if it
    /// occurs at most `k` times or lies outside the alphabet.
    pub fn select(&self, symbol: u32, k: usize) -> Option<usize> {
        if symbol >= self.sigma {
            return None;
        }
        // Start from the candidate position in the bottom level and climb back up, undoing each
        // level's partition. Checked arithmetic turns impossible positions into `None`
        // instead of overflowing.
        let mut pos = self.symbol_start(symbol).checked_add(k)?;
        for (level, (bv, &zero_count)) in self.levels.iter().zip(&self.zeros).enumerate().rev() {
            let bit_pos = self.depth - 1 - level;
            pos = if (symbol >> bit_pos) & 1 == 1 {
                bv.select1(pos.checked_sub(zero_count)?)?
            } else {
                bv.select0(pos)?
            };
        }
        // With zero levels no select call bounds `pos`, so bound it explicitly.
        (pos < self.len).then_some(pos)
    }

    /// Checked variant of [`access`](Self::access): `None` when `i >= len`.
    pub fn get(&self, i: usize) -> Option<u32> {
        (i < self.len).then(|| self.access(i))
    }

    /// Heap bytes reported by each level's bit vector, plus the `zeros` entries.
    ///
    /// The buffer of the `levels` vector itself (the `BitVector` values) is not counted.
    pub fn heap_bytes(&self) -> usize {
        self.levels.iter().map(|bv| bv.heap_bytes()).sum::<usize>()
            + self.zeros.len() * core::mem::size_of::<usize>()
    }

    /// Serializes to the `SBITWM01` format. All integers are little-endian:
    /// magic (8 bytes), `len` (u64), `sigma` (u32), `depth` (u32), then for each level its
    /// byte length (u64) followed by `BitVector::to_bytes`, then `depth` zero counts (u64).
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut out = Vec::new();
        out.extend_from_slice(b"SBITWM01");
        out.extend_from_slice(&(self.len as u64).to_le_bytes());
        out.extend_from_slice(&self.sigma.to_le_bytes());
        out.extend_from_slice(&(self.depth as u32).to_le_bytes());
        for level in &self.levels {
            let bv_bytes = level.to_bytes();
            out.extend_from_slice(&(bv_bytes.len() as u64).to_le_bytes());
            out.extend_from_slice(&bv_bytes);
        }
        for &z in &self.zeros {
            out.extend_from_slice(&(z as u64).to_le_bytes());
        }
        out
    }

    /// Parses the format written by [`to_bytes`](Self::to_bytes).
    ///
    /// # Errors
    ///
    /// Propagates any error from `ByteReader` or `BitVector::from_bytes`: bad magic,
    /// truncated input, or trailing bytes.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
        // `sigma` is a u32, so a valid tree never has more than 32 levels. Clamping the
        // preallocation stops a corrupt `depth` field from requesting a huge allocation;
        // the read loops then fail at the first read past the end of the input.
        const MAX_DEPTH: usize = u32::BITS as usize;
        let mut r = ByteReader::new(bytes);
        r.read_magic(b"SBITWM01", "WaveletMatrix")?;
        let len = r.read_u64()? as usize;
        let sigma = r.read_u32()?;
        let depth = r.read_u32()? as usize;
        let mut levels = Vec::with_capacity(depth.min(MAX_DEPTH));
        for _ in 0..depth {
            let bv_len = r.read_u64()? as usize;
            let bv_bytes = r.take(bv_len)?;
            levels.push(BitVector::from_bytes(bv_bytes)?);
        }
        let mut zeros = Vec::with_capacity(depth.min(MAX_DEPTH));
        for _ in 0..depth {
            zeros.push(r.read_u64()? as usize);
        }
        r.expect_eof("WaveletMatrix")?;
        // NOTE(review): `depth` is not checked against `sigma`, nor level lengths against
        // `len`. A crafted input with depth > 32 passes here, and queries may then panic on
        // shift overflow (debug builds) or return garbage.
        Ok(Self {
            levels,
            zeros,
            len,
            sigma,
            depth,
        })
    }

    /// Maps position `pos` of the top-level sequence down through every level, following the
    /// branch selected by each of `symbol`'s bits, and returns the bottom-level position.
    fn descend(&self, symbol: u32, mut pos: usize) -> usize {
        for (level, (bv, &zero_count)) in self.levels.iter().zip(&self.zeros).enumerate() {
            let bit_pos = self.depth - 1 - level;
            pos = if (symbol >> bit_pos) & 1 == 1 {
                zero_count + bv.rank1(pos)
            } else {
                bv.rank0(pos)
            };
        }
        pos
    }

    /// Bottom-level position where the (possibly empty) run of `symbol` begins.
    fn symbol_start(&self, symbol: u32) -> usize {
        self.descend(symbol, 0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// Shared fixture: every symbol of the alphabet `0..4` occurs exactly twice.
    const SAMPLE: [u32; 8] = [3, 1, 2, 0, 3, 0, 1, 2];

    /// Builds the tree most tests query: `SAMPLE` over an alphabet of size 4.
    fn sample_tree() -> WaveletTree {
        WaveletTree::new(&SAMPLE, 4)
    }

    #[test]
    fn test_wavelet_tree_basic() {
        let wt = sample_tree();
        assert_eq!(wt.len(), 8);
        assert_eq!(wt.access(0), 3);
        assert_eq!(wt.access(3), 0);
        // (symbol, prefix end, expected count)
        let expectations = [
            (3, 8, 2),
            (0, 8, 2),
            (1, 8, 2),
            (2, 8, 2),
            (3, 4, 1),
            (0, 4, 1),
        ];
        for &(symbol, end, count) in &expectations {
            assert_eq!(wt.rank(symbol, end), count);
        }
    }

    #[test]
    fn test_wavelet_tree_select() {
        let wt = sample_tree();
        // (symbol, occurrence index, expected position)
        let present = [(3, 0, 0), (3, 1, 4), (0, 0, 3), (0, 1, 5), (2, 1, 7)];
        for &(symbol, k, pos) in &present {
            assert_eq!(wt.select(symbol, k), Some(pos));
        }
        // Each symbol occurs exactly twice, so a third occurrence never exists.
        for symbol in [3, 0, 1, 2] {
            assert_eq!(wt.select(symbol, 2), None);
        }
    }

    #[test]
    fn test_wavelet_tree_sigma_1() {
        let data = vec![0u32; 4];
        let wt = WaveletTree::new(&data, 1);
        assert_eq!(wt.len(), 4);
        for i in [0, 3] {
            assert_eq!(wt.access(i), 0);
        }
        assert_eq!(wt.rank(0, 4), 4);
        let selections = [(0, Some(0)), (3, Some(3)), (4, None)];
        for &(k, expected) in &selections {
            assert_eq!(wt.select(0, k), expected);
        }
    }

    #[test]
    fn test_wavelet_tree_sigma_2() {
        let data = vec![0, 1, 0, 1, 1];
        let wt = WaveletTree::new(&data, 2);
        for &(symbol, count) in &[(0, 2), (1, 3)] {
            assert_eq!(wt.rank(symbol, 5), count);
        }
        let selections = [(0, 0, 0), (0, 1, 2), (1, 0, 1), (1, 2, 4)];
        for &(symbol, k, pos) in &selections {
            assert_eq!(wt.select(symbol, k), Some(pos));
        }
    }

    #[test]
    fn test_wavelet_tree_access_all() {
        let wt = sample_tree();
        SAMPLE
            .iter()
            .enumerate()
            .for_each(|(i, &expected)| assert_eq!(wt.access(i), expected));
    }

    #[test]
    fn test_wavelet_tree_distinct_ranks() {
        let data = vec![0, 0, 0, 1, 1, 2];
        let wt = WaveletTree::new(&data, 3);
        for (symbol, expected) in (0u32..).zip([3, 2, 1]) {
            assert_eq!(wt.rank(symbol, 6), expected);
        }
    }

    #[test]
    fn test_wavelet_matrix_serialization() {
        let original = sample_tree();
        let restored = WaveletTree::from_bytes(&original.to_bytes()).unwrap();
        assert_eq!(restored.len(), original.len());
        assert_eq!(restored.sigma(), original.sigma());
        for pos in 0..original.len() {
            assert_eq!(restored.access(pos), original.access(pos));
        }
    }
}