Skip to main content

cf1_rs/
kmer.rs

1use crate::dna::Base;
2use std::hash::{Hash, Hasher};
3
4/// Reverse complement lookup table for bytes (4 bases per byte).
5/// Each byte encodes 4 bases in 2-bit format. The reverse complement reverses
6/// the order and complements each base. Matches C++ Kmer_Utility::REVERSE_COMPLEMENT_BYTE.
7pub const REVERSE_COMPLEMENT_BYTE: [u8; 256] = [
8    255, 191, 127, 63, 239, 175, 111, 47, 223, 159, 95, 31, 207, 143, 79, 15, 251, 187, 123, 59,
9    235, 171, 107, 43, 219, 155, 91, 27, 203, 139, 75, 11, 247, 183, 119, 55, 231, 167, 103, 39,
10    215, 151, 87, 23, 199, 135, 71, 7, 243, 179, 115, 51, 227, 163, 99, 35, 211, 147, 83, 19,
11    195, 131, 67, 3, 254, 190, 126, 62, 238, 174, 110, 46, 222, 158, 94, 30, 206, 142, 78, 14,
12    250, 186, 122, 58, 234, 170, 106, 42, 218, 154, 90, 26, 202, 138, 74, 10, 246, 182, 118, 54,
13    230, 166, 102, 38, 214, 150, 86, 22, 198, 134, 70, 6, 242, 178, 114, 50, 226, 162, 98, 34,
14    210, 146, 82, 18, 194, 130, 66, 2, 253, 189, 125, 61, 237, 173, 109, 45, 221, 157, 93, 29,
15    205, 141, 77, 13, 249, 185, 121, 57, 233, 169, 105, 41, 217, 153, 89, 25, 201, 137, 73, 9,
16    245, 181, 117, 53, 229, 165, 101, 37, 213, 149, 85, 21, 197, 133, 69, 5, 241, 177, 113, 49,
17    225, 161, 97, 33, 209, 145, 81, 17, 193, 129, 65, 1, 252, 188, 124, 60, 236, 172, 108, 44,
18    220, 156, 92, 28, 204, 140, 76, 12, 248, 184, 120, 56, 232, 168, 104, 40, 216, 152, 88, 24,
19    200, 136, 72, 8, 244, 180, 116, 52, 228, 164, 100, 36, 212, 148, 84, 20, 196, 132, 68, 4,
20    240, 176, 112, 48, 224, 160, 96, 32, 208, 144, 80, 16, 192, 128, 64, 0,
21];
22
23/// Trait mapping K -> storage type. Implemented for each supported K value.
24pub trait KmerBits: Sized {
25    type Storage: Copy + Clone + Eq + Ord + Hash + Send + Sync + Default + std::fmt::Debug;
26    const NUM_WORDS: usize;
27
28    fn word(storage: &Self::Storage, idx: usize) -> u64;
29    fn set_word(storage: &mut Self::Storage, idx: usize, val: u64);
30    fn as_bytes(storage: &Self::Storage) -> &[u8];
31}
32
33#[derive(Clone, Copy, Debug)]
34pub struct Kmer<const K: usize>
35where
36    Kmer<K>: KmerBits,
37{
38    pub(crate) bits: <Kmer<K> as KmerBits>::Storage,
39}
40
41impl<const K: usize> Default for Kmer<K>
42where
43    Kmer<K>: KmerBits,
44{
45    fn default() -> Self {
46        Kmer {
47            bits: Default::default(),
48        }
49    }
50}
51
52impl<const K: usize> PartialEq for Kmer<K>
53where
54    Kmer<K>: KmerBits,
55{
56    fn eq(&self, other: &Self) -> bool {
57        self.bits == other.bits
58    }
59}
60
61impl<const K: usize> Eq for Kmer<K> where Kmer<K>: KmerBits {}
62
63impl<const K: usize> PartialOrd for Kmer<K>
64where
65    Kmer<K>: KmerBits,
66{
67    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
68        Some(self.cmp(other))
69    }
70}
71
72/// Comparison: high word first (matches C++ operator<).
73impl<const K: usize> Ord for Kmer<K>
74where
75    Kmer<K>: KmerBits,
76{
77    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
78        let n = <Kmer<K> as KmerBits>::NUM_WORDS;
79        for idx in (0..n).rev() {
80            let a = <Kmer<K> as KmerBits>::word(&self.bits, idx);
81            let b = <Kmer<K> as KmerBits>::word(&other.bits, idx);
82            match a.cmp(&b) {
83                std::cmp::Ordering::Equal => continue,
84                other => return other,
85            }
86        }
87        std::cmp::Ordering::Equal
88    }
89}
90
91impl<const K: usize> Hash for Kmer<K>
92where
93    Kmer<K>: KmerBits,
94{
95    fn hash<H: Hasher>(&self, state: &mut H) {
96        self.bits.hash(state);
97    }
98}
99
100impl<const K: usize> Kmer<K>
101where
102    Kmer<K>: KmerBits,
103{
104    const NUM_WORDS: usize = K.div_ceil(32);
105    const CLEAR_MSN_MASK: u64 = !(0b11u64 << (2 * ((K - 1) % 32)));
106    const NUM_BYTES: usize = K.div_ceil(4);
107    const MSN_SHIFT: usize = 2 * ((K - 1) % 32);
108
109    /// Create a k-mer from ASCII sequence at the given offset.
110    pub fn from_ascii(seq: &[u8], offset: usize) -> Self {
111        let label = &seq[offset..offset + K];
112        let mut kmer = Kmer::default();
113
114        let packed_word_count = K / 32;
115
116        // Get fully packed words' binary representations.
117        for data_idx in 0..packed_word_count {
118            let start = K - (data_idx << 5) - 32;
119            let word = encode_word::<32>(&label[start..start + 32]);
120            <Kmer<K> as KmerBits>::set_word(&mut kmer.bits, data_idx, word);
121        }
122
123        // Get the partially packed (highest index) word's binary representation.
124        let rem = K & 31;
125        if rem > 0 {
126            let word = encode_word_dyn(rem, &label[0..rem]);
127            <Kmer<K> as KmerBits>::set_word(&mut kmer.bits, Self::NUM_WORDS - 1, word);
128        }
129
130        kmer
131    }
132
133    /// Create a k-mer from 2-bit packed data at the given base offset.
134    /// Packed format: 4 bases per byte, MSB-first (base j is at bits 6-2*(j%4) of byte j/4).
135    /// The packed data must not contain placeholder (N) bases.
136    pub fn from_packed_2bit(packed: &[u8], base_offset: usize) -> Self {
137        let byte_start = base_offset / 4;
138        let sub_offset = base_offset % 4;
139        let total_bases = sub_offset + K;
140        let bytes_needed = total_bases.div_ceil(4);
141
142        // Read bytes into u128 accumulator (handles up to 64 bases = K≤61 with sub_offset≤3).
143        let mut val = 0u128;
144        for i in 0..bytes_needed {
145            val = (val << 8) | packed[byte_start + i] as u128;
146        }
147        // Left-align to 128 bits, then skip sub_offset bases.
148        val <<= (16 - bytes_needed) * 8 + 2 * sub_offset;
149        // Extract top 2*K bits → right-align.
150        let result = val >> (128 - 2 * K);
151
152        let mut kmer = Kmer::default();
153        <Kmer<K> as KmerBits>::set_word(&mut kmer.bits, 0, result as u64);
154        if <Kmer<K> as KmerBits>::NUM_WORDS > 1 {
155            <Kmer<K> as KmerBits>::set_word(&mut kmer.bits, 1, (result >> 64) as u64);
156        }
157        kmer
158    }
159
160    /// Reverse complement of this k-mer. Matches C++ byte-level algorithm.
161    pub fn reverse_complement(&self) -> Self {
162        let mut rc = Kmer::default();
163
164        let data = <Kmer<K> as KmerBits>::as_bytes(&self.bits);
165
166        // Use raw pointer to get mutable access to rc's bytes.
167        let rc_ptr = &mut rc.bits as *mut <Kmer<K> as KmerBits>::Storage as *mut u8;
168
169        let packed_byte_count = K / 4;
170        for byte_idx in 0..packed_byte_count {
171            unsafe {
172                *rc_ptr.add(packed_byte_count - 1 - byte_idx) =
173                    REVERSE_COMPLEMENT_BYTE[data[byte_idx] as usize];
174            }
175        }
176
177        let rem_base_count = K % 4;
178        if rem_base_count == 0 {
179            return rc;
180        }
181
182        unsafe {
183            *rc_ptr.add(packed_byte_count) = 0;
184        }
185
186        // Left shift by rem_base_count positions.
187        rc.left_shift_by(rem_base_count);
188
189        // Process remaining bases individually.
190        let partial_byte = data[packed_byte_count];
191        for i in 0..rem_base_count {
192            let base = Base::from_2bit((partial_byte >> (2 * i)) & 0b11);
193            let compl = base.complement();
194            unsafe {
195                *rc_ptr |= (compl as u8) << (2 * (rem_base_count - 1 - i));
196            }
197        }
198
199        rc
200    }
201
202    /// Canonical form = min(self, self.reverse_complement()).
203    pub fn canonical(&self) -> Self {
204        let rc = self.reverse_complement();
205        if *self <= rc {
206            *self
207        } else {
208            rc
209        }
210    }
211
212    /// Canonical form given a precomputed reverse complement.
213    pub fn canonical_with_rc(&self, rev_compl: &Self) -> Self {
214        if *self <= *rev_compl {
215            *self
216        } else {
217            *rev_compl
218        }
219    }
220
221    /// True if self is in forward direction relative to kmer_hat (the canonical form).
222    pub fn in_forward(&self, kmer_hat: &Self) -> bool {
223        self == kmer_hat
224    }
225
226    /// Extract front base (MSN = most significant nucleotide, n_{k-1}).
227    pub fn front(&self) -> Base {
228        let w = <Kmer<K> as KmerBits>::word(&self.bits, Self::NUM_WORDS - 1);
229        Base::from_2bit(((w >> Self::MSN_SHIFT) & 0b11) as u8)
230    }
231
232    /// Extract back base (LSN = least significant nucleotide, n_0).
233    pub fn back(&self) -> Base {
234        let w = <Kmer<K> as KmerBits>::word(&self.bits, 0);
235        Base::from_2bit((w & 0b11) as u8)
236    }
237
238    /// Roll forward: chop off front, append base at back.
239    /// Also updates rev_compl accordingly.
240    pub fn roll_to_next_kmer(&mut self, base: Base, rev_compl: &mut Self) {
241        // Clear MSN.
242        let w = <Kmer<K> as KmerBits>::word(&self.bits, Self::NUM_WORDS - 1);
243        <Kmer<K> as KmerBits>::set_word(
244            &mut self.bits,
245            Self::NUM_WORDS - 1,
246            w & Self::CLEAR_MSN_MASK,
247        );
248        // Left shift by 1.
249        self.left_shift();
250        // Insert new base at LSB.
251        let w0 = <Kmer<K> as KmerBits>::word(&self.bits, 0);
252        <Kmer<K> as KmerBits>::set_word(&mut self.bits, 0, w0 | base as u64);
253
254        // Update reverse complement.
255        rev_compl.right_shift();
256        let rw = <Kmer<K> as KmerBits>::word(&rev_compl.bits, Self::NUM_WORDS - 1);
257        <Kmer<K> as KmerBits>::set_word(
258            &mut rev_compl.bits,
259            Self::NUM_WORDS - 1,
260            rw | ((base.complement() as u64) << Self::MSN_SHIFT),
261        );
262    }
263
264    /// Roll forward from ASCII character.
265    pub fn roll_to_next_kmer_char(&mut self, next_char: u8, rev_compl: &mut Self) {
266        let base = Base::map_base(next_char);
267        self.roll_to_next_kmer(base, rev_compl);
268    }
269
270    /// XXH3 64-bit hash matching C++ `to_u64()`.
271    pub fn hash_xxh3(&self) -> u64 {
272        let bytes = <Kmer<K> as KmerBits>::as_bytes(&self.bits);
273        xxhash_rust::xxh3::xxh3_64_with_seed(&bytes[..Self::NUM_BYTES], 0)
274    }
275
276    /// Get string label of the k-mer.
277    pub fn string_label(&self) -> String {
278        let mut copy = *self;
279        let mut label = Vec::with_capacity(K);
280        for _ in 0..K {
281            let w = <Kmer<K> as KmerBits>::word(&copy.bits, 0);
282            let base = Base::from_2bit((w & 0b11) as u8);
283            label.push(base.to_char());
284            copy.right_shift();
285        }
286        label.reverse();
287        String::from_utf8(label).unwrap()
288    }
289
290    /// Left shift all words by 2 bits (one base).
291    fn left_shift(&mut self) {
292        let n = Self::NUM_WORDS;
293        for idx in (1..n).rev() {
294            let curr = <Kmer<K> as KmerBits>::word(&self.bits, idx);
295            let prev = <Kmer<K> as KmerBits>::word(&self.bits, idx - 1);
296            <Kmer<K> as KmerBits>::set_word(
297                &mut self.bits,
298                idx,
299                (curr << 2) | (prev >> 62),
300            );
301        }
302        let w0 = <Kmer<K> as KmerBits>::word(&self.bits, 0);
303        <Kmer<K> as KmerBits>::set_word(&mut self.bits, 0, w0 << 2);
304    }
305
306    /// Left shift by B bases (2B bits).
307    fn left_shift_by(&mut self, b: usize) {
308        if b == 0 {
309            return;
310        }
311        let num_bit_shift = 2 * b;
312        let n = Self::NUM_WORDS;
313        let mask_msns = ((1u64 << num_bit_shift) - 1) << (64 - num_bit_shift);
314        for idx in (1..n).rev() {
315            let curr = <Kmer<K> as KmerBits>::word(&self.bits, idx);
316            let prev = <Kmer<K> as KmerBits>::word(&self.bits, idx - 1);
317            <Kmer<K> as KmerBits>::set_word(
318                &mut self.bits,
319                idx,
320                (curr << num_bit_shift) | ((prev & mask_msns) >> (64 - num_bit_shift)),
321            );
322        }
323        let w0 = <Kmer<K> as KmerBits>::word(&self.bits, 0);
324        <Kmer<K> as KmerBits>::set_word(&mut self.bits, 0, w0 << num_bit_shift);
325    }
326
327    /// Right shift all words by 2 bits (one base).
328    fn right_shift(&mut self) {
329        let n = Self::NUM_WORDS;
330        for idx in 0..n - 1 {
331            let curr = <Kmer<K> as KmerBits>::word(&self.bits, idx);
332            let next = <Kmer<K> as KmerBits>::word(&self.bits, idx + 1);
333            <Kmer<K> as KmerBits>::set_word(
334                &mut self.bits,
335                idx,
336                (curr >> 2) | ((next & 0b11) << 62),
337            );
338        }
339        let wn = <Kmer<K> as KmerBits>::word(&self.bits, n - 1);
340        <Kmer<K> as KmerBits>::set_word(&mut self.bits, n - 1, wn >> 2);
341    }
342
343    /// Write the unitig sequence for this k-mer range to a buffer.
344    pub fn write_label_to<W: std::io::Write>(
345        &self,
346        k: usize,
347        seq: &[u8],
348        start_kmer_idx: usize,
349        end_kmer_idx: usize,
350        dir: bool, // true = FWD
351        writer: &mut W,
352    ) -> std::io::Result<()> {
353        let segment_len = end_kmer_idx - start_kmer_idx + k;
354        if dir {
355            for offset in 0..segment_len {
356                writer.write_all(&[crate::dna::to_upper(seq[start_kmer_idx + offset])])?;
357            }
358        } else {
359            for offset in 0..segment_len {
360                writer.write_all(&[crate::dna::complement_char(
361                    seq[end_kmer_idx + k - 1 - offset],
362                )])?;
363            }
364        }
365        Ok(())
366    }
367}
368
369impl Base {
370    #[inline]
371    pub fn from_2bit(v: u8) -> Base {
372        debug_assert!(v < 4);
373        unsafe { std::mem::transmute(v) }
374    }
375}
376
377/// Encode a fixed-size word from ASCII DNA characters.
378/// First character occupies the most-significant bits, matching C++ cuttlefish.
379fn encode_word<const N: usize>(label: &[u8]) -> u64 {
380    debug_assert!(label.len() >= N);
381    label.iter().take(N).fold(0u64, |acc, &b| (acc << 2) | (Base::map_base(b) as u64))
382}
383
384/// Encode a variable-size word.
385fn encode_word_dyn(n: usize, label: &[u8]) -> u64 {
386    debug_assert!(label.len() >= n);
387    label.iter().take(n).fold(0u64, |acc, &b| (acc << 2) | (Base::map_base(b) as u64))
388}
389
390// Implementation for K=1..=32 (single u64 storage).
391macro_rules! impl_kmer_bits_u64 {
392    ($($k:literal),*) => {
393        $(
394            impl KmerBits for Kmer<$k> {
395                type Storage = u64;
396                const NUM_WORDS: usize = 1;
397
398                #[inline]
399                fn word(storage: &u64, _idx: usize) -> u64 {
400                    *storage
401                }
402
403                #[inline]
404                fn set_word(storage: &mut u64, _idx: usize, val: u64) {
405                    *storage = val;
406                }
407
408                #[inline]
409                fn as_bytes(storage: &u64) -> &[u8] {
410                    unsafe { std::slice::from_raw_parts(storage as *const u64 as *const u8, 8) }
411                }
412            }
413        )*
414    };
415}
416
417// Implementation for K=33..=63 (u128 storage = 2 x u64 words).
418macro_rules! impl_kmer_bits_u128 {
419    ($($k:literal),*) => {
420        $(
421            impl KmerBits for Kmer<$k> {
422                type Storage = u128;
423                const NUM_WORDS: usize = 2;
424
425                #[inline]
426                fn word(storage: &u128, idx: usize) -> u64 {
427                    if idx == 0 {
428                        *storage as u64
429                    } else {
430                        (*storage >> 64) as u64
431                    }
432                }
433
434                #[inline]
435                fn set_word(storage: &mut u128, idx: usize, val: u64) {
436                    if idx == 0 {
437                        *storage = (*storage & (0xFFFF_FFFF_FFFF_FFFFu128 << 64)) | val as u128;
438                    } else {
439                        *storage = (*storage & 0xFFFF_FFFF_FFFF_FFFFu128) | ((val as u128) << 64);
440                    }
441                }
442
443                #[inline]
444                fn as_bytes(storage: &u128) -> &[u8] {
445                    unsafe { std::slice::from_raw_parts(storage as *const u128 as *const u8, 16) }
446                }
447            }
448        )*
449    };
450}
451
452impl_kmer_bits_u64!(
453    1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
454    26, 27, 28, 29, 30, 31, 32
455);
456
457impl_kmer_bits_u128!(
458    33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
459    56, 57, 58, 59, 60, 61, 62, 63
460);
461
462/// Dispatch macro: calls $func::<K>($($arg),*) for the given runtime k value.
463/// Only odd values in [1, 63] are supported (de Bruijn graph convention).
464#[macro_export]
465macro_rules! dispatch_k {
466    ($k:expr, $func:ident $(, $arg:expr)*) => {
467        match $k {
468            1 => $func::<1>($($arg),*),
469            3 => $func::<3>($($arg),*),
470            5 => $func::<5>($($arg),*),
471            7 => $func::<7>($($arg),*),
472            9 => $func::<9>($($arg),*),
473            11 => $func::<11>($($arg),*),
474            13 => $func::<13>($($arg),*),
475            15 => $func::<15>($($arg),*),
476            17 => $func::<17>($($arg),*),
477            19 => $func::<19>($($arg),*),
478            21 => $func::<21>($($arg),*),
479            23 => $func::<23>($($arg),*),
480            25 => $func::<25>($($arg),*),
481            27 => $func::<27>($($arg),*),
482            29 => $func::<29>($($arg),*),
483            31 => $func::<31>($($arg),*),
484            33 => $func::<33>($($arg),*),
485            35 => $func::<35>($($arg),*),
486            37 => $func::<37>($($arg),*),
487            39 => $func::<39>($($arg),*),
488            41 => $func::<41>($($arg),*),
489            43 => $func::<43>($($arg),*),
490            45 => $func::<45>($($arg),*),
491            47 => $func::<47>($($arg),*),
492            49 => $func::<49>($($arg),*),
493            51 => $func::<51>($($arg),*),
494            53 => $func::<53>($($arg),*),
495            55 => $func::<55>($($arg),*),
496            57 => $func::<57>($($arg),*),
497            59 => $func::<59>($($arg),*),
498            61 => $func::<61>($($arg),*),
499            63 => $func::<63>($($arg),*),
500            _ => panic!("k must be odd and in [1, 63], got {}", $k),
501        }
502    };
503}
504
505#[cfg(test)]
506mod tests {
507    use super::*;
508
509    #[test]
510    fn test_from_ascii_and_label() {
511        let seq = b"ACGTACGTACGTACGTACGTACGTACGTACG"; // 31 bases
512        let kmer = Kmer::<31>::from_ascii(seq, 0);
513        let label = kmer.string_label();
514        assert_eq!(label, "ACGTACGTACGTACGTACGTACGTACGTACG");
515    }
516
517    #[test]
518    fn test_reverse_complement() {
519        let seq = b"ACGT";
520        let kmer = Kmer::<4>::from_ascii(seq, 0);
521        let rc = kmer.reverse_complement();
522        assert_eq!(rc.string_label(), "ACGT"); // ACGT is its own reverse complement
523    }
524
525    #[test]
526    fn test_reverse_complement_asymmetric() {
527        let seq = b"AAAC";
528        let kmer = Kmer::<4>::from_ascii(seq, 0);
529        let rc = kmer.reverse_complement();
530        assert_eq!(rc.string_label(), "GTTT");
531    }
532
533    #[test]
534    fn test_canonical() {
535        let seq = b"AAAC";
536        let kmer = Kmer::<4>::from_ascii(seq, 0);
537        let can = kmer.canonical();
538        // AAAC < GTTT so canonical should be AAAC
539        assert_eq!(can.string_label(), "AAAC");
540    }
541
542    #[test]
543    fn test_front_back() {
544        let seq = b"ACGT";
545        let kmer = Kmer::<4>::from_ascii(seq, 0);
546        assert_eq!(kmer.front(), Base::A);
547        assert_eq!(kmer.back(), Base::T);
548    }
549
550    #[test]
551    fn test_roll() {
552        let seq = b"ACGTG";
553        let mut kmer = Kmer::<4>::from_ascii(seq, 0);
554        let mut rc = kmer.reverse_complement();
555        kmer.roll_to_next_kmer(Base::G, &mut rc);
556        assert_eq!(kmer.string_label(), "CGTG");
557        assert_eq!(rc.string_label(), "CACG");
558    }
559
560    #[test]
561    fn test_hash_consistency() {
562        let seq = b"ACGTACGTACGTACGTACGTACGTACGTACG";
563        let kmer = Kmer::<31>::from_ascii(seq, 0);
564        let h1 = kmer.hash_xxh3();
565        let h2 = kmer.hash_xxh3();
566        assert_eq!(h1, h2);
567    }
568
569    #[test]
570    fn test_ordering() {
571        let a = Kmer::<4>::from_ascii(b"AAAA", 0);
572        let b = Kmer::<4>::from_ascii(b"AAAC", 0);
573        let c = Kmer::<4>::from_ascii(b"TTTT", 0);
574        assert!(a < b);
575        assert!(b < c);
576    }
577
578    #[test]
579    fn test_k33() {
580        let seq = b"ACGTACGTACGTACGTACGTACGTACGTACGTACG"; // 35 bases
581        let kmer = Kmer::<33>::from_ascii(seq, 0);
582        let label = kmer.string_label();
583        assert_eq!(&label, "ACGTACGTACGTACGTACGTACGTACGTACGTA");
584
585        let rc = kmer.reverse_complement();
586        let rc_label = rc.string_label();
587        // Verify round-trip
588        let rc_rc = rc.reverse_complement();
589        assert_eq!(rc_rc, kmer);
590        let _ = rc_label;
591    }
592}