use super::super::BitVector;
use super::rank_select::SuccinctBitVector;
/// Index over a sequence of `u64` symbols supporting `access`, `rank` and
/// `select` queries.
///
/// Symbols are remapped to dense codes (their index in the sorted,
/// de-duplicated alphabet), and one bit vector is stored per code bit, most
/// significant bit first.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct WaveletTree {
/// One bit vector per level; level `l` records bit `height - 1 - l` of every code.
levels: Vec<SuccinctBitVector>,
/// Number of levels: 0 for an empty tree, otherwise the number of bits
/// needed for codes `0..sigma` (at least 1).
height: usize,
/// Number of distinct symbols in the indexed sequence.
sigma: u64,
/// Length of the indexed sequence.
len: usize,
/// Distinct symbols; `new` stores them in ascending order, and a symbol's
/// index in this vector is its code.
symbols: Vec<u64>,
/// Inverse of `symbols`: maps each symbol to its code.
symbol_to_code: hashbrown::HashMap<u64, u64>,
}
impl WaveletTree {
/// Builds a wavelet tree over `sequence`.
///
/// The distinct symbols are sorted and each one is assigned a dense code equal
/// to its index in that order; the per-level bit vectors are then built from
/// the coded sequence. An empty `sequence` produces an empty tree with zero
/// height and no levels.
#[must_use]
pub fn new(sequence: &[u64]) -> Self {
    if sequence.is_empty() {
        return Self {
            levels: Vec::new(),
            height: 0,
            sigma: 0,
            len: 0,
            symbols: Vec::new(),
            symbol_to_code: hashbrown::HashMap::default(),
        };
    }

    // Alphabet = sorted distinct symbols of the input.
    let mut alphabet = sequence.to_vec();
    alphabet.sort_unstable();
    alphabet.dedup();

    let sigma = alphabet.len() as u64;
    // Bits needed to tell codes 0..sigma apart; a one-symbol alphabet still
    // gets a single level.
    let height = match sigma {
        0 | 1 => 1,
        n => 64 - (n - 1).leading_zeros() as usize,
    };

    let mut code_of: hashbrown::HashMap<u64, u64> =
        hashbrown::HashMap::with_capacity(alphabet.len());
    code_of.extend(
        alphabet
            .iter()
            .enumerate()
            .map(|(code, &sym)| (sym, code as u64)),
    );

    let mut coded: Vec<u64> = Vec::with_capacity(sequence.len());
    for sym in sequence {
        let code = *code_of
            .get(sym)
            .expect("symbol_to_code built from same sequence");
        coded.push(code);
    }

    Self {
        levels: Self::build_levels(&coded, height),
        height,
        sigma,
        len: sequence.len(),
        symbols: alphabet,
        symbol_to_code: code_of,
    }
}
/// Builds one bit vector per level from the dense `codes`.
///
/// Level `l` records bit `height - 1 - l` (most significant first) of every
/// code in the current ordering. After a level is emitted, the whole sequence
/// is stably partitioned by that bit: codes whose bit is 0 come first, then
/// codes whose bit is 1, each group keeping its relative order. That ordering
/// is the input of the next level.
///
/// Returns an empty vector when `codes` is empty or `height` is 0.
fn build_levels(codes: &[u64], height: usize) -> Vec<SuccinctBitVector> {
    if codes.is_empty() || height == 0 {
        return Vec::new();
    }

    let mut levels = Vec::with_capacity(height);
    let mut current: Vec<u64> = codes.to_vec();

    for level in 0..height {
        let bit_pos = height - 1 - level;
        let mut bits = BitVector::with_capacity(current.len());
        // `zeros` becomes the next ordering, so size it for the full sequence
        // up front; `ones` is appended to it below.
        let mut zeros: Vec<u64> = Vec::with_capacity(current.len());
        let mut ones: Vec<u64> = Vec::new();

        // Single pass: record the bit and route the code to its partition.
        for &code in &current {
            let is_one = (code >> bit_pos) & 1 == 1;
            bits.push(is_one);
            if is_one {
                ones.push(code);
            } else {
                zeros.push(code);
            }
        }

        levels.push(SuccinctBitVector::from_bitvec(bits));

        zeros.extend(ones);
        current = zeros;
    }

    levels
}
/// Returns the length of the indexed sequence.
#[must_use]
pub fn len(&self) -> usize {
self.len
}
/// Returns `true` when the indexed sequence has length zero.
#[must_use]
pub fn is_empty(&self) -> bool {
self.len == 0
}
/// Returns the alphabet size, i.e. the number of distinct symbols.
#[must_use]
pub fn sigma(&self) -> u64 {
self.sigma
}
#[must_use]
pub fn access(&self, i: usize) -> u64 {
assert!(i < self.len, "Index {} out of bounds (len={})", i, self.len);
let mut pos = i;
let mut code = 0u64;
for level in 0..self.height {
let bv = &self.levels[level];
let bit = bv.get(pos).unwrap_or(false);
let bit_pos = self.height - 1 - level;
if bit {
code |= 1 << bit_pos;
let zeros_total = bv.count_zeros();
pos = zeros_total + bv.rank1(pos);
} else {
pos = bv.rank0(pos);
}
}
let code_idx = usize::try_from(code).ok();
code_idx
.and_then(|idx| self.symbols.get(idx))
.copied()
.unwrap_or(0)
}
/// Counts the occurrences of `symbol` among the first `i` positions, i.e. in
/// the half-open range `[0, i)`.
///
/// `i` is clamped to the sequence length. Returns 0 when `i == 0`, when the
/// tree is empty, or when `symbol` is not part of the alphabet.
#[must_use]
pub fn rank(&self, symbol: u64, i: usize) -> usize {
    if self.is_empty() || i == 0 {
        return 0;
    }
    let code = match self.symbol_to_code.get(&symbol) {
        Some(&c) => c,
        None => return 0,
    };

    // Track the range [start, end) through every level, following the bits
    // of `code` from most to least significant.
    let mut start = 0;
    let mut end = i.min(self.len);
    for (level, shift) in (0..self.height).rev().enumerate() {
        let bv = &self.levels[level];
        let bit_is_zero = (code >> shift) & 1 == 0;
        if bit_is_zero {
            start = bv.rank0(start);
            end = bv.rank0(end);
        } else {
            let ones_offset = bv.count_zeros();
            start = ones_offset + bv.rank1(start);
            end = ones_offset + bv.rank1(end);
        }
    }

    end - start
}
/// Returns the position of the `k`-th occurrence of `symbol`, counting from
/// `k = 0`.
///
/// Returns `None` when the tree is empty, when `symbol` is not part of the
/// alphabet, when `symbol` occurs `k` times or fewer, or when a select query
/// on one of the level bit vectors yields `None`.
#[must_use]
pub fn select(&self, symbol: u64, k: usize) -> Option<usize> {
    if self.is_empty() {
        return None;
    }
    let code = *self.symbol_to_code.get(&symbol)?;

    // Descent: map the full range [0, len) down to the block occupied by
    // `symbol` after the last level.
    let (mut start, mut end) = (0usize, self.len);
    for (level, shift) in (0..self.height).rev().enumerate() {
        let bv = &self.levels[level];
        if (code >> shift) & 1 == 0 {
            start = bv.rank0(start);
            end = bv.rank0(end);
        } else {
            let ones_offset = bv.count_zeros();
            start = ones_offset + bv.rank1(start);
            end = ones_offset + bv.rank1(end);
        }
    }

    let occurrences = end - start;
    if k >= occurrences {
        return None;
    }

    // Ascent: translate the k-th slot of that block back up, level by level,
    // to a position in the original sequence.
    let mut pos = start + k;
    for level in (0..self.height).rev() {
        let bv = &self.levels[level];
        let shift = self.height - 1 - level;
        pos = match (code >> shift) & 1 {
            0 => bv.select0(pos)?,
            _ => {
                // Ones start after all zeros, so strip that offset first.
                let ones_offset = bv.count_zeros();
                bv.select1(pos - ones_offset)?
            }
        };
    }

    Some(pos)
}
/// Returns the total number of occurrences of `symbol` in the sequence.
///
/// Computed as `rank(symbol, len)`, so a symbol outside the alphabet yields 0.
#[must_use]
pub fn count(&self, symbol: u64) -> usize {
self.rank(symbol, self.len)
}
/// Iterates over the distinct symbols, in the order they are stored in
/// `symbols` (ascending for trees built by `new`).
pub fn alphabet(&self) -> impl Iterator<Item = u64> + '_ {
self.symbols.iter().copied()
}
/// Estimates the memory footprint of the tree in bytes.
///
/// The estimate is the sum of the size of the struct itself, the value each
/// level reports from its own `size_bytes()`, 8 bytes per alphabet symbol and
/// 16 bytes (one `u64` key plus one `u64` value) per `symbol_to_code` entry.
/// Nothing else about the hash table is counted.
#[must_use]
pub fn size_bytes(&self) -> usize {
    const SYMBOL_BYTES: usize = std::mem::size_of::<u64>();
    const MAP_ENTRY_BYTES: usize = 2 * std::mem::size_of::<u64>();

    let struct_bytes = std::mem::size_of::<Self>();
    let mut level_bytes: usize = 0;
    for bv in &self.levels {
        level_bytes += bv.size_bytes();
    }
    let alphabet_bytes = self.symbols.len() * SYMBOL_BYTES;
    let map_bytes = self.symbol_to_code.len() * MAP_ENTRY_BYTES;

    struct_bytes + level_bytes + alphabet_bytes + map_bytes
}
/// Iterates over `(position, symbol)` pairs in sequence order.
///
/// Every item is produced by a separate `access` call.
pub fn iter(&self) -> impl Iterator<Item = (usize, u64)> + '_ {
(0..self.len).map(move |i| (i, self.access(i)))
}
/// Checks the structural invariants of the tree.
///
/// Useful for instances whose fields were not produced by `new` (for example
/// after deserialization), where the fields may disagree with each other.
///
/// For an empty tree (`len == 0`) every other field must be zero or empty.
/// For a non-empty tree the checks are, in order:
///
/// 1. `levels.len() == height`
/// 2. `symbols.len() == sigma`
/// 3. `symbol_to_code.len() == symbols.len()`
/// 4. `height` matches the bit width implied by `sigma`
/// 5. `sigma` is not zero
/// 6. every level bit vector has length `len`
/// 7. `symbol_to_code` maps each symbol to its index in `symbols`
///
/// # Errors
///
/// Returns a human-readable description of the first violated invariant.
pub fn validate(&self) -> Result<(), String> {
    if self.len == 0 {
        if self.height != 0 {
            return Err(format!("empty tree has non-zero height {}", self.height));
        }
        if self.sigma != 0 {
            return Err(format!("empty tree has non-zero sigma {}", self.sigma));
        }
        if !self.levels.is_empty() {
            return Err(format!("empty tree has {} levels", self.levels.len()));
        }
        if !self.symbols.is_empty() {
            return Err(format!("empty tree has {} symbols", self.symbols.len()));
        }
        // An empty tree must not carry stale code mappings either.
        if !self.symbol_to_code.is_empty() {
            return Err(format!(
                "empty tree has {} symbol_to_code entries",
                self.symbol_to_code.len()
            ));
        }
        return Ok(());
    }

    if self.levels.len() != self.height {
        return Err(format!(
            "levels count {} != height {}",
            self.levels.len(),
            self.height
        ));
    }

    let sigma_usize = usize::try_from(self.sigma).map_err(|_| {
        format!(
            "sigma {} exceeds usize::MAX, cannot validate on this platform",
            self.sigma
        )
    })?;
    if self.symbols.len() != sigma_usize {
        return Err(format!(
            "symbols count {} != sigma {}",
            self.symbols.len(),
            self.sigma
        ));
    }

    if self.symbol_to_code.len() != self.symbols.len() {
        return Err(format!(
            "symbol_to_code size {} != symbols count {}",
            self.symbol_to_code.len(),
            self.symbols.len()
        ));
    }

    // Same formula `new` uses to derive the height from the alphabet size.
    let expected_height = if self.sigma <= 1 {
        1
    } else {
        64 - (self.sigma - 1).leading_zeros() as usize
    };
    if self.height != expected_height {
        return Err(format!(
            "height {} inconsistent with sigma {} (expected {expected_height})",
            self.height, self.sigma
        ));
    }

    // A non-empty sequence contains at least one distinct symbol. Without
    // this check a tree with `len > 0`, `sigma == 0` and `height == 1` would
    // pass, and `access` would then fall back to returning 0.
    if self.sigma == 0 {
        return Err(format!(
            "non-empty tree (len={}) has sigma 0",
            self.len
        ));
    }

    for (i, bv) in self.levels.iter().enumerate() {
        if bv.len() != self.len {
            return Err(format!(
                "level {i} bitvector length {} != sequence length {}",
                bv.len(),
                self.len
            ));
        }
    }

    // `symbol_to_code` must be the exact inverse of `symbols`.
    for (i, &sym) in self.symbols.iter().enumerate() {
        match self.symbol_to_code.get(&sym) {
            Some(&code) if code == i as u64 => {}
            Some(&code) => {
                return Err(format!("symbol_to_code[{sym}] = {code}, expected {i}"));
            }
            None => {
                return Err(format!(
                    "symbol {sym} at index {i} missing from symbol_to_code"
                ));
            }
        }
    }

    Ok(())
}
}
// Unit tests for `WaveletTree`: construction, access/rank/select queries,
// helper accessors, and `validate` on both well-formed and corrupted trees.
#[cfg(test)]
mod tests {
use super::*;
// An empty input yields an empty tree whose queries return 0 / `None`.
#[test]
fn test_empty() {
let wt = WaveletTree::new(&[]);
assert!(wt.is_empty());
assert_eq!(wt.len(), 0);
assert_eq!(wt.sigma(), 0);
assert_eq!(wt.rank(0, 0), 0);
assert_eq!(wt.select(0, 0), None);
}
// A one-element sequence round-trips through access, rank and select.
#[test]
fn test_single() {
let wt = WaveletTree::new(&[42]);
assert_eq!(wt.len(), 1);
assert_eq!(wt.access(0), 42);
assert_eq!(wt.rank(42, 1), 1);
assert_eq!(wt.select(42, 0), Some(0));
}
// Exhaustive spot checks on a small three-symbol sequence.
#[test]
fn test_small() {
let seq = vec![0, 1, 0, 2, 1, 0, 2, 2];
let wt = WaveletTree::new(&seq);
for (i, &expected) in seq.iter().enumerate() {
assert_eq!(wt.access(i), expected, "access({}) failed", i);
}
// rank(symbol, i) counts occurrences in [0, i).
assert_eq!(wt.rank(0, 0), 0);
assert_eq!(wt.rank(0, 1), 1);
assert_eq!(wt.rank(0, 3), 2);
assert_eq!(wt.rank(0, 8), 3);
assert_eq!(wt.rank(1, 2), 1);
assert_eq!(wt.rank(1, 5), 2);
assert_eq!(wt.rank(1, 8), 2);
assert_eq!(wt.rank(2, 4), 1);
assert_eq!(wt.rank(2, 8), 3);
// select(symbol, k) is 0-based and returns `None` past the last occurrence.
assert_eq!(wt.select(0, 0), Some(0));
assert_eq!(wt.select(0, 1), Some(2));
assert_eq!(wt.select(0, 2), Some(5));
assert_eq!(wt.select(0, 3), None);
assert_eq!(wt.select(1, 0), Some(1));
assert_eq!(wt.select(1, 1), Some(4));
assert_eq!(wt.select(1, 2), None);
assert_eq!(wt.select(2, 0), Some(3));
assert_eq!(wt.select(2, 1), Some(6));
assert_eq!(wt.select(2, 2), Some(7));
}
// For every occurrence: rank(select(k)) == k and access(select(k)) == symbol.
// The lint is silenced because `i % 10` is never negative for `i` in 0..1000,
// so the cast to `u64` cannot lose a sign.
#[test]
#[allow(clippy::cast_sign_loss)]
fn test_rank_select_consistency() {
let seq: Vec<u64> = (0..1000).map(|i| (i % 10) as u64).collect();
let wt = WaveletTree::new(&seq);
for sym in 0..10u64 {
let count = wt.count(sym);
for k in 0..count {
let pos = wt.select(sym, k).expect("select should succeed");
assert_eq!(
wt.rank(sym, pos),
k,
"rank(select({})) mismatch for symbol {}",
k,
sym
);
assert_eq!(wt.access(pos), sym, "access mismatch at position {}", pos);
}
}
}
// access reproduces a sequence whose symbols are not contiguous integers.
#[test]
fn test_access_all() {
let seq: Vec<u64> = vec![5, 3, 8, 1, 3, 5, 1, 8, 3];
let wt = WaveletTree::new(&seq);
for (i, &expected) in seq.iter().enumerate() {
assert_eq!(wt.access(i), expected, "access({}) failed", i);
}
}
// access reproduces a sequence drawn from a 50-symbol alphabet.
#[test]
fn test_large_alphabet() {
let seq: Vec<u64> = (0..100).map(|i| i * 7 % 50).collect();
let wt = WaveletTree::new(&seq);
assert_eq!(wt.len(), 100);
for (i, &expected) in seq.iter().enumerate() {
assert_eq!(wt.access(i), expected, "access({}) failed", i);
}
}
// count returns total occurrences, and 0 for a symbol that never appears.
#[test]
fn test_count() {
let seq = vec![0, 1, 0, 2, 1, 0, 2, 2];
let wt = WaveletTree::new(&seq);
assert_eq!(wt.count(0), 3);
assert_eq!(wt.count(1), 2);
assert_eq!(wt.count(2), 3);
assert_eq!(wt.count(99), 0); }
// Queries for a symbol outside the alphabet return 0 / `None`.
#[test]
fn test_nonexistent_symbol() {
let wt = WaveletTree::new(&[1, 2, 3]);
assert_eq!(wt.rank(99, 3), 0);
assert_eq!(wt.select(99, 0), None);
assert_eq!(wt.count(99), 0);
}
// alphabet yields exactly the distinct symbols of the input.
#[test]
fn test_alphabet() {
let seq = vec![5, 3, 8, 1];
let wt = WaveletTree::new(&seq);
let mut alpha: Vec<u64> = wt.alphabet().collect();
alpha.sort_unstable();
assert_eq!(alpha, vec![1, 3, 5, 8]);
}
// iter yields (position, symbol) pairs in sequence order.
#[test]
fn test_iter() {
let seq = vec![2, 0, 1];
let wt = WaveletTree::new(&seq);
let collected: Vec<(usize, u64)> = wt.iter().collect();
assert_eq!(collected, vec![(0, 2), (1, 0), (2, 1)]);
}
// A one-symbol alphabet (sigma == 1) still answers all queries.
#[test]
fn test_single_symbol_repeated() {
let seq = vec![7, 7, 7, 7, 7];
let wt = WaveletTree::new(&seq);
assert_eq!(wt.sigma(), 1);
for i in 0..5 {
assert_eq!(wt.access(i), 7);
}
assert_eq!(wt.rank(7, 3), 3);
assert_eq!(wt.select(7, 2), Some(2));
}
// Large symbol values work because symbols are remapped to dense codes.
#[test]
fn test_large_values() {
let seq: Vec<u64> = vec![1_000_000, 5_000_000, 1_000_000, 10_000_000];
let wt = WaveletTree::new(&seq);
for (i, &expected) in seq.iter().enumerate() {
assert_eq!(wt.access(i), expected, "access({}) failed", i);
}
assert_eq!(wt.count(1_000_000), 2);
assert_eq!(wt.rank(1_000_000, 3), 2);
}
// validate accepts an empty tree built by `new`.
#[test]
fn test_validate_empty() {
let wt = WaveletTree::new(&[]);
assert!(wt.validate().is_ok());
}
// validate accepts a small non-empty tree built by `new`.
#[test]
fn test_validate_non_empty() {
let wt = WaveletTree::new(&[0, 1, 0, 2, 1, 0, 2, 2]);
assert!(wt.validate().is_ok());
}
// validate accepts a tree with a one-symbol alphabet.
#[test]
fn test_validate_single_symbol() {
let wt = WaveletTree::new(&[7, 7, 7]);
assert!(wt.validate().is_ok());
}
// validate accepts a tree with a 50-symbol alphabet.
#[test]
fn test_validate_large_alphabet() {
let seq: Vec<u64> = (0..100).map(|i| i * 7 % 50).collect();
let wt = WaveletTree::new(&seq);
assert!(wt.validate().is_ok());
}
// Corrupts a serialized tree so `height` is 5 (with 5 levels to match) while
// sigma stays 4; validate must report the height/sigma inconsistency.
#[test]
fn test_validate_bad_height() {
let wt = WaveletTree::new(&[0, 1, 2, 3]);
let mut json_val: serde_json::Value = serde_json::to_value(&wt).unwrap();
let first_level = json_val["levels"][0].clone();
let levels = json_val["levels"].as_array_mut().unwrap();
// Pad the levels so the earlier levels-count check passes.
while levels.len() < 5 {
levels.push(first_level.clone());
}
json_val["height"] = serde_json::json!(5);
let corrupted: WaveletTree = serde_json::from_value(json_val).unwrap();
let err = corrupted.validate().unwrap_err();
assert!(
err.contains("height") && err.contains("inconsistent"),
"expected height-inconsistent error, got: {err}"
);
}
// Drops one level from a serialized tree; validate must report that the
// number of levels no longer equals the height.
#[test]
fn test_validate_bad_level_count() {
let wt = WaveletTree::new(&[0, 1, 2, 3]);
let mut json_val: serde_json::Value = serde_json::to_value(&wt).unwrap();
let levels = json_val["levels"].as_array_mut().unwrap();
levels.pop();
let corrupted: WaveletTree = serde_json::from_value(json_val).unwrap();
let err = corrupted.validate().unwrap_err();
assert!(
err.contains("levels count") && err.contains("height"),
"expected levels/height mismatch error, got: {err}"
);
}
// Drops one symbol from a serialized tree; validate must report that the
// symbol count no longer equals sigma.
#[test]
fn test_validate_mismatched_symbols() {
let wt = WaveletTree::new(&[0, 1, 2, 3]);
let mut json_val: serde_json::Value = serde_json::to_value(&wt).unwrap();
let symbols = json_val["symbols"].as_array_mut().unwrap();
symbols.pop();
let corrupted: WaveletTree = serde_json::from_value(json_val).unwrap();
let err = corrupted.validate().unwrap_err();
assert!(
err.contains("symbols count") && err.contains("sigma"),
"expected symbols/sigma mismatch error, got: {err}"
);
}
}