sbwt 0.4.2

Indexing sets of DNA k-mers with the spectral Burrows-Wheeler transform.
Documentation
use std::cmp::min;

use read_exact::{self, ReadExactExt};

use crate::util::is_dna;

/// Error returned by `LongKmer::from_ascii` when an ASCII string cannot be
/// encoded as a packed k-mer.
#[derive(Debug)]
#[allow(dead_code)]
pub enum KmerEncodingError{
    InvalidNucleotide(char), // A byte other than uppercase A/C/G/T; contains the offending byte as a char
    TooLong(usize), // The input exceeds the 32*B nucleotides a LongKmer<B> holds; contains the input length
}

/// A k-mer of up to `32*B` nucleotides packed into `B` 64-bit words.
///
/// Each nucleotide takes 2 bits (A=0, C=1, G=2, T=3). Position 0 occupies the
/// two most significant bits of `data[0]`, position 32 the two most significant
/// bits of `data[1]`, and so on. Because of this layout, the derived ordering of
/// the `[u64; B]` array (lexicographic over words) equals the lexicographic
/// ordering of the nucleotide strings.
///
/// NOTE: comparison is only meaningful between k-mers of equal length. Shorter
/// k-mers are padded with A's at the end up to `32*B` characters, so e.g.
/// "GGGGAC" and "GGGGACAA" have identical encodings.
#[derive(Copy, Clone, PartialEq, Eq, Ord, PartialOrd, Hash, Debug)]
pub struct LongKmer<const B: usize> {
    data: [u64; B] // Packed with 2 bits / nucleotide, most significant bits first, so that bitwise lex comparison is k-mer lex comparison
}

/// Maps an uppercase ASCII nucleotide to its 2-bit code: A=0, C=1, G=2, T=3.
///
/// # Panics
/// On any byte other than `b'A'`, `b'C'`, `b'G'` or `b'T'`.
// The upper-case "ACGT" in the name is intentional and part of the call sites.
#[allow(non_snake_case)]
fn ascii_to_bitpair_panic_if_not_ACGT(c: u8) -> u8 {
    // The code of a nucleotide is its rank in the sorted alphabet.
    let rank = b"ACGT".iter().position(|&nucleotide| nucleotide == c);
    match rank {
        Some(code) => code as u8,
        None => panic!("Invalid nucleotide"),
    }
}

// TODO: always pass these by value since this type is Copy?
#[allow(dead_code)]
impl<const B: usize> LongKmer<B>{

    /// Encodes an ASCII k-mer over the alphabet {A, C, G, T}.
    ///
    /// Character `i` is stored at position `i` counted from the left. If the
    /// string is shorter than `32*B`, the remaining positions on the right are
    /// padded with A's (the all-zero bit pair).
    ///
    /// # Errors
    /// - [`KmerEncodingError::TooLong`] if `ascii` has more than `32*B` characters.
    /// - [`KmerEncodingError::InvalidNucleotide`] for the first byte that is not
    ///   an uppercase `A`, `C`, `G` or `T`.
    pub fn from_ascii(ascii: &[u8]) -> Result<Self, KmerEncodingError>{
        if ascii.len() > B*32{
            return Err(KmerEncodingError::TooLong(ascii.len()));
        }
        let mut data = [0_u64; B];
        for (i, &c) in ascii.iter().enumerate() {
            let bitpair: u64 = match c {
                b'A' => 0,
                b'C' => 1,
                b'G' => 2,
                b'T' => 3,
                _ => return Err(KmerEncodingError::InvalidNucleotide(c as char)),
            };
            // Position 0 lives in the two most significant bits of block 0.
            let block = i / 32;
            let off = 31 - i % 32;
            data[block] |= bitpair << (2*off);
        }
        Ok(Self{data})
    }

    /// Returns a copy of this k-mer with position `i` (counted from the left)
    /// overwritten by the 2-bit code `c`.
    ///
    /// `c` must be in the alphabet {0,1,2,3} (A, C, G, T). A larger value would
    /// spill into the neighbouring position, so debug builds assert against it.
    ///
    /// # Panics
    /// If `i >= 32*B`, and in debug builds if `c > 3`.
    pub fn copy_set_from_left(&self, i: usize, c: u8) -> Self {
        debug_assert!(c < 4, "nucleotide code must be in 0..4, got {}", c);
        let block = i / 32;
        let shift = 2 * (31 - i % 32);
        let mut data = self.data;
        data[block] = (data[block] & !(3_u64 << shift)) | ((c as u64) << shift);
        Self{data}
    }

    /// Returns the 2-bit code (A=0, C=1, G=2, T=3) at position `i` counted
    /// from the left.
    ///
    /// # Panics
    /// If `i >= 32*B`.
    pub fn get_from_left(&self, i: usize) -> u8 {
        let block = i / 32;
        let shift = 2 * (31 - i % 32);
        ((self.data[block] >> shift) & 3) as u8
    }

    /// Shifts the k-mer `chars` positions to the right.
    ///
    /// The character at position `p` moves to `p + chars`; characters pushed
    /// past position `32*B - 1` are dropped and the vacated positions on the
    /// left are filled with A's. Shifting by `32*B` or more yields all A's.
    pub fn right_shifted(&self, chars: usize) -> Self{
        let whole_blocks = chars / 32;
        let bit_shift = (chars % 32) * 2; // Always < 64
        let mut new_data = [0_u64; B];
        for (block, &word) in self.data.iter().enumerate() {
            let dst = block + whole_blocks;
            if dst < B {
                new_data[dst] |= word >> bit_shift;
            }
            // The low `bit_shift` bits spill into the following block. Skip when
            // the shift is block-aligned: shifting a u64 by 64 would overflow.
            if bit_shift > 0 && dst + 1 < B {
                new_data[dst + 1] |= word << (64 - bit_shift);
            }
        }
        Self{data: new_data}
    }

    /// Shifts the k-mer `chars` positions to the left.
    ///
    /// The character at position `p` moves to `p - chars`; the first `chars`
    /// characters are dropped and the vacated positions on the right are
    /// filled with A's. Shifting by `32*B` or more yields all A's.
    pub fn left_shifted(&self, chars: usize) -> Self{
        let whole_blocks = chars / 32;
        let bit_shift = (chars % 32) * 2; // Always < 64
        let mut new_data = [0_u64; B];
        for (block, &word) in self.data.iter().enumerate() {
            // Bits that stay in block `block - whole_blocks` (if it exists).
            if let Some(dst) = block.checked_sub(whole_blocks) {
                new_data[dst] |= word << bit_shift;
            }
            // The high `bit_shift` bits spill into the preceding block. Skip when
            // the shift is block-aligned: shifting a u64 by 64 would overflow.
            if bit_shift > 0 {
                if let Some(dst) = block.checked_sub(whole_blocks + 1) {
                    new_data[dst] |= word >> (64 - bit_shift);
                }
            }
        }
        Self{data: new_data}
    }

    /// Borrows the packed words (position 0 in the high bits of word 0).
    #[allow(dead_code)]
    pub fn get_u64_data(&self) -> &[u64]{
        &self.data
    }

    /// Wraps already-packed words without validation.
    pub fn from_u64_data(data: [u64; B]) -> Self{
        Self{data}
    }

    /// Number of bytes written by [`Self::serialize`]: 8 per word.
    pub fn byte_size() -> usize {
        8*B
    }

    /// Writes the `B` words in order, each in little-endian byte order.
    ///
    /// Returns the number of bytes written, which is always [`Self::byte_size`].
    pub fn serialize<W: std::io::Write>(&self, out: &mut W) -> std::io::Result<usize>{
        // TODO: maybe use an unsafe single write to avoid the loop
        for word in &self.data {
            out.write_all(&word.to_le_bytes())?;
        }
        Ok(Self::byte_size())
    }

    /// Reads a k-mer in the format written by [`Self::serialize`].
    ///
    /// Returns `Ok(None)` if the stream is at EOF before the first byte of the
    /// k-mer, i.e. a clean end of a sequence of serialized k-mers.
    ///
    /// # Errors
    /// Propagates I/O errors from `input`. If the stream ends after some but
    /// not all `B` words, an [`std::io::ErrorKind::UnexpectedEof`] error is
    /// returned instead of silently dropping the truncated record. A word cut
    /// off mid-way is rejected by `read_exact_or_eof` itself.
    pub fn load<R: std::io::Read>(input: &mut R) -> std::io::Result<Option<Self>>{
        // TODO: read with just 1 IO call. A [u8; 8*B] buffer is not expressible
        // with stable const generics yet, so the words are read one by one.
        let mut data = [0_u64; B];
        let mut buf = [0_u8; 8];
        for (i, word) in data.iter_mut().enumerate() {
            if input.read_exact_or_eof(&mut buf)? {
                *word = u64::from_le_bytes(buf);
            } else if i == 0 {
                return Ok(None); // Clean EOF between k-mers
            } else {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "stream ended in the middle of a serialized k-mer",
                ));
            }
        }
        Ok(Some(Self::from_u64_data(data)))
    }

    /// Length of the longest common prefix of `a` and `b`, in nucleotides.
    ///
    /// The padding A's take part in the comparison, so for k-mers shorter than
    /// `32*B` the result may exceed their length; see
    /// [`Self::lcp_with_different_lengths`] to cap it.
    pub fn lcp(a: &Self, b: &Self) -> usize{
        for i in 0..B{
            let xor = a.data[i] ^ b.data[i];
            if xor != 0{
                // Each nucleotide is 2 bits, counted from the most significant end.
                return 32*i + xor.leading_zeros() as usize / 2;
            }
        }
        B*32 // Full k-mer match: 32 nucleotides per block
    }

    /// LCP of two k-mers given as `(kmer, length)` pairs, capped at the shorter
    /// of the two lengths so that padding never counts as a match.
    pub fn lcp_with_different_lengths(a: (&Self, u8), b: (&Self, u8)) -> usize {
        let lcp_value = Self::lcp(a.0, b.0);
        min(lcp_value, min(a.1 as usize, b.1 as usize))
    }

}

/// Iterator over the overlapping k-mers of an ASCII sequence. Windows that
/// contain a character rejected by `is_dna` are skipped: such a character
/// discards the partially assembled window.
pub struct KmerIterator<'a, const B: usize> {
    seq: &'a[u8], // Input sequence (ASCII bytes)
    cur_kmer: LongKmer<B>, // Window being assembled; only its first `cur_len` positions are meaningful
    first_iteration: bool, // True until the first call to `next`
    next_seq_pos: usize, // Index in `seq` of the next byte to consume
    cur_len: usize, // Number of characters currently in `cur_kmer`
    k: usize, // Length of the k-mers to produce
}

impl<'a, const B: usize> KmerIterator<'a, B> {

    #[allow(dead_code)]
    pub fn new(seq: &'a [u8], k: usize) -> KmerIterator<'a, B>{
        Self{seq, cur_kmer: LongKmer::from_u64_data([0; B]), first_iteration: true, next_seq_pos: 0, cur_len: 0, k}
    }

    fn scan_to_next_full_kmer(&mut self) -> Option<LongKmer<B>> {
        // Find the next k-mer that as only symbols from alphabet ACGT 
        while self.cur_len < self.k && self.next_seq_pos < self.seq.len() {
            let ascii_char = self.seq[self.next_seq_pos];
            if is_dna(ascii_char) {
                let c = ascii_to_bitpair_panic_if_not_ACGT(ascii_char);
                self.cur_kmer = self.cur_kmer.copy_set_from_left(self.cur_len, c); // Append c
                self.cur_len += 1;
            } else {
                self.cur_kmer = LongKmer::from_u64_data([0; B]); // Clear
                self.cur_len = 0;
            }
            self.next_seq_pos += 1;
        }

        if self.cur_len == self.k {
            Some(self.cur_kmer)
        } else {
            None
        }

    }
}

impl<const B: usize> Iterator for KmerIterator<'_, B> {
    type Item = LongKmer<B>;

    /// Yields the next k-mer of the sequence, or `None` once the input is
    /// exhausted. After the first `None`, further calls keep returning `None`.
    fn next(&mut self) -> Option<Self::Item> {
        if self.first_iteration {
            self.first_iteration = false;
        } else if self.cur_len == self.k {
            // The previous call returned the current window: slide it by one
            // by deleting its first character, then refill it below.
            self.cur_kmer = self.cur_kmer.left_shifted(1);
            self.cur_len -= 1;
        }
        // Otherwise the previous call already ran out of input (cur_len < k and
        // next_seq_pos == seq.len()), so the scan below returns None again.
        self.scan_to_next_full_kmer()
    }

}

impl<const B: usize> std::fmt::Display for LongKmer<B>{
    /// Renders all `32*B` positions as ASCII nucleotides, including any
    /// trailing A padding.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text: String = (0..B * 32)
            .map(|pos| match self.get_from_left(pos) {
                0 => 'A',
                1 => 'C',
                2 => 'G',
                3 => 'T',
                _ => panic!("Invalid character in DNA sequence"),
            })
            .collect();
        f.write_str(&text)
    }
}

#[cfg(test)]
mod tests{
    use super::*;

    /// String model of `LongKmer::left_shifted(1)`: drop the first character
    /// and append an 'A'.
    fn left_shifted(s: &str) -> String {
        let mut s = s.to_owned();
        s.remove(0);
        s.push('A');
        s
    }

    /// String model of `LongKmer::right_shifted(1)`: drop the last character
    /// and prepend an 'A'.
    fn right_shifted(s: &str) -> String {
        let mut s = s.to_owned();
        s.pop();
        s.insert(0, 'A');
        s
    }

    #[test]
    fn test_lcp(){
        let ascii1 = b"ACGTACGTACGTACGTACGTACGTACGTACGTACATGCATTT";
        let ascii2 = b"ACGTACGTACGTACGTACGTACGTACGTACGTACATGCATAT";
        let x = LongKmer::<2>::from_ascii(ascii1).unwrap();
        let y = LongKmer::<2>::from_ascii(ascii2).unwrap();
        assert_eq!(LongKmer::<2>::lcp(&x, &y), 40);

        // Equal k-mers
        let ascii1 = b"ACGTACGTACGTACGTACGTACGTACGTACGTACATGCATTTCTAGCTAGCTGATCGATCGA";
        let ascii2 = b"ACGTACGTACGTACGTACGTACGTACGTACGTACATGCATTTCTAGCTAGCTGATCGATCGA";
        let x = LongKmer::<2>::from_ascii(ascii1).unwrap();
        let y = LongKmer::<2>::from_ascii(ascii2).unwrap();
        assert_eq!(LongKmer::<2>::lcp(&x, &y), 64);
    }

    #[test]
    fn test_kmer_iterator(){

        let k = 4;

        assert_eq!(KmerIterator::<3>::new(b"", k).count(), 0);
        assert_eq!(KmerIterator::<3>::new(b"NON-NUCLEOTIDES", k).count(), 0);

        let ascii = b"NACGATNACANANAAATN";
        let packed_kmers: Vec<LongKmer<3>> = KmerIterator::<3>::new(ascii, k).collect();
        let ascii_kmers: Vec<Vec<u8>> = packed_kmers.iter().map(|x| x.to_string().as_bytes()[0..k].to_owned()).collect();
        let true_ascii_kmers = vec![b"ACGA".to_vec(), b"CGAT".to_vec(), b"AAAT".to_vec()];
        assert_eq!(ascii_kmers, true_ascii_kmers);
    }

    #[test]
    // The comparison section deliberately exercises `<` and `>` through negation.
    #[allow(clippy::nonminimal_bool)]
    fn test_long_kmer(){
        let ascii = b"ACGTACGTACGTACGTACGTACGTACGTACGTACATGCATTT";
        let mut x = LongKmer::<2>::from_ascii(ascii).unwrap();

        // Setting and getting

        x = x.copy_set_from_left(0, 2);
        x = x.copy_set_from_left(1, 3);
        x = x.copy_set_from_left(62, 1);
        eprintln!("{}", x);

        let mut expected = String::from("GTGTACGTACGTACGTACGTACGTACGTACGTACATGCATTTAAAAAAAAAAAAAAAAAAAACA");
        let actual = format!("{}", x); // This currently uses get_from_left
        assert_eq!(expected, actual);

        // Shifts

        for i in 0..100{
            let our = x.left_shifted(i);
            eprintln!("{}", our);
            assert_eq!(expected, format!("{}", our));
            expected = left_shifted(&expected);
        }

        expected = String::from("GTGTACGTACGTACGTACGTACGTACGTACGTACATGCATTTAAAAAAAAAAAAAAAAAAAACA");
        for i in 0..100{
            let our = x.right_shifted(i);
            eprintln!("{}", our);
            assert_eq!(expected, format!("{}", our));
            expected = right_shifted(&expected);
        }

        // Comparison
        let x = LongKmer::<2>::from_ascii(b"AATCAGCTAGCTACTATCTACGTACTACGTACGGGCGTACGTAGCA").unwrap();
        let y = LongKmer::<2>::from_ascii(b"AATCAGCTAGCTACTATCTACGTACTACGTACGGGCGTACGTCAGC").unwrap();

        assert!(x < y);

        let x = LongKmer::<2>::from_ascii(b"GGGGAC").unwrap();
        let y = LongKmer::<2>::from_ascii(b"GGGGAC").unwrap();

        assert!(x == y);
        assert!(x <= y);
        assert!(!(x < y));
        assert!(!(x > y));

        // COMPARISON ONLY WORKS FOR EQUAL-LENGTH kmers: shorter k-mers are
        // padded with A's, so the assertions below would fail and are kept
        // disabled to document the limitation.
        /*
        let x = LongKmer::<2>::from_ascii(b"GGGGAC").unwrap();
        let y = LongKmer::<2>::from_ascii(b"GGGGACAA").unwrap();

        assert!(x < y);

        let x = LongKmer::<2>::from_ascii(b"AATCAGCTAGCTACTATCTACGTACTACGTACGGGCGTACGTAGCAA").unwrap();
        let y = LongKmer::<2>::from_ascii(b"AATCAGCTAGCTACTATCTACGTACTACGTACGGGCGTACGTAGCA").unwrap();

        assert!(x > y);
        */
    }
}