use rustc_hash::FxHashMap;
use std::collections::HashMap;
/// Counts occurrences of fixed-length k-mers (overlapping length-`k` byte
/// windows) across one or more sequences.
///
/// Counts accumulate across calls to [`KmerCounter::count_sequence`] until
/// [`KmerCounter::clear`] is called. K-mers are compared byte-for-byte; no case
/// folding or canonicalisation is applied here.
#[derive(Debug, Clone)]
pub struct KmerCounter {
    /// Length of every k-mer recorded by this counter.
    k: usize,
    /// Occurrence count for each distinct k-mer seen so far.
    counts: FxHashMap<Vec<u8>, usize>,
}

impl KmerCounter {
    /// Creates an empty counter for k-mers of length `k`.
    ///
    /// `k == 0` is accepted, but such a counter never records anything
    /// (see [`KmerCounter::count_sequence`]).
    pub fn new(k: usize) -> Self {
        Self {
            k,
            counts: FxHashMap::default(),
        }
    }

    /// Adds every overlapping length-`k` window of `seq` to the counts.
    ///
    /// A sequence shorter than `k` contributes nothing. When `k == 0` this is a
    /// no-op: an empty k-mer carries no information and would only inflate
    /// [`KmerCounter::total_count`].
    pub fn count_sequence(&mut self, seq: &[u8]) {
        // `slice::windows(0)` panics, and empty k-mers are meaningless anyway.
        if self.k == 0 {
            return;
        }
        for kmer in seq.windows(self.k) {
            // Look up by borrowed slice first so an owned key is allocated only
            // the first time a k-mer is seen, not once per window.
            if let Some(count) = self.counts.get_mut(kmer) {
                *count += 1;
            } else {
                self.counts.insert(kmer.to_vec(), 1);
            }
        }
    }

    /// Returns how many times `kmer` has been counted, or 0 if never seen.
    pub fn get_count(&self, kmer: &[u8]) -> usize {
        self.counts.get(kmer).copied().unwrap_or(0)
    }

    /// Iterates over `(kmer, count)` pairs in unspecified (hash-map) order.
    pub fn iter(&self) -> impl Iterator<Item = (&Vec<u8>, &usize)> {
        self.counts.iter()
    }

    /// Number of distinct k-mers recorded.
    pub fn num_distinct_kmers(&self) -> usize {
        self.counts.len()
    }

    /// Sum of all counts, i.e. the total number of k-mer windows recorded.
    pub fn total_count(&self) -> usize {
        self.counts.values().sum()
    }

    /// Removes all recorded counts; `k` is kept.
    pub fn clear(&mut self) {
        self.counts.clear();
    }

    /// Returns up to `top_k` k-mers with the highest counts.
    ///
    /// Results are ordered by descending count. Ties are broken by ascending
    /// lexicographic byte order, so the output is deterministic.
    pub fn top_kmers(&self, top_k: usize) -> Vec<(Vec<u8>, usize)> {
        // Rank borrowed keys; only the survivors are cloned below.
        let mut ranked: Vec<(&Vec<u8>, usize)> = self
            .counts
            .iter()
            .map(|(kmer, &count)| (kmer, count))
            .collect();
        // Keys are unique, so this comparator is a total order and an unstable
        // sort produces exactly the same sequence as a stable one.
        ranked.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(b.0)));
        ranked
            .into_iter()
            .take(top_k)
            .map(|(kmer, count)| (kmer.clone(), count))
            .collect()
    }
}
/// Iterator over every overlapping length-`k` window of a byte sequence,
/// yielded as borrowed sub-slices in left-to-right order.
///
/// A sequence shorter than `k` yields nothing. So does `k == 0`, because an
/// empty k-mer carries no information.
#[derive(Debug, Clone)]
pub struct KmerIterator<'a> {
    /// Sequence being scanned.
    seq: &'a [u8],
    /// Window length.
    k: usize,
    /// Start offset of the next window to yield.
    pos: usize,
}

impl<'a> KmerIterator<'a> {
    /// Creates an iterator over the length-`k` windows of `seq`, starting at offset 0.
    pub fn new(seq: &'a [u8], k: usize) -> Self {
        Self { seq, k, pos: 0 }
    }

    /// Number of windows not yet yielded.
    fn remaining(&self) -> usize {
        if self.k == 0 {
            return 0;
        }
        // Remaining windows start at offsets pos..=len-k, i.e. len - pos - k + 1
        // of them when that is non-negative, otherwise none.
        self.seq
            .len()
            .saturating_sub(self.pos)
            .checked_sub(self.k)
            .map_or(0, |slack| slack + 1)
    }
}

impl<'a> Iterator for KmerIterator<'a> {
    type Item = &'a [u8];

    fn next(&mut self) -> Option<Self::Item> {
        if self.k == 0 {
            return None;
        }
        let end = self.pos.checked_add(self.k)?;
        // `get` returns None once the window would run past the end.
        let kmer = self.seq.get(self.pos..end)?;
        self.pos += 1;
        Some(kmer)
    }

    /// Exact: the remaining window count is known up front, which lets
    /// `collect` preallocate.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let n = self.remaining();
        (n, Some(n))
    }
}

impl<'a> ExactSizeIterator for KmerIterator<'a> {}

// Once `next` returns None, `pos` stops advancing, so it keeps returning None.
impl<'a> std::iter::FusedIterator for KmerIterator<'a> {}
/// Returns the reverse complement of `kmer`.
///
/// Bytes are read back-to-front, and each one is replaced by
/// `comp_table[byte]`, a lookup indexed by the raw byte value.
pub fn reverse_complement_kmer(kmer: &[u8], comp_table: &[u8; 256]) -> Vec<u8> {
    let mut complemented = Vec::with_capacity(kmer.len());
    for &nucleotide in kmer.iter().rev() {
        complemented.push(comp_table[usize::from(nucleotide)]);
    }
    complemented
}
/// Returns the canonical form of `kmer`: the lexicographically smaller of the
/// k-mer itself and its reverse complement under `comp_table`.
///
/// `comp_table` is indexed by raw byte value. If the two orientations are
/// equal (a reverse-complement palindrome), the returned bytes are the same
/// either way.
pub fn canonical_kmer(kmer: &[u8], comp_table: &[u8; 256]) -> Vec<u8> {
    // The reverse complement is compared lazily. The comparison stops at the
    // first differing byte, and only the winning orientation is ever
    // allocated (one Vec instead of two).
    let reverse_complement = kmer.iter().rev().map(|&base| comp_table[base as usize]);
    if kmer.iter().copied().lt(reverse_complement.clone()) {
        kmer.to_vec()
    } else {
        reverse_complement.collect()
    }
}
/// Builds a histogram of k-mer multiplicities.
///
/// The result maps each occurrence count to the number of distinct k-mers
/// having exactly that count. The input map may use any hasher `S`, so an
/// `FxHashMap` and a std `HashMap` are both accepted.
pub fn kmer_frequency_distribution<S>(
    counts: &HashMap<Vec<u8>, usize, S>,
) -> HashMap<usize, usize> {
    let mut distribution = HashMap::new();
    for &count in counts.values() {
        *distribution.entry(count).or_insert(0) += 1;
    }
    distribution
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kmer_counter_basic() {
        let seq = b"ACGTACGT";
        let mut counter = KmerCounter::new(3);
        counter.count_sequence(seq);
        assert_eq!(counter.get_count(b"ACG"), 2);
        assert_eq!(counter.get_count(b"CGT"), 2);
        assert_eq!(counter.get_count(b"GTA"), 1);
        assert_eq!(counter.get_count(b"TAC"), 1);
        assert_eq!(counter.get_count(b"TTT"), 0);
        assert_eq!(counter.num_distinct_kmers(), 4);
        assert_eq!(counter.total_count(), 6);
    }

    #[test]
    fn test_kmer_counter_accumulates_across_sequences() {
        let mut counter = KmerCounter::new(3);
        counter.count_sequence(b"ACGT");
        counter.count_sequence(b"ACGT");
        assert_eq!(counter.get_count(b"ACG"), 2);
        assert_eq!(counter.get_count(b"CGT"), 2);
        assert_eq!(counter.total_count(), 4);
    }

    #[test]
    fn test_kmer_counter_short_seq() {
        let seq = b"AC";
        let mut counter = KmerCounter::new(3);
        counter.count_sequence(seq);
        assert_eq!(counter.num_distinct_kmers(), 0);
        assert_eq!(counter.total_count(), 0);
    }

    #[test]
    fn test_kmer_counter_top_kmers() {
        // Windows: ACG CGT GTA TAC ACG CGT GTA TAC ACG -> ACG=3, the rest=2.
        let seq = b"ACGTACGTACG";
        let mut counter = KmerCounter::new(3);
        counter.count_sequence(seq);
        assert_eq!(counter.total_count(), 9);

        let top = counter.top_kmers(2);
        // Highest count first; the tie at 2 is broken by ascending bytes.
        assert_eq!(top, vec![(b"ACG".to_vec(), 3), (b"CGT".to_vec(), 2)]);

        assert_eq!(counter.top_kmers(100).len(), 4);
        assert!(counter.top_kmers(0).is_empty());
    }

    #[test]
    fn test_kmer_counter_matches_iterator() {
        let seq = b"ACGTTGCAACGT";
        let mut counter = KmerCounter::new(4);
        counter.count_sequence(seq);
        assert_eq!(counter.total_count(), KmerIterator::new(seq, 4).count());
        for kmer in KmerIterator::new(seq, 4) {
            assert!(counter.get_count(kmer) >= 1);
        }
    }

    #[test]
    fn test_kmer_iterator() {
        let seq = b"ACGT";
        let kmers: Vec<_> = KmerIterator::new(seq, 2).collect();
        assert_eq!(kmers.len(), 3);
        assert_eq!(kmers[0], b"AC");
        assert_eq!(kmers[1], b"CG");
        assert_eq!(kmers[2], b"GT");
    }

    #[test]
    fn test_kmer_iterator_empty() {
        let seq = b"A";
        let kmers: Vec<_> = KmerIterator::new(seq, 3).collect();
        assert_eq!(kmers.len(), 0);
    }

    #[test]
    fn test_reverse_complement_kmer() {
        use crate::core::tables::LOOKUP_TABLES;
        let kmer = b"ACGT";
        let rc = reverse_complement_kmer(kmer, &LOOKUP_TABLES.comp);
        assert_eq!(rc, b"ACGT");
        let kmer2 = b"AAAA";
        let rc2 = reverse_complement_kmer(kmer2, &LOOKUP_TABLES.comp);
        assert_eq!(rc2, b"TTTT");
    }

    #[test]
    fn test_canonical_kmer() {
        use crate::core::tables::LOOKUP_TABLES;
        let kmer1 = b"ACG";
        let canonical1 = canonical_kmer(kmer1, &LOOKUP_TABLES.comp);
        assert_eq!(canonical1, b"ACG");
        let kmer2 = b"CGT";
        let canonical2 = canonical_kmer(kmer2, &LOOKUP_TABLES.comp);
        assert_eq!(canonical2, b"ACG");
        assert_eq!(canonical1, canonical2);
    }

    #[test]
    fn test_kmer_frequency_distribution() {
        let mut counts = FxHashMap::default();
        counts.insert(b"ACG".to_vec(), 1);
        counts.insert(b"CGT".to_vec(), 2);
        counts.insert(b"GTA".to_vec(), 2);
        counts.insert(b"TAC".to_vec(), 3);
        let dist = kmer_frequency_distribution(&counts);
        assert_eq!(dist.get(&1), Some(&1));
        assert_eq!(dist.get(&2), Some(&2));
        assert_eq!(dist.get(&3), Some(&1));
        assert_eq!(dist.len(), 3);
    }

    #[test]
    fn test_kmer_counter_clear() {
        let seq = b"ACGT";
        let mut counter = KmerCounter::new(2);
        counter.count_sequence(seq);
        assert!(counter.num_distinct_kmers() > 0);
        counter.clear();
        assert_eq!(counter.num_distinct_kmers(), 0);
        assert_eq!(counter.total_count(), 0);
    }
}