1use crate::data_structures::rank_select::RankSelect;
17use bv::BitVec;
18use bv::BitsMut;
19
20const DNA2INT: [u8; 128] = [
21 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
34]; #[derive(Default, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Serialize, Deserialize)]
37pub struct WaveletMatrix {
38 width: usize, height: usize, zeros: Vec<u64>,
41 levels: Vec<RankSelect>,
42}
43
44fn build_partlevel(
45 vals: &[u8],
46 shift: u8,
47 next_zeros: &mut Vec<u8>,
48 next_ones: &mut Vec<u8>,
49 bits: &mut BitVec<u8>,
50 prev_bits: u64,
51) {
52 let mut p = prev_bits;
53 for val in vals {
54 let bit = ((DNA2INT[usize::from(*val)] >> shift) & 1) == 1; bits.set_bit(p, bit);
56 p += 1;
57 if bit {
58 next_ones.push(*val);
59 } else {
60 next_zeros.push(*val);
61 }
62 }
63}
64
65impl WaveletMatrix {
66 pub fn new(text: &[u8]) -> Self {
69 let width = text.len();
70 let height: usize = 3; let mut curr_zeros: Vec<u8> = text.to_vec();
73 let mut curr_ones: Vec<u8> = Vec::new();
74
75 let mut zeros: Vec<u64> = Vec::new();
76 let mut levels: Vec<RankSelect> = Vec::new();
77
78 for level in 0..height {
79 let mut next_zeros: Vec<u8> = Vec::with_capacity(width);
80 let mut next_ones: Vec<u8> = Vec::with_capacity(width);
81 let mut curr_bits: BitVec<u8> = BitVec::new_fill(false, width as u64);
82 let shift = (height - level - 1) as u8;
83 build_partlevel(
84 &curr_zeros,
85 shift,
86 &mut next_zeros,
87 &mut next_ones,
88 &mut curr_bits,
89 0,
90 );
91 build_partlevel(
92 &curr_ones,
93 shift,
94 &mut next_zeros,
95 &mut next_ones,
96 &mut curr_bits,
97 curr_zeros.len() as u64,
98 );
99
100 curr_zeros = next_zeros;
101 curr_ones = next_ones;
102
103 let level = RankSelect::new(curr_bits, 1);
104 levels.push(level);
105 zeros.push(curr_zeros.len() as u64);
106 }
107
108 WaveletMatrix {
109 width,
110 height,
111 zeros,
112 levels,
113 }
114 }
115
116 fn check_overflow(&self, p: u64) -> bool {
117 p >= self.width as u64
118 }
119
120 fn prank(&self, level: usize, p: u64, val: u8) -> u64 {
121 if p == 0 {
122 0
123 } else if val == 0 {
124 self.levels[level].rank_0(p - 1).unwrap()
125 } else {
126 self.levels[level].rank_1(p - 1).unwrap()
127 }
128 }
129
130 pub fn rank(&self, val: u8, p: u64) -> u64 {
133 assert!(
134 !self.check_overflow(p),
135 "Invalid p (it must be in range 0..wm_size-1"
136 );
137 let height = self.height;
138 let mut spos = 0;
139 let mut epos = p + 1;
140 for level in 0..height {
141 let shift = height - level - 1;
142 let bit = ((DNA2INT[val as usize] >> shift) & 1) == 1; if bit {
144 spos = self.prank(level, spos, 1) + self.zeros[level];
145 epos = self.prank(level, epos, 1) + self.zeros[level];
146 } else {
147 spos = self.prank(level, spos, 0);
148 epos = self.prank(level, epos, 0);
149 }
150 }
151 epos - spos
152 }
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `wm` has exactly the given per-level bits and per-level zero counts.
    fn assert_structure(wm: &WaveletMatrix, levels: &[Vec<bool>], zeros: &[u64]) {
        assert_eq!(wm.height, zeros.len());
        assert_eq!(wm.width, levels[0].len());
        for level in 0..wm.height {
            assert_eq!(wm.zeros[level], zeros[level]);
            for i in 0..wm.width {
                assert_eq!(wm.levels[level].bits().get(i as u64), levels[level][i]);
            }
        }
    }

    /// Asserts `wm.rank(alphabet[i], p) == ranks[i][p]` for every symbol and every position.
    fn assert_ranks(wm: &WaveletMatrix, alphabet: &[u8], ranks: &[Vec<u64>]) {
        assert_eq!(alphabet.len(), ranks.len());
        for (&c, expected) in alphabet.iter().zip(ranks) {
            // Every row must cover the whole text, so no position is silently skipped.
            assert_eq!(expected.len(), wm.width);
            for (p, &rank) in expected.iter().enumerate() {
                assert_eq!(wm.rank(c, p as u64), rank);
            }
        }
    }

    #[test]
    fn test_wm_buildpaper() {
        let text = b"476532101417";
        let wm = WaveletMatrix::new(text);
        let levels = vec![
            vec![
                true, true, true, true, false, false, false, false, false, true, false, true,
            ],
            vec![
                true, true, false, false, false, false, false, true, true, false, false, true,
            ],
            vec![
                true, false, true, true, false, true, false, true, false, true, false, true,
            ],
        ];
        assert_structure(&wm, &levels, &[6, 7, 5]);
    }

    #[test]
    fn test_wm_builddna() {
        let text = b"ACGTN$NAGCT$";
        let wm = WaveletMatrix::new(text);
        let levels = vec![
            vec![
                false, false, false, false, true, true, true, false, false, false, false, true,
            ],
            vec![
                false, false, true, true, false, true, false, true, false, false, false, false,
            ],
            vec![
                false, true, false, true, false, true, false, true, false, true, false, true,
            ],
        ];
        assert_structure(&wm, &levels, &[8, 8, 6]);
    }

    #[test]
    #[should_panic(expected = "Invalid p")]
    fn test_wm_rank_overflowpanic() {
        let text = b"476532101417";
        let wm = WaveletMatrix::new(text);
        wm.rank(b'4', text.len() as u64);
    }

    #[test]
    fn test_wm_rank_firstpos() {
        let text = b"476532101417";
        let wm = WaveletMatrix::new(text);
        assert_eq!(wm.rank(b'4', 0), 1);
    }

    #[test]
    fn test_wm_rank_lastpos() {
        let text = b"476532101417";
        let wm = WaveletMatrix::new(text);
        assert_eq!(wm.rank(b'7', text.len() as u64 - 1), 2);
    }

    #[test]
    fn test_wm_rank_1() {
        let text = b"476532101417";
        let wm = WaveletMatrix::new(text);
        assert_eq!(wm.rank(b'0', 6), 0);
        assert_eq!(wm.rank(b'0', 7), 1);
        assert_eq!(wm.rank(b'0', 8), 1);
    }

    #[test]
    fn test_wm_rank_2() {
        let text = b"476532101417";
        let wm = WaveletMatrix::new(text);
        assert_eq!(wm.rank(b'4', 8), 1);
        assert_eq!(wm.rank(b'4', 9), 2);
        assert_eq!(wm.rank(b'4', 10), 2);
    }

    #[test]
    fn test_wm_rank_all() {
        let text = b"476532101417";
        let wm = WaveletMatrix::new(text);

        let ranks = vec![
            vec![0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
            vec![0, 0, 0, 0, 0, 0, 1, 1, 2, 2, 3, 3],
            vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
            vec![0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
            vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2],
            vec![0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
            vec![0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
            vec![0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2],
        ];

        let alphabet = [b'0', b'1', b'2', b'3', b'4', b'5', b'6', b'7'];
        assert_ranks(&wm, &alphabet, &ranks);
    }

    #[test]
    fn test_wm_rank_alldna() {
        let text = b"AAGCTC$$CATTNGA";
        let wm = WaveletMatrix::new(text);

        let ranks = vec![
            vec![1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4],
            vec![0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3],
            vec![0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2],
            vec![0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 3, 3, 3, 3],
            vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
            vec![0, 0, 0, 0, 0, 0, 1, 2, 2, 2, 2, 2, 2, 2, 2],
        ];

        let alphabet = [b'A', b'C', b'G', b'T', b'N', b'$'];
        assert_ranks(&wm, &alphabet, &ranks);
    }

    #[test]
    fn test_wm_rank_lowercase_matches_uppercase() {
        // `DNA2INT` gives lowercase nucleotides the same codes as uppercase ones.
        let text = b"AAGCTC$$CATTNGA";
        let wm = WaveletMatrix::new(text);
        for (upper, lower) in [(b'A', b'a'), (b'C', b'c'), (b'G', b'g'), (b'T', b't'), (b'N', b'n')] {
            for p in 0..text.len() as u64 {
                assert_eq!(wm.rank(upper, p), wm.rank(lower, p));
            }
        }
    }
}