packed_seq/
ascii_seq.rs

use crate::{intrinsics::transpose, packed_seq::read_slice};

use super::*;

/// A `Vec<u8>` representing an ASCII-encoded DNA sequence of `ACGTacgt`.
///
/// Other characters will be mapped into `[0, 4)` via `(c>>1)&3`, or may cause panics.
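///
/// For example, the formula maps `A -> 0`, `C -> 1`, `T -> 2`, `G -> 3`
/// (a quick sanity check of `(c>>1)&3`, nothing crate-specific):
/// ```
/// assert_eq!((b'A' >> 1) & 3, 0);
/// assert_eq!((b'C' >> 1) & 3, 1);
/// assert_eq!((b'T' >> 1) & 3, 2);
/// assert_eq!((b'G' >> 1) & 3, 3);
/// ```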
#[derive(Clone, Debug, Default, MemSize, MemDbg)]
#[cfg_attr(feature = "pyo3", pyo3::pyclass)]
#[cfg_attr(feature = "epserde", derive(epserde::Epserde))]
pub struct AsciiSeqVec {
    pub seq: Vec<u8>,
}

/// A `&[u8]` representing an ASCII-encoded DNA sequence of `ACGTacgt`.
///
/// Other characters will be mapped into `[0, 4)` via `(c>>1)&3`, or may cause panics.
#[derive(Copy, Clone, Debug, MemSize, MemDbg, PartialEq, Eq, PartialOrd, Ord)]
pub struct AsciiSeq<'s>(pub &'s [u8]);

/// Maps ASCII to `[0, 4)` on the fly.
/// Prefer first packing into a `PackedSeqVec` for storage.
impl<'s> Seq<'s> for AsciiSeq<'s> {
    /// Each input byte stores a single character.
    const BASES_PER_BYTE: usize = 1;
    /// But each output bp only takes 2 bits!
    const BITS_PER_CHAR: usize = 2;
    type SeqVec = AsciiSeqVec;

    #[inline(always)]
    fn len(&self) -> usize {
        self.0.len()
    }

    #[inline(always)]
    fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    #[inline(always)]
    fn get(&self, index: usize) -> u8 {
        pack_char(self.0[index])
    }

    #[inline(always)]
    fn get_ascii(&self, index: usize) -> u8 {
        self.0[index]
    }

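    /// Pack the sequence (at most 32 bases) into a `u64`, 2 bits per base,
    /// with the first base in the least-significant bits.
    ///
    /// A doctest sketch; it assumes the crate root re-exports `AsciiSeq`
    /// and the `Seq` trait:
    /// ```
    /// use packed_seq::{AsciiSeq, Seq};
    /// // A = 0b00, C = 0b01, G = 0b11, T = 0b10.
    /// assert_eq!(AsciiSeq(b"ACGT").as_u64(), 0b10_11_01_00);
    /// ```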
    #[inline(always)]
    fn as_u64(&self) -> u64 {
        let len = self.len();
        assert!(len <= u64::BITS as usize / 2);

        let mut val = 0u64;

        #[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))]
        {
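            // `pext` with mask 0x06 in every byte gathers bits 1..=2 of each
            // ASCII character, i.e. `(c >> 1) & 3`, packing 8 characters into
            // 16 bits per loop iteration. Per-byte examples:
            // 'A' = 0x41 -> 0b00, 'C' = 0x43 -> 0b01,
            // 'G' = 0x47 -> 0b11, 'T' = 0x54 -> 0b10.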
            for i in (0..len).step_by(8) {
                let packed_bytes = if i + 8 <= self.len() {
                    let chunk: &[u8; 8] = &self.0[i..i + 8].try_into().unwrap();
                    let ascii = u64::from_ne_bytes(*chunk);
                    unsafe { std::arch::x86_64::_pext_u64(ascii, 0x0606060606060606) }
                } else {
                    let mut chunk: [u8; 8] = [0; 8];
                    // Copy only part of the slice to avoid out-of-bounds indexing.
                    chunk[..self.len() - i].copy_from_slice(&self.0[i..]);
                    let ascii = u64::from_ne_bytes(chunk);
                    unsafe { std::arch::x86_64::_pext_u64(ascii, 0x0606060606060606) }
                };
                val |= packed_bytes << (i * 2);
            }
        }

        #[cfg(not(all(target_arch = "x86_64", target_feature = "bmi2")))]
        {
            for (i, &base) in self.0.iter().enumerate() {
                val |= (pack_char(base) as u64) << (i * 2);
            }
        }

        val
    }

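    /// A doctest sketch for the reverse complement, assuming the crate root
    /// re-exports `AsciiSeq` and the `Seq` trait:
    /// ```
    /// use packed_seq::{AsciiSeq, Seq};
    /// // "ACGT" is its own reverse complement.
    /// let s = AsciiSeq(b"ACGT");
    /// assert_eq!(s.revcomp_as_u64(), s.as_u64());
    /// ```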
    #[inline(always)]
    fn revcomp_as_u64(&self) -> u64 {
        packed_seq::revcomp_u64(self.as_u64(), self.len())
    }

    /// Convert to an owned version.
    #[inline(always)]
    fn to_vec(&self) -> AsciiSeqVec {
        AsciiSeqVec {
            seq: self.0.to_vec(),
        }
    }

    #[inline(always)]
    fn to_revcomp(&self) -> AsciiSeqVec {
        AsciiSeqVec {
            seq: self
                .0
                .iter()
                .rev()
                .copied()
                .map(packed_seq::complement_char)
                .collect(),
        }
    }

    #[inline(always)]
    fn slice(&self, range: Range<usize>) -> Self {
        Self(&self.0[range])
    }

    /// Iterate the basepairs in the sequence, yielding values in `0..4`.
    ///
    /// NOTE: This is only efficient on x86_64 with `BMI2` enabled; elsewhere it
    /// falls back to per-character `pack_char`.
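    ///
    /// A doctest sketch, assuming the crate root re-exports `AsciiSeq` and
    /// the `Seq` trait:
    /// ```
    /// use packed_seq::{AsciiSeq, Seq};
    /// let bp: Vec<u8> = AsciiSeq(b"ACGT").iter_bp().collect();
    /// assert_eq!(bp, vec![0, 1, 3, 2]);
    /// ```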
    #[inline(always)]
    fn iter_bp(self) -> impl ExactSizeIterator<Item = u8> + Clone {
        #[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))]
        {
            let mut cache = 0;
            (0..self.len()).map(
                #[inline(always)]
                move |i| {
                    if i % 8 == 0 {
                        if i + 8 <= self.len() {
                            let chunk: &[u8; 8] = &self.0[i..i + 8].try_into().unwrap();
                            let ascii = u64::from_ne_bytes(*chunk);
                            cache = ascii >> 1;
                        } else {
                            let mut chunk: [u8; 8] = [0; 8];
                            // Copy only part of the slice to avoid out-of-bounds indexing.
                            chunk[..self.len() - i].copy_from_slice(&self.0[i..]);
                            let ascii = u64::from_ne_bytes(chunk);
                            cache = ascii >> 1;
                        }
                    }
                    let base = cache & 0x03;
                    cache >>= 8;
                    base as u8
                },
            )
        }

        #[cfg(not(all(target_arch = "x86_64", target_feature = "bmi2")))]
        self.0.iter().copied().map(pack_char)
    }

    /// Iterate the basepairs in the sequence in 8 parallel streams, yielding values in `0..4`.
    #[inline(always)]
    fn par_iter_bp(self, context: usize) -> (impl ExactSizeIterator<Item = S> + Clone, usize) {
        let num_kmers = self.len().saturating_sub(context - 1);
        let n = num_kmers.div_ceil(L);
        let padding = L * n - num_kmers;
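        // Worked example (with L = 8 lanes): len = 100 and context = 11 give
        // num_kmers = 90, n = ceil(90 / 8) = 12 characters per lane, and
        // padding = 8 * 12 - 90 = 6 unused positions at the end.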

        let offsets: [usize; 8] = from_fn(|l| l * n);
        let mut cur = S::ZERO;

        // Boxed, so it doesn't consume precious registers.
        // Without this, `cur` is not always kept in a register.
        let mut buf = Box::new([S::ZERO; 8]);

        let par_len = if num_kmers == 0 { 0 } else { n + context - 1 };
        let it = (0..par_len).map(
            #[inline(always)]
            move |i| {
                if i % 4 == 0 {
                    if i % 32 == 0 {
                        // Read a u256 for each lane containing the next 32 characters.
                        let data: [u32x8; 8] = from_fn(
                            #[inline(always)]
                            |lane| read_slice(self.0, offsets[lane] + i),
                        );
                        *buf = transpose(data);
                        for x in buf.iter_mut() {
                            *x = *x >> 1;
                        }
                    }
                    cur = buf[(i % 32) / 4];
                }
                // Extract the last 2 bits of each character.
                let chars = cur & S::splat(0x03);
                // Shift remaining characters to the right.
                cur = cur >> S::splat(8);
                chars
            },
        );

        (it, padding)
    }

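    /// Like `par_iter_bp`, but additionally yields, for each position, the
    /// character from `delay` positions earlier in the same lane. The delayed
    /// stream yields zeros until the delay is reached, since the ring buffer
    /// below starts zero-initialized.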
    #[inline(always)]
    fn par_iter_bp_delayed(
        self,
        context: usize,
        delay: usize,
    ) -> (impl ExactSizeIterator<Item = (S, S)> + Clone, usize) {
        assert!(
            delay < usize::MAX / 2,
            "Delay={} should be >= 0.",
            delay as isize
        );

        let num_kmers = self.len().saturating_sub(context - 1);
        let n = num_kmers.div_ceil(L);
        let padding = L * n - num_kmers;

        let offsets: [usize; 8] = from_fn(|l| l * n);
        let mut upcoming = S::ZERO;
        let mut upcoming_d = S::ZERO;

        // The buffer length is a power of two, so the wrapping index updates
        // below are a cheap bitwise AND instead of a modulo.
        // A u32 holds 4 ASCII bp, so the delay spans delay/4 buffer slots.
        let buf_len = (delay / 4 + 8).next_power_of_two();
        let buf_mask = buf_len - 1;
        let mut buf = vec![S::ZERO; buf_len];
        let mut write_idx = 0;
        // We compensate for the first delay/4 triggers of the check below that
        // happen before the delay is actually reached.
        let mut read_idx = (buf_len - delay / 4) % buf_len;
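        // Worked example: delay = 10 gives buf_len = (10/4 + 8).next_power_of_two()
        // = 16 and read_idx = (16 - 2) % 16 = 14, so reads trail writes by
        // delay/4 = 2 slots; the `i % 4 == delay % 4` check below handles the
        // remaining within-u32 offset.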

        let par_len = if num_kmers == 0 { 0 } else { n + context - 1 };
        let it = (0..par_len).map(
            #[inline(always)]
            move |i| {
                if i % 4 == 0 {
                    if i % 32 == 0 {
                        // Read a u256 for each lane containing the next 32 characters.
                        let data: [u32x8; 8] = from_fn(
                            #[inline(always)]
                            |lane| read_slice(self.0, offsets[lane] + i),
                        );
                        unsafe {
                            let mut_array: &mut [u32x8; 8] = buf
                                .get_unchecked_mut(write_idx..write_idx + 8)
                                .try_into()
                                .unwrap_unchecked();
                            *mut_array = transpose(data);
                            for x in mut_array {
                                *x = *x >> 1;
                            }
                        }
                    }
                    upcoming = buf[write_idx];
                    write_idx += 1;
                    write_idx &= buf_mask;
                }
                if i % 4 == delay % 4 {
                    unsafe { assert_unchecked(read_idx < buf.len()) };
                    upcoming_d = buf[read_idx];
                    read_idx += 1;
                    read_idx &= buf_mask;
                }
                // Extract the last 2 bits of each character.
                let chars = upcoming & S::splat(0x03);
                let chars_d = upcoming_d & S::splat(0x03);
                // Shift remaining characters to the right.
                upcoming = upcoming >> S::splat(8);
                upcoming_d = upcoming_d >> S::splat(8);
                (chars, chars_d)
            },
        );

        (it, padding)
    }

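    /// Like `par_iter_bp_delayed`, but with two delayed streams
    /// (`delay1 <= delay2`) sharing one ring buffer sized for `delay2`.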
    #[inline(always)]
    fn par_iter_bp_delayed_2(
        self,
        context: usize,
        delay1: usize,
        delay2: usize,
    ) -> (impl ExactSizeIterator<Item = (S, S, S)> + Clone, usize) {
        assert!(delay1 <= delay2, "Delay1 must be at most delay2.");

        let num_kmers = self.len().saturating_sub(context - 1);
        let n = num_kmers.div_ceil(L);
        let padding = L * n - num_kmers;

        let offsets: [usize; 8] = from_fn(|l| l * n);

        let mut upcoming = S::ZERO;
        let mut upcoming_d1 = S::ZERO;
        let mut upcoming_d2 = S::ZERO;

        // The buffer length is a power of two, so the wrapping index updates
        // below are a cheap bitwise AND instead of a modulo.
        // A u32 holds 4 ASCII bp, so the longer delay spans delay2/4 buffer slots.
        let buf_len = (delay2 / 4 + 8).next_power_of_two();
        let buf_mask = buf_len - 1;
        let mut buf = vec![S::ZERO; buf_len];
        let mut write_idx = 0;
        // We compensate for the first delay1/4 resp. delay2/4 triggers of the
        // checks below that happen before each delay is actually reached.
        let mut read_idx1 = (buf_len - delay1 / 4) % buf_len;
        let mut read_idx2 = (buf_len - delay2 / 4) % buf_len;

        let par_len = if num_kmers == 0 { 0 } else { n + context - 1 };
        let it = (0..par_len).map(
            #[inline(always)]
            move |i| {
                if i % 4 == 0 {
                    if i % 32 == 0 {
                        // Read a u256 for each lane containing the next 32 characters.
                        let data: [u32x8; 8] = from_fn(
                            #[inline(always)]
                            |lane| read_slice(self.0, offsets[lane] + i),
                        );
                        unsafe {
                            let mut_array: &mut [u32x8; 8] = buf
                                .get_unchecked_mut(write_idx..write_idx + 8)
                                .try_into()
                                .unwrap_unchecked();
                            *mut_array = transpose(data);
                            for x in mut_array {
                                *x = *x >> 1;
                            }
                        }
                    }
                    upcoming = buf[write_idx];
                    write_idx += 1;
                    write_idx &= buf_mask;
                }
                if i % 4 == delay1 % 4 {
                    unsafe { assert_unchecked(read_idx1 < buf.len()) };
                    upcoming_d1 = buf[read_idx1];
                    read_idx1 += 1;
                    read_idx1 &= buf_mask;
                }
                if i % 4 == delay2 % 4 {
                    unsafe { assert_unchecked(read_idx2 < buf.len()) };
                    upcoming_d2 = buf[read_idx2];
                    read_idx2 += 1;
                    read_idx2 &= buf_mask;
                }
                // Extract the last 2 bits of each character.
                let chars = upcoming & S::splat(0x03);
                let chars_d1 = upcoming_d1 & S::splat(0x03);
                let chars_d2 = upcoming_d2 & S::splat(0x03);
                // Shift remaining characters to the right.
                upcoming = upcoming >> S::splat(8);
                upcoming_d1 = upcoming_d1 >> S::splat(8);
                upcoming_d2 = upcoming_d2 >> S::splat(8);
                (chars, chars_d1, chars_d2)
            },
        );

        (it, padding)
    }

    // TODO: This is not very optimized.
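    /// Compare two sequences by their ASCII bytes and return the ordering
    /// together with the length of the longest common prefix.
    ///
    /// A doctest sketch, assuming the crate root re-exports `AsciiSeq` and
    /// the `Seq` trait:
    /// ```
    /// use packed_seq::{AsciiSeq, Seq};
    /// let (ord, lcp) = AsciiSeq(b"ACGT").cmp_lcp(&AsciiSeq(b"ACCT"));
    /// assert_eq!((ord, lcp), (std::cmp::Ordering::Greater, 2));
    /// ```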
    fn cmp_lcp(&self, other: &Self) -> (std::cmp::Ordering, usize) {
        for i in 0..self.len().min(other.len()) {
            if self.0[i] != other.0[i] {
                return (self.0[i].cmp(&other.0[i]), i);
            }
        }
        (self.len().cmp(&other.len()), self.len().min(other.len()))
    }
}

impl AsciiSeqVec {
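    /// Wrap an existing ASCII `Vec<u8>` without copying.
    ///
    /// A doctest sketch, assuming the crate root re-exports `AsciiSeqVec`:
    /// ```
    /// use packed_seq::AsciiSeqVec;
    /// let v = AsciiSeqVec::from_vec(b"ACGT".to_vec());
    /// assert_eq!(v.seq, b"ACGT");
    /// ```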
    #[inline(always)]
    pub const fn from_vec(seq: Vec<u8>) -> Self {
        Self { seq }
    }
}

impl SeqVec for AsciiSeqVec {
    type Seq<'s> = AsciiSeq<'s>;

    /// Get the underlying ASCII text.
    #[inline(always)]
    fn into_raw(self) -> Vec<u8> {
        self.seq
    }

    #[inline(always)]
    fn as_slice(&self) -> Self::Seq<'_> {
        AsciiSeq(self.seq.as_slice())
    }

    #[inline(always)]
    fn len(&self) -> usize {
        self.seq.len()
    }

    #[inline(always)]
    fn is_empty(&self) -> bool {
        self.seq.is_empty()
    }

    #[inline(always)]
    fn clear(&mut self) {
        self.seq.clear()
    }

    #[inline(always)]
    fn push_seq(&mut self, seq: AsciiSeq) -> Range<usize> {
        let start = self.seq.len();
        let end = start + seq.len();
        let range = start..end;
        self.seq.extend(seq.0);
        range
    }

    #[inline(always)]
    fn push_ascii(&mut self, seq: &[u8]) -> Range<usize> {
        self.push_seq(AsciiSeq(seq))
    }

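    /// Generate a sequence of `n` characters drawn uniformly at random from `ACGT`.
    /// Only available with the `rand` feature.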
    #[cfg(feature = "rand")]
    fn random(n: usize) -> Self {
        use rand::{RngCore, SeedableRng};

        let mut seq = vec![0; n];
        rand::rngs::SmallRng::from_os_rng().fill_bytes(&mut seq);
        Self {
            seq: seq.into_iter().map(|b| b"ACGT"[b as usize % 4]).collect(),
        }
    }
}