packed_seq/packed_seq.rs

use traits::Seq;
use wide::u16x8;

use crate::{intrinsics::transpose, padded_it::ChunkIt};

use super::*;

#[doc(hidden)]
pub struct Bits<const B: usize>;
#[doc(hidden)]
pub trait SupportedBits {}
impl SupportedBits for Bits<1> {}
impl SupportedBits for Bits<2> {}
impl SupportedBits for Bits<4> {}
impl SupportedBits for Bits<8> {}

/// Number of padding bytes at the end of `PackedSeqVecBase::seq`.
const PADDING: usize = 16;

/// A packed non-owned slice of DNA bases, using `B` bits per base.
#[doc(hidden)]
#[derive(Copy, Clone, Debug, MemSize, MemDbg)]
pub struct PackedSeqBase<'s, const B: usize>
where
    Bits<B>: SupportedBits,
{
    /// Packed data.
    seq: &'s [u8],
    /// Offset in bp from the start of `seq`.
    offset: usize,
    /// Length of the sequence in bp, starting at `offset` from the start of `seq`.
    len: usize,
}

/// A packed owned sequence of DNA bases, using `B` bits per base.
#[doc(hidden)]
#[derive(Clone, Debug, MemSize, MemDbg)]
#[cfg_attr(feature = "pyo3", pyo3::pyclass)]
#[cfg_attr(feature = "epserde", derive(epserde::Epserde))]
pub struct PackedSeqVecBase<const B: usize>
where
    Bits<B>: SupportedBits,
{
    /// NOTE: We maintain the invariant that this has at least 16 bytes of padding
    /// at the end, after `len` finishes.
    /// This ensures that the `read_unaligned` in `as_u64` works OK.
    pub(crate) seq: Vec<u8>,

    /// The length, in bp, of the underlying sequence. See `.len()`.
    len: usize,
}

pub type PackedSeq<'s> = PackedSeqBase<'s, 2>;
pub type PackedSeqVec = PackedSeqVecBase<2>;
pub type BitSeq<'s> = PackedSeqBase<'s, 1>;
pub type BitSeqVec = PackedSeqVecBase<1>;

/// Convenience constants.
/// `B`: bits per char.
impl<'s, const B: usize> PackedSeqBase<'s, B>
where
    Bits<B>: SupportedBits,
{
    /// Lowest `B` bits are 1.
    const CHAR_MASK: u64 = (1 << B) - 1;
    /// Chars per byte.
    const C8: usize = 8 / B;
    /// Chars per u32.
    const C32: usize = 32 / B;
    /// Chars per u256.
    const C256: usize = 256 / B;
    /// Max length of a kmer that can be read as a single u64.
    const K64: usize = (64 - 8) / B + 1;
}

/// Convenience constants.
impl<const B: usize> PackedSeqVecBase<B>
where
    Bits<B>: SupportedBits,
{
    /// Chars per byte.
    const C8: usize = 8 / B;
}

impl<const B: usize> Default for PackedSeqVecBase<B>
where
    Bits<B>: SupportedBits,
{
    fn default() -> Self {
        Self {
            seq: vec![0; PADDING],
            len: 0,
        }
    }
}

// ======================================================================
// 2-BIT HELPER METHODS

/// Pack an ASCII `ACTGactg` character into its 2-bit representation, and panic for anything else.
#[inline(always)]
pub fn pack_char(base: u8) -> u8 {
    match base {
        b'a' | b'A' => 0,
        b'c' | b'C' => 1,
        b'g' | b'G' => 3,
        b't' | b'T' => 2,
        _ => panic!(
            "Unexpected character '{}' with ASCII value {base}. Expected one of ACTGactg.",
            base as char
        ),
    }
}

/// Pack an ASCII `ACTGactg` character into its 2-bit representation, and silently convert other characters into 0..4 as well.
#[inline(always)]
pub fn pack_char_lossy(base: u8) -> u8 {
    (base >> 1) & 3
}

/// Unpack a 2-bit DNA base into the corresponding `ACTG` character.
#[inline(always)]
pub fn unpack_base(base: u8) -> u8 {
    debug_assert!(base < 4, "Base {base} is not <4.");
    b"ACTG"[base as usize]
}

/// Complement an ASCII character: `A<>T` and `C<>G`.
#[inline(always)]
pub const fn complement_char(base: u8) -> u8 {
    match base {
        b'A' => b'T',
        b'C' => b'G',
        b'G' => b'C',
        b'T' => b'A',
        _ => panic!("Unexpected character. Expected one of ACTG."),
    }
}

/// Complement a 2-bit base: `0<>2` and `1<>3`.
#[inline(always)]
pub const fn complement_base(base: u8) -> u8 {
    base ^ 2
}

/// Complement 8 lanes of 2-bit bases: `0<>2` and `1<>3`.
#[inline(always)]
pub fn complement_base_simd(base: u32x8) -> u32x8 {
    base ^ u32x8::splat(2)
}

/// Reverse complement the 2-bit pairs in the input.
#[inline(always)]
const fn revcomp_raw(word: u64) -> u64 {
    #[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
    {
        let mut res = word.reverse_bits(); // ARM can reverse bits in a single instruction
        res = ((res >> 1) & 0x5555_5555_5555_5555) | ((res & 0x5555_5555_5555_5555) << 1);
        res ^ 0xAAAA_AAAA_AAAA_AAAA
    }

    #[cfg(not(any(target_arch = "arm", target_arch = "aarch64")))]
    {
        let mut res = word.swap_bytes();
        res = ((res >> 4) & 0x0F0F_0F0F_0F0F_0F0F) | ((res & 0x0F0F_0F0F_0F0F_0F0F) << 4);
        res = ((res >> 2) & 0x3333_3333_3333_3333) | ((res & 0x3333_3333_3333_3333) << 2);
        res ^ 0xAAAA_AAAA_AAAA_AAAA
    }
}

/// Compute the reverse complement of a short sequence packed in a `u64`.
#[inline(always)]
pub const fn revcomp_u64(word: u64, len: usize) -> u64 {
    revcomp_raw(word) >> (u64::BITS as usize - 2 * len)
}

/// Compute the reverse complement of a short sequence packed in a `u128`.
#[inline(always)]
pub const fn revcomp_u128(word: u128, len: usize) -> u128 {
    let low = word as u64;
    let high = (word >> 64) as u64;
    let rlow = revcomp_raw(low);
    let rhigh = revcomp_raw(high);
    let out = ((rlow as u128) << 64) | rhigh as u128;
    out >> (u128::BITS as usize - 2 * len)
}

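// Added example (not part of the original file): a minimal sketch of how the 2-bit
// helpers above compose. Packing is little-endian in base order: character `i` of a
// kmer occupies bits `2*i..2*i + 2` of the word.
#[cfg(test)]
mod helper_2bit_sketch {
    use super::*;

    #[test]
    fn pack_unpack_and_revcomp() {
        // `pack_char` and `unpack_base` are inverses on upper-case ACTG.
        for &c in b"ACTG" {
            assert_eq!(unpack_base(pack_char(c)), c);
        }
        // `complement_base` swaps A(0)<>T(2) and C(1)<>G(3).
        assert_eq!(complement_base(pack_char(b'A')), pack_char(b'T'));

        // revcomp("AC") == "GT" on 2-bit packed words of length 2.
        let ac = (pack_char(b'A') as u64) | ((pack_char(b'C') as u64) << 2);
        let gt = (pack_char(b'G') as u64) | ((pack_char(b'T') as u64) << 2);
        assert_eq!(revcomp_u64(ac, 2), gt);
    }
}
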
// ======================================================================
// 1-BIT HELPER METHODS

/// 1 when a char is ambiguous, i.e., not one of `ACGTacgt`.
#[inline(always)]
pub fn char_is_ambiguous(base: u8) -> u8 {
    // (!matches!(base, b'A' | b'C'  | b'G'  | b'T' | b'a' | b'c'  | b'g'  | b't')) as u8
    let table = b"ACTG";
    let upper_mask = !(b'a' - b'A');
    (table[pack_char_lossy(base) as usize] != (base & upper_mask)) as u8
}

/// Reverse `len` bits packed in a `u64`.
#[inline(always)]
pub const fn rev_u64(word: u64, len: usize) -> u64 {
    word.reverse_bits() >> (u64::BITS as usize - len)
}

/// Reverse `len` bits packed in a `u128`.
#[inline(always)]
pub const fn rev_u128(word: u128, len: usize) -> u128 {
    word.reverse_bits() >> (u128::BITS as usize - len)
}

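// Added example (not part of the original file): the 1-bit helpers flag ambiguous
// (non-ACGT) characters and reverse plain bit-sequences.
#[cfg(test)]
mod helper_1bit_sketch {
    use super::*;

    #[test]
    fn ambiguity_and_reversal() {
        assert_eq!(char_is_ambiguous(b'A'), 0);
        assert_eq!(char_is_ambiguous(b't'), 0);
        assert_eq!(char_is_ambiguous(b'N'), 1);
        // Reversing the 3 low bits of 0b011 gives 0b110.
        assert_eq!(rev_u64(0b011, 3), 0b110);
    }
}
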
// ======================================================================

impl<const B: usize> PackedSeqBase<'_, B>
where
    Bits<B>: SupportedBits,
{
    /// Shrink `seq` to only just cover the data.
    #[inline(always)]
    pub fn normalize(&self) -> Self {
        let start_byte = self.offset / Self::C8;
        let end_byte = (self.offset + self.len).div_ceil(Self::C8);
        Self {
            seq: &self.seq[start_byte..end_byte],
            offset: self.offset % Self::C8,
            len: self.len,
        }
    }

    /// Return a `Vec<u8>` of ASCII `ACTG` characters.
    #[inline(always)]
    pub fn unpack(&self) -> Vec<u8> {
        self.iter_bp().map(unpack_base).collect()
    }
}

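// Added example (not part of the original file): a minimal sketch showing that slicing
// only adjusts `offset`/`len` while `unpack` recovers the ASCII text. Assumes the
// `Seq`/`SeqVec` traits used throughout this module are brought in by `use super::*`.
#[cfg(test)]
mod slice_unpack_sketch {
    use super::*;

    #[test]
    fn slice_then_unpack() {
        let mut v = PackedSeqVec::default();
        v.push_ascii(b"ACGTACGT");
        // Take the middle four bases: positions 2..6 are "GTAC".
        assert_eq!(v.as_slice().slice(2..6).unpack(), b"GTAC".to_vec());
    }
}
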
/// Read up to 32 bytes starting at `idx`; bytes past the end of `seq` are read as zero.
#[inline(always)]
pub(crate) fn read_slice_32(seq: &[u8], idx: usize) -> u32x8 {
    unsafe {
        let src = seq.as_ptr().add(idx);
        if idx + 32 <= seq.len() {
            std::mem::transmute::<_, *const u32x8>(src).read_unaligned()
        } else {
            let num_bytes = seq.len().saturating_sub(idx);
            let mut result = [0u8; 32];
            std::ptr::copy_nonoverlapping(src, result.as_mut_ptr(), num_bytes);
            std::mem::transmute(result)
        }
    }
}

/// Read up to 16 bytes starting at `idx`; bytes past the end of `seq` are read as zero.
#[allow(unused)]
#[inline(always)]
pub(crate) fn read_slice_16(seq: &[u8], idx: usize) -> u16x8 {
    unsafe {
        let src = seq.as_ptr().add(idx);
        if idx + 16 <= seq.len() {
            std::mem::transmute::<_, *const u16x8>(src).read_unaligned()
        } else {
            let num_bytes = seq.len().saturating_sub(idx);
            let mut result = [0u8; 16];
            std::ptr::copy_nonoverlapping(src, result.as_mut_ptr(), num_bytes);
            std::mem::transmute(result)
        }
    }
}

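// Added example (not part of the original file): reads past the end of the input are
// zero-filled, which is what lets the iterators below over-read within `PADDING` bytes.
// Assumes `wide::u32x8::to_array` for inspecting the lanes.
#[cfg(test)]
mod read_slice_sketch {
    use super::*;

    #[test]
    fn out_of_bounds_bytes_read_as_zero() {
        let seq = [0xFFu8; 4];
        let v = read_slice_32(&seq, 0).to_array();
        assert_eq!(v[0], 0xFFFF_FFFF); // The first 4 bytes are present...
        assert!(v[1..].iter().all(|&x| x == 0)); // ...and the remaining 28 are zero.
    }
}
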
impl<'s, const B: usize> Seq<'s> for PackedSeqBase<'s, B>
where
    Bits<B>: SupportedBits,
{
    const BITS_PER_CHAR: usize = B;
    const BASES_PER_BYTE: usize = Self::C8;
    type SeqVec = PackedSeqVecBase<B>;

    #[inline(always)]
    fn len(&self) -> usize {
        self.len
    }

    #[inline(always)]
    fn is_empty(&self) -> bool {
        self.len == 0
    }

    #[inline(always)]
    fn get_ascii(&self, index: usize) -> u8 {
        unpack_base(self.get(index))
    }

    /// Convert a short sequence (kmer) to a packed representation as `u64`.
    /// Panics if `self` is longer than `64 / B` characters (32 for 2-bit).
    #[inline(always)]
    fn as_u64(&self) -> u64 {
        assert!(self.len() <= 64 / B);
        debug_assert!(self.seq.len() <= 9);

        let mask = u64::MAX >> (64 - B * self.len());

        // The unaligned read is OK, because we ensure that the underlying `PackedSeqVecBase::seq` always
        // has at least 16 bytes (the size of a u128) of padding at the end.
        if self.len() <= Self::K64 {
            let x = unsafe { (self.seq.as_ptr() as *const u64).read_unaligned() };
            (x >> (B * self.offset)) & mask
        } else {
            let x = unsafe { (self.seq.as_ptr() as *const u128).read_unaligned() };
            (x >> (B * self.offset)) as u64 & mask
        }
    }

    /// Convert a short sequence (kmer) to a packed representation of its reverse complement as `u64`.
    /// Panics if `self` is longer than `64 / B` characters (32 for 2-bit).
    #[inline(always)]
    fn revcomp_as_u64(&self) -> u64 {
        match B {
            1 => rev_u64(self.as_u64(), self.len()),
            2 => revcomp_u64(self.as_u64(), self.len()),
            _ => panic!("Rev(comp) is only supported for 1-bit and 2-bit alphabets."),
        }
    }

    /// Convert a short sequence (kmer) to a packed representation as `u128`.
    /// Panics if `self` is longer than `(128 - 8) / B + 1` characters (61 for 2-bit).
    #[inline(always)]
    fn as_u128(&self) -> u128 {
        assert!(
            self.len() <= (128 - 8) / B + 1,
            "Sequences >61 long cannot be read with a single unaligned u128 read."
        );
        debug_assert!(self.seq.len() <= 17);

        let mask = u128::MAX >> (128 - B * self.len());

        // The unaligned read is OK, because we ensure that the underlying `PackedSeqVecBase::seq` always
        // has at least 16 bytes (the size of a u128) of padding at the end.
        let x = unsafe { (self.seq.as_ptr() as *const u128).read_unaligned() };
        (x >> (B * self.offset)) & mask
    }

    /// Convert a short sequence (kmer) to a packed representation of its reverse complement as `u128`.
    /// Panics if `self` is longer than `(128 - 8) / B + 1` characters (61 for 2-bit).
    #[inline(always)]
    fn revcomp_as_u128(&self) -> u128 {
        match B {
            1 => rev_u128(self.as_u128(), self.len()),
            2 => revcomp_u128(self.as_u128(), self.len()),
            _ => panic!("Rev(comp) is only supported for 1-bit and 2-bit alphabets."),
        }
    }

    #[inline(always)]
    fn to_vec(&self) -> PackedSeqVecBase<B> {
        assert_eq!(self.offset, 0);
        PackedSeqVecBase {
            seq: self
                .seq
                .iter()
                .copied()
                .chain(std::iter::repeat_n(0u8, PADDING))
                .collect(),
            len: self.len,
        }
    }

    fn to_revcomp(&self) -> PackedSeqVecBase<B> {
        match B {
            1 | 2 => {}
            _ => panic!("Can only reverse (&complement) 1-bit and 2-bit packed sequences."),
        }

        let mut seq = self.seq[..(self.offset + self.len).div_ceil(Self::C8)]
            .iter()
            // 1. Reverse the bytes.
            .rev()
            .copied()
            .map(|mut res| {
                match B {
                    2 => {
                        // 2. Swap the bases within the byte.
                        // This is auto-vectorized.
                        res = ((res >> 4) & 0x0F) | ((res & 0x0F) << 4);
                        res = ((res >> 2) & 0x33) | ((res & 0x33) << 2);
                        // Complement the bases.
                        res ^ 0xAA
                    }
                    1 => res.reverse_bits(),
                    _ => unreachable!(),
                }
            })
            .chain(std::iter::repeat_n(0u8, PADDING))
            .collect::<Vec<u8>>();

        // 3. Shift away the offset.
        let new_offset = (Self::C8 - (self.offset + self.len) % Self::C8) % Self::C8;

        if new_offset > 0 {
            // Shift all characters down by `new_offset` positions, i.e., right by `B * new_offset` bits.
            let shift = B * new_offset;
            *seq.last_mut().unwrap() >>= shift;
            // This loop is also auto-vectorized.
            for i in 0..seq.len() - 1 {
                seq[i] = (seq[i] >> shift) | (seq[i + 1] << (8 - shift));
            }
        }

        PackedSeqVecBase { seq, len: self.len }
    }

    #[inline(always)]
    fn slice(&self, range: Range<usize>) -> Self {
        debug_assert!(
            range.end <= self.len,
            "Slice index out of bounds: {} > {}",
            range.end,
            self.len
        );
        PackedSeqBase {
            seq: self.seq,
            offset: self.offset + range.start,
            len: range.end - range.start,
        }
        .normalize()
    }

    #[inline(always)]
    fn iter_bp(self) -> impl ExactSizeIterator<Item = u8> {
        assert!(self.len <= self.seq.len() * Self::C8);

        let this = self.normalize();

        // Read a u64 at a time?
        let mut byte = 0;
        (0..this.len + this.offset)
            .map(
                #[inline(always)]
                move |i| {
                    if i % Self::C8 == 0 {
                        byte = this.seq[i / Self::C8];
                    }
                    // Shift byte instead of i?
                    (byte >> (B * (i % Self::C8))) & Self::CHAR_MASK as u8
                },
            )
            .advance(this.offset)
    }

    #[inline(always)]
    fn par_iter_bp(self, context: usize) -> PaddedIt<impl ChunkIt<S>> {
        #[cfg(target_endian = "big")]
        panic!("Big endian architectures are not supported.");

        let this = self.normalize();
        let o = this.offset;
        assert!(o < Self::C8);

        let num_kmers = if this.len == 0 {
            0
        } else {
            (this.len + o).saturating_sub(context - 1)
        };
        // Without `+ o`, since the offset does not count toward the stride.
        let num_kmers_stride = this.len.saturating_sub(context - 1);
        let n = num_kmers_stride.div_ceil(L).next_multiple_of(Self::C8);
        let bytes_per_chunk = n / Self::C8;
        let padding = Self::C8 * L * bytes_per_chunk - num_kmers_stride;

        let offsets: [usize; 8] = from_fn(|l| l * bytes_per_chunk);
        let mut cur = S::ZERO;

        // Boxed, so it doesn't consume precious registers.
        // Without this, cur is not always inlined into a register.
        let mut buf = Box::new([S::ZERO; 8]);

        let par_len = if num_kmers == 0 {
            0
        } else {
            n + context + o - 1
        };
        let it = (0..par_len)
            .map(
                #[inline(always)]
                move |i| {
                    if i % Self::C32 == 0 {
                        if i % Self::C256 == 0 {
                            // Read a u256 for each lane, containing the next `C256` characters.
                            let data: [u32x8; 8] = from_fn(
                                #[inline(always)]
                                |lane| read_slice_32(this.seq, offsets[lane] + (i / Self::C8)),
                            );
                            *buf = transpose(data);
                        }
                        cur = buf[(i % Self::C256) / Self::C32];
                    }
                    // Extract the low `B` bits of each character.
                    let chars = cur & S::splat(Self::CHAR_MASK as u32);
                    // Shift the remaining characters to the right.
                    cur = cur >> S::splat(B as u32);
                    chars
                },
            )
            .advance(o);

        PaddedIt { it, padding }
    }

    #[inline(always)]
    fn par_iter_bp_delayed(self, context: usize, delay: Delay) -> PaddedIt<impl ChunkIt<(S, S)>> {
        self.par_iter_bp_delayed_with_factor(context, delay, 1)
    }

    /// NOTE: When `self` does not start at a byte boundary, the
    /// 'delayed' character is not guaranteed to be `0`.
    #[inline(always)]
    fn par_iter_bp_delayed_2(
        self,
        context: usize,
        delay1: Delay,
        delay2: Delay,
    ) -> PaddedIt<impl ChunkIt<(S, S, S)>> {
        self.par_iter_bp_delayed_2_with_factor(context, delay1, delay2, 1)
    }

    /// Compares `K64` characters (29 for 2-bit) at a time.
    fn cmp_lcp(&self, other: &Self) -> (std::cmp::Ordering, usize) {
        let mut lcp = 0;
        let min_len = self.len.min(other.len);
        for i in (0..min_len).step_by(Self::K64) {
            let len = (min_len - i).min(Self::K64);
            let this = self.slice(i..i + len);
            let other = other.slice(i..i + len);
            let this_word = this.as_u64();
            let other_word = other.as_u64();
            if this_word != other_word {
                // Unfortunately, bases are packed in little-endian order, so the default order is reversed.
                let eq = this_word ^ other_word;
                let t = eq.trailing_zeros() as usize / B * B;
                lcp += t / B;
                let mask = (Self::CHAR_MASK) << t;
                return ((this_word & mask).cmp(&(other_word & mask)), lcp);
            }
            lcp += len;
        }
        (self.len.cmp(&other.len), lcp)
    }

    #[inline(always)]
    fn get(&self, index: usize) -> u8 {
        let offset = self.offset + index;
        let idx = offset / Self::C8;
        let offset = offset % Self::C8;
        (self.seq[idx] >> (B * offset)) & Self::CHAR_MASK as u8
    }
}

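// Added example (not part of the original file): short slices can be read as packed
// words via `as_u64`. "ACGT" is its own reverse complement, so both reads below agree.
// Assumes the `Seq`/`SeqVec` traits are brought in by `use super::*`.
#[cfg(test)]
mod as_u64_sketch {
    use super::*;

    #[test]
    fn kmer_as_word() {
        let mut v = PackedSeqVec::default();
        v.push_ascii(b"ACGT");
        let s = v.as_slice();
        // A=0, C=1, G=3, T=2, packed little-endian (first base in the lowest bits).
        assert_eq!(s.as_u64(), 0b10_11_01_00);
        assert_eq!(s.revcomp_as_u64(), s.as_u64());
    }
}
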
impl<'s, const B: usize> PackedSeqBase<'s, B>
where
    Bits<B>: SupportedBits,
{
    #[inline(always)]
    pub fn par_iter_bp_delayed_with_factor(
        self,
        context: usize,
        Delay(delay): Delay,
        factor: usize,
    ) -> PaddedIt<impl ChunkIt<(S, S)> + use<'s, B>> {
        #[cfg(target_endian = "big")]
        panic!("Big endian architectures are not supported.");

        assert!(
            delay < usize::MAX / 2,
            "Delay={} should be >=0.",
            delay as isize
        );

        let this = self.normalize();
        let o = this.offset;
        assert!(o < Self::C8);

        let num_kmers = if this.len == 0 {
            0
        } else {
            (this.len + o).saturating_sub(context - 1)
        };
        // Without `+ o`, since the offset does not count toward the stride.
        let num_kmers_stride = this.len.saturating_sub(context - 1);
        let n = num_kmers_stride
            .div_ceil(L)
            .next_multiple_of(factor * Self::C8);
        let bytes_per_chunk = n / Self::C8;
        let padding = Self::C8 * L * bytes_per_chunk - num_kmers_stride;

        let offsets: [usize; 8] = from_fn(|l| l * bytes_per_chunk);
        let mut upcoming = S::ZERO;
        let mut upcoming_d = S::ZERO;

        // An even `buf_len` is nice, so the `write == buf_len` check is only needed once.
        // We also round it up to the next power of 2, for faster modulo operations.
        // `delay / Self::C32`: the delay measured in u32 words (`C32` = chars per u32).
        // `+ 8`: some 'random' padding.
        let buf_len = (delay / Self::C32 + 8).next_power_of_two();
        let buf_mask = buf_len - 1;
        let mut buf = vec![S::ZERO; buf_len];
        let mut write_idx = 0;
        // We compensate for the first `delay / C32` triggers of the check below that
        // happen before the delay is actually reached.
        let mut read_idx = (buf_len - delay / Self::C32) % buf_len;

        let par_len = if num_kmers == 0 {
            0
        } else {
            n + context + o - 1
        };
        let it = (0..par_len)
            .map(
                #[inline(always)]
                move |i| {
                    if i % Self::C32 == 0 {
                        if i % Self::C256 == 0 {
                            // Read a u256 for each lane, containing the next `C256` characters.
                            let data: [u32x8; 8] = from_fn(
                                #[inline(always)]
                                |lane| read_slice_32(this.seq, offsets[lane] + (i / Self::C8)),
                            );
                            unsafe {
                                *TryInto::<&mut [u32x8; 8]>::try_into(
                                    buf.get_unchecked_mut(write_idx..write_idx + 8),
                                )
                                .unwrap_unchecked() = transpose(data);
                            }
                            if i == 0 {
                                // Mask out chars before the offset.
                                let elem = !((1u32 << (B * o)) - 1);
                                let mask = S::splat(elem);
                                buf[write_idx] &= mask;
                            }
                        }
                        upcoming = buf[write_idx];
                        write_idx += 1;
                        write_idx &= buf_mask;
                    }
                    if i % Self::C32 == delay % Self::C32 {
                        unsafe { assert_unchecked(read_idx < buf.len()) };
                        upcoming_d = buf[read_idx];
                        read_idx += 1;
                        read_idx &= buf_mask;
                    }
                    // Extract the low `B` bits of each character.
                    let chars = upcoming & S::splat(Self::CHAR_MASK as u32);
                    let chars_d = upcoming_d & S::splat(Self::CHAR_MASK as u32);
                    // Shift the remaining characters to the right.
                    upcoming = upcoming >> S::splat(B as u32);
                    upcoming_d = upcoming_d >> S::splat(B as u32);
                    (chars, chars_d)
                },
            )
            .advance(o);

        PaddedIt { it, padding }
    }

    /// When iterating over 2-bit and 1-bit encoded data in parallel,
    /// one must ensure that they have the same stride.
    /// On the larger type, set `factor` to the ratio to the smaller one,
    /// so that the stride in bytes is a multiple of `factor` and the
    /// smaller type also gets a byte-aligned stride.
    #[inline(always)]
    pub fn par_iter_bp_delayed_2_with_factor(
        self,
        context: usize,
        Delay(delay1): Delay,
        Delay(delay2): Delay,
        factor: usize,
    ) -> PaddedIt<impl ChunkIt<(S, S, S)> + use<'s, B>> {
        #[cfg(target_endian = "big")]
        panic!("Big endian architectures are not supported.");

        let this = self.normalize();
        let o = this.offset;
        assert!(o < Self::C8);
        assert!(delay1 <= delay2, "Delay1 must be at most delay2.");

        let num_kmers = if this.len == 0 {
            0
        } else {
            (this.len + o).saturating_sub(context - 1)
        };
        // Without `+ o`, since the offset does not count toward the stride.
        let num_kmers_stride = this.len.saturating_sub(context - 1);
        let n = num_kmers_stride
            .div_ceil(L)
            .next_multiple_of(factor * Self::C8);
        let bytes_per_chunk = n / Self::C8;
        let padding = Self::C8 * L * bytes_per_chunk - num_kmers_stride;

        let offsets: [usize; 8] = from_fn(|l| l * bytes_per_chunk);
        let mut upcoming = S::ZERO;
        let mut upcoming_d1 = S::ZERO;
        let mut upcoming_d2 = S::ZERO;

        // An even `buf_len` is nice, so the `write == buf_len` check is only needed once.
        let buf_len = (delay2 / Self::C32 + 8).next_power_of_two();
        let buf_mask = buf_len - 1;
        let mut buf = vec![S::ZERO; buf_len];
        let mut write_idx = 0;
        // We compensate for the first `delay / C32` triggers of the check below that
        // happen before the delay is actually reached.
        let mut read_idx1 = (buf_len - delay1 / Self::C32) % buf_len;
        let mut read_idx2 = (buf_len - delay2 / Self::C32) % buf_len;

        let par_len = if num_kmers == 0 {
            0
        } else {
            n + context + o - 1
        };
        let it = (0..par_len)
            .map(
                #[inline(always)]
                move |i| {
                    if i % Self::C32 == 0 {
                        if i % Self::C256 == 0 {
                            // Read a u256 for each lane, containing the next `C256` characters.
                            let data: [u32x8; 8] = from_fn(
                                #[inline(always)]
                                |lane| read_slice_32(this.seq, offsets[lane] + (i / Self::C8)),
                            );
                            unsafe {
                                *TryInto::<&mut [u32x8; 8]>::try_into(
                                    buf.get_unchecked_mut(write_idx..write_idx + 8),
                                )
                                .unwrap_unchecked() = transpose(data);
                            }
                            if i == 0 {
                                // Mask out chars before the offset.
                                let elem = !((1u32 << (B * o)) - 1);
                                let mask = S::splat(elem);
                                buf[write_idx] &= mask;
                            }
                        }
                        upcoming = buf[write_idx];
                        write_idx += 1;
                        write_idx &= buf_mask;
                    }
                    if i % Self::C32 == delay1 % Self::C32 {
                        unsafe { assert_unchecked(read_idx1 < buf.len()) };
                        upcoming_d1 = buf[read_idx1];
                        read_idx1 += 1;
                        read_idx1 &= buf_mask;
                    }
                    if i % Self::C32 == delay2 % Self::C32 {
                        unsafe { assert_unchecked(read_idx2 < buf.len()) };
                        upcoming_d2 = buf[read_idx2];
                        read_idx2 += 1;
                        read_idx2 &= buf_mask;
                    }
                    // Extract the low `B` bits of each character.
                    let chars = upcoming & S::splat(Self::CHAR_MASK as u32);
                    let chars_d1 = upcoming_d1 & S::splat(Self::CHAR_MASK as u32);
                    let chars_d2 = upcoming_d2 & S::splat(Self::CHAR_MASK as u32);
                    // Shift the remaining characters to the right.
                    upcoming = upcoming >> S::splat(B as u32);
                    upcoming_d1 = upcoming_d1 >> S::splat(B as u32);
                    upcoming_d2 = upcoming_d2 >> S::splat(B as u32);
                    (chars, chars_d1, chars_d2)
                },
            )
            .advance(o);

        PaddedIt { it, padding }
    }
}

impl<const B: usize> PartialEq for PackedSeqBase<'_, B>
where
    Bits<B>: SupportedBits,
{
    /// Compares `K64` characters (29 for 2-bit) at a time.
    fn eq(&self, other: &Self) -> bool {
        if self.len != other.len {
            return false;
        }
        for i in (0..self.len).step_by(Self::K64) {
            let len = (self.len - i).min(Self::K64);
            let this = self.slice(i..i + len);
            let that = other.slice(i..i + len);
            if this.as_u64() != that.as_u64() {
                return false;
            }
        }
        true
    }
}

impl<const B: usize> Eq for PackedSeqBase<'_, B> where Bits<B>: SupportedBits {}

impl<const B: usize> PartialOrd for PackedSeqBase<'_, B>
where
    Bits<B>: SupportedBits,
{
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl<const B: usize> Ord for PackedSeqBase<'_, B>
where
    Bits<B>: SupportedBits,
{
    /// Compares `K64` characters (29 for 2-bit) at a time.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let min_len = self.len.min(other.len);
        for i in (0..min_len).step_by(Self::K64) {
            let len = (min_len - i).min(Self::K64);
            let this = self.slice(i..i + len);
            let other = other.slice(i..i + len);
            let this_word = this.as_u64();
            let other_word = other.as_u64();
            if this_word != other_word {
                // Unfortunately, bases are packed in little-endian order, so the default order is reversed.
                let eq = this_word ^ other_word;
                let t = eq.trailing_zeros() as usize / B * B;
                let mask = (Self::CHAR_MASK) << t;
                return (this_word & mask).cmp(&(other_word & mask));
            }
        }
        self.len.cmp(&other.len)
    }
}

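// Added example (not part of the original file): `Ord` compares the packed 2-bit values,
// so the order follows the A < C < T < G encoding (0 < 1 < 2 < 3), not ASCII order.
// Assumes the `Seq`/`SeqVec` traits are brought in by `use super::*`.
#[cfg(test)]
mod order_sketch {
    use super::*;

    #[test]
    fn packed_order_is_not_ascii_order() {
        let mut a = PackedSeqVec::default();
        a.push_ascii(b"AT");
        let mut b = PackedSeqVec::default();
        b.push_ascii(b"AG");
        // 'T' packs to 2 and 'G' packs to 3, so "AT" sorts before "AG".
        assert!(a.as_slice() < b.as_slice());
    }
}
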
impl<const B: usize> SeqVec for PackedSeqVecBase<B>
where
    Bits<B>: SupportedBits,
{
    type Seq<'s> = PackedSeqBase<'s, B>;

    #[inline(always)]
    fn into_raw(mut self) -> Vec<u8> {
        self.seq.resize(self.len.div_ceil(Self::C8), 0);
        self.seq
    }

    #[inline(always)]
    fn as_slice(&self) -> Self::Seq<'_> {
        PackedSeqBase {
            seq: &self.seq[..self.len.div_ceil(Self::C8)],
            offset: 0,
            len: self.len,
        }
    }

    #[inline(always)]
    fn len(&self) -> usize {
        self.len
    }

    #[inline(always)]
    fn is_empty(&self) -> bool {
        self.len == 0
    }

    #[inline(always)]
    fn clear(&mut self) {
        self.seq.clear();
        self.len = 0;
    }

    fn push_seq<'a>(&mut self, seq: PackedSeqBase<'_, B>) -> Range<usize> {
        let start = self.len.next_multiple_of(Self::C8) + seq.offset;
        let end = start + seq.len();
        // Reserve *additional* capacity.
        self.seq.reserve(seq.seq.len());
        // Shrink away the padding.
        self.seq.resize(self.len.div_ceil(Self::C8), 0);
        // Extend.
        self.seq.extend(seq.seq);
        // Push padding.
        self.seq.extend(std::iter::repeat_n(0u8, PADDING));
        self.len = end;
        start..end
    }

    /// For `PackedSeqVec` (2-bit encoding): map ASCII `ACGT` (and `acgt`) to DNA `0132` in that order.
    /// Other characters are silently mapped into `0..4`.
    ///
    /// Uses the BMI2 `pext` instruction when available, based on the
    /// `n_to_bits_pext` method described at
    /// <https://github.com/Daniel-Liu-c0deb0t/cute-nucleotides>.
    ///
    /// For `BitSeqVec` (1-bit encoding): map `ACGTacgt` to `0`, and everything else to `1`.
    fn push_ascii(&mut self, seq: &[u8]) -> Range<usize> {
        match B {
            1 | 2 => {}
            _ => panic!(
                "Can only use ASCII input for 2-bit DNA packing, or 1-bit ambiguous indicators."
            ),
        }

        self.seq
            .resize((self.len + seq.len()).div_ceil(Self::C8) + PADDING, 0);
        let start_aligned = self.len.next_multiple_of(Self::C8);
        let start = self.len;
        let len = seq.len();
        let mut idx = self.len / Self::C8;

        let parse_base = |base| match B {
            1 => char_is_ambiguous(base),
            2 => pack_char_lossy(base),
            _ => unreachable!(),
        };

        let unaligned = core::cmp::min(start_aligned - start, len);
        if unaligned > 0 {
            let mut packed_byte = self.seq[idx];
            for &base in &seq[..unaligned] {
                packed_byte |= parse_base(base) << ((self.len % Self::C8) * B);
                self.len += 1;
            }
            self.seq[idx] = packed_byte;
            idx += 1;
        }

        #[allow(unused)]
        let mut last = unaligned;

        if B == 2 {
            #[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))]
            {
                last = unaligned + (len - unaligned) / 8 * 8;

                for i in (unaligned..last).step_by(8) {
                    let chunk =
                        unsafe { seq.get_unchecked(i..i + 8).try_into().unwrap_unchecked() };
                    let ascii = u64::from_le_bytes(chunk);
                    let packed_bytes =
                        unsafe { std::arch::x86_64::_pext_u64(ascii, 0x0606060606060606) } as u16;
                    unsafe {
                        self.seq
                            .get_unchecked_mut(idx..(idx + 2))
                            .copy_from_slice(&packed_bytes.to_le_bytes())
                    };
                    idx += 2;
                    self.len += 8;
                }
            }

            #[cfg(target_feature = "neon")]
            {
                use core::arch::aarch64::{
                    vandq_u8, vdup_n_u8, vld1q_u8, vpadd_u8, vshlq_u8, vst1_u8,
                };
                use core::mem::transmute;

                last = unaligned + (len - unaligned) / 16 * 16;

                for i in (unaligned..last).step_by(16) {
                    unsafe {
                        let ascii = vld1q_u8(seq.as_ptr().add(i));
                        let masked_bits = vandq_u8(ascii, transmute([6i8; 16]));
                        let (bits_0, bits_1) = transmute(vshlq_u8(
                            masked_bits,
                            transmute([-1i8, 1, 3, 5, -1, 1, 3, 5, -1, 1, 3, 5, -1, 1, 3, 5]),
                        ));
                        let half_packed = vpadd_u8(bits_0, bits_1);
                        let packed = vpadd_u8(half_packed, vdup_n_u8(0));
                        vst1_u8(self.seq.as_mut_ptr().add(idx), packed);
                        idx += Self::C8;
                        self.len += 16;
                    }
                }
            }
        }

        if B == 1 {
            #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
            {
                last = len;
                self.len += len - unaligned;

                let mut last_i = unaligned;

                for i in (unaligned..last).step_by(32) {
                    use std::mem::transmute as t;

                    use wide::CmpEq;
                    // Wide doesn't have u8x32, so this is messy here...
                    type S = wide::i8x32;
                    let chars: S = unsafe { t(read_slice_32(seq, i)) };
                    let upper_mask = !(b'a' - b'A');
                    // Make everything upper case.
                    let chars = chars & S::splat(upper_mask as i8);
                    let lossy_encoded = chars & S::splat(6);
                    let table = unsafe { S::from(t::<_, S>(*b"AxCxTxGxxxxxxxxxAxCxTxGxxxxxxxxx")) };
                    let lookup: S = unsafe {
                        t(std::arch::x86_64::_mm256_shuffle_epi8(
                            t(table),
                            t(lossy_encoded),
                        ))
                    };
                    let packed_bytes = !(chars.cmp_eq(lookup).move_mask() as u32);

                    last_i = i;
                    unsafe {
                        self.seq
                            .get_unchecked_mut(idx..(idx + 4))
                            .copy_from_slice(&packed_bytes.to_le_bytes())
                    };
                    idx += 4;
                }

                // Fix up trailing bytes.
                if unaligned < last {
                    idx -= 4;
                    let mut val = unsafe {
                        u32::from_le_bytes(
                            self.seq
                                .get_unchecked(idx..(idx + 4))
                                .try_into()
                                .unwrap_unchecked(),
                        )
                    };
                    // Keep only the `last - last_i` low bits.
                    let keep = last - last_i;
                    val <<= 32 - keep;
                    val >>= 32 - keep;
                    unsafe {
                        self.seq
                            .get_unchecked_mut(idx..(idx + 4))
                            .copy_from_slice(&val.to_le_bytes())
                    };
                    idx += keep.div_ceil(8);
                }
            }

            #[cfg(target_feature = "neon")]
            {
                use core::arch::aarch64::*;
                use core::mem::transmute;

                last = unaligned + (len - unaligned) / 64 * 64;

                for i in (unaligned..last).step_by(64) {
                    unsafe {
                        let ptr = seq.as_ptr().add(i);
                        let chars = vld4q_u8(ptr);

                        let upper_mask = vdupq_n_u8(!(b'a' - b'A'));
                        let chars = neon::map_8x16x4(chars, |v| vandq_u8(v, upper_mask));

                        let two_bits_mask = vdupq_n_u8(6);
                        let lossy_encoded = neon::map_8x16x4(chars, |v| vandq_u8(v, two_bits_mask));

                        let table = transmute(*b"AxCxTxGxxxxxxxxx");
                        let lookup = neon::map_8x16x4(lossy_encoded, |v| vqtbl1q_u8(table, v));

                        let mask = neon::map_two_8x16x4(chars, lookup, |v1, v2| vceqq_u8(v1, v2));
                        let packed_bytes = !neon::movemask_64(mask);

                        self.seq[idx..(idx + 8)].copy_from_slice(&packed_bytes.to_le_bytes());
                        idx += 8;
                        self.len += 64;
                    }
                }
            }
        }

        let mut packed_byte = 0;
        for &base in &seq[last..] {
            packed_byte |= parse_base(base) << ((self.len % Self::C8) * B);
            self.len += 1;
            if self.len % Self::C8 == 0 {
                self.seq[idx] = packed_byte;
                idx += 1;
                packed_byte = 0;
            }
        }
        if self.len % Self::C8 != 0 && last < len {
            self.seq[idx] = packed_byte;
            idx += 1;
        }
        assert_eq!(idx + PADDING, self.seq.len());
        start..start + len
    }

    #[cfg(feature = "rand")]
    fn random(n: usize) -> Self {
        use rand::{RngCore, SeedableRng};

        let byte_len = n.div_ceil(Self::C8);
        let mut seq = vec![0; byte_len + PADDING];
        rand::rngs::SmallRng::from_os_rng().fill_bytes(&mut seq[..byte_len]);
        // Ensure that the last byte is padded with zeros.
        if n % Self::C8 != 0 {
            seq[byte_len - 1] &= (1 << (B * (n % Self::C8))) - 1;
        }

        Self { seq, len: n }
    }
}

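// Added example (not part of the original file): `push_ascii` returns the bp range of the
// pushed text; packing is lossy, so lower-case input and non-ACGT bytes come back as one
// of `ACTG` (here, 'N' unpacks as 'G'). Assumes the `Seq`/`SeqVec` traits via `use super::*`.
#[cfg(test)]
mod push_ascii_sketch {
    use super::*;

    #[test]
    fn lossy_packing_roundtrip() {
        let mut v = PackedSeqVec::default();
        let range = v.push_ascii(b"acgN");
        assert_eq!(range, 0..4);
        assert_eq!(v.as_slice().slice(range).unpack(), b"ACGG".to_vec());
    }
}
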
impl PackedSeqVecBase<1> {
    pub fn with_len(n: usize) -> Self {
        Self {
            seq: vec![0; n.div_ceil(Self::C8) + PADDING],
            len: n,
        }
    }

    pub fn random(len: usize, n_frac: f32) -> Self {
        let byte_len = len.div_ceil(Self::C8);
        let mut seq = vec![0; byte_len + PADDING];

        assert!(
            (0.0..=0.3).contains(&n_frac),
            "n_frac={} should be in [0, 0.3]",
            n_frac
        );

        for _ in 0..(len as f32 * n_frac) as usize {
            let idx = rand::random::<u64>() as usize % len;
            let byte = idx / Self::C8;
            let offset = idx % Self::C8;
            seq[byte] |= 1 << offset;
        }

        Self { seq, len }
    }
}

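// Added example (not part of the original file): for the 1-bit `BitSeqVec`, `push_ascii`
// stores one ambiguity bit per input character (0 for ACGTacgt, 1 otherwise).
// Assumes the `Seq`/`SeqVec` traits via `use super::*`.
#[cfg(test)]
mod bitseq_sketch {
    use super::*;

    #[test]
    fn ambiguity_bits() {
        let mut v = BitSeqVec::default();
        v.push_ascii(b"ACGTN");
        let s = v.as_slice();
        let bits: Vec<u8> = (0..5).map(|i| s.get(i)).collect();
        assert_eq!(bits, vec![0, 0, 0, 0, 1]);
    }
}
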
impl<'s> PackedSeqBase<'s, 1> {
    /// An iterator indicating for each kmer whether it contains ambiguous bases.
    ///
    /// Returns `n - (k - 1)` elements.
    #[inline(always)]
    pub fn iter_kmer_ambiguity(self, k: usize) -> impl ExactSizeIterator<Item = bool> + use<'s> {
        let this = self.normalize();
        assert!(k > 0);
        assert!(k <= Self::K64);
        (this.offset..this.offset + this.len.saturating_sub(k - 1))
            .map(move |i| self.read_kmer(k, i) != 0)
    }

    /// A parallel iterator indicating for each kmer whether it contains ambiguous bases.
    ///
    /// The first element is the 'kmer' consisting only of the first character of each chunk.
    ///
    /// `k`: length of the windows to check.
    /// `context`: number of overlapping iterations plus 1; determines the stride of each lane.
    /// `skip`: set to `context - 1` to skip the iterations added by the context.
    #[inline(always)]
    pub fn par_iter_kmer_ambiguity(
        self,
        k: usize,
        context: usize,
        skip: usize,
    ) -> PaddedIt<impl ChunkIt<S> + use<'s>> {
        #[cfg(target_endian = "big")]
        panic!("Big endian architectures are not supported.");

        assert!(k <= 64, "par_iter_kmer_ambiguity requires k<=64, but k={k}");

        let this = self.normalize();
        let o = this.offset;
        assert!(o < Self::C8);

        let num_kmers = if this.len == 0 {
            0
        } else {
            (this.len + o).saturating_sub(context - 1)
        };
        // Without `+ o`, since the offset does not count toward the stride.
        let num_kmers_stride = this.len.saturating_sub(context - 1);
        let n = num_kmers_stride.div_ceil(L).next_multiple_of(Self::C8);
        let bytes_per_chunk = n / Self::C8;
        let padding = Self::C8 * L * bytes_per_chunk - num_kmers_stride;

        let offsets: [usize; 8] = from_fn(|l| l * bytes_per_chunk);

        //     prev2 prev    cur
        //           0..31 | 32..63
        // mask      00001111110000
        // mask      00000111111000
        // mask      00000011111100
        // mask      00000001111110
        // mask      00000000111111
        //           cur     next
        //           32..63| 64..95
        // mask      11111100000000

        // [prev2, prev, cur]
        let mut cur = [S::ZERO; 3];
        let mut mask = [S::ZERO; 3];
        if k <= 32 {
            // The high k bits of cur.
            mask[2] = (S::MAX) << S::splat(32 - k as u32);
        } else {
            mask[2] = S::MAX;
            mask[1] = (S::MAX) << S::splat(64 - k as u32);
        }

        #[inline(always)]
        fn rotate_mask(mask: &mut [S; 3], r: u32) {
            let carry01 = mask[0] >> S::splat(32 - r);
            let carry12 = mask[1] >> S::splat(32 - r);
            mask[0] = mask[0] << r;
            mask[1] = (mask[1] << r) | carry01;
            mask[2] = (mask[2] << r) | carry12;
        }

        // Boxed, so it doesn't consume precious registers.
        // Without this, cur is not always inlined into a register.
        let mut buf = Box::new([S::ZERO; 8]);

        // We skip the first o iterations.
        let par_len = if num_kmers == 0 { 0 } else { n + k + o - 1 };

        let mut read = {
            #[inline(always)]
            move |i: usize, cur: &mut [S; 3]| {
                if i % Self::C256 == 0 {
                    // Read a u256 for each lane, containing the next `C256` characters.
                    let data: [u32x8; 8] = from_fn(
                        #[inline(always)]
                        |lane| read_slice_32(this.seq, offsets[lane] + (i / Self::C8)),
                    );
                    *buf = transpose(data);
                }
                cur[0] = cur[1];
                cur[1] = cur[2];
                cur[2] = buf[(i % Self::C256) / Self::C32];
            }
        };

        // Precompute the first o+skip iterations.
        let mut to_skip = o + skip;
        let mut i = 0;
        while to_skip > 0 {
            read(i, &mut cur);
            i += 32;
            if to_skip >= 32 {
                to_skip -= 32;
            } else {
                mask[0] = mask[1];
                mask[1] = mask[2];
                mask[2] = S::splat(0);
                // Rotate the mask by the remainder.
                rotate_mask(&mut mask, to_skip as u32);
                break;
            }
        }

        let it = (o + skip..par_len).map(
            #[inline(always)]
            move |i| {
                if i % Self::C32 == 0 {
                    read(i, &mut cur);
                    mask[0] = mask[1];
                    mask[1] = mask[2];
                    mask[2] = S::splat(0);
                }

                rotate_mask(&mut mask, 1);
                !((cur[0] & mask[0]) | (cur[1] & mask[1]) | (cur[2] & mask[2])).cmp_eq(S::splat(0))
            },
        );

        PaddedIt { it, padding }
    }
}

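// Added example (not part of the original file): a sketch of `iter_kmer_ambiguity`,
// assuming `read_kmer(k, i)` (defined elsewhere in this crate) reads the `k` bits starting
// at position `i`, as its use above suggests. A window is flagged iff it covers at least
// one ambiguous base.
#[cfg(test)]
mod kmer_ambiguity_sketch {
    use super::*;

    #[test]
    fn windows_containing_n_are_flagged() {
        let mut v = BitSeqVec::default();
        v.push_ascii(b"ACGNACG");
        let flags: Vec<bool> = v.as_slice().iter_kmer_ambiguity(3).collect();
        assert_eq!(flags, vec![false, true, true, true, false]);
    }
}
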
#[cfg(target_feature = "neon")]
mod neon {
    use core::arch::aarch64::*;

    #[inline(always)]
    pub fn movemask_64(v: uint8x16x4_t) -> u64 {
        // https://stackoverflow.com/questions/74722950/convert-vector-compare-mask-into-bit-mask-in-aarch64-simd-or-arm-neon/74748402#74748402
        unsafe {
            let acc = vsriq_n_u8(vsriq_n_u8(v.3, v.2, 1), vsriq_n_u8(v.1, v.0, 1), 2);
            vget_lane_u64(
                vreinterpret_u64_u8(vshrn_n_u16(
                    vreinterpretq_u16_u8(vsriq_n_u8(acc, acc, 4)),
                    4,
                )),
                0,
            )
        }
    }

    #[inline(always)]
    pub fn map_8x16x4<F>(v: uint8x16x4_t, mut f: F) -> uint8x16x4_t
    where
        F: FnMut(uint8x16_t) -> uint8x16_t,
    {
        uint8x16x4_t(f(v.0), f(v.1), f(v.2), f(v.3))
    }

    #[inline(always)]
    pub fn map_two_8x16x4<F>(v1: uint8x16x4_t, v2: uint8x16x4_t, mut f: F) -> uint8x16x4_t
    where
        F: FnMut(uint8x16_t, uint8x16_t) -> uint8x16_t,
    {
        uint8x16x4_t(f(v1.0, v2.0), f(v1.1, v2.1), f(v1.2, v2.2), f(v1.3, v2.3))
    }
}