1use crate::{CyaneaError, Result};
10
/// Number of bits covered by one rank-directory (superblock) entry.
const SUPERBLOCK_SIZE: usize = 512;
/// Number of 64-bit words in one superblock.
const BLOCKS_PER_SUPER: usize = SUPERBLOCK_SIZE / 64;

/// An immutable bit vector supporting `rank` and `select` queries.
///
/// Bits are packed into 64-bit words: bit `i` lives at position `i % 64` of
/// word `i / 64`, and padding bits past `len` in the last word are always
/// zero. A cumulative popcount is stored at the start of every superblock of
/// `SUPERBLOCK_SIZE` bits. As a result, `rank1` scans at most
/// `BLOCKS_PER_SUPER` words, and `select1`/`select0` binary-search that table
/// before scanning words.
#[derive(Debug, Clone)]
pub struct RankSelectBitVec {
    /// Packed bits, 64 per word.
    blocks: Vec<u64>,
    /// `superblocks[j]` is the number of set bits before bit
    /// `j * SUPERBLOCK_SIZE`; the final entry is the total number of set bits.
    superblocks: Vec<usize>,
    /// Number of logical bits.
    len: usize,
}

impl RankSelectBitVec {
    /// Packs `bits` into 64-bit words and builds the superblock rank directory.
    pub fn build(bits: &[bool]) -> Self {
        let blocks: Vec<u64> = bits
            .chunks(64)
            .map(|chunk| {
                chunk
                    .iter()
                    .enumerate()
                    .filter(|&(_, &b)| b)
                    .fold(0u64, |word, (j, _)| word | (1u64 << j))
            })
            .collect();

        // One prefix count per superblock, plus a trailing total. The total is
        // read by `count_ones()` and by `rank1(len)` when `len` is a multiple
        // of `SUPERBLOCK_SIZE`.
        let mut superblocks = Vec::with_capacity(blocks.len() / BLOCKS_PER_SUPER + 2);
        let mut cumulative = 0usize;
        for group in blocks.chunks(BLOCKS_PER_SUPER) {
            superblocks.push(cumulative);
            cumulative += group.iter().map(|w| w.count_ones() as usize).sum::<usize>();
        }
        superblocks.push(cumulative);

        Self {
            blocks,
            superblocks,
            len: bits.len(),
        }
    }

    /// Returns bit `i`.
    ///
    /// # Panics
    ///
    /// Panics if `i >= self.len()`.
    pub fn get(&self, i: usize) -> bool {
        assert!(i < self.len, "index out of bounds");
        (self.blocks[i / 64] >> (i % 64)) & 1 == 1
    }

    /// Number of set bits in positions `[0, i)`.
    ///
    /// # Panics
    ///
    /// Panics if `i > self.len()`.
    pub fn rank1(&self, i: usize) -> usize {
        assert!(i <= self.len, "rank1: index out of bounds");
        let block_idx = i / 64;
        let super_idx = block_idx / BLOCKS_PER_SUPER;

        // Whole words between the superblock start and the word holding bit `i`.
        let full_words: usize = self.blocks[super_idx * BLOCKS_PER_SUPER..block_idx]
            .iter()
            .map(|w| w.count_ones() as usize)
            .sum();

        // A word-aligned `i` (including `i == len`) has no partial word, and
        // `blocks[block_idx]` may not exist in that case.
        let bit_pos = i % 64;
        let partial = if bit_pos == 0 {
            0
        } else {
            (self.blocks[block_idx] & ((1u64 << bit_pos) - 1)).count_ones() as usize
        };

        self.superblocks[super_idx] + full_words + partial
    }

    /// Number of clear bits in positions `[0, i)`.
    ///
    /// # Panics
    ///
    /// Panics if `i > self.len()`.
    pub fn rank0(&self, i: usize) -> usize {
        i - self.rank1(i)
    }

    /// Position of the `k`-th set bit (`k` is 1-based).
    ///
    /// Returns `None` if `k == 0` or fewer than `k` bits are set.
    pub fn select1(&self, k: usize) -> Option<usize> {
        if k == 0 || k > self.count_ones() {
            return None;
        }

        // Last superblock whose prefix count is still `< k`; the k-th one lies
        // inside it. `superblocks[0] == 0 < k`, so the subtraction is safe.
        let sb = self.superblocks.partition_point(|&c| c < k) - 1;
        let mut remaining = k - self.superblocks[sb];

        for (b, &word) in self.blocks.iter().enumerate().skip(sb * BLOCKS_PER_SUPER) {
            let ones = word.count_ones() as usize;
            if ones >= remaining {
                let pos = b * 64 + Self::nth_set_bit(word, remaining);
                return if pos < self.len { Some(pos) } else { None };
            }
            remaining -= ones;
        }
        None
    }

    /// Position of the `k`-th clear bit (`k` is 1-based).
    ///
    /// Returns `None` if `k == 0` or fewer than `k` bits are clear.
    pub fn select0(&self, k: usize) -> Option<usize> {
        if k == 0 || k > self.count_zeros() {
            return None;
        }

        // Clear bits before superblock `j`. This is non-decreasing in `j`, and
        // the `min` accounts for the final, possibly partial, superblock.
        let zeros_before = |j: usize| (j * SUPERBLOCK_SIZE).min(self.len) - self.superblocks[j];

        // Here `len >= 1`, so the table has at least two entries.
        // Invariant: zeros_before(lo) < k <= zeros_before(hi).
        let (mut lo, mut hi) = (0, self.superblocks.len() - 1);
        while hi - lo > 1 {
            let mid = lo + (hi - lo) / 2;
            if zeros_before(mid) < k {
                lo = mid;
            } else {
                hi = mid;
            }
        }

        let mut remaining = k - zeros_before(lo);
        for (b, &word) in self.blocks.iter().enumerate().skip(lo * BLOCKS_PER_SUPER) {
            // Only the first `len - b*64` bits of the last word are real.
            let valid = (self.len - b * 64).min(64);
            let mask = if valid == 64 {
                u64::MAX
            } else {
                (1u64 << valid) - 1
            };
            let zero_bits = !word & mask;
            let zeros = zero_bits.count_ones() as usize;
            if zeros >= remaining {
                let pos = b * 64 + Self::nth_set_bit(zero_bits, remaining);
                return if pos < self.len { Some(pos) } else { None };
            }
            remaining -= zeros;
        }
        None
    }

    /// Number of logical bits.
    pub fn len(&self) -> usize {
        self.len
    }

    /// `true` if the bit vector holds no bits.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Total number of set bits, read from the trailing directory entry.
    pub fn count_ones(&self) -> usize {
        self.superblocks.last().copied().unwrap_or(0)
    }

    /// Total number of clear bits.
    pub fn count_zeros(&self) -> usize {
        self.len - self.count_ones()
    }

    /// Offset of the `n`-th (1-based) set bit within `word`.
    ///
    /// Callers guarantee `1 <= n <= word.count_ones()`.
    fn nth_set_bit(mut word: u64, n: usize) -> usize {
        // Clear the lowest set bit `n - 1` times; the answer is then the lowest
        // remaining set bit.
        for _ in 1..n {
            word &= word - 1;
        }
        word.trailing_zeros() as usize
    }
}
216
217#[derive(Debug, Clone)]
224pub struct WaveletMatrix {
225 levels: Vec<RankSelectBitVec>,
226 num_zeros: Vec<usize>,
228 sigma: usize,
229 len: usize,
230}
231
232impl WaveletMatrix {
233 pub fn build(symbols: &[usize], sigma: usize) -> Result<Self> {
239 if sigma == 0 {
240 return Err(CyaneaError::InvalidInput(
241 "WaveletMatrix: sigma must be positive".into(),
242 ));
243 }
244 if let Some(&s) = symbols.iter().find(|&&s| s >= sigma) {
245 return Err(CyaneaError::InvalidInput(format!(
246 "WaveletMatrix: symbol {} out of range [0, {})",
247 s, sigma
248 )));
249 }
250
251 let n = symbols.len();
252 let num_levels = if sigma <= 1 { 1 } else { (sigma as f64).log2().ceil() as usize };
253
254 let mut levels = Vec::with_capacity(num_levels);
255 let mut num_zeros = Vec::with_capacity(num_levels);
256 let mut current = symbols.to_vec();
257
258 for level in (0..num_levels).rev() {
259 let bit = 1 << level;
260 let bits: Vec<bool> = current.iter().map(|&s| s & bit != 0).collect();
261 let bv = RankSelectBitVec::build(&bits);
262 let nz = bv.count_zeros();
263 num_zeros.push(nz);
264 levels.push(bv);
265
266 let mut next = Vec::with_capacity(n);
268 for &s in ¤t {
269 if s & bit == 0 {
270 next.push(s);
271 }
272 }
273 for &s in ¤t {
274 if s & bit != 0 {
275 next.push(s);
276 }
277 }
278 current = next;
279 }
280
281 Ok(Self {
282 levels,
283 num_zeros,
284 sigma,
285 len: n,
286 })
287 }
288
289 pub fn access(&self, mut i: usize) -> Option<usize> {
293 if i >= self.len {
294 return None;
295 }
296
297 let mut symbol = 0;
298 for (level_idx, bv) in self.levels.iter().enumerate() {
299 let bit_val = 1 << (self.levels.len() - 1 - level_idx);
300 if bv.get(i) {
301 symbol |= bit_val;
302 i = self.num_zeros[level_idx] + bv.rank1(i);
303 } else {
304 i = bv.rank0(i);
305 }
306 }
307
308 Some(symbol)
309 }
310
311 pub fn rank(&self, c: usize, mut i: usize) -> usize {
313 if c >= self.sigma || i == 0 {
314 return 0;
315 }
316 if i > self.len {
317 i = self.len;
318 }
319
320 let mut lo = 0;
321 let mut hi = i;
322
323 for (level_idx, bv) in self.levels.iter().enumerate() {
324 let bit_val = 1 << (self.levels.len() - 1 - level_idx);
325 if c & bit_val != 0 {
326 lo = self.num_zeros[level_idx] + bv.rank1(lo);
327 hi = self.num_zeros[level_idx] + bv.rank1(hi);
328 } else {
329 lo = bv.rank0(lo);
330 hi = bv.rank0(hi);
331 }
332 }
333
334 hi - lo
335 }
336
337 pub fn select(&self, c: usize, k: usize) -> Option<usize> {
341 if c >= self.sigma || k == 0 {
342 return None;
343 }
344
345 let mut lo = 0usize;
347 let mut hi = self.len;
348 for (level_idx, bv) in self.levels.iter().enumerate() {
349 let bit_val = 1 << (self.levels.len() - 1 - level_idx);
350 if c & bit_val != 0 {
351 lo = self.num_zeros[level_idx] + bv.rank1(lo);
352 hi = self.num_zeros[level_idx] + bv.rank1(hi);
353 } else {
354 lo = bv.rank0(lo);
355 hi = bv.rank0(hi);
356 }
357 }
358
359 if k > hi - lo {
360 return None;
361 }
362
363 let mut pos = lo + k - 1;
365 for level_idx in (0..self.levels.len()).rev() {
366 let bv = &self.levels[level_idx];
367 let bit_val = 1 << (self.levels.len() - 1 - level_idx);
368 if c & bit_val != 0 {
369 let target_rank = pos - self.num_zeros[level_idx] + 1;
373 pos = bv.select1(target_rank)?;
374 } else {
375 let target_rank = pos + 1;
376 pos = bv.select0(target_rank)?;
377 }
378 }
379
380 Some(pos)
381 }
382
383 pub fn len(&self) -> usize {
385 self.len
386 }
387
388 pub fn is_empty(&self) -> bool {
390 self.len == 0
391 }
392
393 pub fn sigma(&self) -> usize {
395 self.sigma
396 }
397}
398
#[cfg(test)]
mod tests {
    use super::*;

    // ----- RankSelectBitVec -----

    #[test]
    fn rank_empty() {
        let bv = RankSelectBitVec::build(&[]);
        assert_eq!(bv.len(), 0);
        assert!(bv.is_empty());
        assert_eq!(bv.count_ones(), 0);
    }

    #[test]
    fn rank_basic() {
        let bits = [true, false, true, true, false, true, false, false];
        let bv = RankSelectBitVec::build(&bits);
        assert_eq!(bv.len(), 8);
        assert_eq!(bv.count_ones(), 4);
        assert_eq!(bv.count_zeros(), 4);

        assert_eq!(bv.rank1(0), 0);
        assert_eq!(bv.rank1(1), 1);
        assert_eq!(bv.rank1(2), 1);
        assert_eq!(bv.rank1(3), 2);
        assert_eq!(bv.rank1(4), 3);
        assert_eq!(bv.rank1(8), 4);
    }

    #[test]
    fn rank0_basic() {
        let bits = [true, false, true, true, false, true, false, false];
        let bv = RankSelectBitVec::build(&bits);
        assert_eq!(bv.rank0(0), 0);
        assert_eq!(bv.rank0(2), 1);
        assert_eq!(bv.rank0(8), 4);
    }

    #[test]
    fn get_bits() {
        let bits = [true, false, true, false];
        let bv = RankSelectBitVec::build(&bits);
        assert!(bv.get(0));
        assert!(!bv.get(1));
        assert!(bv.get(2));
        assert!(!bv.get(3));
    }

    #[test]
    fn select1_basic() {
        let bits = [true, false, true, true, false, true, false, false];
        let bv = RankSelectBitVec::build(&bits);
        assert_eq!(bv.select1(1), Some(0));
        assert_eq!(bv.select1(2), Some(2));
        assert_eq!(bv.select1(3), Some(3));
        assert_eq!(bv.select1(4), Some(5));
        assert_eq!(bv.select1(5), None);
        assert_eq!(bv.select1(0), None);
    }

    #[test]
    fn select0_basic() {
        let bits = [true, false, true, true, false, true, false, false];
        let bv = RankSelectBitVec::build(&bits);
        assert_eq!(bv.select0(1), Some(1));
        assert_eq!(bv.select0(2), Some(4));
        assert_eq!(bv.select0(3), Some(6));
        assert_eq!(bv.select0(4), Some(7));
        assert_eq!(bv.select0(5), None);
    }

    #[test]
    fn rank_large_bitvec() {
        let n = 1000;
        let bits: Vec<bool> = (0..n).map(|i| i % 3 == 0).collect();
        let bv = RankSelectBitVec::build(&bits);

        for i in (0..=n).step_by(100) {
            let expected = bits[..i].iter().filter(|&&b| b).count();
            assert_eq!(bv.rank1(i), expected, "rank1({}) mismatch", i);
        }
    }

    #[test]
    fn select1_large() {
        let n = 1000;
        let bits: Vec<bool> = (0..n).map(|i| i % 3 == 0).collect();
        let bv = RankSelectBitVec::build(&bits);
        assert_eq!(bv.select1(1), Some(0));
        assert_eq!(bv.select1(2), Some(3));
        assert_eq!(bv.select1(3), Some(6));
    }

    #[test]
    fn all_ones() {
        let bits = vec![true; 200];
        let bv = RankSelectBitVec::build(&bits);
        assert_eq!(bv.count_ones(), 200);
        assert_eq!(bv.rank1(100), 100);
        assert_eq!(bv.select1(50), Some(49));
    }

    #[test]
    fn all_zeros() {
        let bits = vec![false; 200];
        let bv = RankSelectBitVec::build(&bits);
        assert_eq!(bv.count_zeros(), 200);
        assert_eq!(bv.rank0(100), 100);
        assert_eq!(bv.select0(50), Some(49));
        assert_eq!(bv.select1(1), None);
    }

    /// Every rank and select answer on a 2000-bit irregular pattern (several
    /// superblocks) must agree with a naive scan.
    #[test]
    fn rank_select_match_brute_force_across_superblocks() {
        let bits: Vec<bool> = (0..2000usize).map(|i| (i * 7 + i / 13) % 5 < 2).collect();
        let bv = RankSelectBitVec::build(&bits);

        let mut ones = 0usize;
        for (i, &b) in bits.iter().enumerate() {
            assert_eq!(bv.rank1(i), ones, "rank1({}) mismatch", i);
            assert_eq!(bv.rank0(i), i - ones, "rank0({}) mismatch", i);
            if b {
                ones += 1;
            }
        }
        assert_eq!(bv.rank1(bits.len()), ones);
        assert_eq!(bv.count_ones(), ones);

        let one_pos: Vec<usize> = bits
            .iter()
            .enumerate()
            .filter(|&(_, &b)| b)
            .map(|(i, _)| i)
            .collect();
        let zero_pos: Vec<usize> = bits
            .iter()
            .enumerate()
            .filter(|&(_, &b)| !b)
            .map(|(i, _)| i)
            .collect();
        for (k, &p) in one_pos.iter().enumerate() {
            assert_eq!(bv.select1(k + 1), Some(p), "select1({}) mismatch", k + 1);
        }
        for (k, &p) in zero_pos.iter().enumerate() {
            assert_eq!(bv.select0(k + 1), Some(p), "select0({}) mismatch", k + 1);
        }
        assert_eq!(bv.select1(one_pos.len() + 1), None);
        assert_eq!(bv.select0(zero_pos.len() + 1), None);
    }

    /// Superblocks containing no set bits must be skipped correctly by select.
    #[test]
    fn select_skips_superblocks_without_matches() {
        let mut bits = vec![false; 1500];
        bits[1200] = true;
        bits[1499] = true;
        let bv = RankSelectBitVec::build(&bits);
        assert_eq!(bv.select1(1), Some(1200));
        assert_eq!(bv.select1(2), Some(1499));
        assert_eq!(bv.select1(3), None);
        assert_eq!(bv.select0(1200), Some(1199));
        assert_eq!(bv.select0(1201), Some(1201));
        assert_eq!(bv.select0(1498), Some(1498));
        assert_eq!(bv.select0(1499), None);
    }

    /// When `len` is a multiple of the superblock size, `rank1(len)` reads the
    /// trailing directory entry.
    #[test]
    fn rank_at_superblock_aligned_length() {
        let bv = RankSelectBitVec::build(&vec![true; 1024]);
        assert_eq!(bv.rank1(512), 512);
        assert_eq!(bv.rank1(1024), 1024);
        assert_eq!(bv.count_ones(), 1024);
        assert_eq!(bv.select1(1024), Some(1023));
        assert_eq!(bv.select0(1), None);
    }

    // ----- WaveletMatrix -----

    #[test]
    fn wavelet_access() {
        let data = [3, 1, 4, 1, 5, 9, 2, 6];
        let wm = WaveletMatrix::build(&data, 10).unwrap();
        for (i, &expected) in data.iter().enumerate() {
            assert_eq!(wm.access(i), Some(expected), "access({}) failed", i);
        }
        assert_eq!(wm.access(8), None);
    }

    #[test]
    fn wavelet_rank() {
        let data = [3, 1, 4, 1, 5, 9, 2, 6];
        let wm = WaveletMatrix::build(&data, 10).unwrap();
        assert_eq!(wm.rank(1, 4), 2); // 1 at positions 1 and 3
        assert_eq!(wm.rank(1, 2), 1); // only position 1
        assert_eq!(wm.rank(4, 3), 1); // 4 at position 2
        assert_eq!(wm.rank(7, 8), 0); // 7 never occurs
    }

    #[test]
    fn wavelet_select() {
        let data = [3, 1, 4, 1, 5, 9, 2, 6];
        let wm = WaveletMatrix::build(&data, 10).unwrap();
        assert_eq!(wm.select(1, 1), Some(1)); // first 1
        assert_eq!(wm.select(1, 2), Some(3)); // second 1
        assert_eq!(wm.select(1, 3), None); // only two 1s
        assert_eq!(wm.select(3, 1), Some(0)); // first 3
    }

    #[test]
    fn wavelet_binary_alphabet() {
        let data = [0, 1, 0, 1, 1, 0];
        let wm = WaveletMatrix::build(&data, 2).unwrap();
        assert_eq!(wm.rank(0, 6), 3);
        assert_eq!(wm.rank(1, 6), 3);
        assert_eq!(wm.select(0, 1), Some(0));
        assert_eq!(wm.select(0, 2), Some(2));
    }

    #[test]
    fn wavelet_single_symbol() {
        let data = [0, 0, 0, 0];
        let wm = WaveletMatrix::build(&data, 1).unwrap();
        assert_eq!(wm.access(0), Some(0));
        assert_eq!(wm.rank(0, 4), 4);
    }

    #[test]
    fn wavelet_empty() {
        let wm = WaveletMatrix::build(&[], 4).unwrap();
        assert_eq!(wm.len(), 0);
        assert!(wm.is_empty());
        assert_eq!(wm.access(0), None);
    }

    #[test]
    fn wavelet_invalid() {
        assert!(WaveletMatrix::build(&[], 0).is_err());
        assert!(WaveletMatrix::build(&[5], 4).is_err());
    }

    #[test]
    fn wavelet_dna_encoded() {
        let dna = [0, 1, 2, 3, 0, 1, 2, 3]; // A=0, C=1, G=2, T=3
        let wm = WaveletMatrix::build(&dna, 4).unwrap();
        assert_eq!(wm.rank(0, 8), 2); // two As
        assert_eq!(wm.rank(1, 8), 2); // two Cs
        assert_eq!(wm.select(2, 1), Some(2)); // first G
        assert_eq!(wm.select(3, 2), Some(7)); // second T
    }

    /// A power-of-two `sigma` must still allocate enough levels for `sigma - 1`.
    #[test]
    fn wavelet_sigma_power_of_two() {
        let data = [7, 0, 4, 7, 3];
        let wm = WaveletMatrix::build(&data, 8).unwrap();
        for (i, &expected) in data.iter().enumerate() {
            assert_eq!(wm.access(i), Some(expected), "access({}) failed", i);
        }
        assert_eq!(wm.rank(7, 5), 2);
        assert_eq!(wm.select(7, 2), Some(3));
        assert_eq!(wm.select(4, 1), Some(2));
    }

    /// access/rank/select on a pseudo-random sequence must agree with a naive scan.
    #[test]
    fn wavelet_matches_brute_force() {
        let sigma: usize = 37;
        let mut state: u64 = 12345;
        let data: Vec<usize> = (0..300)
            .map(|_| {
                state = state
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407);
                ((state >> 33) % sigma as u64) as usize
            })
            .collect();
        let wm = WaveletMatrix::build(&data, sigma).unwrap();

        for (i, &s) in data.iter().enumerate() {
            assert_eq!(wm.access(i), Some(s), "access({}) failed", i);
        }
        for c in 0..sigma {
            let positions: Vec<usize> = data
                .iter()
                .enumerate()
                .filter(|&(_, &s)| s == c)
                .map(|(i, _)| i)
                .collect();
            for i in 0..=data.len() {
                let expected = positions.iter().filter(|&&p| p < i).count();
                assert_eq!(wm.rank(c, i), expected, "rank({}, {}) mismatch", c, i);
            }
            for (k, &p) in positions.iter().enumerate() {
                assert_eq!(wm.select(c, k + 1), Some(p), "select({}, {}) mismatch", c, k + 1);
            }
            assert_eq!(wm.select(c, positions.len() + 1), None);
        }
        assert_eq!(wm.rank(sigma, data.len()), 0);
        assert_eq!(wm.select(sigma, 1), None);
    }
}