packed_seq/ascii_seq.rs

use crate::{intrinsics::transpose, packed_seq::read_slice_32, padded_it::ChunkIt};

use super::*;

/// A `Vec<u8>` representing an ASCII-encoded DNA sequence of `ACGTacgt`.
///
/// Other characters will be mapped into `[0, 4)` via `(c>>1)&3`, or may cause panics.
#[derive(Clone, Debug, Default, MemSize, MemDbg)]
#[cfg_attr(feature = "pyo3", pyo3::pyclass)]
#[cfg_attr(feature = "epserde", derive(epserde::Epserde))]
pub struct AsciiSeqVec {
    pub seq: Vec<u8>,
}

/// A `&[u8]` representing an ASCII-encoded DNA sequence of `ACGTacgt`.
///
/// Other characters will be mapped into `[0, 4)` via `(c>>1)&3`, or may cause panics.
#[derive(Copy, Clone, Debug, MemSize, MemDbg, PartialEq, Eq, PartialOrd, Ord)]
pub struct AsciiSeq<'s>(pub &'s [u8]);
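
// Concretely, `(c>>1)&3` maps `b'A'` (0x41) to 0, `b'C'` (0x43) to 1, `b'T'` (0x54) to 2,
// and `b'G'` (0x47) to 3; lowercase letters map identically, since bits 1 and 2 of their
// ASCII codes are the same as for the uppercase letters.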

/// Maps ASCII to `[0, 4)` on the fly.
/// Prefer first packing into a `PackedSeqVec` for storage.
impl<'s> Seq<'s> for AsciiSeq<'s> {
    /// Each input byte stores a single character.
    const BASES_PER_BYTE: usize = 1;
    /// But each output bp only takes 2 bits!
    const BITS_PER_CHAR: usize = 2;
    type SeqVec = AsciiSeqVec;

    #[inline(always)]
    fn len(&self) -> usize {
        self.0.len()
    }

    #[inline(always)]
    fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    #[inline(always)]
    fn get(&self, index: usize) -> u8 {
        pack_char(self.0[index])
    }

    #[inline(always)]
    fn get_ascii(&self, index: usize) -> u8 {
        self.0[index]
    }

    #[inline(always)]
    fn as_u64(&self) -> u64 {
        let len = self.len();
        assert!(len <= u64::BITS as usize / 2);

        let mut val = 0u64;

        #[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))]
        {
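            // `_pext_u64` with mask 0x06 in every byte gathers bits 1 and 2 of each of the
            // 8 ASCII bytes, i.e. `(c>>1)&3` per character, into the low 16 bits.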
            for i in (0..len).step_by(8) {
                let packed_bytes = if i + 8 <= self.len() {
                    let chunk: &[u8; 8] = &self.0[i..i + 8].try_into().unwrap();
                    let ascii = u64::from_ne_bytes(*chunk);
                    unsafe { std::arch::x86_64::_pext_u64(ascii, 0x0606060606060606) }
                } else {
                    let mut chunk: [u8; 8] = [0; 8];
                    // Copy only part of the slice to avoid out-of-bounds indexing.
                    chunk[..self.len() - i].copy_from_slice(self.0[i..].try_into().unwrap());
                    let ascii = u64::from_ne_bytes(chunk);
                    unsafe { std::arch::x86_64::_pext_u64(ascii, 0x0606060606060606) }
                };
                val |= packed_bytes << (i * 2);
            }
        }

        #[cfg(target_feature = "neon")]
        {
            use core::arch::aarch64::{vandq_u8, vdup_n_u8, vld1q_u8, vpadd_u8, vshlq_u8};
            use core::mem::transmute;

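            // Per 16-byte chunk: `& 6` keeps bits 1 and 2 of every byte, the per-byte shifts
            // move the 2-bit codes of 4 consecutive characters to distinct positions within
            // each 4-byte group, and the two pairwise adds fold each group into one byte,
            // leaving 32 packed bits in the low half of the resulting u64.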
            for i in (0..len).step_by(16) {
                let packed_bytes: u64 = if i + 16 <= self.len() {
                    unsafe {
                        let ascii = vld1q_u8(self.0.as_ptr().add(i));
                        let masked_bits = vandq_u8(ascii, transmute([6i8; 16]));
                        let (bits_0, bits_1) = transmute(vshlq_u8(
                            masked_bits,
                            transmute([-1i8, 1, 3, 5, -1, 1, 3, 5, -1, 1, 3, 5, -1, 1, 3, 5]),
                        ));
                        let half_packed = vpadd_u8(bits_0, bits_1);
                        let packed = vpadd_u8(half_packed, vdup_n_u8(0));
                        transmute(packed)
                    }
                } else {
                    let mut chunk: [u8; 16] = [0; 16];
                    // Copy only part of the slice to avoid out-of-bounds indexing.
                    chunk[..self.len() - i].copy_from_slice(self.0[i..].try_into().unwrap());
                    unsafe {
                        let ascii = vld1q_u8(chunk.as_ptr());
                        let masked_bits = vandq_u8(ascii, transmute([6i8; 16]));
                        let (bits_0, bits_1) = transmute(vshlq_u8(
                            masked_bits,
                            transmute([-1i8, 1, 3, 5, -1, 1, 3, 5, -1, 1, 3, 5, -1, 1, 3, 5]),
                        ));
                        let half_packed = vpadd_u8(bits_0, bits_1);
                        let packed = vpadd_u8(half_packed, vdup_n_u8(0));
                        transmute(packed)
                    }
                };
                val |= packed_bytes << (i * 2);
            }
        }

        #[cfg(not(any(
            all(target_arch = "x86_64", target_feature = "bmi2"),
            target_feature = "neon"
        )))]
        {
            for (i, &base) in self.0.iter().enumerate() {
                val |= (pack_char(base) as u64) << (i * 2);
            }
        }

        val
    }

    // TODO: Dedup against as_u64.
    #[inline(always)]
    fn as_u128(&self) -> u128 {
        let len = self.len();
        assert!(len <= u128::BITS as usize / 2);

        let mut val = 0u128;

        #[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))]
        {
            for i in (0..len).step_by(8) {
                let packed_bytes = if i + 8 <= self.len() {
                    let chunk: &[u8; 8] = &self.0[i..i + 8].try_into().unwrap();
                    let ascii = u64::from_ne_bytes(*chunk);
                    unsafe { std::arch::x86_64::_pext_u64(ascii, 0x0606060606060606) }
                } else {
                    let mut chunk: [u8; 8] = [0; 8];
                    // Copy only part of the slice to avoid out-of-bounds indexing.
                    chunk[..self.len() - i].copy_from_slice(self.0[i..].try_into().unwrap());
                    let ascii = u64::from_ne_bytes(chunk);
                    unsafe { std::arch::x86_64::_pext_u64(ascii, 0x0606060606060606) }
                };
                val |= (packed_bytes as u128) << (i * 2);
            }
        }

        #[cfg(target_feature = "neon")]
        {
            use core::arch::aarch64::{vandq_u8, vdup_n_u8, vld1q_u8, vpadd_u8, vshlq_u8};
            use core::mem::transmute;

            for i in (0..len).step_by(16) {
                let packed_bytes: u64 = if i + 16 <= self.len() {
                    unsafe {
                        let ascii = vld1q_u8(self.0.as_ptr().add(i));
                        let masked_bits = vandq_u8(ascii, transmute([6i8; 16]));
                        let (bits_0, bits_1) = transmute(vshlq_u8(
                            masked_bits,
                            transmute([-1i8, 1, 3, 5, -1, 1, 3, 5, -1, 1, 3, 5, -1, 1, 3, 5]),
                        ));
                        let half_packed = vpadd_u8(bits_0, bits_1);
                        let packed = vpadd_u8(half_packed, vdup_n_u8(0));
                        transmute(packed)
                    }
                } else {
                    let mut chunk: [u8; 16] = [0; 16];
                    // Copy only part of the slice to avoid out-of-bounds indexing.
                    chunk[..self.len() - i].copy_from_slice(self.0[i..].try_into().unwrap());
                    unsafe {
                        let ascii = vld1q_u8(chunk.as_ptr());
                        let masked_bits = vandq_u8(ascii, transmute([6i8; 16]));
                        let (bits_0, bits_1) = transmute(vshlq_u8(
                            masked_bits,
                            transmute([-1i8, 1, 3, 5, -1, 1, 3, 5, -1, 1, 3, 5, -1, 1, 3, 5]),
                        ));
                        let half_packed = vpadd_u8(bits_0, bits_1);
                        let packed = vpadd_u8(half_packed, vdup_n_u8(0));
                        transmute(packed)
                    }
                };
                val |= (packed_bytes as u128) << (i * 2);
            }
        }

        #[cfg(not(any(
            all(target_arch = "x86_64", target_feature = "bmi2"),
            target_feature = "neon"
        )))]
        {
            for (i, &base) in self.0.iter().enumerate() {
                val |= (pack_char(base) as u128) << (i * 2);
            }
        }

        val
    }

    #[inline(always)]
    fn revcomp_as_u64(&self) -> u64 {
        packed_seq::revcomp_u64(self.as_u64(), self.len())
    }

    #[inline(always)]
    fn revcomp_as_u128(&self) -> u128 {
        packed_seq::revcomp_u128(self.as_u128(), self.len())
    }

    /// Convert to an owned version.
    #[inline(always)]
    fn to_vec(&self) -> AsciiSeqVec {
        AsciiSeqVec {
            seq: self.0.to_vec(),
        }
    }

    #[inline(always)]
    fn to_revcomp(&self) -> AsciiSeqVec {
        AsciiSeqVec {
            seq: self
                .0
                .iter()
                .rev()
                .copied()
                .map(packed_seq::complement_char)
                .collect(),
        }
    }

    #[inline(always)]
    fn slice(&self, range: Range<usize>) -> Self {
        Self(&self.0[range])
    }

    /// Iterate the basepairs in the sequence, yielding values in `0..4`.
    ///
    /// NOTE: This is only efficient on x86_64 with `BMI2` support for `pext`.
    #[inline(always)]
    fn iter_bp(self) -> impl ExactSizeIterator<Item = u8> {
        #[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))]
        {
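            // Load 8 ASCII bytes at a time; after the word-wide `>> 1`, the low two bits of
            // each byte of `cache` hold that character's 2-bit code, so each call returns
            // `cache & 3` and then shifts the next byte into place.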
            let mut cache = 0;
            (0..self.len()).map(
                #[inline(always)]
                move |i| {
                    if i % 8 == 0 {
                        if i + 8 <= self.len() {
                            let chunk: &[u8; 8] = &self.0[i..i + 8].try_into().unwrap();
                            let ascii = u64::from_ne_bytes(*chunk);
                            cache = ascii >> 1;
                        } else {
                            let mut chunk: [u8; 8] = [0; 8];
                            // Copy only part of the slice to avoid out-of-bounds indexing.
                            chunk[..self.len() - i]
                                .copy_from_slice(self.0[i..].try_into().unwrap());
                            let ascii = u64::from_ne_bytes(chunk);
                            cache = ascii >> 1;
                        }
                    }
                    let base = cache & 0x03;
                    cache >>= 8;
                    base as u8
                },
            )
        }

        #[cfg(not(all(target_arch = "x86_64", target_feature = "bmi2")))]
        self.0.iter().copied().map(pack_char)
    }

    /// Iterate the basepairs in the sequence in 8 parallel streams, yielding values in `0..4`.
    #[inline(always)]
    fn par_iter_bp(self, context: usize) -> PaddedIt<impl ChunkIt<S>> {
        let num_kmers = self.len().saturating_sub(context - 1);
        let n = num_kmers.div_ceil(L);
        let padding = L * n - num_kmers;

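        // The sequence is processed in `L` (8) parallel lanes: lane `l` starts at offset
        // `l * n` and covers `n` k-mers, so each iteration yields the next character of
        // all 8 lanes at once. For example, with `len() == 100` and `context == 31`:
        // `num_kmers = 70`, `n = 9`, and `padding = 8 * 9 - 70 = 2`.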
        let offsets: [usize; 8] = from_fn(|l| l * n);
        let mut cur = S::ZERO;

        // Boxed, so it doesn't consume precious registers.
        // Without this, `cur` is not always kept in a register.
        let mut buf = Box::new([S::ZERO; 8]);

        let par_len = if num_kmers == 0 { 0 } else { n + context - 1 };
        let it = (0..par_len).map(
            #[inline(always)]
            move |i| {
                if i % 4 == 0 {
                    if i % 32 == 0 {
                        // Read a u256 for each lane containing the next 32 characters.
                        let data: [u32x8; 8] = from_fn(
                            #[inline(always)]
                            |lane| read_slice_32(self.0, offsets[lane] + i),
                        );
                        *buf = transpose(data);
                        for x in buf.iter_mut() {
                            *x = *x >> 1;
                        }
                    }
                    cur = buf[(i % 32) / 4];
                }
                // Extract the last 2 bits of each character.
                let chars = cur & S::splat(0x03);
                // Shift remaining characters to the right.
                cur = cur >> S::splat(8);
                chars
            },
        );

        PaddedIt { it, padding }
    }

    #[inline(always)]
    fn par_iter_bp_delayed(
        self,
        context: usize,
        Delay(delay): Delay,
    ) -> PaddedIt<impl ChunkIt<(S, S)>> {
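        // Each step yields `(current, delayed)` per lane, where `delayed` is the character
        // `delay` positions earlier in the same lane (zeros until that position exists).
        // The ring buffer below holds the in-flight u32 words between write and read.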
        assert!(
            delay < usize::MAX / 2,
            "Delay={} should be >=0.",
            delay as isize
        );

        let num_kmers = self.len().saturating_sub(context - 1);
        let n = num_kmers.div_ceil(L);
        let padding = L * n - num_kmers;

        let offsets: [usize; 8] = from_fn(|l| l * n);
        let mut upcoming = S::ZERO;
        let mut upcoming_d = S::ZERO;

        // An even buf_len is nice, so the write==buf_len check is only needed once.
        // We also make it the next power of 2, for faster modulo operations.
        // delay/4: the delay measured in u32 words, since each u32 holds 4 ASCII bp.
        let buf_len = (delay / 4 + 8).next_power_of_two();
        let buf_mask = buf_len - 1;
        let mut buf = vec![S::ZERO; buf_len];
        let mut write_idx = 0;
        // We compensate for the first delay/4 triggers of the check below that
        // happen before the delay is actually reached.
        let mut read_idx = (buf_len - delay / 4) % buf_len;

        let par_len = if num_kmers == 0 { 0 } else { n + context - 1 };
        let it = (0..par_len).map(
            #[inline(always)]
            move |i| {
                if i % 4 == 0 {
                    if i % 32 == 0 {
                        // Read a u256 for each lane containing the next 32 characters.
                        let data: [u32x8; 8] = from_fn(
                            #[inline(always)]
                            |lane| read_slice_32(self.0, offsets[lane] + i),
                        );
                        unsafe {
                            let mut_array: &mut [u32x8; 8] = buf
                                .get_unchecked_mut(write_idx..write_idx + 8)
                                .try_into()
                                .unwrap_unchecked();
                            *mut_array = transpose(data);
                            for x in mut_array {
                                *x = *x >> 1;
                            }
                        }
                    }
                    upcoming = buf[write_idx];
                    write_idx += 1;
                    write_idx &= buf_mask;
                }
                if i % 4 == delay % 4 {
                    unsafe { assert_unchecked(read_idx < buf.len()) };
                    upcoming_d = buf[read_idx];
                    read_idx += 1;
                    read_idx &= buf_mask;
                }
                // Extract the last 2 bits of each character.
                let chars = upcoming & S::splat(0x03);
                let chars_d = upcoming_d & S::splat(0x03);
                // Shift remaining characters to the right.
                upcoming = upcoming >> S::splat(8);
                upcoming_d = upcoming_d >> S::splat(8);
                (chars, chars_d)
            },
        );

        PaddedIt { it, padding }
    }

    #[inline(always)]
    fn par_iter_bp_delayed_2(
        self,
        context: usize,
        Delay(delay1): Delay,
        Delay(delay2): Delay,
    ) -> PaddedIt<impl ChunkIt<(S, S, S)>> {
        assert!(delay1 <= delay2, "Delay1 must be at most delay2.");

        let num_kmers = self.len().saturating_sub(context - 1);
        let n = num_kmers.div_ceil(L);
        let padding = L * n - num_kmers;

        let offsets: [usize; 8] = from_fn(|l| l * n);

        let mut upcoming = S::ZERO;
        let mut upcoming_d1 = S::ZERO;
        let mut upcoming_d2 = S::ZERO;

        // An even buf_len is nice, so the write==buf_len check is only needed once.
        // We also make it the next power of 2, for faster modulo operations.
        // delay2/4: the larger delay measured in u32 words, since each u32 holds 4 ASCII bp.
        let buf_len = (delay2 / 4 + 8).next_power_of_two();
        let buf_mask = buf_len - 1;
        let mut buf = vec![S::ZERO; buf_len];
        let mut write_idx = 0;
        // We compensate for the first delay/4 triggers of the checks below that
        // happen before each delay is actually reached.
        let mut read_idx1 = (buf_len - delay1 / 4) % buf_len;
        let mut read_idx2 = (buf_len - delay2 / 4) % buf_len;

        let par_len = if num_kmers == 0 { 0 } else { n + context - 1 };
        let it = (0..par_len).map(
            #[inline(always)]
            move |i| {
                if i % 4 == 0 {
                    if i % 32 == 0 {
                        // Read a u256 for each lane containing the next 32 characters.
                        let data: [u32x8; 8] = from_fn(
                            #[inline(always)]
                            |lane| read_slice_32(self.0, offsets[lane] + i),
                        );
                        unsafe {
                            let mut_array: &mut [u32x8; 8] = buf
                                .get_unchecked_mut(write_idx..write_idx + 8)
                                .try_into()
                                .unwrap_unchecked();
                            *mut_array = transpose(data);
                            for x in mut_array {
                                *x = *x >> 1;
                            }
                        }
                    }
                    upcoming = buf[write_idx];
                    write_idx += 1;
                    write_idx &= buf_mask;
                }
                if i % 4 == delay1 % 4 {
                    unsafe { assert_unchecked(read_idx1 < buf.len()) };
                    upcoming_d1 = buf[read_idx1];
                    read_idx1 += 1;
                    read_idx1 &= buf_mask;
                }
                if i % 4 == delay2 % 4 {
                    unsafe { assert_unchecked(read_idx2 < buf.len()) };
                    upcoming_d2 = buf[read_idx2];
                    read_idx2 += 1;
                    read_idx2 &= buf_mask;
                }
                // Extract the last 2 bits of each character.
                let chars = upcoming & S::splat(0x03);
                let chars_d1 = upcoming_d1 & S::splat(0x03);
                let chars_d2 = upcoming_d2 & S::splat(0x03);
                // Shift remaining characters to the right.
                upcoming = upcoming >> S::splat(8);
                upcoming_d1 = upcoming_d1 >> S::splat(8);
                upcoming_d2 = upcoming_d2 >> S::splat(8);
                (chars, chars_d1, chars_d2)
            },
        );

        PaddedIt { it, padding }
    }

    // TODO: This is not very optimized.
    fn cmp_lcp(&self, other: &Self) -> (std::cmp::Ordering, usize) {
        for i in 0..self.len().min(other.len()) {
            if self.0[i] != other.0[i] {
                return (self.0[i].cmp(&other.0[i]), i);
            }
        }
        (self.len().cmp(&other.len()), self.len().min(other.len()))
    }
}

impl AsciiSeqVec {
    #[inline(always)]
    pub const fn from_vec(seq: Vec<u8>) -> Self {
        Self { seq }
    }
}

impl SeqVec for AsciiSeqVec {
    type Seq<'s> = AsciiSeq<'s>;

    /// Get the underlying ASCII text.
    #[inline(always)]
    fn into_raw(self) -> Vec<u8> {
        self.seq
    }

    #[inline(always)]
    fn as_slice(&self) -> Self::Seq<'_> {
        AsciiSeq(self.seq.as_slice())
    }

    #[inline(always)]
    fn len(&self) -> usize {
        self.seq.len()
    }

    #[inline(always)]
    fn is_empty(&self) -> bool {
        self.seq.is_empty()
    }

    #[inline(always)]
    fn clear(&mut self) {
        self.seq.clear()
    }

    #[inline(always)]
    fn push_seq(&mut self, seq: AsciiSeq) -> Range<usize> {
        let start = self.seq.len();
        let end = start + seq.len();
        let range = start..end;
        self.seq.extend(seq.0);
        range
    }

    #[inline(always)]
    fn push_ascii(&mut self, seq: &[u8]) -> Range<usize> {
        self.push_seq(AsciiSeq(seq))
    }

    #[cfg(feature = "rand")]
    fn random(n: usize) -> Self {
        use rand::{RngCore, SeedableRng};

        let mut seq = vec![0; n];
        rand::rngs::SmallRng::from_os_rng().fill_bytes(&mut seq);
        Self {
            seq: seq.into_iter().map(|b| b"ACGT"[b as usize % 4]).collect(),
        }
    }
}
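
// A minimal test sketch added for illustration (not part of the original file); it assumes
// the `Seq` trait and `AsciiSeq` are in scope via `super::*`, as in the impl above.
#[cfg(test)]
mod ascii_seq_sketch {
    use super::*;

    #[test]
    fn pack_acgt() {
        let seq = AsciiSeq(b"ACGT");
        // `(c>>1)&3`: A -> 0, C -> 1, G -> 3, T -> 2.
        assert_eq!(seq.get(0), 0);
        assert_eq!(seq.get(1), 1);
        assert_eq!(seq.get(2), 3);
        assert_eq!(seq.get(3), 2);
        // `as_u64` packs the 2-bit codes little-endian: A | C << 2 | G << 4 | T << 6.
        assert_eq!(seq.as_u64(), 0b10_11_01_00);
    }
}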