use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::collections::HashMap;
/// Serializes the prefix-count table as a map with textual keys.
///
/// Each byte-string key is rendered with `String::from_utf8_lossy`, so bytes
/// that are not valid UTF-8 appear as U+FFFD in the output. Entries are
/// emitted in the map's own iteration order.
fn serialize_seq_counts<S>(counts: &HashMap<Vec<u8>, u64>, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    use serde::ser::SerializeMap;

    let mut state = serializer.serialize_map(Some(counts.len()))?;
    counts.iter().try_for_each(|(bytes, count)| {
        state.serialize_entry(&String::from_utf8_lossy(bytes), count)
    })?;
    state.end()
}
/// Deserializes a string-keyed count map and converts every key to its raw
/// bytes, producing the in-memory `HashMap<Vec<u8>, u64>` representation.
fn deserialize_seq_counts<'de, D>(deserializer: D) -> Result<HashMap<Vec<u8>, u64>, D::Error>
where
    D: Deserializer<'de>,
{
    let by_text = HashMap::<String, u64>::deserialize(deserializer)?;
    let mut by_bytes = HashMap::with_capacity(by_text.len());
    for (text, count) in by_text {
        by_bytes.insert(text.into_bytes(), count);
    }
    Ok(by_bytes)
}
/// Sampling interval used by `KmerStats::new` and `KmerStats::disabled`:
/// `update` records a read only when `(read_count - 1)` is a multiple of this value.
const DEFAULT_SAMPLING_RATE: u32 = 20;
/// Default upper bound on how many leading bytes of a read are used as its counting key.
const DEFAULT_PREFIX_LENGTH: usize = 32;
/// Number of entries requested by `KmerStats::top_overrepresented`.
const DEFAULT_TOP_N: usize = 20;
/// Counts how often each (uppercased) read prefix occurs among sampled reads.
///
/// The type is serializable; `sequence_counts` goes through custom helpers so
/// that its byte-string keys are written as text keys.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KmerStats {
/// Uppercased read prefix -> number of sampled reads that started with it.
#[serde(serialize_with = "serialize_seq_counts", deserialize_with = "deserialize_seq_counts")]
sequence_counts: HashMap<Vec<u8>, u64>,
/// Sampling interval: a read is recorded when `(read_count - 1) % sampling_rate == 0`.
sampling_rate: u32,
/// Maximum number of leading bytes of a read used as the key.
prefix_length: usize,
/// Number of reads that were actually recorded in `sequence_counts`.
reads_sampled: u64,
/// When `false`, `update` ignores every read.
enabled: bool,
}
impl Default for KmerStats {
fn default() -> Self {
Self::new()
}
}
impl KmerStats {
    /// Creates an enabled tracker with [`DEFAULT_SAMPLING_RATE`] and
    /// [`DEFAULT_PREFIX_LENGTH`].
    pub fn new() -> Self {
        Self::with_sampling_rate(DEFAULT_SAMPLING_RATE)
    }

    /// Creates a tracker whose [`update`](Self::update) ignores every read.
    pub fn disabled() -> Self {
        Self {
            sequence_counts: HashMap::new(),
            sampling_rate: DEFAULT_SAMPLING_RATE,
            prefix_length: DEFAULT_PREFIX_LENGTH,
            reads_sampled: 0,
            enabled: false,
        }
    }

    /// Creates an enabled tracker that records reads at the given interval.
    ///
    /// A rate of `0` is clamped to `1` (every read is recorded).
    pub fn with_sampling_rate(sampling_rate: u32) -> Self {
        Self {
            sequence_counts: HashMap::new(),
            sampling_rate: sampling_rate.max(1),
            prefix_length: DEFAULT_PREFIX_LENGTH,
            reads_sampled: 0,
            enabled: true,
        }
    }

    /// Creates a tracker that records every read, keyed by at most
    /// `prefix_length` leading bytes.
    ///
    /// `sample_size` is only used as an on/off switch: the tracker is enabled
    /// when it is greater than zero.
    pub fn with_settings(sample_size: u64, prefix_length: usize) -> Self {
        Self {
            sequence_counts: HashMap::new(),
            sampling_rate: 1,
            prefix_length,
            reads_sampled: 0,
            enabled: sample_size > 0,
        }
    }

    /// Returns whether this tracker records reads at all.
    #[inline]
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Records the prefix of `seq` if the tracker is enabled and the read is
    /// selected by the sampling interval.
    ///
    /// `read_count` is the caller's read counter; a read is selected when
    /// `(read_count - 1) % sampling_rate == 0`. Empty sequences are skipped.
    /// The key is the first `prefix_length` bytes (or the whole read if
    /// shorter), ASCII-uppercased.
    #[inline]
    pub fn update(&mut self, seq: &[u8], read_count: u64) {
        if !self.enabled {
            return;
        }
        // The constructors never store 0, but a deserialized value can be 0;
        // clamp so the modulo below cannot divide by zero.
        let rate = u64::from(self.sampling_rate.max(1));
        if read_count.wrapping_sub(1) % rate != 0 {
            return;
        }
        if seq.is_empty() {
            return;
        }
        let prefix_len = seq.len().min(self.prefix_length);
        let normalized = seq[..prefix_len].to_ascii_uppercase();
        *self.sequence_counts.entry(normalized).or_insert(0) += 1;
        self.reads_sampled += 1;
    }

    /// Returns up to `n` prefixes with the highest counts, most frequent first.
    ///
    /// Prefixes with equal counts are ordered lexicographically so the result
    /// does not depend on hash-map iteration order.
    pub fn top_sequences(&self, n: usize) -> Vec<(Vec<u8>, u64)> {
        let mut ranked: Vec<(&Vec<u8>, u64)> = self
            .sequence_counts
            .iter()
            .map(|(seq, &count)| (seq, count))
            .collect();
        ranked.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(b.0)));
        // Only the surviving entries are cloned.
        ranked
            .into_iter()
            .take(n)
            .map(|(seq, count)| (seq.clone(), count))
            .collect()
    }

    /// Returns the [`DEFAULT_TOP_N`] most frequent prefixes.
    pub fn top_overrepresented(&self) -> Vec<(Vec<u8>, u64)> {
        self.top_sequences(DEFAULT_TOP_N)
    }

    /// Returns the number of distinct prefixes recorded.
    pub fn unique_count(&self) -> usize {
        self.sequence_counts.len()
    }

    /// Returns the number of reads that were recorded.
    ///
    /// This is the same value as [`reads_sampled`](Self::reads_sampled); reads
    /// skipped by sampling are not counted.
    pub fn reads_processed(&self) -> u64 {
        self.reads_sampled
    }

    /// Returns the number of reads that were recorded.
    pub fn reads_sampled(&self) -> u64 {
        self.reads_sampled
    }

    /// Returns whether `prefix` (compared case-insensitively) has been recorded.
    ///
    /// The argument is matched as a whole key; it is not truncated to the
    /// configured prefix length.
    pub fn contains(&self, prefix: &[u8]) -> bool {
        self.sequence_counts
            .contains_key(&prefix.to_ascii_uppercase())
    }

    /// Returns how many sampled reads had `prefix` (compared
    /// case-insensitively) as their key, or `0` if it was never recorded.
    pub fn get_count(&self, prefix: &[u8]) -> u64 {
        self.sequence_counts
            .get(&prefix.to_ascii_uppercase())
            .copied()
            .unwrap_or(0)
    }

    /// Returns the share of sampled reads keyed by `prefix`, in percent, or
    /// `None` when nothing has been sampled yet.
    pub fn overrepresentation_percent(&self, prefix: &[u8]) -> Option<f64> {
        if self.reads_sampled == 0 {
            return None;
        }
        let count = self.get_count(prefix);
        Some((count as f64 / self.reads_sampled as f64) * 100.0)
    }

    /// Returns every prefix whose share of sampled reads is at least
    /// `threshold_percent`, as `(prefix, count, percent)`, most frequent first.
    ///
    /// Ties are ordered lexicographically by prefix. A NaN threshold matches
    /// nothing. Returns an empty vector when nothing has been sampled.
    pub fn sequences_above_threshold(&self, threshold_percent: f64) -> Vec<(Vec<u8>, u64, f64)> {
        if self.reads_sampled == 0 {
            return Vec::new();
        }
        let total = self.reads_sampled as f64;
        // Compare `count * 100 >= threshold * total` instead of truncating the
        // cutoff to an integer: truncation rounded the cutoff down and let
        // sequences below the requested percentage through.
        let scaled_cutoff = threshold_percent * total;
        let mut result: Vec<(Vec<u8>, u64, f64)> = self
            .sequence_counts
            .iter()
            .filter(|&(_, &count)| count as f64 * 100.0 >= scaled_cutoff)
            .map(|(seq, &count)| {
                let percent = (count as f64 / total) * 100.0;
                (seq.clone(), count, percent)
            })
            .collect();
        result.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
        result
    }

    /// Adds the counts and the sampled-read total of `other` into `self`.
    ///
    /// The configuration of `self` (rate, prefix length, enabled flag) is left
    /// untouched.
    pub fn merge(&mut self, other: &KmerStats) {
        for (seq, &count) in &other.sequence_counts {
            *self.sequence_counts.entry(seq.clone()).or_insert(0) += count;
        }
        self.reads_sampled += other.reads_sampled;
    }

    /// Returns whether the tracker is enabled (same value as
    /// [`is_enabled`](Self::is_enabled)).
    pub fn is_sampling(&self) -> bool {
        self.enabled
    }

    /// Returns the configured sampling interval.
    pub fn sampling_rate(&self) -> u32 {
        self.sampling_rate
    }

    /// Always returns `0`; kept only for source compatibility.
    #[deprecated(note = "Use sampling_rate() instead")]
    pub fn sample_size(&self) -> u64 {
        0
    }

    /// Returns the maximum number of leading bytes used as a key.
    pub fn prefix_length(&self) -> usize {
        self.prefix_length
    }
}
/// Number of counter slots for 5-mers: five bases at 2 bits each give 4^5 = 1024 indices.
const FIVEMER_ARRAY_SIZE: usize = 1024;
/// Lookup table from an input byte to its 2-bit base code
/// (`A`/`a` = 0, `C`/`c` = 1, `G`/`g` = 2, `T`/`t` = 3).
/// Every other byte maps to the sentinel `255`.
const BASE_TO_BITS: [u8; 256] = {
    const ALPHABET: [u8; 4] = *b"ACGT";
    let mut table = [255u8; 256];
    let mut code = 0;
    while code < ALPHABET.len() {
        let upper = ALPHABET[code];
        table[upper as usize] = code as u8;
        table[upper.to_ascii_lowercase() as usize] = code as u8;
        code += 1;
    }
    table
};
/// Expands a packed 5-mer index back into its five uppercase bases.
///
/// The index holds 2 bits per base with the first base in the highest bits;
/// only the low 10 bits of `index` are consulted.
fn decode_fivemer(index: usize) -> [u8; 5] {
    const BASES: [u8; 4] = [b'A', b'C', b'G', b'T'];
    let mut kmer = [0u8; 5];
    let mut remaining = index;
    // Fill from the last base backwards, peeling two bits at a time.
    for slot in kmer.iter_mut().rev() {
        *slot = BASES[remaining & 3];
        remaining >>= 2;
    }
    kmer
}
mod fivemer_array_serde {
use super::FIVEMER_ARRAY_SIZE;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
pub fn serialize<S>(array: &[u64; FIVEMER_ARRAY_SIZE], serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
array.as_slice().serialize(serializer)
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<[u64; FIVEMER_ARRAY_SIZE], D::Error>
where
D: Deserializer<'de>,
{
let vec = Vec::<u64>::deserialize(deserializer)?;
if vec.len() != FIVEMER_ARRAY_SIZE {
return Err(serde::de::Error::custom(format!(
"expected {} elements, got {}",
FIVEMER_ARRAY_SIZE,
vec.len()
)));
}
let mut array = [0u64; FIVEMER_ARRAY_SIZE];
array.copy_from_slice(&vec);
Ok(array)
}
}
/// Occurrence counters for every 5-mer over A/C/G/T, indexed by the packed
/// 2-bit-per-base encoding.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FiveMerStats {
/// One counter per packed 5-mer index; (de)serialized as a plain sequence.
#[serde(with = "fivemer_array_serde")]
counts: [u64; FIVEMER_ARRAY_SIZE],
/// Total number of valid 5-mers counted (the sum of all increments to `counts`).
total_kmers: u64,
}
impl Default for FiveMerStats {
fn default() -> Self {
Self::new()
}
}
impl FiveMerStats {
    /// Creates a counter set with every count at zero.
    pub fn new() -> Self {
        Self {
            counts: [0u64; FIVEMER_ARRAY_SIZE],
            total_kmers: 0,
        }
    }

    /// Packs the first five bases of `seq` into an index, 2 bits per base with
    /// the first base in the highest bits.
    ///
    /// Returns `None` when `seq` has fewer than five bytes or when any of the
    /// first five bytes is not A/C/G/T (either case).
    #[inline]
    fn encode_fivemer(seq: &[u8]) -> Option<usize> {
        seq.get(..5)?.iter().try_fold(0usize, |packed, &base| {
            match BASE_TO_BITS[base as usize] {
                255 => None,
                bits => Some((packed << 2) | (bits as usize)),
            }
        })
    }

    /// Counts every 5-byte window of `seq` that encodes to a valid 5-mer.
    ///
    /// Windows containing a byte other than A/C/G/T are skipped; sequences
    /// shorter than five bytes contribute nothing.
    #[inline]
    pub fn update(&mut self, seq: &[u8]) {
        // `windows(5)` is empty for inputs shorter than five bytes.
        for index in seq.windows(5).filter_map(Self::encode_fivemer) {
            self.counts[index] += 1;
            self.total_kmers += 1;
        }
    }

    /// Returns the count for the 5-mer formed by the first five bytes of
    /// `kmer`, or `0` when those bytes do not encode a valid 5-mer.
    #[inline]
    pub fn get_count(&self, kmer: &[u8]) -> u64 {
        match Self::encode_fivemer(kmer) {
            Some(index) => self.counts[index],
            None => 0,
        }
    }

    /// Returns the total number of valid 5-mers counted so far.
    #[inline]
    pub fn total_kmers(&self) -> u64 {
        self.total_kmers
    }

    /// Returns the count of `kmer` divided by the total number of counted
    /// 5-mers, or `0.0` when nothing has been counted.
    pub fn frequency(&self, kmer: &[u8]) -> f64 {
        match self.total_kmers {
            0 => 0.0,
            total => self.get_count(kmer) as f64 / total as f64,
        }
    }

    /// Returns up to `n` 5-mers with non-zero counts, highest count first.
    ///
    /// The sort is stable over index order, so 5-mers with equal counts appear
    /// in ascending index order.
    pub fn top_kmers(&self, n: usize) -> Vec<([u8; 5], u64)> {
        let mut ranked: Vec<(usize, u64)> = self
            .counts
            .iter()
            .copied()
            .enumerate()
            .filter(|&(_, occurrences)| occurrences > 0)
            .collect();
        ranked.sort_by(|lhs, rhs| rhs.1.cmp(&lhs.1));
        ranked
            .into_iter()
            .take(n)
            .map(|(index, occurrences)| (decode_fivemer(index), occurrences))
            .collect()
    }

    /// Returns how many distinct 5-mers have a non-zero count.
    pub fn unique_count(&self) -> usize {
        self.counts
            .iter()
            .filter(|&&occurrences| occurrences > 0)
            .count()
    }

    /// Adds every counter and the total of `other` into `self`.
    pub fn merge(&mut self, other: &FiveMerStats) {
        for (mine, &theirs) in self.counts.iter_mut().zip(other.counts.iter()) {
            *mine += theirs;
        }
        self.total_kmers += other.total_kmers;
    }

    /// Returns the raw counter array, indexed by packed 5-mer index.
    #[inline]
    pub fn counts(&self) -> &[u64; FIVEMER_ARRAY_SIZE] {
        &self.counts
    }
}
// Unit tests for `KmerStats` (sampled prefix counting) and `FiveMerStats`
// (dense 5-mer counting).
#[cfg(test)]
mod tests {
use super::*;
// A fresh tracker is enabled and empty.
#[test]
fn test_kmer_stats_new() {
let ks = KmerStats::new();
assert_eq!(ks.reads_processed(), 0);
assert_eq!(ks.unique_count(), 0);
assert!(ks.is_sampling());
}
// Read number 1 is always selected, even at the default sampling rate.
#[test]
fn test_kmer_stats_update_single() {
let mut ks = KmerStats::new();
ks.update(b"ATGCATGCATGCATGC", 1);
assert_eq!(ks.reads_processed(), 1);
assert_eq!(ks.unique_count(), 1);
}
// Identical reads accumulate under a single key.
#[test]
fn test_kmer_stats_update_duplicate() {
let mut ks = KmerStats::with_sampling_rate(1);
ks.update(b"ATGCATGC", 1);
ks.update(b"ATGCATGC", 2);
assert_eq!(ks.reads_processed(), 2);
assert_eq!(ks.unique_count(), 1);
assert_eq!(ks.get_count(b"ATGCATGC"), 2);
}
// Keys are uppercased on insert and on lookup.
#[test]
fn test_kmer_stats_case_insensitive() {
let mut ks = KmerStats::with_sampling_rate(1);
ks.update(b"ATGC", 1);
ks.update(b"atgc", 2);
assert_eq!(ks.unique_count(), 1);
assert_eq!(ks.get_count(b"ATGC"), 2);
assert_eq!(ks.get_count(b"atgc"), 2);
}
// Only the first `prefix_length` (here 8) bytes form the key.
#[test]
fn test_kmer_stats_prefix_truncation() {
let mut ks = KmerStats::with_settings(100, 8);
ks.update(b"ATGCATGCATGCATGC", 1);
ks.update(b"ATGCATGCNNNNNNNN", 2);
assert_eq!(ks.unique_count(), 1);
assert_eq!(ks.get_count(b"ATGCATGC"), 2);
}
// With rate 5, reads 1, 6, 11 and 16 out of 1..=20 are recorded.
#[test]
fn test_kmer_stats_sampling_rate() {
let mut ks = KmerStats::with_sampling_rate(5);
for i in 1..=20 {
ks.update(format!("SEQ{:05}", i).as_bytes(), i);
}
assert_eq!(ks.reads_processed(), 4);
assert!(ks.is_sampling()); }
// The top-N list is ordered by descending count and truncated to N.
#[test]
fn test_kmer_stats_top_sequences() {
let mut ks = KmerStats::with_sampling_rate(1);
ks.update(b"AAAA", 1);
ks.update(b"AAAA", 2);
ks.update(b"AAAA", 3);
ks.update(b"TTTT", 4);
ks.update(b"TTTT", 5);
ks.update(b"GGGG", 6);
let top = ks.top_sequences(2);
assert_eq!(top.len(), 2);
assert_eq!(top[0].0, b"AAAA");
assert_eq!(top[0].1, 3);
assert_eq!(top[1].0, b"TTTT");
assert_eq!(top[1].1, 2);
}
// 10 of 100 sampled reads share a prefix, so its share is 10 %.
#[test]
fn test_kmer_stats_overrepresentation_percent() {
let mut ks = KmerStats::with_sampling_rate(1);
for i in 1..=100 {
if i <= 10 {
ks.update(b"COMMON", i);
} else {
ks.update(format!("UNIQ{:03}", i).as_bytes(), i);
}
}
let percent = ks.overrepresentation_percent(b"COMMON").unwrap();
assert!((percent - 10.0).abs() < 0.1);
}
// Only the 20 % prefix clears a 15 % threshold; the 1 % prefixes do not.
#[test]
fn test_kmer_stats_sequences_above_threshold() {
let mut ks = KmerStats::with_sampling_rate(1);
for i in 1..=100 {
if i <= 20 {
ks.update(b"COMMON", i);
} else {
ks.update(format!("UNIQ{:03}", i).as_bytes(), i);
}
}
let above = ks.sequences_above_threshold(15.0);
assert_eq!(above.len(), 1);
assert_eq!(above[0].0, b"COMMON");
}
// Merging sums per-prefix counts and the sampled-read totals.
#[test]
fn test_kmer_stats_merge() {
let mut ks1 = KmerStats::with_sampling_rate(1);
ks1.update(b"AAAA", 1);
ks1.update(b"TTTT", 2);
let mut ks2 = KmerStats::with_sampling_rate(1);
ks2.update(b"AAAA", 1);
ks2.update(b"GGGG", 2);
ks1.merge(&ks2);
assert_eq!(ks1.reads_processed(), 4);
assert_eq!(ks1.get_count(b"AAAA"), 2);
assert_eq!(ks1.get_count(b"TTTT"), 1);
assert_eq!(ks1.get_count(b"GGGG"), 1);
}
// Empty reads are ignored entirely.
#[test]
fn test_kmer_stats_empty_sequence() {
let mut ks = KmerStats::new();
ks.update(b"", 1);
assert_eq!(ks.reads_processed(), 0);
assert_eq!(ks.unique_count(), 0);
}
// `contains` matches case-insensitively and rejects unseen prefixes.
#[test]
fn test_kmer_stats_contains() {
let mut ks = KmerStats::new();
ks.update(b"ATGC", 1);
assert!(ks.contains(b"ATGC"));
assert!(ks.contains(b"atgc"));
assert!(!ks.contains(b"GCTA"));
}
// A JSON round trip preserves the sampled-read total and the number of keys.
#[test]
fn test_kmer_stats_serialize() {
let mut ks = KmerStats::with_sampling_rate(1);
ks.update(b"ATGC", 1);
ks.update(b"GCTA", 2);
let json = serde_json::to_string(&ks).unwrap();
let ks2: KmerStats = serde_json::from_str(&json).unwrap();
assert_eq!(ks.reads_processed(), ks2.reads_processed());
assert_eq!(ks.unique_count(), ks2.unique_count());
}
// With nothing sampled there is no percentage to report.
#[test]
fn test_kmer_stats_empty_overrepresentation() {
let ks = KmerStats::new();
assert!(ks.overrepresentation_percent(b"ATGC").is_none());
}
// A fresh 5-mer counter set is empty.
#[test]
fn test_fivemer_stats_new() {
let fs = FiveMerStats::new();
assert_eq!(fs.total_kmers(), 0);
assert_eq!(fs.unique_count(), 0);
}
// A 5-base read yields exactly one 5-mer.
#[test]
fn test_fivemer_stats_update() {
let mut fs = FiveMerStats::new();
fs.update(b"ACGTA");
assert_eq!(fs.total_kmers(), 1);
assert_eq!(fs.unique_count(), 1);
assert_eq!(fs.get_count(b"ACGTA"), 1);
}
// A 6-base read yields two overlapping 5-mers.
#[test]
fn test_fivemer_stats_sliding_window() {
let mut fs = FiveMerStats::new();
fs.update(b"ACGTAC");
assert_eq!(fs.total_kmers(), 2);
assert_eq!(fs.get_count(b"ACGTA"), 1);
assert_eq!(fs.get_count(b"CGTAC"), 1);
}
// Upper- and lowercase bases map to the same counter.
#[test]
fn test_fivemer_stats_case_insensitive() {
let mut fs = FiveMerStats::new();
fs.update(b"ACGTA");
fs.update(b"acgta");
assert_eq!(fs.get_count(b"ACGTA"), 2);
assert_eq!(fs.get_count(b"acgta"), 2);
}
// A window containing a non-ACGT byte is not counted.
#[test]
fn test_fivemer_stats_with_n() {
let mut fs = FiveMerStats::new();
fs.update(b"ACNTA");
assert_eq!(fs.total_kmers(), 0);
}
// Reads shorter than five bases contribute nothing.
#[test]
fn test_fivemer_stats_short_sequence() {
let mut fs = FiveMerStats::new();
fs.update(b"ACGT");
assert_eq!(fs.total_kmers(), 0);
}
// The top-N list is ordered by descending count and truncated to N.
#[test]
fn test_fivemer_stats_top_kmers() {
let mut fs = FiveMerStats::new();
for _ in 0..10 {
fs.update(b"AAAAA");
}
for _ in 0..5 {
fs.update(b"CCCCC");
}
fs.update(b"GGGGG");
let top = fs.top_kmers(2);
assert_eq!(top.len(), 2);
assert_eq!(&top[0].0, b"AAAAA");
assert_eq!(top[0].1, 10);
assert_eq!(&top[1].0, b"CCCCC");
assert_eq!(top[1].1, 5);
}
// Merging sums per-5-mer counts and the totals.
#[test]
fn test_fivemer_stats_merge() {
let mut fs1 = FiveMerStats::new();
fs1.update(b"ACGTA");
let mut fs2 = FiveMerStats::new();
fs2.update(b"ACGTA");
fs2.update(b"TGCAT");
fs1.merge(&fs2);
assert_eq!(fs1.total_kmers(), 3);
assert_eq!(fs1.get_count(b"ACGTA"), 2);
assert_eq!(fs1.get_count(b"TGCAT"), 1);
}
// Encoding followed by decoding returns the original 5-mer.
#[test]
fn test_fivemer_encoding_decoding() {
let kmer = b"ACGTA";
let index = FiveMerStats::encode_fivemer(kmer).unwrap();
let decoded = decode_fivemer(index);
assert_eq!(&decoded, kmer);
}
// Two distinct 5-mers counted once each have frequency 0.5.
#[test]
fn test_fivemer_stats_frequency() {
let mut fs = FiveMerStats::new();
fs.update(b"AAAAA"); fs.update(b"CCCCC");
assert!((fs.frequency(b"AAAAA") - 0.5).abs() < 0.001);
assert!((fs.frequency(b"CCCCC") - 0.5).abs() < 0.001);
}
}