use std::fmt;
/// Immutable bit vector with constant-time `rank` queries.
///
/// Bits are packed into 64-bit words: bit `i` is stored at bit `i % 64` of `words[i / 64]`.
/// `prefix[w]` caches the number of set bits in `words[..w]`, so a rank query is one table
/// lookup plus one popcount on a partial word.
#[derive(Clone, Debug)]
struct BitRankVector {
    /// Packed bits; bits of the last word at positions `>= len` are always zero.
    words: Vec<u64>,
    /// `prefix[w]` = number of ones in `words[..w]`; holds `words.len() + 1` entries.
    prefix: Vec<usize>,
    /// Number of logical bits stored.
    len: usize,
}

impl BitRankVector {
    /// Packs `bits` into 64-bit words and precomputes per-word prefix popcounts.
    fn from_bits(bits: &[bool]) -> Self {
        let len = bits.len();
        // One word per 64-bit chunk; a short final chunk leaves its high bits zero.
        let words: Vec<u64> = bits
            .chunks(64)
            .map(|chunk| {
                chunk
                    .iter()
                    .enumerate()
                    .filter(|&(_, &b)| b)
                    .fold(0u64, |w, (j, _)| w | (1u64 << j))
            })
            .collect();
        let mut prefix = Vec::with_capacity(words.len() + 1);
        prefix.push(0usize);
        let mut running = 0usize;
        for &w in &words {
            running += w.count_ones() as usize;
            prefix.push(running);
        }
        BitRankVector { words, prefix, len }
    }

    /// Returns bit `i`; positions past the end read as `false`.
    #[inline]
    fn get(&self, i: usize) -> bool {
        if i >= self.len {
            return false;
        }
        (self.words[i / 64] >> (i % 64)) & 1 == 1
    }

    /// Number of set bits in positions `0..i`; `i` is clamped to `len()`.
    ///
    /// Safe on an empty vector: clamping happens before any index arithmetic, so
    /// `rank1(i)` with `len == 0` returns 0 instead of underflowing.
    #[inline]
    fn rank1(&self, i: usize) -> usize {
        let i = i.min(self.len);
        let (word_idx, bit_off) = (i / 64, i % 64);
        // When `bit_off == 0` the count is exactly `prefix[word_idx]`. Skipping the partial
        // word also avoids reading `words[word_idx]` when `i == len` is a multiple of 64.
        let partial = if bit_off == 0 {
            0
        } else {
            (self.words[word_idx] & ((1u64 << bit_off) - 1)).count_ones() as usize
        };
        self.prefix[word_idx] + partial
    }

    /// Number of clear bits in positions `0..i`; `i` is clamped to `len()`.
    #[inline]
    fn rank0(&self, i: usize) -> usize {
        i.min(self.len) - self.rank1(i)
    }

    /// Number of logical bits stored.
    fn len(&self) -> usize {
        self.len
    }
}
/// Internal node of a wavelet tree, responsible for the half-open symbol range `[lo, hi)`.
///
/// With `mid = lo + (hi - lo) / 2`, bit `i` of `bits` is set exactly when the `i`-th symbol
/// routed to this node is `>= mid`. Symbols `< mid` continue, in order, into `left`, and the
/// rest into `right`. A child is `None` when no symbols reached it or when its range holds a
/// single symbol. In the single-symbol case the answer is read directly from this node's bits.
/// Nodes are only created for ranges holding at least two symbols (`lo + 1 < hi`).
#[derive(Clone, Debug)]
struct WaveletNode {
    lo: u32,
    hi: u32,
    bits: BitRankVector,
    left: Option<Box<WaveletNode>>,
    right: Option<Box<WaveletNode>>,
}

impl WaveletNode {
    /// Recursively builds the subtree for `seq`, whose symbols must lie in `[lo, hi)`.
    ///
    /// Returns `None` for an empty `seq` or a single-symbol range; neither needs a node.
    fn build(seq: &[u32], lo: u32, hi: u32) -> Option<Box<Self>> {
        if seq.is_empty() || lo + 1 >= hi {
            return None;
        }
        let mid = lo + (hi - lo) / 2;
        let bits_raw: Vec<bool> = seq.iter().map(|&s| s >= mid).collect();
        let bits = BitRankVector::from_bits(&bits_raw);
        // `partition` keeps the relative order within each half, which rank/select rely on.
        let (left_seq, right_seq): (Vec<u32>, Vec<u32>) =
            seq.iter().copied().partition(|&s| s < mid);
        Some(Box::new(WaveletNode {
            lo,
            hi,
            bits,
            left: WaveletNode::build(&left_seq, lo, mid),
            right: WaveletNode::build(&right_seq, mid, hi),
        }))
    }

    /// Split point between the left (`< mid`) and right (`>= mid`) halves of `[lo, hi)`.
    #[inline]
    fn mid(&self) -> u32 {
        self.lo + (self.hi - self.lo) / 2
    }

    /// Returns the symbol at position `i` of this node's subsequence (`i < bits.len()`).
    fn access(&self, i: usize) -> u32 {
        if self.bits.get(i) {
            let j = self.bits.rank1(i + 1) - 1;
            // A set bit means the right half is non-empty, so a missing right child can
            // only be the single-symbol range `[mid, mid + 1)`.
            self.right.as_ref().map_or(self.mid(), |child| child.access(j))
        } else {
            let j = self.bits.rank0(i + 1) - 1;
            // Likewise, a missing left child here is the single-symbol range `[lo, lo + 1)`.
            self.left.as_ref().map_or(self.lo, |child| child.access(j))
        }
    }

    /// Counts occurrences of `symbol` (assumed to lie in `[lo, hi)`) in positions `0..i` of
    /// this node's subsequence.
    fn rank(&self, symbol: u32, i: usize) -> usize {
        if i == 0 {
            return 0;
        }
        // `j` is how many of the first `i` symbols fell into the chosen half. A missing child
        // is either empty (then `j == 0`) or a single-symbol leaf whose only symbol is
        // `symbol`, so `j` is already the answer.
        if symbol < self.mid() {
            let j = self.bits.rank0(i);
            self.left.as_ref().map_or(j, |child| child.rank(symbol, j))
        } else {
            let j = self.bits.rank1(i);
            self.right.as_ref().map_or(j, |child| child.rank(symbol, j))
        }
    }

    /// Returns the 0-based position, within this node's subsequence, of the `k`-th (1-based)
    /// occurrence of `symbol`. Returns `None` if `k == 0` or there are fewer than `k`
    /// occurrences.
    fn select(&self, symbol: u32, k: usize) -> Option<usize> {
        if k == 0 {
            return None;
        }
        if symbol < self.mid() {
            // Rank of the target among this node's zeros. It comes from the child when one
            // exists; otherwise the child is empty or a single-symbol leaf, and `k` already
            // is that rank.
            let nth_zero = match &self.left {
                Some(child) => child.select(symbol, k)? + 1,
                None => k,
            };
            self.select_bit(false, nth_zero)
        } else {
            let nth_one = match &self.right {
                Some(child) => child.select(symbol, k)? + 1,
                None => k,
            };
            self.select_bit(true, nth_one)
        }
    }

    /// Position of the `k`-th (1-based) bit equal to `bit` in this node's bit vector, or
    /// `None` if `k == 0` or fewer than `k` such bits exist.
    fn select_bit(&self, bit: bool, k: usize) -> Option<usize> {
        if k == 0 {
            return None;
        }
        let len = self.bits.len();
        let mut remaining = k;
        for (word_idx, &word) in self.bits.words.iter().enumerate() {
            let valid_bits = (len - word_idx * 64).min(64);
            let valid_mask = if valid_bits == 64 {
                u64::MAX
            } else {
                (1u64 << valid_bits) - 1
            };
            // Searching for zeros means searching for ones in the complement, restricted
            // to positions that are actually stored.
            let raw = if bit { word } else { !word };
            let mut w = raw & valid_mask;
            let count = w.count_ones() as usize;
            if remaining > count {
                remaining -= count;
                continue;
            }
            // Drop the lowest `remaining - 1` matches; the next lowest set bit is the target.
            for _ in 1..remaining {
                w &= w - 1;
            }
            return Some(word_idx * 64 + w.trailing_zeros() as usize);
        }
        None
    }
}
/// Static wavelet tree over a sequence of `u32` symbols drawn from `0..sigma`.
///
/// Answers `access`, `rank`, and `select` by descending a balanced binary partition of the
/// alphabet, with one rank-enabled bit vector per internal node.
pub struct WaveletTree {
    /// Length of the indexed sequence.
    n: usize,
    /// Effective alphabet size (never less than 2).
    sigma: u32,
    /// Root of the node hierarchy; `None` only when the sequence is empty.
    root: Option<Box<WaveletNode>>,
}

impl WaveletTree {
    /// Indexes `seq` over the alphabet `0..sigma`.
    ///
    /// `sigma` is raised to at least 2. Symbols that do not fit the (raised) alphabet are not
    /// rejected: they are stored as `sigma - 1`, so `access` reports them as that value.
    pub fn build(seq: &[u32], sigma: u32) -> Self {
        let sigma = if sigma < 2 { 2 } else { sigma };
        let top = sigma - 1;
        let clamped: Vec<u32> = seq.iter().map(|&sym| sym.min(top)).collect();
        Self {
            n: seq.len(),
            sigma,
            root: WaveletNode::build(&clamped, 0, sigma),
        }
    }

    /// Returns the (possibly clamped) symbol stored at position `i`.
    ///
    /// # Panics
    /// Panics if `i >= self.len()`.
    pub fn access(&self, i: usize) -> u32 {
        assert!(i < self.n, "index {i} out of bounds (len={})", self.n);
        self.root.as_ref().map_or(0, |node| node.access(i))
    }

    /// Number of occurrences of `symbol` among the first `i` positions.
    ///
    /// `i` is clamped to `len()`. Symbols outside the alphabet always yield 0.
    pub fn rank(&self, symbol: u32, i: usize) -> usize {
        if i == 0 || symbol >= self.sigma {
            return 0;
        }
        let end = i.min(self.n);
        self.root.as_ref().map_or(0, |node| node.rank(symbol, end))
    }

    /// 0-based position of the `k`-th (1-based) occurrence of `symbol`.
    ///
    /// Returns `None` when `k == 0`, when `symbol` lies outside the alphabet, or when fewer
    /// than `k` occurrences exist.
    pub fn select(&self, symbol: u32, k: usize) -> Option<usize> {
        if k == 0 || symbol >= self.sigma {
            return None;
        }
        self.root.as_ref().and_then(|node| node.select(symbol, k))
    }

    /// Length of the indexed sequence.
    pub fn len(&self) -> usize {
        self.n
    }

    /// `true` when the indexed sequence has no symbols.
    pub fn is_empty(&self) -> bool {
        self.n == 0
    }

    /// Effective alphabet size (the `sigma` passed to `build`, raised to at least 2).
    pub fn sigma(&self) -> u32 {
        self.sigma
    }
}

impl fmt::Debug for WaveletTree {
    /// Prints only the summary fields; the node hierarchy is omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut dbg = f.debug_struct("WaveletTree");
        dbg.field("n", &self.n);
        dbg.field("sigma", &self.sigma);
        dbg.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Small fixture with repeated and absent symbols over an alphabet of 10.
    fn sample_seq() -> Vec<u32> {
        vec![3, 1, 4, 1, 5, 9, 2, 6, 5, 3]
    }

    #[test]
    fn access_roundtrip() {
        let seq = sample_seq();
        let sigma = 10;
        let wt = WaveletTree::build(&seq, sigma);
        for (i, &expected) in seq.iter().enumerate() {
            assert_eq!(wt.access(i), expected, "access({i}) mismatch");
        }
    }

    #[test]
    fn rank_basic() {
        let seq = sample_seq();
        let wt = WaveletTree::build(&seq, 10);
        assert_eq!(wt.rank(1, 5), 2);
        assert_eq!(wt.rank(3, 10), 2);
        assert_eq!(wt.rank(9, 10), 1);
        assert_eq!(wt.rank(7, 10), 0);
        assert_eq!(wt.rank(3, 0), 0);
    }

    #[test]
    fn rank_out_of_sigma() {
        let wt = WaveletTree::build(&[0u32, 1, 2], 3);
        assert_eq!(wt.rank(3, 3), 0);
        assert_eq!(wt.rank(100, 3), 0);
    }

    #[test]
    fn select_basic() {
        let seq = sample_seq();
        let wt = WaveletTree::build(&seq, 10);
        assert_eq!(wt.select(1, 1), Some(1));
        assert_eq!(wt.select(1, 2), Some(3));
        assert_eq!(wt.select(1, 3), None);
        assert_eq!(wt.select(9, 1), Some(5));
        assert_eq!(wt.select(5, 2), Some(8));
    }

    #[test]
    fn select_k_zero_returns_none() {
        let wt = WaveletTree::build(&[1u32, 2, 3], 4);
        assert_eq!(wt.select(1, 0), None);
    }

    #[test]
    fn select_absent_symbol() {
        let wt = WaveletTree::build(&[1u32, 2, 3], 4);
        assert_eq!(wt.select(0, 1), None);
    }

    #[test]
    fn empty_sequence() {
        let wt = WaveletTree::build(&[], 8);
        assert!(wt.is_empty());
        assert_eq!(wt.rank(0, 0), 0);
        assert_eq!(wt.select(0, 1), None);
    }

    #[test]
    fn single_symbol_repeated() {
        let seq = vec![7u32; 20];
        let wt = WaveletTree::build(&seq, 16);
        for i in 0..20 {
            assert_eq!(wt.access(i), 7);
        }
        assert_eq!(wt.rank(7, 20), 20);
        assert_eq!(wt.rank(7, 10), 10);
        for k in 1..=20 {
            assert_eq!(wt.select(7, k), Some(k - 1));
        }
        assert_eq!(wt.select(7, 21), None);
    }

    /// Every symbol appears exactly twice, so `select` must succeed for k = 1, 2 and fail
    /// for k = 3. The result must not be allowed to silently come back `None`.
    #[test]
    fn rank_select_consistency() {
        let seq: Vec<u32> = (0u32..8).flat_map(|c| vec![c, c]).collect();
        let wt = WaveletTree::build(&seq, 8);
        for c in 0u32..8 {
            for k in 1..=2 {
                let pos = wt
                    .select(c, k)
                    .unwrap_or_else(|| panic!("select({c}, {k}) returned None"));
                assert_eq!(pos, 2 * c as usize + k - 1);
                let r = wt.rank(c, pos + 1);
                assert_eq!(r, k, "rank({c}, select({c}, {k})+1) = {r} ≠ {k}");
            }
            assert_eq!(wt.select(c, 3), None);
        }
    }

    /// Cross-checks access/rank/select against a linear scan on a sequence long enough to
    /// span several 64-bit words at the root.
    #[test]
    fn matches_brute_force_across_word_boundaries() {
        let sigma = 13u32;
        let seq: Vec<u32> = (0u32..200).map(|i| (i * 31 + 7) % sigma).collect();
        let wt = WaveletTree::build(&seq, sigma);
        assert_eq!(wt.len(), seq.len());
        for (i, &s) in seq.iter().enumerate() {
            assert_eq!(wt.access(i), s, "access({i})");
        }
        for c in 0..sigma {
            let mut count = 0usize;
            for (i, &s) in seq.iter().enumerate() {
                assert_eq!(wt.rank(c, i), count, "rank({c}, {i})");
                if s == c {
                    count += 1;
                    assert_eq!(wt.select(c, count), Some(i), "select({c}, {count})");
                }
            }
            assert_eq!(wt.rank(c, seq.len()), count);
            assert_eq!(wt.select(c, count + 1), None);
        }
    }

    #[test]
    fn bit_rank_vector_correctness() {
        let bits = vec![true, false, true, true, false, false, true, false];
        let brv = BitRankVector::from_bits(&bits);
        assert_eq!(brv.rank1(0), 0);
        assert_eq!(brv.rank1(1), 1);
        assert_eq!(brv.rank1(2), 1);
        assert_eq!(brv.rank1(3), 2);
        assert_eq!(brv.rank1(4), 3);
        assert_eq!(brv.rank1(8), 4);
        assert_eq!(brv.rank0(8), 4);
    }
}