use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::{Metric, frac};
/// One summary row describing how often a single UMI was observed.
///
/// Rows built by `UmiCountTracker::to_metrics` hold one entry per distinct UMI.
/// The fields keep their original declaration order: serde's derive serializes
/// struct fields in declaration order, so reordering them would presumably
/// change the serialized output.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct UmiMetric {
    /// The UMI this row describes (for example `"ACGT"`).
    pub umi: String,
    /// Total raw observations of this UMI (the summed `raw_count` values when
    /// built by `UmiCountTracker`).
    pub raw_observations: usize,
    /// Raw observations of this UMI that carried errors (the summed
    /// `error_count` values when built by `UmiCountTracker`).
    pub raw_observations_with_errors: usize,
    /// Number of observations of this UMI that were flagged as unique.
    pub unique_observations: usize,
    /// `raw_observations` relative to the raw total over all UMIs.
    pub fraction_raw_observations: f64,
    /// `unique_observations` relative to the unique total over all UMIs.
    pub fraction_unique_observations: f64,
}
impl UmiMetric {
    /// Creates a metric row for `umi` with every count and fraction set to zero.
    #[must_use]
    pub fn new(umi: String) -> Self {
        let (raw_observations, raw_observations_with_errors, unique_observations) = (0, 0, 0);
        let (fraction_raw_observations, fraction_unique_observations) = (0.0, 0.0);
        Self {
            umi,
            raw_observations,
            raw_observations_with_errors,
            unique_observations,
            fraction_raw_observations,
            fraction_unique_observations,
        }
    }
}
impl Default for UmiMetric {
fn default() -> Self {
Self::new(String::new())
}
}
impl Metric for UmiMetric {
    /// Identifier this metric type reports through the `Metric` trait.
    fn metric_name() -> &'static str {
        const NAME: &str = "UMI";
        NAME
    }
}
/// Accumulates observation counts per UMI.
///
/// Feed it with `record` and turn it into sorted [`UmiMetric`] rows with
/// `to_metrics`.
#[derive(Debug, Clone)]
pub struct UmiCountTracker {
    /// UMI -> `(raw observations, raw observations with errors, unique observations)`.
    counts: HashMap<String, (usize, usize, usize)>,
}
impl UmiCountTracker {
    /// Creates a tracker with no UMIs recorded yet.
    #[must_use]
    pub fn new() -> Self {
        let counts = HashMap::new();
        Self { counts }
    }

    /// Adds one batch of observations for `umi`.
    ///
    /// `raw_count` and `error_count` are added to the UMI's running raw and
    /// with-errors totals, and its unique total grows by one when `is_unique`
    /// is `true`. A UMI seen for the first time gets a fresh entry.
    pub fn record(&mut self, umi: &str, raw_count: usize, error_count: usize, is_unique: bool) {
        let unique_step = usize::from(is_unique);
        // Probe with the borrowed `&str` first so an owned key is only
        // allocated when the UMI has not been seen before.
        match self.counts.get_mut(umi) {
            Some((raw, errors, unique)) => {
                *raw += raw_count;
                *errors += error_count;
                *unique += unique_step;
            }
            None => {
                self.counts.insert(umi.to_owned(), (raw_count, error_count, unique_step));
            }
        }
    }

    /// Sum of the raw-observation counts over every recorded UMI.
    #[must_use]
    pub fn total_raw(&self) -> usize {
        self.counts.values().fold(0, |acc, &(raw, _, _)| acc + raw)
    }

    /// Sum of the unique-observation counts over every recorded UMI.
    #[must_use]
    pub fn total_unique(&self) -> usize {
        self.counts.values().fold(0, |acc, &(_, _, unique)| acc + unique)
    }

    /// Iterates over `(umi, raw, raw_with_errors, unique)` for every recorded
    /// UMI, in the map's unspecified iteration order.
    pub(crate) fn iter(&self) -> impl Iterator<Item = (&str, usize, usize, usize)> {
        self.counts.iter().map(|(umi, counts)| {
            let &(raw, errors, unique) = counts;
            (umi.as_str(), raw, errors, unique)
        })
    }

    /// Builds one [`UmiMetric`] per recorded UMI, sorted by UMI.
    ///
    /// Each row's fractions are computed with `frac` against the raw and
    /// unique totals over all UMIs. Returns an empty vector when nothing has
    /// been recorded.
    #[must_use]
    pub fn to_metrics(&self) -> Vec<UmiMetric> {
        let (total_raw, total_unique) = (self.total_raw(), self.total_unique());
        let mut metrics = Vec::with_capacity(self.counts.len());
        for (umi, raw, errors, unique) in self.iter() {
            metrics.push(UmiMetric {
                umi: umi.to_string(),
                raw_observations: raw,
                raw_observations_with_errors: errors,
                unique_observations: unique,
                fraction_raw_observations: frac(raw, total_raw),
                fraction_unique_observations: frac(unique, total_unique),
            });
        }
        metrics.sort_by(|left, right| left.umi.cmp(&right.umi));
        metrics
    }
}
impl Default for UmiCountTracker {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Float comparison used for the fraction checks (within `f64::EPSILON`).
    fn approx_eq(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < f64::EPSILON
    }

    /// Asserts that `tracker` reports no observations at all.
    fn assert_tracker_is_empty(tracker: &UmiCountTracker) {
        assert_eq!(tracker.total_raw(), 0);
        assert_eq!(tracker.total_unique(), 0);
        assert_eq!(tracker.iter().count(), 0);
    }

    /// Fixture: `AAAA` recorded twice (once unique), `CCCC` recorded once.
    fn two_umi_tracker() -> UmiCountTracker {
        let mut tracker = UmiCountTracker::new();
        tracker.record("AAAA", 10, 2, true);
        tracker.record("AAAA", 5, 1, false);
        tracker.record("CCCC", 8, 0, true);
        tracker
    }

    /// Looks up the metric row for `umi`, panicking with `message` if absent.
    fn metric_for<'a>(metrics: &'a [UmiMetric], umi: &str, message: &str) -> &'a UmiMetric {
        metrics.iter().find(|m| m.umi == umi).expect(message)
    }

    #[test]
    fn test_umi_count_tracker_empty() {
        assert_tracker_is_empty(&UmiCountTracker::new());
    }

    #[test]
    fn test_umi_count_tracker_default() {
        assert_tracker_is_empty(&UmiCountTracker::default());
    }

    #[test]
    fn test_umi_count_tracker_record_and_iter() {
        let tracker = two_umi_tracker();
        assert_eq!(tracker.total_raw(), 23);
        assert_eq!(tracker.total_unique(), 2);

        let mut rows: Vec<(&str, usize, usize, usize)> = tracker.iter().collect();
        rows.sort_by_key(|row| row.0);
        assert_eq!(rows.len(), 2);
        assert_eq!(rows[0], ("AAAA", 15, 3, 1));
        assert_eq!(rows[1], ("CCCC", 8, 0, 1));
    }

    #[test]
    fn test_umi_metric_new() {
        let metric = UmiMetric::new(String::from("ACGT"));
        assert_eq!(metric.umi, "ACGT");
        assert_eq!(metric.raw_observations, 0);
        assert_eq!(metric.unique_observations, 0);
    }

    #[test]
    fn test_to_metrics_sorting() {
        let mut tracker = UmiCountTracker::new();
        for umi in ["ZZZZ", "AAAA", "MMMM"] {
            tracker.record(umi, 1, 0, true);
        }
        let metrics = tracker.to_metrics();
        let order: Vec<&str> = metrics.iter().map(|m| m.umi.as_str()).collect();
        assert_eq!(order[..3], ["AAAA", "MMMM", "ZZZZ"]);
    }

    #[test]
    fn test_to_metrics_fractions() {
        let metrics = two_umi_tracker().to_metrics();
        assert_eq!(metrics.len(), 2);

        let aaaa = metric_for(&metrics, "AAAA", "AAAA UMI metric should be present");
        assert_eq!(aaaa.raw_observations, 15);
        assert_eq!(aaaa.raw_observations_with_errors, 3);
        assert_eq!(aaaa.unique_observations, 1);
        assert!(approx_eq(aaaa.fraction_raw_observations, 15.0 / 23.0));
        assert!(approx_eq(aaaa.fraction_unique_observations, 0.5));

        let cccc = metric_for(&metrics, "CCCC", "CCCC UMI metric should be present");
        assert_eq!(cccc.raw_observations, 8);
        assert_eq!(cccc.unique_observations, 1);
        assert!(approx_eq(cccc.fraction_raw_observations, 8.0 / 23.0));
        assert!(approx_eq(cccc.fraction_unique_observations, 0.5));
    }

    #[test]
    fn test_to_metrics_empty() {
        assert!(UmiCountTracker::new().to_metrics().is_empty());
    }
}