//! `packed_seq/ascii_seq.rs`: ASCII-encoded DNA sequences (`AsciiSeq`, `AsciiSeqVec`).

use crate::{intrinsics::transpose, packed_seq::read_slice};

use super::*;

/// A `Vec<u8>` representing an ASCII-encoded DNA sequence of `ACGTacgt`.
///
/// Other characters will be mapped into `[0, 4)` via `(c>>1)&3`, or may cause panics.
#[derive(Clone, Debug, Default, MemSize, MemDbg)]
#[cfg_attr(feature = "pyo3", pyo3::pyclass)]
#[cfg_attr(feature = "epserde", derive(epserde::Epserde))]
pub struct AsciiSeqVec {
    // The raw ASCII bytes of the sequence; one byte per basepair.
    pub seq: Vec<u8>,
}
14
/// A `&[u8]` representing an ASCII-encoded DNA sequence of `ACGTacgt`.
///
/// Other characters will be mapped into `[0, 4)` via `(c>>1)&3`, or may cause panics.
///
/// NOTE: the derived `PartialOrd`/`Ord` compare the raw ASCII bytes,
/// not the packed 2-bit values.
#[derive(Copy, Clone, Debug, MemSize, MemDbg, PartialEq, Eq, PartialOrd, Ord)]
pub struct AsciiSeq<'s>(pub &'s [u8]);
20
21/// Maps ASCII to `[0, 4)` on the fly.
22/// Prefer first packing into a `PackedSeqVec` for storage.
23impl<'s> Seq<'s> for AsciiSeq<'s> {
    /// Each input byte stores a single character.
    const BASES_PER_BYTE: usize = 1;
    /// But each output bp only takes 2 bits!
    const BITS_PER_CHAR: usize = 2;
    type SeqVec = AsciiSeqVec;

    /// The number of characters in the sequence.
    #[inline(always)]
    fn len(&self) -> usize {
        self.0.len()
    }

    /// The 2-bit packed value of the character at `index`.
    ///
    /// Panics if `index` is out of bounds.
    #[inline(always)]
    fn get(&self, index: usize) -> u8 {
        pack_char(self.0[index])
    }

    /// The raw ASCII byte at `index`, without packing.
    ///
    /// Panics if `index` is out of bounds.
    #[inline(always)]
    fn get_ascii(&self, index: usize) -> u8 {
        self.0[index]
    }
44
    /// Pack the sequence into the low bits of a `usize`, 2 bits per character.
    ///
    /// Panics if the sequence is longer than `usize::BITS / 2` characters
    /// (32 on 64-bit targets).
    #[inline(always)]
    fn to_word(&self) -> usize {
        let len = self.len();
        // Each character takes 2 bits, so at most usize::BITS/2 characters fit.
        assert!(len <= usize::BITS as usize / 2);

        let mut val = 0u64;

        #[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))]
        {
            // Process 8 ASCII bytes at a time: `pext` with mask 0x06 per byte
            // extracts bits 1 and 2 of each character, i.e. `(c>>1)&3`.
            for i in (0..len).step_by(8) {
                let packed_bytes = if i + 8 <= self.len() {
                    let chunk: &[u8; 8] = &self.0[i..i + 8].try_into().unwrap();
                    let ascii = u64::from_ne_bytes(*chunk);
                    unsafe { std::arch::x86_64::_pext_u64(ascii, 0x0606060606060606) }
                } else {
                    // Partial tail chunk: zero-pad to 8 bytes (zeros pack to 0).
                    let mut chunk: [u8; 8] = [0; 8];
                    // Copy only part of the slice to avoid out-of-bounds indexing.
                    chunk[..self.len() - i].copy_from_slice(self.0[i..].try_into().unwrap());
                    let ascii = u64::from_ne_bytes(chunk);
                    unsafe { std::arch::x86_64::_pext_u64(ascii, 0x0606060606060606) }
                };
                val |= packed_bytes << (i * 2);
            }
        }

        #[cfg(not(all(target_arch = "x86_64", target_feature = "bmi2")))]
        {
            // Scalar fallback: pack one character at a time.
            for (i, &base) in self.0.iter().enumerate() {
                val |= (pack_char(base) as u64) << (i * 2);
            }
        }

        val as usize
    }
79
80    /// Convert to an owned version.
81    fn to_vec(&self) -> AsciiSeqVec {
82        AsciiSeqVec {
83            seq: self.0.to_vec(),
84        }
85    }
86
87    #[inline(always)]
88    fn slice(&self, range: Range<usize>) -> Self {
89        Self(&self.0[range])
90    }
91
    /// Iterate the basepairs in the sequence, assuming values in `0..4`.
    ///
    /// NOTE: This is only efficient on x86_64 with `BMI2` support for `pext`.
    #[inline(always)]
    fn iter_bp(self) -> impl ExactSizeIterator<Item = u8> + Clone {
        #[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))]
        {
            // Refill `cache` with 8 ASCII bytes at a time. The packed value of a
            // character is `(c>>1)&3`, so shifting the whole word right by 1 once
            // lets each output be just `cache & 3`.
            let mut cache = 0;
            (0..self.len()).map(
                #[inline(always)]
                move |i| {
                    if i % 8 == 0 {
                        if i + 8 <= self.len() {
                            let chunk: &[u8; 8] = &self.0[i..i + 8].try_into().unwrap();
                            let ascii = u64::from_ne_bytes(*chunk);
                            cache = ascii >> 1;
                        } else {
                            // Partial tail chunk: zero-pad to 8 bytes.
                            let mut chunk: [u8; 8] = [0; 8];
                            // Copy only part of the slice to avoid out-of-bounds indexing.
                            chunk[..self.len() - i]
                                .copy_from_slice(self.0[i..].try_into().unwrap());
                            let ascii = u64::from_ne_bytes(chunk);
                            cache = ascii >> 1;
                        }
                    }
                    // The low byte of the cache holds the current (pre-shifted) character.
                    let base = cache & 0x03;
                    cache >>= 8;
                    base as u8
                },
            )
        }

        #[cfg(not(all(target_arch = "x86_64", target_feature = "bmi2")))]
        // Scalar fallback: pack each character individually.
        self.0.iter().copied().map(pack_char)
    }
127
    /// Iterate the basepairs in the sequence in 8 parallel streams, assuming values in `0..4`.
    ///
    /// Returns the iterator and the amount of padding in the last lane.
    #[inline(always)]
    fn par_iter_bp(self, context: usize) -> (impl ExactSizeIterator<Item = S> + Clone, usize) {
        // Number of windows of length `context`, and the per-lane stream length.
        let num_kmers = self.len().saturating_sub(context - 1);
        let n = num_kmers.div_ceil(L);
        let padding = L * n - num_kmers;

        // Lane `l` reads characters starting at offset `l * n`.
        let offsets: [usize; 8] = from_fn(|l| (l * n)).into();
        let mut cur = S::ZERO;

        // Boxed, so it doesn't consume precious registers.
        // Without this, cur is not always inlined into a register.
        let mut buf = Box::new([S::ZERO; 8]);

        let par_len = if num_kmers == 0 { 0 } else { n + context - 1 };
        let it = (0..par_len).map(
            #[inline(always)]
            move |i| {
                // Every 4 characters, pull the next u32 per lane from the buffer.
                if i % 4 == 0 {
                    if i % 32 == 0 {
                        // Read a u256 for each lane containing the next 32 characters.
                        let data: [u32x8; 8] = from_fn(
                            #[inline(always)]
                            |lane| read_slice(self.0, offsets[lane] + i),
                        );
                        *buf = transpose(data);
                        for x in buf.iter_mut() {
                            // Pre-shift so each packed value is the low 2 bits.
                            *x = *x >> 1;
                        }
                    }
                    cur = buf[(i % 32) / 4];
                }
                // Extract the last 2 bits of each character.
                let chars = cur & S::splat(0x03);
                // Shift remaining characters to the right.
                cur = cur >> S::splat(8);
                chars
            },
        );

        (it, padding)
    }
170
171    #[inline(always)]
172    fn par_iter_bp_delayed(
173        self,
174        context: usize,
175        delay: usize,
176    ) -> (impl ExactSizeIterator<Item = (S, S)> + Clone, usize) {
177        assert!(
178            delay < usize::MAX / 2,
179            "Delay={} should be >=0.",
180            delay as isize
181        );
182
183        let num_kmers = self.len().saturating_sub(context - 1);
184        let n = num_kmers.div_ceil(L);
185        eprintln!(
186            "len {} kmers {num_kmers} n {n} total out {}",
187            self.len(),
188            L * n
189        );
190        let padding = L * n - num_kmers;
191
192        let offsets: [usize; 8] = from_fn(|l| (l * n)).into();
193        let mut upcoming = S::ZERO;
194        let mut upcoming_d = S::ZERO;
195
196        // Even buf_len is nice to only have the write==buf_len check once.
197        // We also make it the next power of 2, for faster modulo operations.
198        // delay/4: number of bp in a u32.
199        let buf_len = (delay / 4 + 8).next_power_of_two();
200        let buf_mask = buf_len - 1;
201        let mut buf = vec![S::ZERO; buf_len];
202        let mut write_idx = 0;
203        // We compensate for the first delay/16 triggers of the check below that
204        // happen before the delay is actually reached.
205        let mut read_idx = (buf_len - delay / 4) % buf_len;
206
207        let par_len = if num_kmers == 0 { 0 } else { n + context - 1 };
208        let it = (0..par_len).map(
209            #[inline(always)]
210            move |i| {
211                if i % 4 == 0 {
212                    if i % 32 == 0 {
213                        // Read a u256 for each lane containing the next 32 characters.
214                        let data: [u32x8; 8] = from_fn(
215                            #[inline(always)]
216                            |lane| read_slice(self.0, offsets[lane] + i),
217                        );
218                        unsafe {
219                            let mut_array: &mut [u32x8; 8] = buf
220                                .get_unchecked_mut(write_idx..write_idx + 8)
221                                .try_into()
222                                .unwrap_unchecked();
223                            *mut_array = transpose(data);
224                            for x in mut_array {
225                                *x = *x >> 1;
226                            }
227                        }
228                    }
229                    upcoming = buf[write_idx];
230                    write_idx += 1;
231                    write_idx &= buf_mask;
232                }
233                if i % 4 == delay % 4 {
234                    unsafe { assert_unchecked(read_idx < buf.len()) };
235                    upcoming_d = buf[read_idx];
236                    read_idx += 1;
237                    read_idx &= buf_mask;
238                }
239                // Extract the last 2 bits of each character.
240                let chars = upcoming & S::splat(0x03);
241                let chars_d = upcoming_d & S::splat(0x03);
242                // Shift remaining characters to the right.
243                upcoming = upcoming >> S::splat(8);
244                upcoming_d = upcoming_d >> S::splat(8);
245                (chars, chars_d)
246            },
247        );
248
249        (it, padding)
250    }
251
    /// Iterate the basepairs in the sequence in 8 parallel streams, assuming
    /// values in `0..4`, yielding each character together with the characters
    /// `delay1` and `delay2` positions earlier in the same stream.
    ///
    /// Returns the iterator and the amount of padding in the last lane.
    #[inline(always)]
    fn par_iter_bp_delayed_2(
        self,
        context: usize,
        delay1: usize,
        delay2: usize,
    ) -> (impl ExactSizeIterator<Item = (S, S, S)> + Clone, usize) {
        assert!(delay1 <= delay2, "Delay1 must be at most delay2.");

        let num_kmers = self.len().saturating_sub(context - 1);
        let n = num_kmers.div_ceil(L);
        let padding = L * n - num_kmers;

        // Lane `l` reads characters starting at offset `l * n`.
        let offsets: [usize; 8] = from_fn(|l| (l * n)).into();

        let mut upcoming = S::ZERO;
        let mut upcoming_d1 = S::ZERO;
        let mut upcoming_d2 = S::ZERO;

        // Even buf_len is nice to only have the write==buf_len check once.
        // We also make it the next power of 2, for faster modulo operations.
        // delay/4: number of bp in a u32.
        let buf_len = (delay2 / 4 + 8).next_power_of_two();
        let buf_mask = buf_len - 1;
        let mut buf = vec![S::ZERO; buf_len];
        let mut write_idx = 0;
        // We compensate for the first delay/16 triggers of the check below that
        // happen before the delay is actually reached.
        let mut read_idx1 = (buf_len - delay1 / 4) % buf_len;
        let mut read_idx2 = (buf_len - delay2 / 4) % buf_len;

        let par_len = if num_kmers == 0 { 0 } else { n + context - 1 };
        let it = (0..par_len).map(
            #[inline(always)]
            move |i| {
                if i % 4 == 0 {
                    if i % 32 == 0 {
                        // Read a u256 for each lane containing the next 32 characters.
                        let data: [u32x8; 8] = from_fn(
                            #[inline(always)]
                            |lane| read_slice(self.0, offsets[lane] + i),
                        );
                        // SAFETY: when i % 32 == 0, write_idx is a multiple of 8
                        // modulo buf_len (a power of two >= 8), so the window
                        // write_idx..write_idx + 8 is in bounds.
                        unsafe {
                            let mut_array: &mut [u32x8; 8] = buf
                                .get_unchecked_mut(write_idx..write_idx + 8)
                                .try_into()
                                .unwrap_unchecked();
                            *mut_array = transpose(data);
                            for x in mut_array {
                                // Pre-shift so each packed value is the low 2 bits.
                                *x = *x >> 1;
                            }
                        }
                    }
                    upcoming = buf[write_idx];
                    write_idx += 1;
                    write_idx &= buf_mask;
                }
                if i % 4 == delay1 % 4 {
                    unsafe { assert_unchecked(read_idx1 < buf.len()) };
                    upcoming_d1 = buf[read_idx1];
                    read_idx1 += 1;
                    read_idx1 &= buf_mask;
                }
                if i % 4 == delay2 % 4 {
                    unsafe { assert_unchecked(read_idx2 < buf.len()) };
                    upcoming_d2 = buf[read_idx2];
                    read_idx2 += 1;
                    read_idx2 &= buf_mask;
                }
                // Extract the last 2 bits of each character.
                let chars = upcoming & S::splat(0x03);
                let chars_d1 = upcoming_d1 & S::splat(0x03);
                let chars_d2 = upcoming_d2 & S::splat(0x03);
                // Shift remaining characters to the right.
                upcoming = upcoming >> S::splat(8);
                upcoming_d1 = upcoming_d1 >> S::splat(8);
                upcoming_d2 = upcoming_d2 >> S::splat(8);
                (chars, chars_d1, chars_d2)
            },
        );

        (it, padding)
    }
335
336    // TODO: This is not very optimized.
337    fn cmp_lcp(&self, other: &Self) -> (std::cmp::Ordering, usize) {
338        for i in 0..self.len().min(other.len()) {
339            if self.0[i] != other.0[i] {
340                return (self.0[i].cmp(&other.0[i]), i);
341            }
342        }
343        (self.len().cmp(&other.len()), self.len().min(other.len()))
344    }
345}
346
impl AsciiSeqVec {
    /// Wrap an existing byte vector without copying or validating it.
    ///
    /// The bytes are assumed to be `ACGTacgt`; see the type-level docs.
    pub fn from_vec(seq: Vec<u8>) -> Self {
        Self { seq }
    }
}
352
353impl SeqVec for AsciiSeqVec {
354    type Seq<'s> = AsciiSeq<'s>;
355
356    /// Get the underlying ASCII text.
357    fn into_raw(self) -> Vec<u8> {
358        self.seq
359    }
360
361    #[inline(always)]
362    fn as_slice(&self) -> Self::Seq<'_> {
363        AsciiSeq(self.seq.as_slice())
364    }
365
366    fn push_seq(&mut self, seq: AsciiSeq) -> Range<usize> {
367        let start = self.seq.len();
368        let end = start + seq.len();
369        let range = start..end;
370        self.seq.extend(seq.0);
371        range
372    }
373
374    fn push_ascii(&mut self, seq: &[u8]) -> Range<usize> {
375        self.push_seq(AsciiSeq(seq))
376    }
377
378    fn random(n: usize) -> Self {
379        let mut seq = vec![0; n];
380        rand::rngs::SmallRng::from_os_rng().fill_bytes(&mut seq);
381        Self {
382            seq: seq.into_iter().map(|b| b"ACGT"[b as usize % 4]).collect(),
383        }
384    }
385}