Skip to main content

cyanea_core/
bitvec.rs

//! Rank/select bitvectors and wavelet matrix.
//!
//! [`RankSelectBitVec`] supports O(1) rank queries and O(log n) select queries
//! using a two-level index over u64 blocks.
//!
//! [`WaveletMatrix`] uses a stack of bitvectors to support access, rank, and
//! select over integer alphabets in O(log σ) time.
8
9use crate::{CyaneaError, Result};
10
/// Superblock size in bits (must be a multiple of 64).
const SUPERBLOCK_SIZE: usize = 512;
/// Number of u64 blocks per superblock.
const BLOCKS_PER_SUPER: usize = SUPERBLOCK_SIZE / 64;

/// A bitvector with O(1) rank and O(log n) select support.
///
/// Bits are packed into `u64` words, least-significant bit first. A superblock
/// index stores the cumulative popcount every 512 bits, so a rank query reads
/// one index entry and popcounts at most eight words. Both select variants
/// binary-search the superblock index and then scan a single superblock.
#[derive(Debug, Clone)]
pub struct RankSelectBitVec {
    /// Packed bits; positions `>= len` in the last word are always zero.
    blocks: Vec<u64>,
    /// Cumulative popcount at the start of each superblock, followed by one
    /// sentinel entry holding the total number of 1-bits.
    superblocks: Vec<usize>,
    /// Number of valid bits.
    len: usize,
}

impl RankSelectBitVec {
    /// Build a bitvector from a slice of booleans.
    pub fn build(bits: &[bool]) -> Self {
        let len = bits.len();
        let mut blocks = vec![0u64; (len + 63) / 64];
        for (i, _) in bits.iter().enumerate().filter(|&(_, &bit)| bit) {
            blocks[i / 64] |= 1u64 << (i % 64);
        }

        // One entry per superblock (popcount of everything before it) plus a
        // trailing sentinel with the grand total, so binary searches always
        // have an upper bound to compare against.
        let num_super_groups = (blocks.len() + BLOCKS_PER_SUPER - 1) / BLOCKS_PER_SUPER;
        let mut superblocks = Vec::with_capacity(num_super_groups + 1);
        let mut cumulative = 0usize;
        for group in blocks.chunks(BLOCKS_PER_SUPER) {
            superblocks.push(cumulative);
            cumulative += group.iter().map(|w| w.count_ones() as usize).sum::<usize>();
        }
        superblocks.push(cumulative);

        Self {
            blocks,
            superblocks,
            len,
        }
    }

    /// Get the bit at position `i`.
    ///
    /// # Panics
    ///
    /// Panics if `i >= len`.
    pub fn get(&self, i: usize) -> bool {
        assert!(i < self.len, "index out of bounds");
        (self.blocks[i / 64] >> (i % 64)) & 1 == 1
    }

    /// Count the number of 1-bits in positions `[0, i)`.
    ///
    /// Returns 0 if `i == 0`.
    ///
    /// # Panics
    ///
    /// Panics if `i > len`.
    pub fn rank1(&self, i: usize) -> usize {
        assert!(i <= self.len, "rank1: index out of bounds");
        if i == 0 {
            return 0;
        }

        // Word holding bit `i - 1`, the last bit inside the queried prefix.
        let block_idx = (i - 1) / 64;
        let super_idx = block_idx / BLOCKS_PER_SUPER;
        let first_block = super_idx * BLOCKS_PER_SUPER;

        let whole_words: usize = self.blocks[first_block..block_idx]
            .iter()
            .map(|w| w.count_ones() as usize)
            .sum();

        // Between 1 and 64 bits of the last word belong to the prefix; a
        // shift by 64 would overflow, so the full-word case is handled apart.
        let bits_in_last = i - block_idx * 64;
        let mask = if bits_in_last == 64 {
            u64::MAX
        } else {
            (1u64 << bits_in_last) - 1
        };

        self.superblocks[super_idx]
            + whole_words
            + (self.blocks[block_idx] & mask).count_ones() as usize
    }

    /// Count the number of 0-bits in positions `[0, i)`.
    ///
    /// # Panics
    ///
    /// Panics if `i > len`.
    pub fn rank0(&self, i: usize) -> usize {
        i - self.rank1(i)
    }

    /// Find the position of the `k`-th 1-bit (1-indexed).
    ///
    /// Returns `None` if `k == 0` or there are fewer than `k` set bits.
    pub fn select1(&self, k: usize) -> Option<usize> {
        if k == 0 || k > self.count_ones() {
            return None;
        }

        // Last superblock whose preceding popcount is still below `k`. Entry 0
        // is 0 (< k) and the sentinel is >= k, so the index is always valid
        // and never points at the sentinel.
        let super_idx = self.superblocks.partition_point(|&ones| ones < k) - 1;
        let mut remaining = k - self.superblocks[super_idx];

        let range = self.superblock_words(super_idx);
        let start = range.start;
        for (offset, &word) in self.blocks[range].iter().enumerate() {
            let ones = word.count_ones() as usize;
            if remaining <= ones {
                // Padding bits are never set, so the hit is always < len.
                return Some((start + offset) * 64 + Self::select_in_word(word, remaining));
            }
            remaining -= ones;
        }

        None
    }

    /// Find the position of the `k`-th 0-bit (1-indexed).
    ///
    /// Returns `None` if `k == 0` or there are fewer than `k` zero bits.
    pub fn select0(&self, k: usize) -> Option<usize> {
        if k == 0 || k > self.count_zeros() {
            return None;
        }

        // Binary search on the zero counts implied by the superblock index.
        // Invariant: zeros_before(lo) < k <= zeros_before(hi). It holds at the
        // start because zeros_before(0) == 0 and the sentinel yields the total
        // number of zeros, which the guard above proved to be >= k.
        let mut lo = 0usize;
        let mut hi = self.superblocks.len() - 1;
        while hi - lo > 1 {
            let mid = lo + (hi - lo) / 2;
            if self.zeros_before_superblock(mid) < k {
                lo = mid;
            } else {
                hi = mid;
            }
        }

        let mut remaining = k - self.zeros_before_superblock(lo);
        let range = self.superblock_words(lo);
        let start = range.start;
        for (offset, &word) in self.blocks[range].iter().enumerate() {
            let block_idx = start + offset;
            // The final word may be partially filled; padding must not be
            // mistaken for zero bits of the sequence.
            let valid_bits = (self.len - block_idx * 64).min(64);
            let mask = if valid_bits == 64 {
                u64::MAX
            } else {
                (1u64 << valid_bits) - 1
            };
            let zero_bits = !word & mask;
            let zeros = zero_bits.count_ones() as usize;
            if remaining <= zeros {
                return Some(block_idx * 64 + Self::select_in_word(zero_bits, remaining));
            }
            remaining -= zeros;
        }

        None
    }

    /// Total number of bits in the bitvector.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Whether the bitvector has zero length.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Total number of 1-bits.
    pub fn count_ones(&self) -> usize {
        // The sentinel entry written by `build` already holds the total.
        self.superblocks.last().copied().unwrap_or(0)
    }

    /// Total number of 0-bits.
    pub fn count_zeros(&self) -> usize {
        self.len - self.count_ones()
    }

    /// Range of word indices covered by superblock `super_idx`, clipped to the
    /// words that actually exist.
    fn superblock_words(&self, super_idx: usize) -> std::ops::Range<usize> {
        let start = super_idx * BLOCKS_PER_SUPER;
        start..(start + BLOCKS_PER_SUPER).min(self.blocks.len())
    }

    /// Number of 0-bits located before superblock `super_idx`. For the
    /// sentinel index this is the total number of 0-bits.
    fn zeros_before_superblock(&self, super_idx: usize) -> usize {
        (super_idx * SUPERBLOCK_SIZE).min(self.len) - self.superblocks[super_idx]
    }

    /// Offset of the `rank`-th set bit (1-indexed) inside `word`.
    ///
    /// The caller guarantees `1 <= rank <= word.count_ones()`.
    fn select_in_word(mut word: u64, rank: usize) -> usize {
        debug_assert!(rank >= 1 && rank <= word.count_ones() as usize);
        for _ in 1..rank {
            word &= word - 1; // clear lowest set bit
        }
        word.trailing_zeros() as usize
    }
}
216
217// ── Wavelet Matrix ───────────────────────────────────────────────────────
218
219/// A wavelet matrix over an integer alphabet `[0, σ)`.
220///
221/// Supports access, rank, and select queries in O(log σ) time using
222/// ⌈log₂ σ⌉ levels of [`RankSelectBitVec`]s.
223#[derive(Debug, Clone)]
224pub struct WaveletMatrix {
225    levels: Vec<RankSelectBitVec>,
226    /// Number of zeros at each level (for navigation).
227    num_zeros: Vec<usize>,
228    sigma: usize,
229    len: usize,
230}
231
232impl WaveletMatrix {
233    /// Build a wavelet matrix from a sequence of symbols in `[0, sigma)`.
234    ///
235    /// # Errors
236    ///
237    /// Returns an error if any symbol is ≥ `sigma` or `sigma` is 0.
238    pub fn build(symbols: &[usize], sigma: usize) -> Result<Self> {
239        if sigma == 0 {
240            return Err(CyaneaError::InvalidInput(
241                "WaveletMatrix: sigma must be positive".into(),
242            ));
243        }
244        if let Some(&s) = symbols.iter().find(|&&s| s >= sigma) {
245            return Err(CyaneaError::InvalidInput(format!(
246                "WaveletMatrix: symbol {} out of range [0, {})",
247                s, sigma
248            )));
249        }
250
251        let n = symbols.len();
252        let num_levels = if sigma <= 1 { 1 } else { (sigma as f64).log2().ceil() as usize };
253
254        let mut levels = Vec::with_capacity(num_levels);
255        let mut num_zeros = Vec::with_capacity(num_levels);
256        let mut current = symbols.to_vec();
257
258        for level in (0..num_levels).rev() {
259            let bit = 1 << level;
260            let bits: Vec<bool> = current.iter().map(|&s| s & bit != 0).collect();
261            let bv = RankSelectBitVec::build(&bits);
262            let nz = bv.count_zeros();
263            num_zeros.push(nz);
264            levels.push(bv);
265
266            // Stable partition: 0-bit symbols first, then 1-bit symbols
267            let mut next = Vec::with_capacity(n);
268            for &s in &current {
269                if s & bit == 0 {
270                    next.push(s);
271                }
272            }
273            for &s in &current {
274                if s & bit != 0 {
275                    next.push(s);
276                }
277            }
278            current = next;
279        }
280
281        Ok(Self {
282            levels,
283            num_zeros,
284            sigma,
285            len: n,
286        })
287    }
288
289    /// Access the symbol at position `i`.
290    ///
291    /// Returns `None` if `i >= len`.
292    pub fn access(&self, mut i: usize) -> Option<usize> {
293        if i >= self.len {
294            return None;
295        }
296
297        let mut symbol = 0;
298        for (level_idx, bv) in self.levels.iter().enumerate() {
299            let bit_val = 1 << (self.levels.len() - 1 - level_idx);
300            if bv.get(i) {
301                symbol |= bit_val;
302                i = self.num_zeros[level_idx] + bv.rank1(i);
303            } else {
304                i = bv.rank0(i);
305            }
306        }
307
308        Some(symbol)
309    }
310
311    /// Count occurrences of symbol `c` in positions `[0, i)`.
312    pub fn rank(&self, c: usize, mut i: usize) -> usize {
313        if c >= self.sigma || i == 0 {
314            return 0;
315        }
316        if i > self.len {
317            i = self.len;
318        }
319
320        let mut lo = 0;
321        let mut hi = i;
322
323        for (level_idx, bv) in self.levels.iter().enumerate() {
324            let bit_val = 1 << (self.levels.len() - 1 - level_idx);
325            if c & bit_val != 0 {
326                lo = self.num_zeros[level_idx] + bv.rank1(lo);
327                hi = self.num_zeros[level_idx] + bv.rank1(hi);
328            } else {
329                lo = bv.rank0(lo);
330                hi = bv.rank0(hi);
331            }
332        }
333
334        hi - lo
335    }
336
337    /// Find the position of the `k`-th occurrence of symbol `c` (1-indexed).
338    ///
339    /// Returns `None` if there are fewer than `k` occurrences.
340    pub fn select(&self, c: usize, k: usize) -> Option<usize> {
341        if c >= self.sigma || k == 0 {
342            return None;
343        }
344
345        // Navigate down to find the range for symbol c
346        let mut lo = 0usize;
347        let mut hi = self.len;
348        for (level_idx, bv) in self.levels.iter().enumerate() {
349            let bit_val = 1 << (self.levels.len() - 1 - level_idx);
350            if c & bit_val != 0 {
351                lo = self.num_zeros[level_idx] + bv.rank1(lo);
352                hi = self.num_zeros[level_idx] + bv.rank1(hi);
353            } else {
354                lo = bv.rank0(lo);
355                hi = bv.rank0(hi);
356            }
357        }
358
359        if k > hi - lo {
360            return None;
361        }
362
363        // Navigate back up from position lo + k - 1
364        let mut pos = lo + k - 1;
365        for level_idx in (0..self.levels.len()).rev() {
366            let bv = &self.levels[level_idx];
367            let bit_val = 1 << (self.levels.len() - 1 - level_idx);
368            if c & bit_val != 0 {
369                // pos is in the 1-zone: pos = nz + rank1(original_pos)
370                // We need to find original_pos such that nz + rank1(original_pos) == pos
371                // i.e., rank1(original_pos) == pos - nz
372                let target_rank = pos - self.num_zeros[level_idx] + 1;
373                pos = bv.select1(target_rank)?;
374            } else {
375                let target_rank = pos + 1;
376                pos = bv.select0(target_rank)?;
377            }
378        }
379
380        Some(pos)
381    }
382
383    /// Length of the indexed sequence.
384    pub fn len(&self) -> usize {
385        self.len
386    }
387
388    /// Whether the sequence is empty.
389    pub fn is_empty(&self) -> bool {
390        self.len == 0
391    }
392
393    /// Alphabet size.
394    pub fn sigma(&self) -> usize {
395        self.sigma
396    }
397}
398
#[cfg(test)]
mod tests {
    use super::*;

    /// Eight-bit fixture shared by the small bitvector tests: 1 0 1 1 0 1 0 0.
    const SAMPLE_BITS: [bool; 8] = [true, false, true, true, false, true, false, false];

    /// Symbol sequence shared by the general wavelet matrix tests.
    const SAMPLE_SYMBOLS: [usize; 8] = [3, 1, 4, 1, 5, 9, 2, 6];

    /// 1000 bits with every third position set; long enough to span more than
    /// one 512-bit superblock.
    fn every_third(n: usize) -> Vec<bool> {
        (0..n).map(|i| i % 3 == 0).collect()
    }

    // ── RankSelectBitVec tests ───────────────────────────────────────

    #[test]
    fn rank_empty() {
        let empty = RankSelectBitVec::build(&[]);
        assert_eq!(empty.len(), 0);
        assert!(empty.is_empty());
        assert_eq!(empty.count_ones(), 0);
    }

    #[test]
    fn rank_basic() {
        let bv = RankSelectBitVec::build(&SAMPLE_BITS);
        assert_eq!(bv.len(), 8);
        assert_eq!(bv.count_ones(), 4);
        assert_eq!(bv.count_zeros(), 4);

        // (prefix length, expected number of 1-bits in that prefix)
        for &(prefix, expected) in &[(0, 0), (1, 1), (2, 1), (3, 2), (4, 3), (8, 4)] {
            assert_eq!(bv.rank1(prefix), expected);
        }
    }

    #[test]
    fn rank0_basic() {
        let bv = RankSelectBitVec::build(&SAMPLE_BITS);
        for &(prefix, expected) in &[(0, 0), (2, 1), (8, 4)] {
            assert_eq!(bv.rank0(prefix), expected);
        }
    }

    #[test]
    fn get_bits() {
        let pattern = [true, false, true, false];
        let bv = RankSelectBitVec::build(&pattern);
        assert!(bv.get(0));
        assert!(!bv.get(1));
        assert!(bv.get(2));
        assert!(!bv.get(3));
    }

    #[test]
    fn select1_basic() {
        let bv = RankSelectBitVec::build(&SAMPLE_BITS);
        let cases = [
            (1, Some(0)),
            (2, Some(2)),
            (3, Some(3)),
            (4, Some(5)),
            (5, None),
            (0, None),
        ];
        for &(k, expected) in &cases {
            assert_eq!(bv.select1(k), expected);
        }
    }

    #[test]
    fn select0_basic() {
        let bv = RankSelectBitVec::build(&SAMPLE_BITS);
        let cases = [
            (1, Some(1)),
            (2, Some(4)),
            (3, Some(6)),
            (4, Some(7)),
            (5, None),
        ];
        for &(k, expected) in &cases {
            assert_eq!(bv.select0(k), expected);
        }
    }

    #[test]
    fn rank_large_bitvec() {
        // More than 512 bits, so the superblock index is exercised.
        let n = 1000;
        let bits = every_third(n);
        let bv = RankSelectBitVec::build(&bits);

        // Compare against a brute-force prefix count at every 100th position.
        for i in (0..=n).step_by(100) {
            let expected = bits[..i].iter().filter(|&&b| b).count();
            assert_eq!(bv.rank1(i), expected, "rank1({}) mismatch", i);
        }
    }

    #[test]
    fn select1_large() {
        let bv = RankSelectBitVec::build(&every_third(1000));
        // The k-th set bit sits at position 3 * (k - 1).
        for &(k, expected) in &[(1, Some(0)), (2, Some(3)), (3, Some(6))] {
            assert_eq!(bv.select1(k), expected);
        }
    }

    #[test]
    fn all_ones() {
        let bv = RankSelectBitVec::build(&vec![true; 200]);
        assert_eq!(bv.count_ones(), 200);
        assert_eq!(bv.rank1(100), 100);
        assert_eq!(bv.select1(50), Some(49));
    }

    #[test]
    fn all_zeros() {
        let bv = RankSelectBitVec::build(&vec![false; 200]);
        assert_eq!(bv.count_zeros(), 200);
        assert_eq!(bv.rank0(100), 100);
        assert_eq!(bv.select0(50), Some(49));
        assert_eq!(bv.select1(1), None);
    }

    // ── WaveletMatrix tests ──────────────────────────────────────────

    #[test]
    fn wavelet_access() {
        let wm = WaveletMatrix::build(&SAMPLE_SYMBOLS, 10).unwrap();
        for (i, &expected) in SAMPLE_SYMBOLS.iter().enumerate() {
            assert_eq!(wm.access(i), Some(expected), "access({}) failed", i);
        }
        assert_eq!(wm.access(8), None);
    }

    #[test]
    fn wavelet_rank() {
        let wm = WaveletMatrix::build(&SAMPLE_SYMBOLS, 10).unwrap();
        // (symbol, prefix length, expected occurrences in the prefix)
        let cases = [
            (1, 4, 2), // two 1s in [0, 4)
            (1, 2, 1), // one 1 in [0, 2)
            (4, 3, 1), // one 4 in [0, 3)
            (7, 8, 0), // no 7s
        ];
        for &(symbol, prefix, expected) in &cases {
            assert_eq!(wm.rank(symbol, prefix), expected);
        }
    }

    #[test]
    fn wavelet_select() {
        let wm = WaveletMatrix::build(&SAMPLE_SYMBOLS, 10).unwrap();
        // (symbol, occurrence, expected index)
        let cases = [
            (1, 1, Some(1)), // first 1 at index 1
            (1, 2, Some(3)), // second 1 at index 3
            (1, 3, None),    // no third 1
            (3, 1, Some(0)), // first 3 at index 0
        ];
        for &(symbol, k, expected) in &cases {
            assert_eq!(wm.select(symbol, k), expected);
        }
    }

    #[test]
    fn wavelet_binary_alphabet() {
        let symbols = [0, 1, 0, 1, 1, 0];
        let wm = WaveletMatrix::build(&symbols, 2).unwrap();
        assert_eq!(wm.rank(0, 6), 3);
        assert_eq!(wm.rank(1, 6), 3);
        assert_eq!(wm.select(0, 1), Some(0));
        assert_eq!(wm.select(0, 2), Some(2));
    }

    #[test]
    fn wavelet_single_symbol() {
        let symbols = [0, 0, 0, 0];
        let wm = WaveletMatrix::build(&symbols, 1).unwrap();
        assert_eq!(wm.access(0), Some(0));
        assert_eq!(wm.rank(0, 4), 4);
    }

    #[test]
    fn wavelet_empty() {
        let wm = WaveletMatrix::build(&[], 4).unwrap();
        assert_eq!(wm.len(), 0);
        assert!(wm.is_empty());
        assert_eq!(wm.access(0), None);
    }

    #[test]
    fn wavelet_invalid() {
        // A zero-sized alphabet and an out-of-range symbol are both rejected.
        assert!(WaveletMatrix::build(&[], 0).is_err());
        assert!(WaveletMatrix::build(&[5], 4).is_err());
    }

    #[test]
    fn wavelet_dna_encoded() {
        // DNA as 0=A, 1=C, 2=G, 3=T; the sequence below is ACGTACGT.
        let dna = [0, 1, 2, 3, 0, 1, 2, 3];
        let wm = WaveletMatrix::build(&dna, 4).unwrap();
        assert_eq!(wm.rank(0, 8), 2); // 2 A's
        assert_eq!(wm.rank(1, 8), 2); // 2 C's
        assert_eq!(wm.select(2, 1), Some(2)); // first G at index 2
        assert_eq!(wm.select(3, 2), Some(7)); // second T at index 7
    }
}