bio/data_structures/
bwt.rs

1// Copyright 2014-2016 Johannes Köster, Taylor Cramer.
2// Licensed under the MIT license (http://opensource.org/licenses/MIT)
3// This file may not be copied, modified, or distributed
4// except according to those terms.
5
6//! The [Burrows-Wheeler-Transform](https://www.semanticscholar.org/paper/A-Block-sorting-Lossless-Data-Compression-Algorithm-Burrows-Wheeler/af56e6d4901dcd0f589bf969e604663d40f1be5d) and related data structures.
7//! The implementation is based on the lecture notes
8//! "Algorithmen auf Sequenzen", Kopczynski, Marschall, Martin and Rahmann, 2008 - 2015.
9
10use std::iter::repeat;
11
12use crate::alphabets::Alphabet;
13use crate::data_structures::suffix_array::RawSuffixArraySlice;
14use crate::utils::prescan;
15
/// Owned Burrows-Wheeler transform: one byte per symbol of the text.
pub type BWT = Vec<u8>;
/// Unsized slice form of a Burrows-Wheeler transform, used as `&BWTSlice`.
pub type BWTSlice = [u8];
/// Array type returned by [`less`].
pub type Less = Vec<usize>;
/// Array type returned by [`bwtfind`].
pub type BWTFind = Vec<usize>;
20
21/// Calculate Burrows-Wheeler-Transform of the given text of length n.
22/// Complexity: O(n).
23///
24/// # Arguments
25///
26/// * `text` - the text ended by sentinel symbol (being lexicographically smallest)
27/// * `pos` - the suffix array for the text
28///
29/// # Example
30///
31/// ```
32/// use bio::data_structures::bwt::bwt;
33/// use bio::data_structures::suffix_array::suffix_array;
34/// let text = b"GCCTTAACATTATTACGCCTA$";
35/// let pos = suffix_array(text);
36/// let bwt = bwt(text, &pos);
37/// assert_eq!(bwt, b"ATTATTCAGGACCC$CTTTCAA");
38/// ```
39pub fn bwt(text: &[u8], pos: RawSuffixArraySlice) -> BWT {
40    assert_eq!(text.len(), pos.len());
41    let n = text.len();
42    let mut bwt: BWT = repeat(0).take(n).collect();
43    for r in 0..n {
44        let p = pos[r];
45        bwt[r] = if p > 0 { text[p - 1] } else { text[n - 1] };
46    }
47
48    bwt
49}
50
51/// Calculate the inverse of a BWT of length n, which is the original text.
52/// Complexity: O(n).
53///
54/// This only works if the last sentinel in the original text is unique
55/// and lexicographically the smallest.
56///
57/// # Arguments
58///
59/// * `bwt` - the BWT
60pub fn invert_bwt(bwt: &BWTSlice) -> Vec<u8> {
61    let alphabet = Alphabet::new(bwt);
62    let n = bwt.len();
63    let bwtfind = bwtfind(bwt, &alphabet);
64    let mut inverse = Vec::with_capacity(n);
65
66    let mut r = bwtfind[0];
67    for _ in 0..n {
68        r = bwtfind[r];
69        inverse.push(bwt[r]);
70    }
71
72    inverse
73}
74
/// An occurrence array implementation.
///
/// Stores cumulative symbol counts of a BWT at every `k`-th position only;
/// counts in between are recovered by scanning the BWT (see [`Occ::get`]).
#[derive(Default, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Serialize, Deserialize)]
pub struct Occ {
    /// Sampled counts, indexed by symbol and then by checkpoint:
    /// `occ[a][i]` is the number of occurrences of `a` in `bwt[..=i * k]`.
    /// Symbols that are not sampled keep an empty vector.
    occ: Vec<Vec<usize>>,
    /// Sampling rate: a checkpoint is stored for every `k`-th BWT position.
    k: u32,
}
81
82impl Occ {
83    /// Calculate occ array with sampling from BWT of length n.
84    /// Time complexity: O(n).
85    /// Space complexity: O(n / k * A) with A being the alphabet size.
86    /// The specified alphabet must match the alphabet of the text and its BWT.
87    /// For large texts, it is advisable to transform
88    /// the text before calculating the BWT (see alphabets::rank_transform).
89    ///
90    /// # Arguments
91    ///
92    /// * `bwt` - the BWT
93    /// * `k` - the sampling rate: every k-th entry will be stored
94    pub fn new(bwt: &BWTSlice, k: u32, alphabet: &Alphabet) -> Self {
95        let n = bwt.len();
96        let m = alphabet
97            .max_symbol()
98            .expect("Expecting non-empty alphabet.") as usize
99            + 1;
100        let mut alpha = alphabet.symbols.iter().collect::<Vec<usize>>();
101        // include sentinel '$'
102        if (b'$' as usize) < m && !alphabet.is_word(b"$") {
103            alpha.push(b'$' as usize);
104        }
105        let mut occ: Vec<Vec<usize>> = vec![Vec::new(); m];
106        let mut curr_occ = vec![0usize; m];
107
108        // characters not in the alphabet won't take up much space
109        for &a in &alpha {
110            occ[a].reserve(n / k as usize);
111        }
112
113        for (i, &c) in bwt.iter().enumerate() {
114            curr_occ[c as usize] += 1;
115
116            if i % k as usize == 0 {
117                // only visit characters in the alphabet
118                for &a in &alpha {
119                    occ[a].push(curr_occ[a]);
120                }
121            }
122        }
123
124        Occ { occ, k }
125    }
126
127    /// Get occurrence count of symbol a in BWT[..r+1].
128    /// Complexity: O(k).
129    pub fn get(&self, bwt: &BWTSlice, r: usize, a: u8) -> usize {
130        // NOTE:
131        //
132        // Retrieving byte match counts in this function is critical to the performance of FM Index.
133        //
134        // The below manual count code is roughly equivalent to:
135        // ```
136        // let count = bwt[(i * self.k) + 1..r + 1]
137        //     .iter()
138        //     .filter(|&&c| c == a)
139        //     .count();
140        // self.occ[a as usize][i] + count
141        // ```
142        //
143        // But there are a couple of reasons to do this manually:
144        // 1) As of 2016, versions of rustc/LLVM vectorize this manual loop more reliably
145        //    than the iterator adapter version.
146        // 2) Manually accumulating the byte match count in a single chunk can allows
147        //    us to use a `u32` for that count, which has faster arithmetic on common arches.
148        //    This does necessitate storing `k` as a u32.
149        //
150        // See the conversation in these issues for some of the history here:
151        //
152        // https://github.com/rust-bio/rust-bio/pull/74
153        // https://github.com/rust-bio/rust-bio/pull/76
154
155        // self.k is our sampling rate, so find the checkpoints either side of r.
156        let lo_checkpoint = r / self.k as usize;
157        // Get the occurences at the low checkpoint
158        let lo_occ = self.occ[a as usize][lo_checkpoint];
159
160        // If the sampling rate is infrequent it is worth checking if there is a closer
161        // hi checkpoint.
162        if self.k > 64 {
163            let hi_checkpoint = lo_checkpoint + 1;
164            if let Some(&hi_occ) = self.occ[a as usize].get(hi_checkpoint) {
165                // Its possible that there are no occurences between the low and high
166                // checkpoint in which case we bail early.
167                if lo_occ == hi_occ {
168                    return lo_occ;
169                }
170
171                // If r is closer to the high checkpoint, count backwards from there.
172                let hi_idx = hi_checkpoint * self.k as usize;
173                if (hi_idx - r) < (self.k as usize / 2) {
174                    return hi_occ - bytecount::count(&bwt[r + 1..=hi_idx], a);
175                }
176            }
177        }
178
179        // Otherwise the default case is to count from the low checkpoint.
180        let lo_idx = lo_checkpoint * self.k as usize;
181        bytecount::count(&bwt[lo_idx + 1..=r], a) + lo_occ
182    }
183}
184
185/// Calculate the less array for a given BWT. Complexity O(n).
186pub fn less(bwt: &BWTSlice, alphabet: &Alphabet) -> Less {
187    let m = alphabet
188        .max_symbol()
189        .expect("Expecting non-empty alphabet.") as usize
190        + 2;
191    let mut less: Less = repeat(0).take(m).collect();
192    for &c in bwt.iter() {
193        less[c as usize] += 1;
194    }
195    // calculate +-prescan
196    prescan(&mut less[..], 0, |a, b| a + b);
197
198    less
199}
200
201/// Calculate the bwtfind array needed for inverting the BWT. Complexity O(n).
202pub fn bwtfind(bwt: &BWTSlice, alphabet: &Alphabet) -> BWTFind {
203    let n = bwt.len();
204    let mut less = less(bwt, alphabet);
205
206    let mut bwtfind: BWTFind = repeat(0).take(n).collect();
207    for (r, &c) in bwt.iter().enumerate() {
208        bwtfind[less[c as usize]] = r;
209        less[c as usize] += 1;
210    }
211
212    bwtfind
213}
214
#[cfg(test)]
mod tests {
    use super::{bwt, bwtfind, invert_bwt, Occ};
    use crate::alphabets::dna;
    use crate::alphabets::Alphabet;
    use crate::data_structures::suffix_array::suffix_array;
    use crate::data_structures::wavelet_matrix::WaveletMatrix;

    /// `bwtfind` of the BWT of a small text matches the hand-computed table.
    #[test]
    fn test_bwtfind() {
        let text = b"cabca$";
        let alphabet = Alphabet::new(b"abc$");
        let sa = suffix_array(text);
        let transformed = bwt(text, &sa);
        let table = bwtfind(&transformed, &alphabet);
        assert_eq!(table, vec![5, 0, 3, 4, 1, 2]);
    }

    /// Transforming and inverting a text yields the text again.
    #[test]
    fn test_invert_bwt() {
        let text = b"cabca$";
        let sa = suffix_array(text);
        let transformed = bwt(text, &sa);
        let restored = invert_bwt(&transformed);
        assert_eq!(restored, text);
    }

    /// Sampled counts and point queries on a tiny rank-transformed BWT.
    #[test]
    fn test_occ() {
        let transformed = vec![1u8, 3, 3, 1, 2, 0];
        let alphabet = Alphabet::new([0u8, 1u8, 2u8, 3u8]);
        let occ = Occ::new(&transformed, 3, &alphabet);
        assert_eq!(occ.occ, [[0, 0], [1, 2], [0, 0], [0, 2]]);
        assert_eq!(occ.get(&transformed, 4, 2u8), 1);
        assert_eq!(occ.get(&transformed, 4, 3u8), 2);
    }

    /// `Occ::get` agrees with the wavelet matrix rank for every symbol and position.
    #[test]
    fn test_occwm() {
        let text = b"GCCTTAACATTATTACGCCTA$";
        let mut alphabet = dna::n_alphabet();
        alphabet.insert(b'$');

        let sa = suffix_array(text);
        let transformed = bwt(text, &sa);
        let occ = Occ::new(&transformed, 3, &alphabet);
        let wm = WaveletMatrix::new(&transformed);

        for symbol in [b'A', b'C', b'G', b'T', b'$'] {
            for position in 0..text.len() {
                let expected = wm.rank(symbol, position as u64);
                assert_eq!(occ.get(&transformed, position, symbol) as u64, expected);
            }
        }
    }
}