// jellyfish_reader/mer.rs

1use std::cmp::Ordering;
2use std::fmt;
3use std::hash::{Hash, Hasher};
4use std::str::FromStr;
5
6use crate::error::Error;
7
/// How many 2-bit bases are packed into one `u64` storage word (64 / 2 = 32).
const BASES_PER_WORD: usize = (u64::BITS / 2) as usize;
10
/// Map an ASCII nucleotide byte to its 2-bit code.
///
/// Case-insensitive: `A`/`a` → 0, `C`/`c` → 1, `G`/`g` → 2, `T`/`t` → 3.
/// Every other byte (including `N`) yields `None`.
#[inline]
pub fn encode_base(c: u8) -> Option<u8> {
    let code = match c.to_ascii_uppercase() {
        b'A' => 0,
        b'C' => 1,
        b'G' => 2,
        b'T' => 3,
        _ => return None,
    };
    Some(code)
}
24
/// Map a 2-bit code back to its uppercase ASCII nucleotide.
///
/// Only the low two bits of `code` are consulted, so every input maps to one of `ACGT`.
#[inline]
pub fn decode_base(code: u8) -> u8 {
    const ALPHABET: &[u8; 4] = b"ACGT";
    ALPHABET[usize::from(code & 0b11)]
}
36
/// Complement a 2-bit base code: A (0) ↔ T (3), C (1) ↔ G (2).
///
/// With this encoding the complement is just both low bits flipped; any higher
/// bits of `code` pass through unchanged.
#[inline]
pub fn complement_code(code: u8) -> u8 {
    code ^ 0b11
}
42
43/// Compute the number of u64 words needed for k bases.
44#[inline]
45pub fn words_for_k(k: usize) -> usize {
46    k.div_ceil(BASES_PER_WORD)
47}
48
/// Reverse complement of one `u64` word holding 32 packed 2-bit bases.
///
/// The 2-bit groups are emitted in reverse order and each one is complemented.
fn word_reverse_complement(word: u64) -> u64 {
    // Flipping every bit complements every base (A↔T, C↔G).
    // Reversing all 64 bits puts the 2-bit groups in reverse order, but it also
    // swaps the two bits inside each group.
    const LOW_BIT_OF_PAIR: u64 = 0x5555_5555_5555_5555;
    let reversed = (!word).reverse_bits();
    // Swap the two bits inside each group back into place.
    ((reversed >> 1) & LOW_BIT_OF_PAIR) | ((reversed & LOW_BIT_OF_PAIR) << 1)
}
63
/// A DNA k-mer kept in 2-bit packed form inside `u64` words.
///
/// The layout mirrors Jellyfish's MerDNA:
/// - each base takes 2 bits (A=0, C=1, G=2, T=3);
/// - inside a word, bases fill from the least significant bits upward;
/// - `words[0]` is the least significant word, the last word the most significant.
///
/// # Examples
///
/// ```
/// use jellyfish_reader::MerDna;
///
/// let mer: MerDna = "ACGT".parse().unwrap();
/// assert_eq!(mer.to_string(), "ACGT");
/// assert_eq!(mer.k(), 4);
///
/// let rc = mer.get_reverse_complement();
/// assert_eq!(rc.to_string(), "ACGT"); // ACGT is its own reverse complement
/// ```
#[derive(Clone)]
pub struct MerDna {
    /// Number of bases in the k-mer.
    k: usize,
    /// Packed bases; base `i` occupies bits `2*(i % 32)..2*(i % 32)+2` of `words[i / 32]`.
    words: Vec<u64>,
}
90
91impl MerDna {
92    /// Create a new k-mer of the given length, initialized to all A's.
93    pub fn new(k: usize) -> Self {
94        Self {
95            words: vec![0u64; words_for_k(k)],
96            k,
97        }
98    }
99
100    /// Create a MerDna from raw word data and k-mer length.
101    ///
102    /// The words should contain 2-bit packed bases matching Jellyfish's encoding.
103    pub fn from_words(words: Vec<u64>, k: usize) -> Self {
104        debug_assert!(words.len() == words_for_k(k));
105        let mut mer = Self { words, k };
106        mer.clean_high_bits();
107        mer
108    }
109
110    /// Create a MerDna by reading packed bytes (as stored in Jellyfish binary files).
111    ///
112    /// Bytes are read in order and packed into u64 words in little-endian byte order.
113    pub fn from_bytes(bytes: &[u8], k: usize) -> Self {
114        let n_words = words_for_k(k);
115        let mut words = vec![0u64; n_words];
116
117        for (i, &byte) in bytes.iter().enumerate() {
118            let word_idx = i / 8;
119            let byte_idx = i % 8;
120            if word_idx < n_words {
121                words[word_idx] |= (byte as u64) << (byte_idx * 8);
122            }
123        }
124
125        let mut mer = Self { words, k };
126        mer.clean_high_bits();
127        mer
128    }
129
130    /// The k-mer length (number of bases).
131    #[inline]
132    pub fn k(&self) -> usize {
133        self.k
134    }
135
136    /// Access the raw u64 words.
137    #[inline]
138    pub fn words(&self) -> &[u64] {
139        &self.words
140    }
141
142    /// Get the base at position `i` (0-indexed from the right/LSB end).
143    ///
144    /// # Panics
145    /// Panics if `i >= k`.
146    pub fn get_base(&self, i: usize) -> u8 {
147        assert!(i < self.k, "base index {i} out of range for k={}", self.k);
148        let word_idx = i / BASES_PER_WORD;
149        let bit_offset = (i % BASES_PER_WORD) * 2;
150        ((self.words[word_idx] >> bit_offset) & 0x3) as u8
151    }
152
153    /// Set the base at position `i` (0-indexed from the right/LSB end).
154    ///
155    /// # Panics
156    /// Panics if `i >= k` or if `base_code` is not in 0..4.
157    pub fn set_base(&mut self, i: usize, base_code: u8) {
158        assert!(i < self.k, "base index {i} out of range for k={}", self.k);
159        assert!(base_code < 4, "invalid base code: {base_code}");
160        let word_idx = i / BASES_PER_WORD;
161        let bit_offset = (i % BASES_PER_WORD) * 2;
162        self.words[word_idx] &= !(0x3u64 << bit_offset);
163        self.words[word_idx] |= (base_code as u64) << bit_offset;
164    }
165
166    /// Shift the k-mer left by one position, inserting `base` at the right end.
167    ///
168    /// Returns the base character that was shifted out from the left end.
169    pub fn shift_left(&mut self, base: u8) -> Option<u8> {
170        let code = encode_base(base)?;
171        let old_high = self.get_base(self.k - 1);
172
173        // Shift each word left by 2 bits, propagating carries
174        let n = self.words.len();
175        for i in (1..n).rev() {
176            self.words[i] = (self.words[i] << 2) | (self.words[i - 1] >> 62);
177        }
178        self.words[0] = (self.words[0] << 2) | (code as u64);
179        self.clean_high_bits();
180
181        Some(decode_base(old_high))
182    }
183
184    /// Shift the k-mer right by one position, inserting `base` at the left end.
185    ///
186    /// Returns the base character that was shifted out from the right end.
187    pub fn shift_right(&mut self, base: u8) -> Option<u8> {
188        let code = encode_base(base)?;
189        let old_low = self.get_base(0);
190
191        // Shift each word right by 2 bits, propagating carries
192        let n = self.words.len();
193        for i in 0..n - 1 {
194            self.words[i] = (self.words[i] >> 2) | (self.words[i + 1] << 62);
195        }
196        self.words[n - 1] >>= 2;
197
198        // Insert new base at the high end
199        let high_pos = self.k - 1;
200        let word_idx = high_pos / BASES_PER_WORD;
201        let bit_offset = (high_pos % BASES_PER_WORD) * 2;
202        self.words[word_idx] |= (code as u64) << bit_offset;
203
204        Some(decode_base(old_low))
205    }
206
207    /// Compute the reverse complement of this k-mer.
208    pub fn get_reverse_complement(&self) -> MerDna {
209        let n = self.words.len();
210
211        if n == 1 {
212            let mut result = vec![0u64; 1];
213            result[0] = word_reverse_complement(self.words[0]) >> (64 - self.k * 2);
214            let mut mer = MerDna {
215                words: result,
216                k: self.k,
217            };
218            mer.clean_high_bits();
219            return mer;
220        }
221
222        // For multi-word k-mers, use base-by-base approach for correctness.
223        // Position i in self maps to position (k-1-i) in result, with complement.
224        let mut result = MerDna::new(self.k);
225        for i in 0..self.k {
226            let base = self.get_base(i);
227            result.set_base(self.k - 1 - i, complement_code(base));
228        }
229        result
230    }
231
232    /// Modify this k-mer in place to its reverse complement.
233    pub fn reverse_complement(&mut self) {
234        *self = self.get_reverse_complement();
235    }
236
237    /// Get the canonical form (lexicographically smaller of self and reverse complement).
238    pub fn get_canonical(&self) -> MerDna {
239        let rc = self.get_reverse_complement();
240        if *self <= rc { self.clone() } else { rc }
241    }
242
243    /// Modify this k-mer in place to its canonical form.
244    pub fn canonicalize(&mut self) {
245        let rc = self.get_reverse_complement();
246        if rc < *self {
247            *self = rc;
248        }
249    }
250
251    /// Check if this k-mer is a homopolymer (all same base).
252    pub fn is_homopolymer(&self) -> bool {
253        if self.k == 0 {
254            return true;
255        }
256        let base = self.get_base(0);
257        (1..self.k).all(|i| self.get_base(i) == base)
258    }
259
260    /// Set all bases to A.
261    pub fn poly_a(&mut self) {
262        self.words.fill(0);
263    }
264
265    /// Set all bases to C.
266    pub fn poly_c(&mut self) {
267        self.fill_with_code(1);
268    }
269
270    /// Set all bases to G.
271    pub fn poly_g(&mut self) {
272        self.fill_with_code(2);
273    }
274
275    /// Set all bases to T.
276    pub fn poly_t(&mut self) {
277        self.fill_with_code(3);
278    }
279
280    /// Fill all bases with the given 2-bit code.
281    fn fill_with_code(&mut self, code: u8) {
282        let pattern = match code {
283            0 => 0x0000_0000_0000_0000u64,
284            1 => 0x5555_5555_5555_5555u64,
285            2 => 0xAAAA_AAAA_AAAA_AAAAu64,
286            3 => 0xFFFF_FFFF_FFFF_FFFFu64,
287            _ => unreachable!(),
288        };
289        self.words.fill(pattern);
290        self.clean_high_bits();
291    }
292
293    /// Zero out bits above the k-mer length in the highest word.
294    fn clean_high_bits(&mut self) {
295        if self.k == 0 {
296            return;
297        }
298        let used_bits = self.k * 2;
299        let total_bits = self.words.len() * 64;
300        if used_bits < total_bits {
301            let last = self.words.len() - 1;
302            let bits_in_last = used_bits - last * 64;
303            self.words[last] &= (1u64 << bits_in_last) - 1;
304        }
305    }
306}
307
308impl fmt::Debug for MerDna {
309    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
310        write!(f, "MerDna(\"{}\")", self)
311    }
312}
313
314impl fmt::Display for MerDna {
315    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
316        for i in (0..self.k).rev() {
317            let code = self.get_base(i);
318            f.write_str(std::str::from_utf8(&[decode_base(code)]).unwrap())?;
319        }
320        Ok(())
321    }
322}
323
324impl FromStr for MerDna {
325    type Err = Error;
326
327    fn from_str(s: &str) -> Result<Self, Error> {
328        let k = s.len();
329        if k == 0 {
330            return Err(Error::InvalidKmer("empty k-mer string".to_string()));
331        }
332
333        let mut mer = MerDna::new(k);
334        let bytes = s.as_bytes();
335
336        for (i, &ch) in bytes.iter().enumerate() {
337            let code = encode_base(ch).ok_or_else(|| {
338                Error::InvalidKmer(format!("invalid base '{}' at position {i}", ch as char))
339            })?;
340            // String is stored with first character at highest position
341            let pos = k - 1 - i;
342            mer.set_base(pos, code);
343        }
344
345        Ok(mer)
346    }
347}
348
349impl PartialEq for MerDna {
350    fn eq(&self, other: &Self) -> bool {
351        self.k == other.k && self.words == other.words
352    }
353}
354
355impl Eq for MerDna {}
356
357impl PartialOrd for MerDna {
358    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
359        Some(self.cmp(other))
360    }
361}
362
363impl Ord for MerDna {
364    fn cmp(&self, other: &Self) -> Ordering {
365        // Compare from most significant word to least
366        assert_eq!(
367            self.k, other.k,
368            "cannot compare k-mers of different lengths"
369        );
370        for i in (0..self.words.len()).rev() {
371            match self.words[i].cmp(&other.words[i]) {
372                Ordering::Equal => continue,
373                ord => return ord,
374            }
375        }
376        Ordering::Equal
377    }
378}
379
380impl Hash for MerDna {
381    fn hash<H: Hasher>(&self, state: &mut H) {
382        self.k.hash(state);
383        self.words.hash(state);
384    }
385}
386
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encode_decode_bases() {
        for (ch, code) in [(b'A', 0), (b'C', 1), (b'G', 2), (b'T', 3)] {
            assert_eq!(encode_base(ch), Some(code));
            assert_eq!(decode_base(code), ch);
        }
        // Lowercase
        for (ch, code) in [(b'a', 0), (b'c', 1), (b'g', 2), (b't', 3)] {
            assert_eq!(encode_base(ch), Some(code));
        }
        assert_eq!(encode_base(b'N'), None);
        assert_eq!(encode_base(b'X'), None);
    }

    #[test]
    fn test_complement_code() {
        assert_eq!(complement_code(0), 3); // A -> T
        assert_eq!(complement_code(1), 2); // C -> G
        assert_eq!(complement_code(2), 1); // G -> C
        assert_eq!(complement_code(3), 0); // T -> A
    }

    #[test]
    fn test_new_mer() {
        let mer = MerDna::new(4);
        assert_eq!(mer.k(), 4);
        assert_eq!(mer.to_string(), "AAAA");
    }

    #[test]
    fn test_from_str_basic() {
        let mer: MerDna = "ACGT".parse().unwrap();
        assert_eq!(mer.k(), 4);
        assert_eq!(mer.to_string(), "ACGT");
    }

    #[test]
    fn test_from_str_lowercase() {
        let mer: MerDna = "acgt".parse().unwrap();
        assert_eq!(mer.to_string(), "ACGT");
    }

    #[test]
    fn test_from_str_single_base() {
        for (ch, expected) in [("A", "A"), ("C", "C"), ("G", "G"), ("T", "T")] {
            let mer: MerDna = ch.parse().unwrap();
            assert_eq!(mer.to_string(), expected);
        }
    }

    #[test]
    fn test_from_str_invalid() {
        assert!("ACGN".parse::<MerDna>().is_err());
        assert!("".parse::<MerDna>().is_err());
        assert!("ACGX".parse::<MerDna>().is_err());
    }

    #[test]
    fn test_roundtrip_various_lengths() {
        let seqs = [
            "A",
            "AC",
            "ACG",
            "ACGT",
            "ACGTACGT",
            "ACGTACGTACGTACGTACGTACGTACGTACGT", // 32 bases = 1 word exactly
            "ACGTACGTACGTACGTACGTACGTACGTACGTA", // 33 bases = 2 words
        ];
        for seq in seqs {
            let mer: MerDna = seq.parse().unwrap();
            assert_eq!(mer.to_string(), seq, "roundtrip failed for {seq}");
        }
    }

    #[test]
    fn test_get_set_base() {
        let mut mer: MerDna = "ACGT".parse().unwrap();
        // String "ACGT": A is at position 3, C at 2, G at 1, T at 0
        assert_eq!(mer.get_base(0), 3); // T
        assert_eq!(mer.get_base(1), 2); // G
        assert_eq!(mer.get_base(2), 1); // C
        assert_eq!(mer.get_base(3), 0); // A

        mer.set_base(0, 0); // T -> A
        assert_eq!(mer.to_string(), "ACGA");
    }

    #[test]
    fn test_reverse_complement_palindrome() {
        // ACGT is its own reverse complement
        let mer: MerDna = "ACGT".parse().unwrap();
        let rc = mer.get_reverse_complement();
        assert_eq!(rc.to_string(), "ACGT");
    }

    #[test]
    fn test_reverse_complement_simple() {
        let mer: MerDna = "AAAA".parse().unwrap();
        let rc = mer.get_reverse_complement();
        assert_eq!(rc.to_string(), "TTTT");
    }

    #[test]
    fn test_reverse_complement_asymmetric() {
        let mer: MerDna = "AACG".parse().unwrap();
        let rc = mer.get_reverse_complement();
        assert_eq!(rc.to_string(), "CGTT");
    }

    #[test]
    fn test_reverse_complement_involution() {
        // RC(RC(x)) == x
        let seqs = ["ACGT", "AAAA", "GCTA", "AACG", "TTTCCCGGGAAA"];
        for seq in seqs {
            let mer: MerDna = seq.parse().unwrap();
            let rc2 = mer.get_reverse_complement().get_reverse_complement();
            assert_eq!(mer, rc2, "RC involution failed for {seq}");
        }
    }

    #[test]
    fn test_reverse_complement_matches_string_reference() {
        // Cross-check against a naive string reverse complement for lengths
        // spanning one, two and three words, including exact word boundaries.
        // An involution check alone cannot catch a wrong but self-inverse RC.
        fn rc_string(seq: &str) -> String {
            seq.bytes()
                .rev()
                .map(|b| match b {
                    b'A' => 'T',
                    b'C' => 'G',
                    b'G' => 'C',
                    b'T' => 'A',
                    other => panic!("unexpected base {other}"),
                })
                .collect()
        }
        let alphabet = b"ACGT";
        for k in 1..=70usize {
            let seq: String = (0..k)
                .map(|i| alphabet[(i * 7 + i / 5) % 4] as char)
                .collect();
            let mer: MerDna = seq.parse().unwrap();
            assert_eq!(
                mer.get_reverse_complement().to_string(),
                rc_string(&seq),
                "RC mismatch for k={k}"
            );
        }
    }

    #[test]
    fn test_canonical_already_canonical() {
        let mer: MerDna = "AAAA".parse().unwrap();
        let canonical = mer.get_canonical();
        assert_eq!(canonical.to_string(), "AAAA"); // AAAA < TTTT
    }

    #[test]
    fn test_canonical_needs_rc() {
        let mer: MerDna = "TTTT".parse().unwrap();
        let canonical = mer.get_canonical();
        assert_eq!(canonical.to_string(), "AAAA"); // AAAA < TTTT
    }

    #[test]
    fn test_canonical_palindrome() {
        let mer: MerDna = "ACGT".parse().unwrap();
        let canonical = mer.get_canonical();
        assert_eq!(canonical.to_string(), "ACGT");
    }

    #[test]
    fn test_canonical_idempotent() {
        let seqs = ["ACGT", "TGCA", "AAAA", "CCCC", "AACG"];
        for seq in seqs {
            let mer: MerDna = seq.parse().unwrap();
            let c1 = mer.get_canonical();
            let c2 = c1.get_canonical();
            assert_eq!(c1, c2, "canonical not idempotent for {seq}");
        }
    }

    #[test]
    fn test_canonical_multi_word() {
        let mer: MerDna = "T".repeat(40).parse().unwrap();
        assert_eq!(mer.get_canonical().to_string(), "A".repeat(40));
    }

    #[test]
    fn test_canonicalize_in_place() {
        let mut mer: MerDna = "TTTT".parse().unwrap();
        mer.canonicalize();
        assert_eq!(mer.to_string(), "AAAA");
    }

    #[test]
    fn test_ordering() {
        let a: MerDna = "AAAA".parse().unwrap();
        let c: MerDna = "CCCC".parse().unwrap();
        let g: MerDna = "GGGG".parse().unwrap();
        let t: MerDna = "TTTT".parse().unwrap();
        assert!(a < c);
        assert!(c < g);
        assert!(g < t);
    }

    #[test]
    fn test_hash_consistency() {
        use std::collections::HashMap;
        let mer1: MerDna = "ACGT".parse().unwrap();
        let mer2: MerDna = "ACGT".parse().unwrap();
        let mut map = HashMap::new();
        map.insert(mer1, 42);
        assert_eq!(map.get(&mer2), Some(&42));
    }

    #[test]
    fn test_shift_left() {
        let mut mer: MerDna = "ACGT".parse().unwrap();
        let out = mer.shift_left(b'A');
        assert_eq!(out, Some(b'A'));
        assert_eq!(mer.to_string(), "CGTA");
    }

    #[test]
    fn test_shift_right() {
        let mut mer: MerDna = "ACGT".parse().unwrap();
        let out = mer.shift_right(b'A');
        assert_eq!(out, Some(b'T'));
        assert_eq!(mer.to_string(), "AACG");
    }

    #[test]
    fn test_shift_left_across_word_boundary() {
        let seq = format!("{}A", "ACGT".repeat(9)); // 37 bases, spans 2 words
        let mut mer: MerDna = seq.parse().unwrap();
        assert_eq!(mer.shift_left(b'C'), Some(b'A'));
        assert_eq!(mer.to_string(), format!("{}C", &seq[1..]));
    }

    #[test]
    fn test_shift_right_across_word_boundary() {
        let seq = format!("{}A", "ACGT".repeat(9)); // 37 bases, spans 2 words
        let mut mer: MerDna = seq.parse().unwrap();
        assert_eq!(mer.shift_right(b'g'), Some(b'A'));
        assert_eq!(mer.to_string(), format!("G{}", &seq[..seq.len() - 1]));
    }

    #[test]
    fn test_shift_invalid_base() {
        let mut mer: MerDna = "ACGT".parse().unwrap();
        assert_eq!(mer.shift_left(b'N'), None);
        assert_eq!(mer.to_string(), "ACGT"); // unchanged
    }

    #[test]
    fn test_homopolymer() {
        let aaaa: MerDna = "AAAA".parse().unwrap();
        assert!(aaaa.is_homopolymer());

        let cccc: MerDna = "CCCC".parse().unwrap();
        assert!(cccc.is_homopolymer());

        let acgt: MerDna = "ACGT".parse().unwrap();
        assert!(!acgt.is_homopolymer());
    }

    #[test]
    fn test_poly_constructors() {
        let mut mer = MerDna::new(4);
        mer.poly_a();
        assert_eq!(mer.to_string(), "AAAA");

        mer.poly_c();
        assert_eq!(mer.to_string(), "CCCC");

        mer.poly_g();
        assert_eq!(mer.to_string(), "GGGG");

        mer.poly_t();
        assert_eq!(mer.to_string(), "TTTT");
    }

    #[test]
    fn test_equality() {
        let a: MerDna = "ACGT".parse().unwrap();
        let b: MerDna = "ACGT".parse().unwrap();
        let c: MerDna = "ACGA".parse().unwrap();
        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    #[test]
    fn test_from_bytes() {
        // For k=4 (8 bits), "ACGT" in 2-bit encoding from LSB:
        // Position 0 (T) = 11, Position 1 (G) = 10, Position 2 (C) = 01, Position 3 (A) = 00
        // = 0b00_01_10_11 = 0x1B
        let mer = MerDna::from_bytes(&[0x1B], 4);
        assert_eq!(mer.to_string(), "ACGT");
    }

    #[test]
    fn test_from_bytes_multi_word() {
        // The low byte of word 0 holds "ACGT". The low two bits of word 1 hold
        // base 32 (G = 0b10). Bits above 2*k must be discarded.
        let expected = format!("G{}ACGT", "A".repeat(28));
        let mer = MerDna::from_bytes(&[0x1B, 0, 0, 0, 0, 0, 0, 0, 0x02], 33);
        assert_eq!(mer.to_string(), expected);
        let noisy = MerDna::from_bytes(&[0x1B, 0, 0, 0, 0, 0, 0, 0, 0xFE], 33);
        assert_eq!(noisy, mer);
    }

    #[test]
    fn test_long_kmer() {
        // Test with k=33 (requires 2 words)
        let seq = "ACGTACGTACGTACGTACGTACGTACGTACGTA";
        let mer: MerDna = seq.parse().unwrap();
        assert_eq!(mer.k(), 33);
        assert_eq!(mer.to_string(), seq);

        // Test reverse complement involution
        let rc2 = mer.get_reverse_complement().get_reverse_complement();
        assert_eq!(mer, rc2);
    }

    #[test]
    fn test_word_reverse_complement_basic() {
        // All A's (0x0000) -> All T's (0xFFFF...)
        assert_eq!(word_reverse_complement(0), u64::MAX);
    }
}