1use std::cmp::Ordering;
2use std::fmt;
3use std::hash::{Hash, Hasher};
4use std::str::FromStr;
5
6use crate::error::Error;
7
/// Number of 2-bit base codes that fit in one 64-bit storage word.
const BASES_PER_WORD: usize = (u64::BITS / 2) as usize;
10
/// Converts an ASCII nucleotide letter into its 2-bit code.
///
/// `A`, `C`, `G` and `T` map to 0, 1, 2 and 3 respectively, in either case.
/// Any other byte yields `None`.
#[inline]
pub fn encode_base(c: u8) -> Option<u8> {
    // Folding to upper case first lets one arm serve both spellings of a base.
    let code = match c.to_ascii_uppercase() {
        b'A' => 0,
        b'C' => 1,
        b'G' => 2,
        b'T' => 3,
        _ => return None,
    };
    Some(code)
}
24
/// Converts a 2-bit base code back into its upper-case ASCII letter.
///
/// Only the two lowest bits of `code` are considered, so every input value
/// maps to one of `A`, `C`, `G` or `T`.
#[inline]
pub fn decode_base(code: u8) -> u8 {
    const LETTERS: [u8; 4] = *b"ACGT";
    LETTERS[usize::from(code & 0x3)]
}
36
/// Returns the code of the complementary base (A<->T, C<->G).
///
/// Implemented by flipping the two lowest bits; any higher bits of `code`
/// pass through unchanged.
#[inline]
pub fn complement_code(code: u8) -> u8 {
    const LOW_TWO_BITS: u8 = 0b11;
    code ^ LOW_TWO_BITS
}
42
43#[inline]
45pub fn words_for_k(k: usize) -> usize {
46 k.div_ceil(BASES_PER_WORD)
47}
48
/// Reverse-complements all 32 two-bit bases held in a single word.
///
/// Every base code is complemented and the order of the 2-bit groups is
/// reversed across the full 64 bits. Callers working with fewer than 32
/// bases must shift the result down themselves.
fn word_reverse_complement(word: u64) -> u64 {
    const PAIR_MASK: u64 = 0x3333_3333_3333_3333;
    const NIBBLE_MASK: u64 = 0x0F0F_0F0F_0F0F_0F0F;

    // Exchanges each group of `width` bits selected by `mask` with its upper neighbour.
    let swap_adjacent =
        |w: u64, mask: u64, width: u32| ((w >> width) & mask) | ((w & mask) << width);

    // Bitwise NOT complements every 2-bit code (code ^ 0b11).
    let complemented = !word;
    // Reverse the four bases inside each byte, then reverse the bytes.
    let pairs_swapped = swap_adjacent(complemented, PAIR_MASK, 2);
    let nibbles_swapped = swap_adjacent(pairs_swapped, NIBBLE_MASK, 4);
    nibbles_swapped.swap_bytes()
}
63
/// A DNA k-mer stored as packed 2-bit base codes.
///
/// Base `i` lives in bits `2 * (i % 32)` and `2 * (i % 32) + 1` of
/// `words[i / 32]`. The methods that rewrite whole words mask off the bits
/// above position `2 * k` in the last word, so unused bits stay zero.
#[derive(Clone)]
pub struct MerDna {
    /// Packed base codes, 32 bases per word, lowest base index in the lowest bits.
    words: Vec<u64>,
    /// Number of bases in the k-mer.
    k: usize,
}
90
91impl MerDna {
92 pub fn new(k: usize) -> Self {
94 Self {
95 words: vec![0u64; words_for_k(k)],
96 k,
97 }
98 }
99
100 pub fn from_words(words: Vec<u64>, k: usize) -> Self {
104 debug_assert!(words.len() == words_for_k(k));
105 let mut mer = Self { words, k };
106 mer.clean_high_bits();
107 mer
108 }
109
110 pub fn from_bytes(bytes: &[u8], k: usize) -> Self {
114 let n_words = words_for_k(k);
115 let mut words = vec![0u64; n_words];
116
117 for (i, &byte) in bytes.iter().enumerate() {
118 let word_idx = i / 8;
119 let byte_idx = i % 8;
120 if word_idx < n_words {
121 words[word_idx] |= (byte as u64) << (byte_idx * 8);
122 }
123 }
124
125 let mut mer = Self { words, k };
126 mer.clean_high_bits();
127 mer
128 }
129
    /// Returns the number of bases in this k-mer.
    #[inline]
    pub fn k(&self) -> usize {
        self.k
    }
135
    /// Returns the packed storage words, lowest base indices first.
    #[inline]
    pub fn words(&self) -> &[u64] {
        &self.words
    }
141
142 pub fn get_base(&self, i: usize) -> u8 {
147 assert!(i < self.k, "base index {i} out of range for k={}", self.k);
148 let word_idx = i / BASES_PER_WORD;
149 let bit_offset = (i % BASES_PER_WORD) * 2;
150 ((self.words[word_idx] >> bit_offset) & 0x3) as u8
151 }
152
153 pub fn set_base(&mut self, i: usize, base_code: u8) {
158 assert!(i < self.k, "base index {i} out of range for k={}", self.k);
159 assert!(base_code < 4, "invalid base code: {base_code}");
160 let word_idx = i / BASES_PER_WORD;
161 let bit_offset = (i % BASES_PER_WORD) * 2;
162 self.words[word_idx] &= !(0x3u64 << bit_offset);
163 self.words[word_idx] |= (base_code as u64) << bit_offset;
164 }
165
166 pub fn shift_left(&mut self, base: u8) -> Option<u8> {
170 let code = encode_base(base)?;
171 let old_high = self.get_base(self.k - 1);
172
173 let n = self.words.len();
175 for i in (1..n).rev() {
176 self.words[i] = (self.words[i] << 2) | (self.words[i - 1] >> 62);
177 }
178 self.words[0] = (self.words[0] << 2) | (code as u64);
179 self.clean_high_bits();
180
181 Some(decode_base(old_high))
182 }
183
184 pub fn shift_right(&mut self, base: u8) -> Option<u8> {
188 let code = encode_base(base)?;
189 let old_low = self.get_base(0);
190
191 let n = self.words.len();
193 for i in 0..n - 1 {
194 self.words[i] = (self.words[i] >> 2) | (self.words[i + 1] << 62);
195 }
196 self.words[n - 1] >>= 2;
197
198 let high_pos = self.k - 1;
200 let word_idx = high_pos / BASES_PER_WORD;
201 let bit_offset = (high_pos % BASES_PER_WORD) * 2;
202 self.words[word_idx] |= (code as u64) << bit_offset;
203
204 Some(decode_base(old_low))
205 }
206
207 pub fn get_reverse_complement(&self) -> MerDna {
209 let n = self.words.len();
210
211 if n == 1 {
212 let mut result = vec![0u64; 1];
213 result[0] = word_reverse_complement(self.words[0]) >> (64 - self.k * 2);
214 let mut mer = MerDna {
215 words: result,
216 k: self.k,
217 };
218 mer.clean_high_bits();
219 return mer;
220 }
221
222 let mut result = MerDna::new(self.k);
225 for i in 0..self.k {
226 let base = self.get_base(i);
227 result.set_base(self.k - 1 - i, complement_code(base));
228 }
229 result
230 }
231
232 pub fn reverse_complement(&mut self) {
234 *self = self.get_reverse_complement();
235 }
236
237 pub fn get_canonical(&self) -> MerDna {
239 let rc = self.get_reverse_complement();
240 if *self <= rc { self.clone() } else { rc }
241 }
242
243 pub fn canonicalize(&mut self) {
245 let rc = self.get_reverse_complement();
246 if rc < *self {
247 *self = rc;
248 }
249 }
250
251 pub fn is_homopolymer(&self) -> bool {
253 if self.k == 0 {
254 return true;
255 }
256 let base = self.get_base(0);
257 (1..self.k).all(|i| self.get_base(i) == base)
258 }
259
260 pub fn poly_a(&mut self) {
262 self.words.fill(0);
263 }
264
265 pub fn poly_c(&mut self) {
267 self.fill_with_code(1);
268 }
269
270 pub fn poly_g(&mut self) {
272 self.fill_with_code(2);
273 }
274
275 pub fn poly_t(&mut self) {
277 self.fill_with_code(3);
278 }
279
280 fn fill_with_code(&mut self, code: u8) {
282 let pattern = match code {
283 0 => 0x0000_0000_0000_0000u64,
284 1 => 0x5555_5555_5555_5555u64,
285 2 => 0xAAAA_AAAA_AAAA_AAAAu64,
286 3 => 0xFFFF_FFFF_FFFF_FFFFu64,
287 _ => unreachable!(),
288 };
289 self.words.fill(pattern);
290 self.clean_high_bits();
291 }
292
293 fn clean_high_bits(&mut self) {
295 if self.k == 0 {
296 return;
297 }
298 let used_bits = self.k * 2;
299 let total_bits = self.words.len() * 64;
300 if used_bits < total_bits {
301 let last = self.words.len() - 1;
302 let bits_in_last = used_bits - last * 64;
303 self.words[last] &= (1u64 << bits_in_last) - 1;
304 }
305 }
306}
307
308impl fmt::Debug for MerDna {
309 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
310 write!(f, "MerDna(\"{}\")", self)
311 }
312}
313
314impl fmt::Display for MerDna {
315 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
316 for i in (0..self.k).rev() {
317 let code = self.get_base(i);
318 f.write_str(std::str::from_utf8(&[decode_base(code)]).unwrap())?;
319 }
320 Ok(())
321 }
322}
323
324impl FromStr for MerDna {
325 type Err = Error;
326
327 fn from_str(s: &str) -> Result<Self, Error> {
328 let k = s.len();
329 if k == 0 {
330 return Err(Error::InvalidKmer("empty k-mer string".to_string()));
331 }
332
333 let mut mer = MerDna::new(k);
334 let bytes = s.as_bytes();
335
336 for (i, &ch) in bytes.iter().enumerate() {
337 let code = encode_base(ch).ok_or_else(|| {
338 Error::InvalidKmer(format!("invalid base '{}' at position {i}", ch as char))
339 })?;
340 let pos = k - 1 - i;
342 mer.set_base(pos, code);
343 }
344
345 Ok(mer)
346 }
347}
348
349impl PartialEq for MerDna {
350 fn eq(&self, other: &Self) -> bool {
351 self.k == other.k && self.words == other.words
352 }
353}
354
355impl Eq for MerDna {}
356
357impl PartialOrd for MerDna {
358 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
359 Some(self.cmp(other))
360 }
361}
362
363impl Ord for MerDna {
364 fn cmp(&self, other: &Self) -> Ordering {
365 assert_eq!(
367 self.k, other.k,
368 "cannot compare k-mers of different lengths"
369 );
370 for i in (0..self.words.len()).rev() {
371 match self.words[i].cmp(&other.words[i]) {
372 Ordering::Equal => continue,
373 ord => return ord,
374 }
375 }
376 Ordering::Equal
377 }
378}
379
380impl Hash for MerDna {
381 fn hash<H: Hasher>(&self, state: &mut H) {
382 self.k.hash(state);
383 self.words.hash(state);
384 }
385}
386
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `seq` into a k-mer, panicking if it is not a valid sequence.
    fn mer(seq: &str) -> MerDna {
        seq.parse().unwrap()
    }

    #[test]
    fn test_encode_decode_bases() {
        for (code, upper) in (0u8..).zip(*b"ACGT") {
            assert_eq!(encode_base(upper), Some(code));
            assert_eq!(decode_base(code), upper);
            assert_eq!(encode_base(upper.to_ascii_lowercase()), Some(code));
        }
        for rejected in [b'N', b'X'] {
            assert_eq!(encode_base(rejected), None);
        }
    }

    #[test]
    fn test_complement_code() {
        for code in 0u8..4 {
            assert_eq!(complement_code(code), 3 - code);
        }
    }

    #[test]
    fn test_new_mer() {
        let fresh = MerDna::new(4);
        assert_eq!(fresh.k(), 4);
        assert_eq!(fresh.to_string(), "AAAA");
    }

    #[test]
    fn test_from_str_basic() {
        let parsed = mer("ACGT");
        assert_eq!(parsed.k(), 4);
        assert_eq!(parsed.to_string(), "ACGT");
    }

    #[test]
    fn test_from_str_lowercase() {
        assert_eq!(mer("acgt").to_string(), "ACGT");
    }

    #[test]
    fn test_from_str_single_base() {
        for base in ["A", "C", "G", "T"] {
            assert_eq!(mer(base).to_string(), base);
        }
    }

    #[test]
    fn test_from_str_invalid() {
        for bad in ["ACGN", "", "ACGX"] {
            assert!(bad.parse::<MerDna>().is_err());
        }
    }

    #[test]
    fn test_roundtrip_various_lengths() {
        let seqs = [
            "A",
            "AC",
            "ACG",
            "ACGT",
            "ACGTACGT",
            // Exactly one full word (32 bases).
            "ACGTACGTACGTACGTACGTACGTACGTACGT",
            // One base past the word boundary (33 bases).
            "ACGTACGTACGTACGTACGTACGTACGTACGTA",
        ];
        for seq in seqs {
            assert_eq!(mer(seq).to_string(), seq, "roundtrip failed for {seq}");
        }
    }

    #[test]
    fn test_get_set_base() {
        let mut kmer = mer("ACGT");
        // Index 0 is the last printed character, so the codes read T, G, C, A.
        let codes: Vec<u8> = (0..4).map(|i| kmer.get_base(i)).collect();
        assert_eq!(codes, [3, 2, 1, 0]);

        kmer.set_base(0, 0);
        assert_eq!(kmer.to_string(), "ACGA");
    }

    #[test]
    fn test_reverse_complement_palindrome() {
        let rc = mer("ACGT").get_reverse_complement();
        assert_eq!(rc.to_string(), "ACGT");
    }

    #[test]
    fn test_reverse_complement_simple() {
        let rc = mer("AAAA").get_reverse_complement();
        assert_eq!(rc.to_string(), "TTTT");
    }

    #[test]
    fn test_reverse_complement_asymmetric() {
        let rc = mer("AACG").get_reverse_complement();
        assert_eq!(rc.to_string(), "CGTT");
    }

    #[test]
    fn test_reverse_complement_involution() {
        for seq in ["ACGT", "AAAA", "GCTA", "AACG", "TTTCCCGGGAAA"] {
            let original = mer(seq);
            let twice = original.get_reverse_complement().get_reverse_complement();
            assert_eq!(original, twice, "RC involution failed for {seq}");
        }
    }

    #[test]
    fn test_canonical_already_canonical() {
        assert_eq!(mer("AAAA").get_canonical().to_string(), "AAAA");
    }

    #[test]
    fn test_canonical_needs_rc() {
        assert_eq!(mer("TTTT").get_canonical().to_string(), "AAAA");
    }

    #[test]
    fn test_canonical_palindrome() {
        assert_eq!(mer("ACGT").get_canonical().to_string(), "ACGT");
    }

    #[test]
    fn test_canonical_idempotent() {
        for seq in ["ACGT", "TGCA", "AAAA", "CCCC", "AACG"] {
            let once = mer(seq).get_canonical();
            let twice = once.get_canonical();
            assert_eq!(once, twice, "canonical not idempotent for {seq}");
        }
    }

    #[test]
    fn test_canonicalize_in_place() {
        let mut kmer = mer("TTTT");
        kmer.canonicalize();
        assert_eq!(kmer.to_string(), "AAAA");
    }

    #[test]
    fn test_ordering() {
        let [a, c, g, t] = ["AAAA", "CCCC", "GGGG", "TTTT"].map(mer);
        assert!(a < c);
        assert!(c < g);
        assert!(g < t);
    }

    #[test]
    fn test_hash_consistency() {
        use std::collections::HashMap;

        let mut counts = HashMap::new();
        counts.insert(mer("ACGT"), 42);
        // A separately parsed but equal k-mer must find the same entry.
        assert_eq!(counts.get(&mer("ACGT")), Some(&42));
    }

    #[test]
    fn test_shift_left() {
        let mut kmer = mer("ACGT");
        let pushed_out = kmer.shift_left(b'A');
        assert_eq!(pushed_out, Some(b'A'));
        assert_eq!(kmer.to_string(), "CGTA");
    }

    #[test]
    fn test_shift_right() {
        let mut kmer = mer("ACGT");
        let pushed_out = kmer.shift_right(b'A');
        assert_eq!(pushed_out, Some(b'T'));
        assert_eq!(kmer.to_string(), "AACG");
    }

    #[test]
    fn test_shift_invalid_base() {
        let mut kmer = mer("ACGT");
        assert_eq!(kmer.shift_left(b'N'), None);
        // A rejected base must leave the k-mer untouched.
        assert_eq!(kmer.to_string(), "ACGT");
    }

    #[test]
    fn test_homopolymer() {
        assert!(mer("AAAA").is_homopolymer());
        assert!(mer("CCCC").is_homopolymer());
        assert!(!mer("ACGT").is_homopolymer());
    }

    #[test]
    fn test_poly_constructors() {
        let mut kmer = MerDna::new(4);

        kmer.poly_a();
        assert_eq!(kmer.to_string(), "AAAA");

        kmer.poly_c();
        assert_eq!(kmer.to_string(), "CCCC");

        kmer.poly_g();
        assert_eq!(kmer.to_string(), "GGGG");

        kmer.poly_t();
        assert_eq!(kmer.to_string(), "TTTT");
    }

    #[test]
    fn test_equality() {
        let reference = mer("ACGT");
        assert_eq!(reference, mer("ACGT"));
        assert_ne!(reference, mer("ACGA"));
    }

    #[test]
    fn test_from_bytes() {
        // 0x1B = 0b00_01_10_11: A, C, G, T from the highest base index down.
        let kmer = MerDna::from_bytes(&[0x1B], 4);
        assert_eq!(kmer.to_string(), "ACGT");
    }

    #[test]
    fn test_long_kmer() {
        // 33 bases: spills into a second storage word.
        let seq = "ACGTACGTACGTACGTACGTACGTACGTACGTA";
        let kmer = mer(seq);
        assert_eq!(kmer.k(), 33);
        assert_eq!(kmer.to_string(), seq);

        let twice = kmer.get_reverse_complement().get_reverse_complement();
        assert_eq!(kmer, twice);
    }

    #[test]
    fn test_word_reverse_complement_basic() {
        // 32 A's reverse-complement to 32 T's.
        assert_eq!(word_reverse_complement(0), u64::MAX);
    }
}