bio/data_structures/
wavelet_matrix.rs

1//! Wavelet Matrix data structure for DNA alphabet.
2//! The implementation is based on the paper
//! [Francisco Claude and Gonzalo Navarro. The wavelet matrix. SPIRE (2012)](https://doi.org/10.1007/978-3-642-34109-0_18)
4//!
5//! # Example
6//!
7//! ```
8//! use bio::data_structures::wavelet_matrix::WaveletMatrix;
9//! let text = b"AANGGT$ACCNTT$";
10//! let wm = WaveletMatrix::new(text);
11//! assert_eq!(wm.rank(b'A', 0), 1);
12//! assert_eq!(wm.rank(b'G', 9), 2);
13//! assert_eq!(wm.rank(b'T', 13), 3);
14//! ```
15
16use crate::data_structures::rank_select::RankSelect;
17use bv::BitVec;
18use bv::BitsMut;
19
/// Maps every possible byte to the 3-bit symbol code used by the wavelet matrix.
///
/// * `C`/`c` -> 1, `G`/`g` -> 2, `T`/`t` -> 3, `N`/`n` -> 4, `$` -> 5
/// * the ASCII digits `1`..=`7` map to their numeric value
/// * every other byte (including `A`/`a` and the digit `0`) maps to 0
///
/// The table spans the whole `u8` range, so indexing it with an arbitrary byte
/// can never go out of bounds.
const DNA2INT: [u8; 256] = {
    let mut table = [0u8; 256];

    // Sentinel symbol.
    table[b'$' as usize] = 5;

    // Digits, so that texts over the alphabet 0..=7 can be indexed directly.
    table[b'1' as usize] = 1;
    table[b'2' as usize] = 2;
    table[b'3' as usize] = 3;
    table[b'4' as usize] = 4;
    table[b'5' as usize] = 5;
    table[b'6' as usize] = 6;
    table[b'7' as usize] = 7;

    // Nucleotides, case-insensitive. `A`/`a` keep the default code 0.
    table[b'C' as usize] = 1;
    table[b'c' as usize] = 1;
    table[b'G' as usize] = 2;
    table[b'g' as usize] = 2;
    table[b'T' as usize] = 3;
    table[b't' as usize] = 3;
    table[b'N' as usize] = 4;
    table[b'n' as usize] = 4;

    table
};
35
36#[derive(Default, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Serialize, Deserialize)]
37pub struct WaveletMatrix {
38    width: usize,  // levels[0].len()
39    height: usize, // zeros.len() == levels.len()
40    zeros: Vec<u64>,
41    levels: Vec<RankSelect>,
42}
43
44fn build_partlevel(
45    vals: &[u8],
46    shift: u8,
47    next_zeros: &mut Vec<u8>,
48    next_ones: &mut Vec<u8>,
49    bits: &mut BitVec<u8>,
50    prev_bits: u64,
51) {
52    let mut p = prev_bits;
53    for val in vals {
54        let bit = ((DNA2INT[usize::from(*val)] >> shift) & 1) == 1; // get shifted lsb
55        bits.set_bit(p, bit);
56        p += 1;
57        if bit {
58            next_ones.push(*val);
59        } else {
60            next_zeros.push(*val);
61        }
62    }
63}
64
65impl WaveletMatrix {
66    /// Construct a new instance of the wavelet matrix of given text of length n (DNA alphabet plus sentinel symbol).
67    /// Complexity: O(n).
68    pub fn new(text: &[u8]) -> Self {
69        let width = text.len();
70        let height: usize = 3; // hardcoded for alphabet size <= 8 (ACGTN$)
71
72        let mut curr_zeros: Vec<u8> = text.to_vec();
73        let mut curr_ones: Vec<u8> = Vec::new();
74
75        let mut zeros: Vec<u64> = Vec::new();
76        let mut levels: Vec<RankSelect> = Vec::new();
77
78        for level in 0..height {
79            let mut next_zeros: Vec<u8> = Vec::with_capacity(width);
80            let mut next_ones: Vec<u8> = Vec::with_capacity(width);
81            let mut curr_bits: BitVec<u8> = BitVec::new_fill(false, width as u64);
82            let shift = (height - level - 1) as u8;
83            build_partlevel(
84                &curr_zeros,
85                shift,
86                &mut next_zeros,
87                &mut next_ones,
88                &mut curr_bits,
89                0,
90            );
91            build_partlevel(
92                &curr_ones,
93                shift,
94                &mut next_zeros,
95                &mut next_ones,
96                &mut curr_bits,
97                curr_zeros.len() as u64,
98            );
99
100            curr_zeros = next_zeros;
101            curr_ones = next_ones;
102
103            let level = RankSelect::new(curr_bits, 1);
104            levels.push(level);
105            zeros.push(curr_zeros.len() as u64);
106        }
107
108        WaveletMatrix {
109            width,
110            height,
111            zeros,
112            levels,
113        }
114    }
115
116    fn check_overflow(&self, p: u64) -> bool {
117        p >= self.width as u64
118    }
119
120    fn prank(&self, level: usize, p: u64, val: u8) -> u64 {
121        if p == 0 {
122            0
123        } else if val == 0 {
124            self.levels[level].rank_0(p - 1).unwrap()
125        } else {
126            self.levels[level].rank_1(p - 1).unwrap()
127        }
128    }
129
130    /// Compute the number of occurrences of symbol val in the original text up to position p (inclusive).
131    /// Complexity O(1).
132    pub fn rank(&self, val: u8, p: u64) -> u64 {
133        assert!(
134            !self.check_overflow(p),
135            "Invalid p (it must be in range 0..wm_size-1"
136        );
137        let height = self.height;
138        let mut spos = 0;
139        let mut epos = p + 1;
140        for level in 0..height {
141            let shift = height - level - 1;
142            let bit = ((DNA2INT[val as usize] >> shift) & 1) == 1; // get shifted lsb
143            if bit {
144                spos = self.prank(level, spos, 1) + self.zeros[level];
145                epos = self.prank(level, epos, 1) + self.zeros[level];
146            } else {
147                spos = self.prank(level, spos, 0);
148                epos = self.prank(level, epos, 0);
149            }
150        }
151        epos - spos
152    }
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    /// Text over the digit alphabet `0..=7` shared by most tests.
    const DIGIT_TEXT: &[u8] = b"476532101417";

    /// Compares the dimensions, the per-level zero counts and every bit of `wm`
    /// with the expected values.
    fn assert_layout(wm: &WaveletMatrix, levels: &[Vec<bool>], zeros: &[u64]) {
        assert_eq!(wm.height, zeros.len());
        assert_eq!(wm.width, levels[0].len());
        for (level, (row, &zero_count)) in levels.iter().zip(zeros).enumerate() {
            assert_eq!(wm.zeros[level], zero_count);
            for (i, &expected) in (0u64..).zip(row) {
                assert_eq!(wm.levels[level].bits().get(i), expected);
            }
        }
    }

    /// Builds the matrix of `text` and checks `rank` of every symbol of `alphabet`
    /// at every position; `expected[i][p]` is the rank of `alphabet[i]` at `p`.
    fn assert_all_ranks(text: &[u8], alphabet: &[u8], expected: &[Vec<u64>]) {
        let wm = WaveletMatrix::new(text);
        for (&symbol, row) in alphabet.iter().zip(expected) {
            for (p, &want) in (0u64..).zip(row) {
                assert_eq!(wm.rank(symbol, p), want);
            }
        }
    }

    #[test]
    fn test_wm_buildpaper() {
        let wm = WaveletMatrix::new(DIGIT_TEXT);
        let levels = [
            vec![
                true, true, true, true, false, false, false, false, false, true, false, true,
            ],
            vec![
                true, true, false, false, false, false, false, true, true, false, false, true,
            ],
            vec![
                true, false, true, true, false, true, false, true, false, true, false, true,
            ],
        ];

        assert_layout(&wm, &levels, &[6, 7, 5]);
    }

    #[test]
    fn test_wm_builddna() {
        let wm = WaveletMatrix::new(b"ACGTN$NAGCT$");
        let levels = [
            vec![
                false, false, false, false, true, true, true, false, false, false, false, true,
            ],
            vec![
                false, false, true, true, false, true, false, true, false, false, false, false,
            ],
            vec![
                false, true, false, true, false, true, false, true, false, true, false, true,
            ],
        ];

        assert_layout(&wm, &levels, &[8, 8, 6]);
    }

    #[test]
    #[should_panic]
    fn test_wm_rank_overflowpanic() {
        let wm = WaveletMatrix::new(DIGIT_TEXT);
        // One past the last valid position.
        wm.rank(b'4', DIGIT_TEXT.len() as u64);
    }

    #[test]
    fn test_wm_rank_firstpos() {
        let wm = WaveletMatrix::new(DIGIT_TEXT);
        assert_eq!(wm.rank(b'4', 0), 1);
    }

    #[test]
    fn test_wm_rank_lastpos() {
        let wm = WaveletMatrix::new(DIGIT_TEXT);
        let last = DIGIT_TEXT.len() as u64 - 1;
        assert_eq!(wm.rank(b'7', last), 2);
    }

    #[test]
    fn test_wm_rank_1() {
        let wm = WaveletMatrix::new(DIGIT_TEXT);
        // The only '0' sits at position 7.
        for &(p, expected) in &[(6, 0), (7, 1), (8, 1)] {
            assert_eq!(wm.rank(b'0', p), expected);
        }
    }

    #[test]
    fn test_wm_rank_2() {
        let wm = WaveletMatrix::new(DIGIT_TEXT);
        // The second '4' sits at position 9.
        for &(p, expected) in &[(8, 1), (9, 2), (10, 2)] {
            assert_eq!(wm.rank(b'4', p), expected);
        }
    }

    #[test]
    fn test_wm_rank_all() {
        // One row per symbol '0'..='7', one column per text position.
        let ranks = vec![
            vec![0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
            vec![0, 0, 0, 0, 0, 0, 1, 1, 2, 2, 3, 3],
            vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
            vec![0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
            vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2],
            vec![0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
            vec![0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
            vec![0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2],
        ];

        assert_all_ranks(DIGIT_TEXT, b"01234567", &ranks);
    }

    #[test]
    fn test_wm_rank_alldna() {
        // One row per symbol of "ACGTN$", one column per text position.
        let ranks = vec![
            vec![1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4],
            vec![0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3],
            vec![0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2],
            vec![0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 3, 3, 3, 3],
            vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
            vec![0, 0, 0, 0, 0, 0, 1, 2, 2, 2, 2, 2, 2, 2, 2],
        ];

        assert_all_ranks(b"AAGCTC$$CATTNGA", b"ACGTN$", &ranks);
    }
}