//! packed_seq/packed_seq.rs — 2-bit packed DNA sequences: borrowed `PackedSeq`
//! slices and the owned `PackedSeqVec`.

1use traits::Seq;
2
3use crate::intrinsics::transpose;
4
5use super::*;
6
/// A 2-bit packed non-owned slice of DNA bases.
///
/// Bases are stored 4 per byte, little-endian within each byte: the earliest
/// base occupies the lowest two bits (see `Seq::get` below).
#[derive(Copy, Clone, Debug, MemSize, MemDbg)]
pub struct PackedSeq<'s> {
    /// Packed data.
    pub seq: &'s [u8],
    /// Offset in bp from the start of the `seq`.
    pub offset: usize,
    /// Length of the sequence in bp, starting at `offset` from the start of `seq`.
    pub len: usize,
}
17
/// A 2-bit packed owned sequence of DNA bases.
#[derive(Clone, Debug, Default, MemSize, MemDbg)]
#[cfg_attr(feature = "pyo3", pyo3::pyclass)]
#[cfg_attr(feature = "epserde", derive(epserde::Epserde))]
pub struct PackedSeqVec {
    /// Packed data, 4 bases per byte.
    pub seq: Vec<u8>,
    /// Length of the sequence in bp; the final byte of `seq` may be partially used.
    pub len: usize,
}
26
/// Pack an ASCII `ACTGactg` character into its 2-bit representation.
///
/// The encoding is `A=0`, `C=1`, `T=2`, `G=3`, chosen so that complementing
/// is a single XOR with 2. Panics on any other input byte.
pub fn pack_char(base: u8) -> u8 {
    // Fold lower case onto upper case; bytes outside `a..=z` are unchanged.
    match base.to_ascii_uppercase() {
        b'A' => 0,
        b'C' => 1,
        b'T' => 2,
        b'G' => 3,
        _ => panic!(
            "Unexpected character '{}' with ASCII value {base}. Expected one of ACTGactg.",
            base as char
        ),
    }
}
40
/// Unpack a 2-bit DNA base into the corresponding `ACTG` character.
///
/// Inverse of [`pack_char`]. In release builds an out-of-range `base`
/// panics via the table index; in debug builds the assert fires first.
pub fn unpack_base(base: u8) -> u8 {
    /// Lookup table indexed by the 2-bit code.
    const ASCII: &[u8; 4] = b"ACTG";
    debug_assert!(base < 4, "Base {base} is not <4.");
    ASCII[base as usize]
}
46
/// Complement an ASCII character: `A<>T` and `C<>G`, case-preserving.
///
/// Bug fix: the panic message always claimed `ACTGactg` was accepted, but the
/// match only handled upper case, so lowercase input panicked — inconsistent
/// with [`pack_char`], which accepts both cases. Lowercase arms are now
/// handled (returning the lowercase complement); the panic is unchanged for
/// genuinely invalid bytes.
pub const fn complement_char(base: u8) -> u8 {
    match base {
        b'A' => b'T',
        b'C' => b'G',
        b'G' => b'C',
        b'T' => b'A',
        b'a' => b't',
        b'c' => b'g',
        b'g' => b'c',
        b't' => b'a',
        _ => panic!("Unexpected character. Expected one of ACTGactg.",),
    }
}
57
/// Complement a 2-bit base: `0<>2` and `1<>3`.
///
/// With the `A=0, C=1, T=2, G=3` encoding, flipping bit 1 swaps each base
/// with its complement.
pub const fn complement_base(base: u8) -> u8 {
    base ^ 0b10
}
62
63/// Complement 8 lanes of 2-bit bases: `0<>2` and `1<>3`.
64pub fn complement_base_simd(base: u32x8) -> u32x8 {
65    base ^ u32x8::splat(2)
66}
67
68impl<'s> PackedSeq<'s> {
69    /// Shrink `seq` to only just cover the data.
70    #[inline(always)]
71    pub fn normalize(&self) -> Self {
72        let start = self.offset / 4;
73        let end = (self.offset + self.len).div_ceil(4);
74        Self {
75            seq: &self.seq[start..end],
76            offset: self.offset % 4,
77            len: self.len,
78        }
79    }
80
81    /// Return a `Vec<u8>` of ASCII `ACTG` characters.
82    pub fn unpack(&self) -> Vec<u8> {
83        self.iter_bp().map(unpack_base).collect()
84    }
85}
86
87#[inline(always)]
88pub(crate) fn read_slice(seq: &[u8], idx: usize) -> u32x8 {
89    // assert!(idx <= seq.len());
90    let mut result = [0u8; 32];
91    let num_bytes = 32.min(seq.len().saturating_sub(idx));
92    unsafe {
93        let src = seq.as_ptr().add(idx);
94        std::ptr::copy_nonoverlapping(src, result.as_mut_ptr(), num_bytes);
95        std::mem::transmute(result)
96    }
97}
98
impl<'s> Seq<'s> for PackedSeq<'s> {
    /// 4 bases per byte, i.e. 2 bits per base.
    const BASES_PER_BYTE: usize = 4;
    const BITS_PER_CHAR: usize = 2;
    type SeqVec = PackedSeqVec;

    /// Number of bases in the sequence.
    #[inline(always)]
    fn len(&self) -> usize {
        self.len
    }

    /// 2-bit code of the base at `index`.
    ///
    /// Bases are packed little-endian within a byte: earlier bases sit in
    /// lower bits, hence the shift by `2 * (position % 4)`.
    #[inline(always)]
    fn get(&self, index: usize) -> u8 {
        let offset = self.offset + index;
        let idx = offset / 4;
        let offset = offset % 4;
        // SAFETY: not bounds-checked; callers must pass `index < self.len`
        // so that `idx` lands inside `seq`.
        unsafe { (*self.seq.get_unchecked(idx) >> (2 * offset)) & 3 }
    }

    /// ASCII `ACTG` character of the base at `index`.
    #[inline(always)]
    fn get_ascii(&self, index: usize) -> u8 {
        unpack_base(self.get(index))
    }

    /// Pack the whole (short) sequence into the low `2*len` bits of a `usize`.
    ///
    /// Panics if `self` is longer than 29 characters.
    /// NOTE(review): this does an unaligned 8-byte read starting at `seq[0]`,
    /// which can extend past the slice when fewer than 8 bytes are packed —
    /// presumably the backing allocations guarantee enough readable bytes;
    /// TODO confirm. Also, `len == 0` would shift by 64 (overflow) — callers
    /// appear to always pass non-empty slices.
    #[inline(always)]
    fn to_word(&self) -> usize {
        debug_assert!(self.len() <= usize::BITS as usize / 2 - 3);
        let mask = usize::MAX >> (64 - 2 * self.len());
        unsafe {
            ((self.seq.as_ptr() as *const usize).read_unaligned() >> (2 * self.offset)) & mask
        }
    }

    /// Copy the packed bytes into an owned `PackedSeqVec`.
    /// Only supported for byte-aligned slices (`offset == 0`).
    fn to_vec(&self) -> PackedSeqVec {
        assert_eq!(self.offset, 0);
        PackedSeqVec {
            seq: self.seq.to_vec(),
            len: self.len,
        }
    }

    /// Sub-slice of the bases in `range`, re-normalized so unused leading and
    /// trailing bytes are dropped.
    #[inline(always)]
    fn slice(&self, range: Range<usize>) -> Self {
        debug_assert!(
            range.end <= self.len,
            "Slice index out of bounds: {} > {}",
            range.end,
            self.len
        );
        PackedSeq {
            seq: self.seq,
            offset: self.offset + range.start,
            len: range.end - range.start,
        }
        .normalize()
    }

    /// Iterator over the 2-bit codes of all bases, front to back.
    #[inline(always)]
    fn iter_bp(self) -> impl ExactSizeIterator<Item = u8> + Clone {
        assert!(self.len <= self.seq.len() * 4);

        let this = self.normalize();

        // read u64 at a time?
        let mut byte = 0;
        let mut it = (0..this.len + this.offset).map(
            #[inline(always)]
            move |i| {
                // Refill the cached byte once every 4 bases.
                if i % 4 == 0 {
                    byte = this.seq[i / 4];
                }
                // Shift byte instead of i?
                (byte >> (2 * (i % 4))) & 0b11
            },
        );
        // Consume the first `offset` sub-byte positions so iteration starts
        // at base 0 of the slice.
        it.by_ref().take(this.offset).for_each(drop);
        it
    }

    /// Iterate all bases in 8 parallel SIMD lanes, yielding one 2-bit base per
    /// lane per step.
    ///
    /// The data is split into 8 chunks of `bytes_per_chunk` bytes each; lane
    /// `l` streams chunk `l`. Also returns the number of padding steps at the
    /// end of the last lane that do not correspond to real k-mers.
    /// NOTE(review): `L` (lane count, used as 8 here) and the SIMD type `S`
    /// come from `super::*`.
    #[inline(always)]
    fn par_iter_bp(self, context: usize) -> (impl ExactSizeIterator<Item = S> + Clone, usize) {
        #[cfg(target_endian = "big")]
        panic!("Big endian architectures are not supported.");

        let this = self.normalize();
        assert_eq!(this.offset, 0, "Non-byte offsets are not yet supported.");

        // Number of k-mers of length `context`; chunk sizes are rounded so
        // each chunk is a whole number of bytes.
        let num_kmers = this.len.saturating_sub(context - 1);
        let n = num_kmers.div_ceil(L).next_multiple_of(4);
        let bytes_per_chunk = n / 4;
        let padding = 4 * L * bytes_per_chunk - num_kmers;

        // Byte offset where each lane's chunk starts.
        let offsets: [usize; 8] = from_fn(|l| (l * bytes_per_chunk)).into();
        let mut cur = S::ZERO;

        // Boxed, so it doesn't consume precious registers.
        // Without this, cur is not always inlined into a register.
        let mut buf = Box::new([S::ZERO; 8]);

        let par_len = if num_kmers == 0 { 0 } else { n + context - 1 };
        let it = (0..par_len).map(
            #[inline(always)]
            move |i| {
                // Every 16 bases, move the next u32-per-lane from `buf` into `cur`.
                if i % 16 == 0 {
                    if i % 128 == 0 {
                        // Read a u256 for each lane containing the next 128 characters.
                        let data: [u32x8; 8] = from_fn(
                            #[inline(always)]
                            |lane| read_slice(this.seq, offsets[lane] + (i / 4)),
                        );
                        // Transpose so buf[j] holds the j'th u32 of every lane.
                        *buf = transpose(data);
                    }
                    cur = buf[(i % 128) / 16];
                }
                // Extract the last 2 bits of each character.
                let chars = cur & S::splat(0x03);
                // Shift remaining characters to the right.
                cur = cur >> S::splat(2);
                chars
            },
        );

        (it, padding)
    }

    /// Like [`Seq::par_iter_bp`], but each step additionally yields the base
    /// from `delay` positions earlier (zero until that many bases were seen).
    #[inline(always)]
    fn par_iter_bp_delayed(
        self,
        context: usize,
        delay: usize,
    ) -> (impl ExactSizeIterator<Item = (S, S)> + Clone, usize) {
        #[cfg(target_endian = "big")]
        panic!("Big endian architectures are not supported.");

        // Guards against a negative value having been cast into `usize`.
        assert!(
            delay < usize::MAX / 2,
            "Delay={} should be >=0.",
            delay as isize
        );

        let this = self.normalize();
        assert_eq!(this.offset, 0, "Non-byte offsets are not yet supported.");

        let num_kmers = this.len.saturating_sub(context - 1);
        let n = num_kmers.div_ceil(L).next_multiple_of(4);
        let bytes_per_chunk = n / 4;
        let padding = 4 * L * bytes_per_chunk - num_kmers;

        let offsets: [usize; 8] = from_fn(|l| (l * bytes_per_chunk)).into();
        let mut upcoming = S::ZERO;
        let mut upcoming_d = S::ZERO;

        // Ring buffer of recently loaded lane-words; the read cursor trails
        // the write cursor by `delay/16` words.
        // Even buf_len is nice to only have the write==buf_len check once.
        // We also make it the next power of 2, for faster modulo operations.
        // delay/16: number of bp in a u32.
        let buf_len = (delay / 16 + 8).next_power_of_two();
        let buf_mask = buf_len - 1;
        let mut buf = vec![S::ZERO; buf_len];
        let mut write_idx = 0;
        // We compensate for the first delay/16 triggers of the check below that
        // happen before the delay is actually reached.
        let mut read_idx = (buf_len - delay / 16) % buf_len;

        let par_len = if num_kmers == 0 { 0 } else { n + context - 1 };
        let it = (0..par_len).map(
            #[inline(always)]
            move |i| {
                if i % 16 == 0 {
                    if i % 128 == 0 {
                        // Read a u256 for each lane containing the next 128 characters.
                        let data: [u32x8; 8] = from_fn(
                            #[inline(always)]
                            |lane| read_slice(this.seq, offsets[lane] + (i / 4)),
                        );
                        // SAFETY: at `i % 128 == 0`, `write_idx` is a multiple
                        // of 8 (it increments once per 16 steps and is masked
                        // by a power-of-two >= 8), so `write_idx + 8 <= buf_len`.
                        unsafe {
                            *TryInto::<&mut [u32x8; 8]>::try_into(
                                buf.get_unchecked_mut(write_idx..write_idx + 8),
                            )
                            .unwrap_unchecked() = transpose(data);
                        }
                    }
                    upcoming = buf[write_idx];
                    write_idx += 1;
                    write_idx &= buf_mask;
                }
                // Pick up the delayed word once per 16 steps, phase-shifted by
                // `delay % 16` so the sub-word shift counts line up.
                if i % 16 == delay % 16 {
                    unsafe { assert_unchecked(read_idx < buf.len()) };
                    upcoming_d = buf[read_idx];
                    read_idx += 1;
                    read_idx &= buf_mask;
                }
                // Extract the last 2 bits of each character.
                let chars = upcoming & S::splat(0x03);
                let chars_d = upcoming_d & S::splat(0x03);
                // Shift remaining characters to the right.
                upcoming = upcoming >> S::splat(2);
                upcoming_d = upcoming_d >> S::splat(2);
                (chars, chars_d)
            },
        );

        (it, padding)
    }

    /// Like [`Seq::par_iter_bp_delayed`], but with two delayed streams,
    /// requiring `delay1 <= delay2`.
    #[inline(always)]
    fn par_iter_bp_delayed_2(
        self,
        context: usize,
        delay1: usize,
        delay2: usize,
    ) -> (impl ExactSizeIterator<Item = (S, S, S)> + Clone, usize) {
        #[cfg(target_endian = "big")]
        panic!("Big endian architectures are not supported.");

        let this = self.normalize();
        assert_eq!(this.offset, 0, "Non-byte offsets are not yet supported.");
        assert!(delay1 <= delay2, "Delay1 must be at most delay2.");

        let num_kmers = this.len.saturating_sub(context - 1);
        let n = num_kmers.div_ceil(L).next_multiple_of(4);
        let bytes_per_chunk = n / 4;
        let padding = 4 * L * bytes_per_chunk - num_kmers;

        let offsets: [usize; 8] = from_fn(|l| (l * bytes_per_chunk)).into();
        let mut upcoming = S::ZERO;
        let mut upcoming_d1 = S::ZERO;
        let mut upcoming_d2 = S::ZERO;

        // Ring buffer sized for the larger delay; both read cursors trail the
        // single write cursor.
        // Even buf_len is nice to only have the write==buf_len check once.
        let buf_len = (delay2 / 16 + 8).next_power_of_two();
        let buf_mask = buf_len - 1;
        let mut buf = vec![S::ZERO; buf_len];
        let mut write_idx = 0;
        // We compensate for the first delay/16 triggers of the check below that
        // happen before the delay is actually reached.
        let mut read_idx1 = (buf_len - delay1 / 16) % buf_len;
        let mut read_idx2 = (buf_len - delay2 / 16) % buf_len;

        let par_len = if num_kmers == 0 { 0 } else { n + context - 1 };
        let it = (0..par_len).map(
            #[inline(always)]
            move |i| {
                if i % 16 == 0 {
                    if i % 128 == 0 {
                        // Read a u256 for each lane containing the next 128 characters.
                        let data: [u32x8; 8] = from_fn(
                            #[inline(always)]
                            |lane| read_slice(this.seq, offsets[lane] + (i / 4)),
                        );
                        // SAFETY: as in `par_iter_bp_delayed`, `write_idx` is a
                        // multiple of 8 here, so the 8-element window is in bounds.
                        unsafe {
                            *TryInto::<&mut [u32x8; 8]>::try_into(
                                buf.get_unchecked_mut(write_idx..write_idx + 8),
                            )
                            .unwrap_unchecked() = transpose(data);
                        }
                    }
                    upcoming = buf[write_idx];
                    write_idx += 1;
                    write_idx &= buf_mask;
                }
                if i % 16 == delay1 % 16 {
                    unsafe { assert_unchecked(read_idx1 < buf.len()) };
                    upcoming_d1 = buf[read_idx1];
                    read_idx1 += 1;
                    read_idx1 &= buf_mask;
                }
                if i % 16 == delay2 % 16 {
                    unsafe { assert_unchecked(read_idx2 < buf.len()) };
                    upcoming_d2 = buf[read_idx2];
                    read_idx2 += 1;
                    read_idx2 &= buf_mask;
                }
                // Extract the last 2 bits of each character.
                let chars = upcoming & S::splat(0x03);
                let chars_d1 = upcoming_d1 & S::splat(0x03);
                let chars_d2 = upcoming_d2 & S::splat(0x03);
                // Shift remaining characters to the right.
                upcoming = upcoming >> S::splat(2);
                upcoming_d1 = upcoming_d1 >> S::splat(2);
                upcoming_d2 = upcoming_d2 >> S::splat(2);
                (chars, chars_d1, chars_d2)
            },
        );

        (it, padding)
    }

    /// Compares 29 characters at a time.
    /// Returns the ordering and the length of the longest common prefix.
    fn cmp_lcp(&self, other: &Self) -> (std::cmp::Ordering, usize) {
        let mut lcp = 0;
        let min_len = self.len.min(other.len);
        for i in (0..min_len).step_by(29) {
            let len = (min_len - i).min(29);
            let this = self.slice(i..i + len);
            let other = other.slice(i..i + len);
            let this_word = this.to_word();
            let other_word = other.to_word();
            if this_word != other_word {
                // Unfortunately, bases are packed in little endian order, so the default order is reversed.
                let eq = this_word ^ other_word;
                // Round the lowest differing bit down to a 2-bit base boundary.
                let t = eq.trailing_zeros() / 2 * 2;
                lcp += t as usize / 2;
                // Compare only the first differing base (lowest bits = earliest base).
                let mask = 0b11 << t;
                return ((this_word & mask).cmp(&(other_word & mask)), lcp);
            }
            lcp += len;
        }
        (self.len.cmp(&other.len), lcp)
    }
}
409
410impl PartialEq for PackedSeq<'_> {
411    /// Compares 29 characters at a time.
412    fn eq(&self, other: &Self) -> bool {
413        if self.len != other.len {
414            return false;
415        }
416        for i in (0..self.len).step_by(29) {
417            let len = (self.len - i).min(29);
418            let this = self.slice(i..i + len);
419            let that = other.slice(i..i + len);
420            if this.to_word() != that.to_word() {
421                return false;
422            }
423        }
424        return true;
425    }
426}
427
428impl Eq for PackedSeq<'_> {}
429
430impl PartialOrd for PackedSeq<'_> {
431    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
432        Some(self.cmp(other))
433    }
434}
435
impl Ord for PackedSeq<'_> {
    /// Lexicographic order on bases; compares 29 characters at a time.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let min_len = self.len.min(other.len);
        for i in (0..min_len).step_by(29) {
            let len = (min_len - i).min(29);
            let this = self.slice(i..i + len);
            let other = other.slice(i..i + len);
            let this_word = this.to_word();
            let other_word = other.to_word();
            if this_word != other_word {
                // Unfortunately, bases are packed in little endian order, so the default order is reversed.
                let eq = this_word ^ other_word;
                // Round the lowest differing bit down to a 2-bit base boundary.
                let t = eq.trailing_zeros() / 2 * 2;
                // Isolate the first differing base (lowest bits = earliest base)
                // and compare only that.
                let mask = 0b11 << t;
                return (this_word & mask).cmp(&(other_word & mask));
            }
        }
        // Shared prefix equal: the shorter sequence sorts first.
        self.len.cmp(&other.len)
    }
}
457
impl SeqVec for PackedSeqVec {
    type Seq<'s> = PackedSeq<'s>;

    /// Consume `self` and return the underlying packed bytes.
    fn into_raw(self) -> Vec<u8> {
        self.seq
    }

    /// Borrow the whole vector as a `PackedSeq` slice.
    #[inline(always)]
    fn as_slice(&self) -> Self::Seq<'_> {
        PackedSeq {
            seq: &self.seq,
            offset: 0,
            len: self.len,
        }
    }

    /// Append an already-packed slice byte-wise and return the bp-range it
    /// occupies within `self`.
    ///
    /// NOTE(review): bytes are appended at a byte boundary, and `self.len` is
    /// then set to `4 * seq.len()` — any partial final byte from a previous
    /// push becomes padding counted in `len`. Confirm callers only rely on
    /// byte-aligned contents here.
    fn push_seq<'a>(&mut self, seq: PackedSeq<'_>) -> Range<usize> {
        let start = 4 * self.seq.len() + seq.offset;
        let end = start + seq.len();
        self.seq.extend(seq.seq);
        self.len = 4 * self.seq.len();
        start..end
    }

    /// Push an ASCII sequence to an `PackedSeqVec`.
    /// `Aa` map to `0`, `Cc` to `1`, `Gg` to `3`, and `Tt` to `2`.
    /// Other characters may be silently mapped into `[0, 4)` or panic.
    /// (TODO: Explicitly support different conversions.)
    ///
    /// Uses the BMI2 `pext` instruction when available, based on the
    /// `n_to_bits_pext` method described at
    /// <https://github.com/Daniel-Liu-c0deb0t/cute-nucleotides>.
    ///
    /// TODO: Optimize for non-BMI2 platforms.
    /// TODO: Support multiple ways of dealing with non-`ACTG` characters.
    fn push_ascii(&mut self, seq: &[u8]) -> Range<usize> {
        // `start_aligned` is the bp position of the next byte boundary;
        // `start` is the true current end in bp.
        let start_aligned = 4 * self.seq.len();
        let start = self.len;
        let len = seq.len();

        // Step 1: fill up the partially-used final byte, if any.
        let unaligned = core::cmp::min(start_aligned - start, len);
        if unaligned > 0 {
            let mut packed_byte = *self.seq.last().unwrap();
            for &base in &seq[..unaligned] {
                packed_byte |= pack_char(base) << ((self.len % 4) * 2);
                self.len += 1;
            }
            *self.seq.last_mut().unwrap() = packed_byte;
        }

        #[allow(unused)]
        let mut last = unaligned;

        // Step 2: bulk-pack 8 characters at a time. Bits 1..=2 of the ASCII
        // codes of `ACGTacgt` are exactly the 2-bit encoding used here
        // (A=00, C=01, T=10, G=11), so `pext` with mask 0x06 per byte packs
        // 8 chars into 16 bits. Note: non-ACGT bytes are silently mangled on
        // this path instead of panicking like `pack_char`.
        #[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))]
        {
            last = unaligned + (len - unaligned) / 8 * 8;

            for i in (unaligned..last).step_by(8) {
                let chunk = &seq[i..i + 8].try_into().unwrap();
                let ascii = u64::from_ne_bytes(*chunk);
                let packed_bytes =
                    unsafe { std::arch::x86_64::_pext_u64(ascii, 0x0606060606060606) };
                self.seq.push(packed_bytes as u8);
                self.seq.push((packed_bytes >> 8) as u8);
                self.len += 8;
            }
        }

        // Step 3: pack the remaining tail one character at a time.
        let mut packed_byte = 0;
        for &base in &seq[last..] {
            packed_byte |= pack_char(base) << ((self.len % 4) * 2);
            self.len += 1;
            if self.len % 4 == 0 {
                self.seq.push(packed_byte);
                packed_byte = 0;
            }
        }
        // Flush a trailing partial byte, if step 3 produced one.
        if self.len % 4 != 0 && last < len {
            self.seq.push(packed_byte);
        }
        start..start + len
    }

    /// A packed sequence of `n` bases filled with OS-seeded random bytes.
    /// (Unused high bits in the final byte are random as well.)
    fn random(n: usize) -> Self {
        let mut seq = vec![0; n.div_ceil(4)];
        rand::rngs::SmallRng::from_os_rng().fill_bytes(&mut seq);
        PackedSeqVec { seq, len: n }
    }
}
546}