use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Accumulated read-length statistics: a per-length histogram plus running
/// min/max and totals, built incrementally via `update` and mergeable across
/// accumulators with `merge`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LengthStats {
    /// Number of reads observed at each exact length.
    distribution: HashMap<usize, u64>,
    /// Shortest length seen; holds the `usize::MAX` sentinel until the first
    /// update (the `min_length()` accessor masks it as 0 when empty).
    min_length: usize,
    /// Longest length seen (0 until the first update).
    max_length: usize,
    /// Sum of all recorded read lengths.
    total_bases: u64,
    /// Count of non-zero-length reads recorded.
    total_reads: u64,
}
impl Default for LengthStats {
fn default() -> Self {
Self::new()
}
}
impl LengthStats {
    /// Creates an empty statistics accumulator.
    ///
    /// `min_length` starts at `usize::MAX` as a sentinel so the first
    /// `update` establishes the real minimum; the `min_length()` accessor
    /// masks the sentinel by returning 0 while no reads are recorded.
    pub fn new() -> Self {
        Self {
            distribution: HashMap::new(),
            min_length: usize::MAX,
            max_length: 0,
            total_bases: 0,
            total_reads: 0,
        }
    }

    /// Records a single read of `len` bases; zero-length reads are ignored.
    #[inline]
    pub fn update(&mut self, len: usize) {
        if len == 0 {
            return;
        }
        *self.distribution.entry(len).or_insert(0) += 1;
        self.min_length = self.min_length.min(len);
        self.max_length = self.max_length.max(len);
        self.total_bases += len as u64;
        self.total_reads += 1;
    }

    /// Mean read length, or 0.0 when no reads have been recorded.
    pub fn mean_length(&self) -> f64 {
        if self.total_reads == 0 {
            0.0
        } else {
            self.total_bases as f64 / self.total_reads as f64
        }
    }

    /// Shortest recorded read length, or 0 when empty (masks the
    /// `usize::MAX` sentinel set by `new`).
    pub fn min_length(&self) -> usize {
        if self.total_reads == 0 {
            0
        } else {
            self.min_length
        }
    }

    /// Longest recorded read length (0 when empty).
    pub fn max_length(&self) -> usize {
        self.max_length
    }

    /// Total number of bases across all recorded reads.
    pub fn total_bases(&self) -> u64 {
        self.total_bases
    }

    /// Total number of recorded (non-zero-length) reads.
    pub fn total_reads(&self) -> u64 {
        self.total_reads
    }

    /// Read-count histogram keyed by exact read length.
    pub fn distribution(&self) -> &HashMap<usize, u64> {
        &self.distribution
    }

    /// N50: the length L such that reads of length >= L contain at least
    /// half of all bases. Delegates to `nx(50)` (the previous inline scan
    /// duplicated that logic with an identical target of `total_bases / 2`).
    pub fn n50(&self) -> usize {
        self.nx(50)
    }

    /// N90: the length L such that reads of length >= L contain at least
    /// 90% of all bases. See [`Self::nx`].
    pub fn n90(&self) -> usize {
        self.nx(90)
    }

    /// Generic Nx: the length L such that reads of length >= L contain at
    /// least `x`% of all bases. Returns 0 when the accumulator is empty or
    /// `x` is outside `1..=100`.
    pub fn nx(&self, x: u8) -> usize {
        if self.total_reads == 0 || x == 0 || x > 100 {
            return 0;
        }
        // Exact integer arithmetic in u128. The previous f64 computation
        // could round the target down by one base (e.g. total_bases = 10,
        // x = 70 gave a target of 6 instead of 7, because 0.7 is not
        // representable in binary) and loses precision above 2^53 bases.
        let target_bases = (self.total_bases as u128 * x as u128 / 100) as u64;
        let mut cumulative_bases: u64 = 0;
        // Walk lengths from longest to shortest, accumulating bases.
        for (len, count) in self.sorted_counts().into_iter().rev() {
            cumulative_bases += (len as u64) * count;
            if cumulative_bases >= target_bases {
                return len;
            }
        }
        0 // unreachable when total_bases > 0; kept as a safe fallback
    }

    /// Median read length (the upper middle value for an even read count);
    /// 0 when empty.
    pub fn median_length(&self) -> usize {
        if self.total_reads == 0 {
            return 0;
        }
        let half = self.total_reads / 2;
        let mut cumulative: u64 = 0;
        for (len, count) in self.sorted_counts() {
            cumulative += count;
            if cumulative > half {
                return len;
            }
        }
        0 // unreachable: cumulative always reaches total_reads > half
    }

    /// Folds `other` into `self`. The min/max update is skipped for an
    /// empty `other` so its `usize::MAX` sentinel cannot clobber our
    /// minimum (and its zero maximum is harmless but meaningless).
    pub fn merge(&mut self, other: &LengthStats) {
        for (&len, &count) in &other.distribution {
            *self.distribution.entry(len).or_insert(0) += count;
        }
        if other.total_reads > 0 {
            self.min_length = self.min_length.min(other.min_length);
            self.max_length = self.max_length.max(other.max_length);
        }
        self.total_bases += other.total_bases;
        self.total_reads += other.total_reads;
    }

    /// Read-length percentile `p` (0.0..=100.0), counted over reads rather
    /// than bases. Returns 0 for an empty accumulator or out-of-range `p`.
    pub fn percentile(&self, p: f64) -> usize {
        if self.total_reads == 0 || !(0.0..=100.0).contains(&p) {
            return 0;
        }
        let target = ((self.total_reads as f64 * p) / 100.0) as u64;
        let mut cumulative: u64 = 0;
        for (len, count) in self.sorted_counts() {
            cumulative += count;
            if cumulative >= target {
                return len;
            }
        }
        self.max_length
    }

    /// Reconstructs a `LengthStats` from previously captured fields, e.g.
    /// after loading from an external representation.
    ///
    /// NOTE(review): invariants (min/max/totals consistent with
    /// `distribution`) are not validated here — callers must supply
    /// coherent values.
    pub fn from_raw(
        distribution: HashMap<usize, u64>,
        min_length: usize,
        max_length: usize,
        total_bases: u64,
        total_reads: u64,
    ) -> Self {
        Self {
            distribution,
            min_length,
            max_length,
            total_bases,
            total_reads,
        }
    }

    /// Snapshot of the histogram as `(length, count)` pairs sorted by
    /// ascending length; shared by the Nx/median/percentile scans.
    fn sorted_counts(&self) -> Vec<(usize, u64)> {
        let mut pairs: Vec<(usize, u64)> = self
            .distribution
            .iter()
            .map(|(&len, &count)| (len, count))
            .collect();
        pairs.sort_unstable_by_key(|&(len, _)| len);
        pairs
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an accumulator pre-loaded with the given read lengths.
    fn stats_from(lengths: &[usize]) -> LengthStats {
        let mut stats = LengthStats::new();
        for &len in lengths {
            stats.update(len);
        }
        stats
    }

    #[test]
    fn test_length_stats_new() {
        let stats = LengthStats::new();
        assert_eq!(stats.total_reads(), 0);
        assert_eq!(stats.total_bases(), 0);
        assert_eq!(stats.min_length(), 0);
        assert_eq!(stats.max_length(), 0);
    }

    #[test]
    fn test_length_stats_update_single() {
        let stats = stats_from(&[100]);
        assert_eq!(stats.total_reads(), 1);
        assert_eq!(stats.total_bases(), 100);
        assert_eq!(stats.min_length(), 100);
        assert_eq!(stats.max_length(), 100);
        assert!((stats.mean_length() - 100.0).abs() < 0.001);
    }

    #[test]
    fn test_length_stats_update_multiple() {
        let stats = stats_from(&[100, 200, 300]);
        assert_eq!(stats.total_reads(), 3);
        assert_eq!(stats.total_bases(), 600);
        assert_eq!(stats.min_length(), 100);
        assert_eq!(stats.max_length(), 300);
        assert!((stats.mean_length() - 200.0).abs() < 0.001);
    }

    #[test]
    fn test_length_stats_update_zero() {
        // Zero-length reads must be ignored entirely.
        let stats = stats_from(&[0]);
        assert_eq!(stats.total_reads(), 0);
        assert_eq!(stats.total_bases(), 0);
    }

    #[test]
    fn test_length_stats_n50() {
        // Total 1500 bases; 500 + 400 = 900 >= 750, so N50 is 400.
        let stats = stats_from(&[100, 200, 300, 400, 500]);
        assert_eq!(stats.n50(), 400);
    }

    #[test]
    fn test_length_stats_n50_long_reads() {
        let mut lengths = vec![1000usize; 10];
        lengths.extend(std::iter::repeat(10000).take(5));
        lengths.push(50000);
        let stats = stats_from(&lengths);
        assert_eq!(stats.n50(), 10000);
    }

    #[test]
    fn test_length_stats_median() {
        let stats = stats_from(&[100, 200, 300]);
        assert_eq!(stats.median_length(), 200);
    }

    #[test]
    fn test_length_stats_percentile() {
        let lengths: Vec<usize> = (1..=100).collect();
        let stats = stats_from(&lengths);
        let p25 = stats.percentile(25.0);
        let p75 = stats.percentile(75.0);
        assert!((24..=26).contains(&p25));
        assert!((74..=76).contains(&p75));
    }

    #[test]
    fn test_length_stats_merge() {
        let mut combined = stats_from(&[100, 200]);
        let other = stats_from(&[300, 400]);
        combined.merge(&other);
        assert_eq!(combined.total_reads(), 4);
        assert_eq!(combined.total_bases(), 1000);
        assert_eq!(combined.min_length(), 100);
        assert_eq!(combined.max_length(), 400);
    }

    #[test]
    fn test_length_stats_merge_empty() {
        // Merging an empty accumulator must not disturb min/max.
        let mut combined = stats_from(&[100]);
        combined.merge(&LengthStats::new());
        assert_eq!(combined.total_reads(), 1);
        assert_eq!(combined.min_length(), 100);
    }

    #[test]
    fn test_length_stats_distribution() {
        let stats = stats_from(&[100, 100, 200]);
        let dist = stats.distribution();
        assert_eq!(dist.get(&100), Some(&2));
        assert_eq!(dist.get(&200), Some(&1));
    }

    #[test]
    fn test_length_stats_nx() {
        // N90 covers more bases, so it can never exceed N50.
        let stats = stats_from(&[100, 200, 300, 400, 500]);
        assert!(stats.n90() <= stats.n50());
    }

    #[test]
    fn test_length_stats_serialize() {
        let stats = stats_from(&[100, 200]);
        let json = serde_json::to_string(&stats).unwrap();
        let restored: LengthStats = serde_json::from_str(&json).unwrap();
        assert_eq!(stats.total_reads(), restored.total_reads());
        assert_eq!(stats.total_bases(), restored.total_bases());
    }

    #[test]
    fn test_length_stats_empty_n50() {
        let empty = LengthStats::new();
        assert_eq!(empty.n50(), 0);
        assert_eq!(empty.median_length(), 0);
    }
}