bio/data_structures/bwt.rs
1// Copyright 2014-2016 Johannes Köster, Taylor Cramer.
2// Licensed under the MIT license (http://opensource.org/licenses/MIT)
3// This file may not be copied, modified, or distributed
4// except according to those terms.
5
6//! The [Burrows-Wheeler-Transform](https://www.semanticscholar.org/paper/A-Block-sorting-Lossless-Data-Compression-Algorithm-Burrows-Wheeler/af56e6d4901dcd0f589bf969e604663d40f1be5d) and related data structures.
7//! The implementation is based on the lecture notes
8//! "Algorithmen auf Sequenzen", Kopczynski, Marschall, Martin and Rahmann, 2008 - 2015.
9
10use std::iter::repeat;
11
12use crate::alphabets::Alphabet;
13use crate::data_structures::suffix_array::RawSuffixArraySlice;
14use crate::utils::prescan;
15
/// The Burrows-Wheeler transform of a text, stored as one byte per symbol.
pub type BWT = Vec<u8>;
/// A borrowed view of a [`BWT`].
pub type BWTSlice = [u8];
/// The less array of a BWT, as computed by [`less`]: one entry per symbol value,
/// indexed by symbol.
pub type Less = Vec<usize>;
/// The array computed by [`bwtfind`] and used by [`invert_bwt`] to invert a BWT.
pub type BWTFind = Vec<usize>;
20
21/// Calculate Burrows-Wheeler-Transform of the given text of length n.
22/// Complexity: O(n).
23///
24/// # Arguments
25///
26/// * `text` - the text ended by sentinel symbol (being lexicographically smallest)
27/// * `pos` - the suffix array for the text
28///
29/// # Example
30///
31/// ```
32/// use bio::data_structures::bwt::bwt;
33/// use bio::data_structures::suffix_array::suffix_array;
34/// let text = b"GCCTTAACATTATTACGCCTA$";
35/// let pos = suffix_array(text);
36/// let bwt = bwt(text, &pos);
37/// assert_eq!(bwt, b"ATTATTCAGGACCC$CTTTCAA");
38/// ```
39pub fn bwt(text: &[u8], pos: RawSuffixArraySlice) -> BWT {
40 assert_eq!(text.len(), pos.len());
41 let n = text.len();
42 let mut bwt: BWT = repeat(0).take(n).collect();
43 for r in 0..n {
44 let p = pos[r];
45 bwt[r] = if p > 0 { text[p - 1] } else { text[n - 1] };
46 }
47
48 bwt
49}
50
51/// Calculate the inverse of a BWT of length n, which is the original text.
52/// Complexity: O(n).
53///
54/// This only works if the last sentinel in the original text is unique
55/// and lexicographically the smallest.
56///
57/// # Arguments
58///
59/// * `bwt` - the BWT
60pub fn invert_bwt(bwt: &BWTSlice) -> Vec<u8> {
61 let alphabet = Alphabet::new(bwt);
62 let n = bwt.len();
63 let bwtfind = bwtfind(bwt, &alphabet);
64 let mut inverse = Vec::with_capacity(n);
65
66 let mut r = bwtfind[0];
67 for _ in 0..n {
68 r = bwtfind[r];
69 inverse.push(bwt[r]);
70 }
71
72 inverse
73}
74
75/// An occurrence array implementation.
76#[derive(Default, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Serialize, Deserialize)]
77pub struct Occ {
78 occ: Vec<Vec<usize>>,
79 k: u32,
80}
81
82impl Occ {
83 /// Calculate occ array with sampling from BWT of length n.
84 /// Time complexity: O(n).
85 /// Space complexity: O(n / k * A) with A being the alphabet size.
86 /// The specified alphabet must match the alphabet of the text and its BWT.
87 /// For large texts, it is advisable to transform
88 /// the text before calculating the BWT (see alphabets::rank_transform).
89 ///
90 /// # Arguments
91 ///
92 /// * `bwt` - the BWT
93 /// * `k` - the sampling rate: every k-th entry will be stored
94 pub fn new(bwt: &BWTSlice, k: u32, alphabet: &Alphabet) -> Self {
95 let n = bwt.len();
96 let m = alphabet
97 .max_symbol()
98 .expect("Expecting non-empty alphabet.") as usize
99 + 1;
100 let mut alpha = alphabet.symbols.iter().collect::<Vec<usize>>();
101 // include sentinel '$'
102 if (b'$' as usize) < m && !alphabet.is_word(b"$") {
103 alpha.push(b'$' as usize);
104 }
105 let mut occ: Vec<Vec<usize>> = vec![Vec::new(); m];
106 let mut curr_occ = vec![0usize; m];
107
108 // characters not in the alphabet won't take up much space
109 for &a in &alpha {
110 occ[a].reserve(n / k as usize);
111 }
112
113 for (i, &c) in bwt.iter().enumerate() {
114 curr_occ[c as usize] += 1;
115
116 if i % k as usize == 0 {
117 // only visit characters in the alphabet
118 for &a in &alpha {
119 occ[a].push(curr_occ[a]);
120 }
121 }
122 }
123
124 Occ { occ, k }
125 }
126
127 /// Get occurrence count of symbol a in BWT[..r+1].
128 /// Complexity: O(k).
129 pub fn get(&self, bwt: &BWTSlice, r: usize, a: u8) -> usize {
130 // NOTE:
131 //
132 // Retrieving byte match counts in this function is critical to the performance of FM Index.
133 //
134 // The below manual count code is roughly equivalent to:
135 // ```
136 // let count = bwt[(i * self.k) + 1..r + 1]
137 // .iter()
138 // .filter(|&&c| c == a)
139 // .count();
140 // self.occ[a as usize][i] + count
141 // ```
142 //
143 // But there are a couple of reasons to do this manually:
144 // 1) As of 2016, versions of rustc/LLVM vectorize this manual loop more reliably
145 // than the iterator adapter version.
146 // 2) Manually accumulating the byte match count in a single chunk can allows
147 // us to use a `u32` for that count, which has faster arithmetic on common arches.
148 // This does necessitate storing `k` as a u32.
149 //
150 // See the conversation in these issues for some of the history here:
151 //
152 // https://github.com/rust-bio/rust-bio/pull/74
153 // https://github.com/rust-bio/rust-bio/pull/76
154
155 // self.k is our sampling rate, so find the checkpoints either side of r.
156 let lo_checkpoint = r / self.k as usize;
157 // Get the occurences at the low checkpoint
158 let lo_occ = self.occ[a as usize][lo_checkpoint];
159
160 // If the sampling rate is infrequent it is worth checking if there is a closer
161 // hi checkpoint.
162 if self.k > 64 {
163 let hi_checkpoint = lo_checkpoint + 1;
164 if let Some(&hi_occ) = self.occ[a as usize].get(hi_checkpoint) {
165 // Its possible that there are no occurences between the low and high
166 // checkpoint in which case we bail early.
167 if lo_occ == hi_occ {
168 return lo_occ;
169 }
170
171 // If r is closer to the high checkpoint, count backwards from there.
172 let hi_idx = hi_checkpoint * self.k as usize;
173 if (hi_idx - r) < (self.k as usize / 2) {
174 return hi_occ - bytecount::count(&bwt[r + 1..=hi_idx], a);
175 }
176 }
177 }
178
179 // Otherwise the default case is to count from the low checkpoint.
180 let lo_idx = lo_checkpoint * self.k as usize;
181 bytecount::count(&bwt[lo_idx + 1..=r], a) + lo_occ
182 }
183}
184
185/// Calculate the less array for a given BWT. Complexity O(n).
186pub fn less(bwt: &BWTSlice, alphabet: &Alphabet) -> Less {
187 let m = alphabet
188 .max_symbol()
189 .expect("Expecting non-empty alphabet.") as usize
190 + 2;
191 let mut less: Less = repeat(0).take(m).collect();
192 for &c in bwt.iter() {
193 less[c as usize] += 1;
194 }
195 // calculate +-prescan
196 prescan(&mut less[..], 0, |a, b| a + b);
197
198 less
199}
200
201/// Calculate the bwtfind array needed for inverting the BWT. Complexity O(n).
202pub fn bwtfind(bwt: &BWTSlice, alphabet: &Alphabet) -> BWTFind {
203 let n = bwt.len();
204 let mut less = less(bwt, alphabet);
205
206 let mut bwtfind: BWTFind = repeat(0).take(n).collect();
207 for (r, &c) in bwt.iter().enumerate() {
208 bwtfind[less[c as usize]] = r;
209 less[c as usize] += 1;
210 }
211
212 bwtfind
213}
214
#[cfg(test)]
mod tests {
    use super::{bwt, bwtfind, invert_bwt, Occ};
    use crate::alphabets::dna;
    use crate::alphabets::Alphabet;
    use crate::data_structures::suffix_array::suffix_array;
    use crate::data_structures::wavelet_matrix::WaveletMatrix;

    #[test]
    fn test_bwtfind() {
        let text = b"cabca$";
        let alphabet = Alphabet::new(b"abc$");
        let pos = suffix_array(text);
        let bwt = bwt(text, &pos);
        let bwtfind = bwtfind(&bwt, &alphabet);
        assert_eq!(bwtfind, vec![5, 0, 3, 4, 1, 2]);
    }

    #[test]
    fn test_invert_bwt() {
        let text = b"cabca$";
        let pos = suffix_array(text);
        let bwt = bwt(text, &pos);
        let inverse = invert_bwt(&bwt);
        assert_eq!(inverse, text);
    }

    #[test]
    fn test_occ() {
        let bwt = vec![1u8, 3u8, 3u8, 1u8, 2u8, 0u8];
        let alphabet = Alphabet::new([0u8, 1u8, 2u8, 3u8]);
        let occ = Occ::new(&bwt, 3, &alphabet);
        assert_eq!(occ.occ, [[0, 0], [1, 2], [0, 0], [0, 2]]);
        assert_eq!(occ.get(&bwt, 4, 2u8), 1);
        assert_eq!(occ.get(&bwt, 4, 3u8), 2);
    }

    #[test]
    fn test_occ_sparse_sampling() {
        // With k > 64, `Occ::get` may count backwards from the next checkpoint, or return
        // early when two neighbouring checkpoints agree (always the case for '$' before
        // the last checkpoint here). Compare every lookup against a naive prefix count.
        let mut bwt: Vec<u8> = (0..499usize)
            .map(|i| b"ACGT"[(i * 7 + i / 3) % 4])
            .collect();
        bwt.push(b'$');
        let alphabet = Alphabet::new(b"ACGT$");
        let occ = Occ::new(&bwt, 100, &alphabet);

        for &c in b"ACGT$" {
            for r in 0..bwt.len() {
                let expected = bwt[..=r].iter().filter(|&&x| x == c).count();
                assert_eq!(
                    occ.get(&bwt, r, c),
                    expected,
                    "symbol {} at row {}",
                    c as char,
                    r
                );
            }
        }
    }

    #[test]
    fn test_occwm() {
        let text = b"GCCTTAACATTATTACGCCTA$";
        let alphabet = {
            let mut a = dna::n_alphabet();
            a.insert(b'$');
            a
        };
        let sa = suffix_array(text);
        let bwt = bwt(text, &sa);
        let occ = Occ::new(&bwt, 3, &alphabet);
        let wm = WaveletMatrix::new(&bwt);

        for c in [b'A', b'C', b'G', b'T', b'$'] {
            for p in 0..text.len() {
                assert_eq!(occ.get(&bwt, p, c) as u64, wm.rank(c, p as u64));
            }
        }
    }
}