use traits::Seq;

use crate::intrinsics::transpose;

use super::*;

/// A slice of a packed DNA sequence, with 4 bases packed into each byte.
#[derive(Copy, Clone, Debug, MemSize, MemDbg)]
pub struct PackedSeq<'s> {
    /// The packed data, 4 bases per byte.
    pub seq: &'s [u8],
    /// The offset (in bases) into the first byte where this slice starts.
    pub offset: usize,
    /// The length (in bases) of this slice.
    pub len: usize,
}

/// An owned packed DNA sequence, with 4 bases packed into each byte.
#[derive(Clone, Debug, Default, MemSize, MemDbg)]
#[cfg_attr(feature = "pyo3", pyo3::pyclass)]
#[cfg_attr(feature = "epserde", derive(epserde::Epserde))]
pub struct PackedSeqVec {
    /// The packed data, 4 bases per byte.
    pub seq: Vec<u8>,
    /// The length (in bases) of the sequence.
    pub len: usize,
}
/// Packs an ASCII base into its 2-bit representation: A=0, C=1, T=2, G=3.
/// This coincides with `(base >> 1) & 3` for all of `ACTGactg`.
///
/// Panics on any other character.
pub fn pack_char(base: u8) -> u8 {
    match base {
        b'a' | b'A' => 0,
        b'c' | b'C' => 1,
        b'g' | b'G' => 3,
        b't' | b'T' => 2,
        _ => panic!(
            "Unexpected character '{}' with ASCII value {base}. Expected one of ACTGactg.",
            base as char
        ),
    }
}

/// Unpacks a 2-bit base into its ASCII representation; the inverse of `pack_char`.
pub fn unpack_base(base: u8) -> u8 {
    debug_assert!(base < 4, "Base {base} is not <4.");
    b"ACTG"[base as usize]
}
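
// A small added sanity check, not part of the original file: the 2-bit
// encoding above coincides with `(base >> 1) & 3` for all of `ACTGactg`,
// which is the identity the BMI2 `pext` fast path in `push_ascii` relies on,
// and `unpack_base` inverts `pack_char` up to case.
#[test]
fn pack_char_matches_shift_trick() {
    for &c in b"ACTGactg" {
        assert_eq!(pack_char(c), (c >> 1) & 3);
        assert_eq!(unpack_base(pack_char(c)), c.to_ascii_uppercase());
    }
}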

/// Complements an ASCII base: A<->T, C<->G. Lowercase input is complemented
/// to the uppercase complement.
pub const fn complement_char(base: u8) -> u8 {
    match base {
        b'a' | b'A' => b'T',
        b'c' | b'C' => b'G',
        b'g' | b'G' => b'C',
        b't' | b'T' => b'A',
        _ => panic!("Unexpected character. Expected one of ACTGactg."),
    }
}

/// Complements a 2-bit base. Because A=0, C=1, T=2, G=3, complementing
/// (A<->T, C<->G) is simply an XOR with 2.
pub const fn complement_base(base: u8) -> u8 {
    base ^ 2
}
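
// Another added sketch, not in the original file: XOR with 2 swaps
// A(0)<->T(2) and C(1)<->G(3), and agrees with the ASCII-level complement.
#[test]
fn complement_base_agrees_with_complement_char() {
    for &c in b"ACTG" {
        assert_eq!(
            unpack_base(complement_base(pack_char(c))),
            complement_char(c)
        );
    }
}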

/// Complements 8 lanes of 2-bit bases at once; the SIMD version of `complement_base`.
pub fn complement_base_simd(base: u32x8) -> u32x8 {
    base ^ u32x8::splat(2)
}

impl<'s> PackedSeq<'s> {
    /// Shrinks `seq` to only the bytes covering the current range,
    /// leaving an offset of at most 3 bases into the first byte.
    #[inline(always)]
    pub fn normalize(&self) -> Self {
        let start = self.offset / 4;
        let end = (self.offset + self.len).div_ceil(4);
        Self {
            seq: &self.seq[start..end],
            offset: self.offset % 4,
            len: self.len,
        }
    }

    /// Unpacks the sequence into a `Vec` of ASCII characters.
    pub fn unpack(&self) -> Vec<u8> {
        self.iter_bp().map(unpack_base).collect()
    }
}
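
// An added sketch of the `slice`/`normalize` semantics (this test is not part
// of the original file): after slicing into the middle of a packed sequence,
// only the bytes covering the range are kept and the offset is sub-byte.
#[test]
fn slice_is_normalized() {
    let mut v = PackedSeqVec::default();
    v.push_ascii(b"ACGTACGTACGT");
    let s = v.as_slice().slice(5..11);
    assert!(s.offset < 4);
    // Bases 5..11 span bytes 1..3 of the packed data.
    assert_eq!(s.seq.len(), 2);
    assert_eq!(s.unpack(), b"CGTACG");
}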

/// Reads up to 32 bytes starting at `idx`, zero-padding reads past the end of `seq`.
#[inline(always)]
pub(crate) fn read_slice(seq: &[u8], idx: usize) -> u32x8 {
    let mut result = [0u8; 32];
    let num_bytes = 32.min(seq.len().saturating_sub(idx));
    unsafe {
        // Clamp the start so the pointer arithmetic stays in bounds even when
        // `idx` points past the end; the copy is empty in that case.
        let src = seq.as_ptr().add(idx.min(seq.len()));
        std::ptr::copy_nonoverlapping(src, result.as_mut_ptr(), num_bytes);
        std::mem::transmute(result)
    }
}
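
// An added check of the zero-padding behavior of `read_slice`, assuming a
// little-endian target (the parallel iterators below already require one):
// bytes past the end of `seq` read as zero.
#[test]
fn read_slice_zero_pads() {
    let seq: Vec<u8> = (1..=8).collect();
    let v = read_slice(&seq, 4);
    let lanes = v.to_array();
    // The last 4 bytes of `seq` (5, 6, 7, 8) land in lane 0; the rest is zero.
    assert_eq!(lanes[0], u32::from_le_bytes([5, 6, 7, 8]));
    assert!(lanes[1..].iter().all(|&x| x == 0));
}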

impl<'s> Seq<'s> for PackedSeq<'s> {
    const BASES_PER_BYTE: usize = 4;
    const BITS_PER_CHAR: usize = 2;
    type SeqVec = PackedSeqVec;

    #[inline(always)]
    fn len(&self) -> usize {
        self.len
    }

    /// Returns the 2-bit base at the given index.
    #[inline(always)]
    fn get(&self, index: usize) -> u8 {
        let offset = self.offset + index;
        let idx = offset / 4;
        let offset = offset % 4;
        unsafe { (*self.seq.get_unchecked(idx) >> (2 * offset)) & 3 }
    }

    #[inline(always)]
    fn get_ascii(&self, index: usize) -> u8 {
        unpack_base(self.get(index))
    }

    /// Packs the first `len` bases into the low `2 * len` bits of a `usize`.
    /// `len` must be at most `usize::BITS / 2 - 3` (29 on 64-bit targets), so
    /// that all bases remain after shifting out up to 3 offset bases.
    #[inline(always)]
    fn to_word(&self) -> usize {
        debug_assert!(self.len() <= usize::BITS as usize / 2 - 3);
        let mask = usize::MAX >> (usize::BITS as usize - 2 * self.len());
        unsafe {
            // An unaligned word-sized read; this assumes the underlying
            // allocation extends far enough past the slice for the read to be
            // valid.
            ((self.seq.as_ptr() as *const usize).read_unaligned() >> (2 * self.offset)) & mask
        }
    }
131
132 fn to_vec(&self) -> PackedSeqVec {
133 assert_eq!(self.offset, 0);
134 PackedSeqVec {
135 seq: self.seq.to_vec(),
136 len: self.len,
137 }
138 }
139
140 #[inline(always)]
141 fn slice(&self, range: Range<usize>) -> Self {
142 debug_assert!(
143 range.end <= self.len,
144 "Slice index out of bounds: {} > {}",
145 range.end,
146 self.len
147 );
148 PackedSeq {
149 seq: self.seq,
150 offset: self.offset + range.start,
151 len: range.end - range.start,
152 }
153 .normalize()
154 }
155
156 #[inline(always)]
157 fn iter_bp(self) -> impl ExactSizeIterator<Item = u8> + Clone {
158 assert!(self.len <= self.seq.len() * 4);
159
160 let this = self.normalize();
161
162 let mut byte = 0;
164 let mut it = (0..this.len + this.offset).map(
165 #[inline(always)]
166 move |i| {
167 if i % 4 == 0 {
168 byte = this.seq[i / 4];
169 }
170 (byte >> (2 * (i % 4))) & 0b11
172 },
173 );
174 it.by_ref().take(this.offset).for_each(drop);
175 it
176 }

    /// Iterates the sequence in 8 parallel lanes, yielding one base per lane
    /// per step. The sequence is split into 8 chunks of equal byte length and
    /// lane `l` streams chunk `l`. Returns the iterator and the number of
    /// padding k-mer positions at the end of the last chunk.
    #[inline(always)]
    fn par_iter_bp(self, context: usize) -> (impl ExactSizeIterator<Item = S> + Clone, usize) {
        #[cfg(target_endian = "big")]
        panic!("Big endian architectures are not supported.");

        let this = self.normalize();
        assert_eq!(this.offset, 0, "Non-byte offsets are not yet supported.");

        let num_kmers = this.len.saturating_sub(context - 1);
        // K-mer positions per lane, rounded up to a multiple of 4 so that
        // each chunk starts at a byte boundary.
        let n = num_kmers.div_ceil(L).next_multiple_of(4);
        let bytes_per_chunk = n / 4;
        let padding = 4 * L * bytes_per_chunk - num_kmers;

        let offsets: [usize; 8] = from_fn(|l| l * bytes_per_chunk);
        let mut cur = S::ZERO;

        // 8 transposed u32x8 words, each holding 16 upcoming bases per lane.
        let mut buf = Box::new([S::ZERO; 8]);

        let par_len = if num_kmers == 0 { 0 } else { n + context - 1 };
        let it = (0..par_len).map(
            #[inline(always)]
            move |i| {
                if i % 16 == 0 {
                    if i % 128 == 0 {
                        // Every 128 iterations, gather 32 bytes (128 bases)
                        // per lane and transpose them into 8 u32x8 words.
                        let data: [u32x8; 8] = from_fn(
                            #[inline(always)]
                            |lane| read_slice(this.seq, offsets[lane] + (i / 4)),
                        );
                        *buf = transpose(data);
                    }
                    // Every 16 iterations, load the next u32 of each lane.
                    cur = buf[(i % 128) / 16];
                }
                // Pop the low 2 bits (one base) of each lane.
                let chars = cur & S::splat(0x03);
                cur = cur >> S::splat(2);
                chars
            },
        );

        (it, padding)
    }

    /// Like `par_iter_bp`, but additionally yields the bases from `delay`
    /// iterations ago (zeros for the first `delay` iterations). A ring buffer
    /// of transposed words keeps the delayed stream in memory.
    #[inline(always)]
    fn par_iter_bp_delayed(
        self,
        context: usize,
        delay: usize,
    ) -> (impl ExactSizeIterator<Item = (S, S)> + Clone, usize) {
        #[cfg(target_endian = "big")]
        panic!("Big endian architectures are not supported.");

        assert!(
            delay < usize::MAX / 2,
            "Delay {} is too large; it may have underflowed below 0.",
            delay as isize
        );

        let this = self.normalize();
        assert_eq!(this.offset, 0, "Non-byte offsets are not yet supported.");

        let num_kmers = this.len.saturating_sub(context - 1);
        let n = num_kmers.div_ceil(L).next_multiple_of(4);
        let bytes_per_chunk = n / 4;
        let padding = 4 * L * bytes_per_chunk - num_kmers;

        let offsets: [usize; 8] = from_fn(|l| l * bytes_per_chunk);
        let mut upcoming = S::ZERO;
        let mut upcoming_d = S::ZERO;

        // The ring buffer length must be a power of two, and large enough to
        // hold the `delay / 16` words between the read and write heads plus
        // the 8 words written in one batch.
        let buf_len = (delay / 16 + 8).next_power_of_two();
        let buf_mask = buf_len - 1;
        let mut buf = vec![S::ZERO; buf_len];
        let mut write_idx = 0;
        // The read head trails the write head by `delay / 16` words; the
        // remaining `delay % 16` bases are handled by the shifted check below.
        let mut read_idx = (buf_len - delay / 16) % buf_len;

        let par_len = if num_kmers == 0 { 0 } else { n + context - 1 };
        let it = (0..par_len).map(
            #[inline(always)]
            move |i| {
                if i % 16 == 0 {
                    if i % 128 == 0 {
                        // Gather and transpose the next 128 bases per lane
                        // directly into the ring buffer.
                        let data: [u32x8; 8] = from_fn(
                            #[inline(always)]
                            |lane| read_slice(this.seq, offsets[lane] + (i / 4)),
                        );
                        unsafe {
                            *TryInto::<&mut [u32x8; 8]>::try_into(
                                buf.get_unchecked_mut(write_idx..write_idx + 8),
                            )
                            .unwrap_unchecked() = transpose(data);
                        }
                    }
                    upcoming = buf[write_idx];
                    write_idx += 1;
                    write_idx &= buf_mask;
                }
                if i % 16 == delay % 16 {
                    unsafe { assert_unchecked(read_idx < buf.len()) };
                    upcoming_d = buf[read_idx];
                    read_idx += 1;
                    read_idx &= buf_mask;
                }
                let chars = upcoming & S::splat(0x03);
                let chars_d = upcoming_d & S::splat(0x03);
                upcoming = upcoming >> S::splat(2);
                upcoming_d = upcoming_d >> S::splat(2);
                (chars, chars_d)
            },
        );

        (it, padding)
    }

    /// Like `par_iter_bp_delayed`, but with two delayed streams.
    /// Requires `delay1 <= delay2`.
    #[inline(always)]
    fn par_iter_bp_delayed_2(
        self,
        context: usize,
        delay1: usize,
        delay2: usize,
    ) -> (impl ExactSizeIterator<Item = (S, S, S)> + Clone, usize) {
        #[cfg(target_endian = "big")]
        panic!("Big endian architectures are not supported.");

        let this = self.normalize();
        assert_eq!(this.offset, 0, "Non-byte offsets are not yet supported.");
        assert!(delay1 <= delay2, "Delay1 must be at most delay2.");

        let num_kmers = this.len.saturating_sub(context - 1);
        let n = num_kmers.div_ceil(L).next_multiple_of(4);
        let bytes_per_chunk = n / 4;
        let padding = 4 * L * bytes_per_chunk - num_kmers;

        let offsets: [usize; 8] = from_fn(|l| l * bytes_per_chunk);
        let mut upcoming = S::ZERO;
        let mut upcoming_d1 = S::ZERO;
        let mut upcoming_d2 = S::ZERO;

        // The ring buffer is sized for the larger delay; both read heads
        // trail the write head, as in `par_iter_bp_delayed`.
        let buf_len = (delay2 / 16 + 8).next_power_of_two();
        let buf_mask = buf_len - 1;
        let mut buf = vec![S::ZERO; buf_len];
        let mut write_idx = 0;
        let mut read_idx1 = (buf_len - delay1 / 16) % buf_len;
        let mut read_idx2 = (buf_len - delay2 / 16) % buf_len;

        let par_len = if num_kmers == 0 { 0 } else { n + context - 1 };
        let it = (0..par_len).map(
            #[inline(always)]
            move |i| {
                if i % 16 == 0 {
                    if i % 128 == 0 {
                        let data: [u32x8; 8] = from_fn(
                            #[inline(always)]
                            |lane| read_slice(this.seq, offsets[lane] + (i / 4)),
                        );
                        unsafe {
                            *TryInto::<&mut [u32x8; 8]>::try_into(
                                buf.get_unchecked_mut(write_idx..write_idx + 8),
                            )
                            .unwrap_unchecked() = transpose(data);
                        }
                    }
                    upcoming = buf[write_idx];
                    write_idx += 1;
                    write_idx &= buf_mask;
                }
                if i % 16 == delay1 % 16 {
                    unsafe { assert_unchecked(read_idx1 < buf.len()) };
                    upcoming_d1 = buf[read_idx1];
                    read_idx1 += 1;
                    read_idx1 &= buf_mask;
                }
                if i % 16 == delay2 % 16 {
                    unsafe { assert_unchecked(read_idx2 < buf.len()) };
                    upcoming_d2 = buf[read_idx2];
                    read_idx2 += 1;
                    read_idx2 &= buf_mask;
                }
                let chars = upcoming & S::splat(0x03);
                let chars_d1 = upcoming_d1 & S::splat(0x03);
                let chars_d2 = upcoming_d2 & S::splat(0x03);
                upcoming = upcoming >> S::splat(2);
                upcoming_d1 = upcoming_d1 >> S::splat(2);
                upcoming_d2 = upcoming_d2 >> S::splat(2);
                (chars, chars_d1, chars_d2)
            },
        );

        (it, padding)
    }

    /// Compares two sequences 29 bases (one word) at a time, and returns the
    /// ordering together with the length of the longest common prefix.
    fn cmp_lcp(&self, other: &Self) -> (std::cmp::Ordering, usize) {
        let mut lcp = 0;
        let min_len = self.len.min(other.len);
        for i in (0..min_len).step_by(29) {
            let len = (min_len - i).min(29);
            let this = self.slice(i..i + len);
            let other = other.slice(i..i + len);
            let this_word = this.to_word();
            let other_word = other.to_word();
            if this_word != other_word {
                // Find the first differing base: round the number of equal
                // low bits down to a multiple of 2.
                let eq = this_word ^ other_word;
                let t = eq.trailing_zeros() / 2 * 2;
                lcp += t as usize / 2;
                let mask = 0b11 << t;
                return ((this_word & mask).cmp(&(other_word & mask)), lcp);
            }
            lcp += len;
        }
        (self.len.cmp(&other.len), lcp)
    }
}
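
// An added sanity test for `cmp_lcp` (not part of the original file): two
// 64-base sequences that agree on their first 32 bases and differ at the
// 33rd, forcing the comparison into its second 29-base word.
#[test]
fn cmp_lcp_finds_first_difference() {
    let s = b"ACGT".repeat(16);
    let mut t = s.clone();
    t[32] = b'C'; // Replace the 'A' at index 32.
    let mut a = PackedSeqVec::default();
    a.push_ascii(&s);
    let mut b = PackedSeqVec::default();
    b.push_ascii(&t);
    let (ord, lcp) = a.as_slice().cmp_lcp(&b.as_slice());
    assert_eq!(lcp, 32);
    assert_eq!(ord, std::cmp::Ordering::Less);
}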

impl PartialEq for PackedSeq<'_> {
    fn eq(&self, other: &Self) -> bool {
        if self.len != other.len {
            return false;
        }
        for i in (0..self.len).step_by(29) {
            let len = (self.len - i).min(29);
            let this = self.slice(i..i + len);
            let that = other.slice(i..i + len);
            if this.to_word() != that.to_word() {
                return false;
            }
        }
        true
    }
}

impl Eq for PackedSeq<'_> {}

impl PartialOrd for PackedSeq<'_> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for PackedSeq<'_> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let min_len = self.len.min(other.len);
        for i in (0..min_len).step_by(29) {
            let len = (min_len - i).min(29);
            let this = self.slice(i..i + len);
            let other = other.slice(i..i + len);
            let this_word = this.to_word();
            let other_word = other.to_word();
            if this_word != other_word {
                let eq = this_word ^ other_word;
                let t = eq.trailing_zeros() / 2 * 2;
                let mask = 0b11 << t;
                return (this_word & mask).cmp(&(other_word & mask));
            }
        }
        self.len.cmp(&other.len)
    }
}

impl SeqVec for PackedSeqVec {
    type Seq<'s> = PackedSeq<'s>;

    fn into_raw(self) -> Vec<u8> {
        self.seq
    }

    #[inline(always)]
    fn as_slice(&self) -> Self::Seq<'_> {
        PackedSeq {
            seq: &self.seq,
            offset: 0,
            len: self.len,
        }
    }

    /// Appends a packed sequence byte-by-byte and returns the range of
    /// positions (in bases) it occupies. The total length is rounded up to a
    /// whole number of bytes.
    fn push_seq(&mut self, seq: PackedSeq<'_>) -> Range<usize> {
        let start = 4 * self.seq.len() + seq.offset;
        let end = start + seq.len();
        self.seq.extend(seq.seq);
        self.len = 4 * self.seq.len();
        start..end
    }

    /// Packs and appends an ASCII sequence, and returns the range of
    /// positions (in bases) it occupies. Expects only `ACTGactg`: the scalar
    /// path panics on other characters, while the `pext` fast path does not
    /// validate its input.
    fn push_ascii(&mut self, seq: &[u8]) -> Range<usize> {
        let start_aligned = 4 * self.seq.len();
        let start = self.len;
        let len = seq.len();

        // First fill the unused bases of the partially filled last byte.
        let unaligned = core::cmp::min(start_aligned - start, len);
        if unaligned > 0 {
            let mut packed_byte = *self.seq.last().unwrap();
            for &base in &seq[..unaligned] {
                packed_byte |= pack_char(base) << ((self.len % 4) * 2);
                self.len += 1;
            }
            *self.seq.last_mut().unwrap() = packed_byte;
        }

        #[allow(unused)]
        let mut last = unaligned;

        #[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))]
        {
            last = unaligned + (len - unaligned) / 8 * 8;

            // Pack 8 characters at a time: extract bits 1 and 2 of each byte,
            // since `(c >> 1) & 3` equals `pack_char(c)` for all of ACTGactg.
            for i in (unaligned..last).step_by(8) {
                let chunk: [u8; 8] = seq[i..i + 8].try_into().unwrap();
                let ascii = u64::from_ne_bytes(chunk);
                let packed_bytes =
                    unsafe { std::arch::x86_64::_pext_u64(ascii, 0x0606060606060606) };
                self.seq.push(packed_bytes as u8);
                self.seq.push((packed_bytes >> 8) as u8);
                self.len += 8;
            }
        }

        // Pack the remaining characters one at a time.
        let mut packed_byte = 0;
        for &base in &seq[last..] {
            packed_byte |= pack_char(base) << ((self.len % 4) * 2);
            self.len += 1;
            if self.len % 4 == 0 {
                self.seq.push(packed_byte);
                packed_byte = 0;
            }
        }
        // Push the trailing partially filled byte, if any.
        if self.len % 4 != 0 && last < len {
            self.seq.push(packed_byte);
        }
        start..start + len
    }

    /// Creates a packed sequence of `n` random bases.
    fn random(n: usize) -> Self {
        let mut seq = vec![0; n.div_ceil(4)];
        rand::rngs::SmallRng::from_os_rng().fill_bytes(&mut seq);
        PackedSeqVec { seq, len: n }
    }
}
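
// An added round-trip sketch, not part of the original file: ASCII in,
// packed 2-bit storage, uppercased ASCII out.
#[test]
fn push_ascii_round_trips() {
    let mut v = PackedSeqVec::default();
    let range = v.push_ascii(b"ACGTacgt");
    assert_eq!(range, 0..8);
    assert_eq!(v.len, 8);
    // 8 bases pack into 2 bytes.
    assert_eq!(v.seq.len(), 2);
    assert_eq!(v.as_slice().unpack(), b"ACGTACGT");
}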