use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::shared::{UmiCountTracker, UmiMetric};
use crate::{Metric, frac};
/// One row of the family-size distribution, describing a single family size.
///
/// The `cs_`, `ss_` and `ds_` column groups mirror the three family kinds that
/// `DuplexMetricsCollector` tallies separately (its totals for them are named
/// coordinate-strand, single-strand and double-strand respectively).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FamilySizeMetric {
/// Family size this row describes.
pub family_size: usize,
/// Number of CS families recorded with exactly this size.
pub cs_count: usize,
/// `cs_count` relative to all recorded CS families.
pub cs_fraction: f64,
/// Fraction of CS families whose size is at least `family_size`.
pub cs_fraction_gt_or_eq_size: f64,
/// Number of SS families recorded with exactly this size.
pub ss_count: usize,
/// `ss_count` relative to all recorded SS families.
pub ss_fraction: f64,
/// Fraction of SS families whose size is at least `family_size`.
pub ss_fraction_gt_or_eq_size: f64,
/// Number of DS families recorded with exactly this size.
pub ds_count: usize,
/// `ds_count` relative to all recorded DS families.
pub ds_fraction: f64,
/// Fraction of DS families whose size is at least `family_size`.
pub ds_fraction_gt_or_eq_size: f64,
}
impl FamilySizeMetric {
    /// Builds the row for `family_size` with every count and fraction set to zero.
    #[must_use]
    pub fn new(family_size: usize) -> Self {
        // The three family kinds all start from the same empty state.
        let (cs_count, ss_count, ds_count) = (0, 0, 0);
        let (cs_fraction, ss_fraction, ds_fraction) = (0.0, 0.0, 0.0);
        let (cs_fraction_gt_or_eq_size, ss_fraction_gt_or_eq_size, ds_fraction_gt_or_eq_size) =
            (0.0, 0.0, 0.0);
        Self {
            family_size,
            cs_count,
            cs_fraction,
            cs_fraction_gt_or_eq_size,
            ss_count,
            ss_fraction,
            ss_fraction_gt_or_eq_size,
            ds_count,
            ds_fraction,
            ds_fraction_gt_or_eq_size,
        }
    }
}
impl Default for FamilySizeMetric {
fn default() -> Self {
Self::new(0)
}
}
impl Metric for FamilySizeMetric {
    /// Label reported for this metric type through the `Metric` trait.
    fn metric_name() -> &'static str {
        const NAME: &str = "duplex family size";
        NAME
    }
}
/// One row of the two-dimensional duplex family-size distribution.
///
/// Rows produced by `DuplexMetricsCollector` always have `ab_size >= ba_size`,
/// because the collector stores the larger of the two sizes as AB.
/// Equality and ordering for this type look at `(ab_size, ba_size)` only.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DuplexFamilySizeMetric {
/// Size on the AB side of the family.
pub ab_size: usize,
/// Size on the BA side of the family.
pub ba_size: usize,
/// Number of families recorded with exactly this `(ab_size, ba_size)` pair.
pub count: usize,
/// `count` relative to all recorded duplex families.
pub fraction: f64,
/// Fraction of recorded duplex families whose AB size is at least `ab_size`
/// and whose BA size is at least `ba_size`.
pub fraction_gt_or_eq_size: f64,
}
impl DuplexFamilySizeMetric {
    /// Builds the row for the given AB/BA sizes with a zero count and zero fractions.
    #[must_use]
    pub fn new(ab_size: usize, ba_size: usize) -> Self {
        let count = 0;
        let fraction = 0.0;
        let fraction_gt_or_eq_size = 0.0;
        Self { ab_size, ba_size, count, fraction, fraction_gt_or_eq_size }
    }
}
impl Default for DuplexFamilySizeMetric {
fn default() -> Self {
Self::new(0, 0)
}
}
impl Metric for DuplexFamilySizeMetric {
    /// Label reported for this metric type through the `Metric` trait.
    fn metric_name() -> &'static str {
        const NAME: &str = "duplex AB/BA family size";
        NAME
    }
}
impl Ord for DuplexFamilySizeMetric {
    /// Orders by `ab_size` first and `ba_size` second; counts and fractions are ignored.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Tuples compare lexicographically, which is exactly "AB, then BA".
        let mine = (self.ab_size, self.ba_size);
        let theirs = (other.ab_size, other.ba_size);
        mine.cmp(&theirs)
    }
}
impl PartialOrd for DuplexFamilySizeMetric {
    /// Always `Some`: delegates to the total order defined by `Ord`.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
impl Eq for DuplexFamilySizeMetric {}

impl PartialEq for DuplexFamilySizeMetric {
    /// Two rows are equal when their `(ab_size, ba_size)` pairs match; counts and
    /// fractions do not take part, which keeps equality consistent with `Ord`.
    fn eq(&self, other: &Self) -> bool {
        (self.ab_size, self.ba_size) == (other.ab_size, other.ba_size)
    }
}
/// Yield summary record keyed by a `fraction` value.
///
/// NOTE(review): nothing in this file fills these fields in beyond
/// `DuplexYieldMetric::new`, which zeroes everything except `fraction`. The
/// per-field descriptions below are inferred from the names only — confirm them
/// against the code that populates this record.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DuplexYieldMetric {
/// Value passed to `new`; presumably the fraction of the input that was sampled.
pub fraction: f64,
/// Presumably the number of read pairs considered at this fraction.
pub read_pairs: usize,
/// Presumably the number of CS families observed.
pub cs_families: usize,
/// Presumably the number of SS families observed.
pub ss_families: usize,
/// Presumably the number of DS families observed.
pub ds_families: usize,
/// Presumably the number of DS families that qualify as duplexes.
pub ds_duplexes: usize,
/// Presumably `ds_duplexes` relative to `ds_families`.
pub ds_fraction_duplexes: f64,
/// Presumably the duplex fraction expected under an idealized model — TODO confirm which.
pub ds_fraction_duplexes_ideal: f64,
}
impl DuplexYieldMetric {
    /// Builds a record for `fraction` with every count and derived fraction set to zero.
    #[must_use]
    pub fn new(fraction: f64) -> Self {
        let (read_pairs, cs_families, ss_families, ds_families, ds_duplexes) = (0, 0, 0, 0, 0);
        let (ds_fraction_duplexes, ds_fraction_duplexes_ideal) = (0.0, 0.0);
        Self {
            fraction,
            read_pairs,
            cs_families,
            ss_families,
            ds_families,
            ds_duplexes,
            ds_fraction_duplexes,
            ds_fraction_duplexes_ideal,
        }
    }
}
impl Default for DuplexYieldMetric {
fn default() -> Self {
Self::new(0.0)
}
}
impl Metric for DuplexYieldMetric {
    /// Label reported for this metric type through the `Metric` trait.
    fn metric_name() -> &'static str {
        const NAME: &str = "duplex yield";
        NAME
    }
}
/// Per-UMI row of the duplex UMI report produced by
/// `DuplexMetricsCollector::duplex_umi_metrics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DuplexUmiMetric {
/// UMI string as recorded; when it contains a `-`, the text before and after
/// the first `-` is treated as the two single UMIs.
pub umi: String,
/// Raw observation count reported by the duplex UMI tracker for this UMI.
pub raw_observations: usize,
/// Error count reported by the duplex UMI tracker for this UMI.
pub raw_observations_with_errors: usize,
/// Unique observation count reported by the duplex UMI tracker for this UMI.
pub unique_observations: usize,
/// `raw_observations` relative to the tracker's raw total.
pub fraction_raw_observations: f64,
/// `unique_observations` relative to the tracker's unique total.
pub fraction_unique_observations: f64,
/// Product of the two single UMIs' `fraction_unique_observations` values.
/// A half that is missing from the single-UMI metrics contributes `0.0`, and
/// the whole value is `0.0` when `umi` contains no `-`.
pub fraction_unique_observations_expected: f64,
}
impl DuplexUmiMetric {
    /// Builds a row for `umi` with every count and fraction set to zero.
    #[must_use]
    pub fn new(umi: String) -> Self {
        let (raw_observations, raw_observations_with_errors, unique_observations) = (0, 0, 0);
        let (fraction_raw_observations, fraction_unique_observations) = (0.0, 0.0);
        let fraction_unique_observations_expected = 0.0;
        Self {
            umi,
            raw_observations,
            raw_observations_with_errors,
            unique_observations,
            fraction_raw_observations,
            fraction_unique_observations,
            fraction_unique_observations_expected,
        }
    }
}
impl Default for DuplexUmiMetric {
fn default() -> Self {
Self::new(String::new())
}
}
impl Metric for DuplexUmiMetric {
    /// Label reported for this metric type through the `Metric` trait.
    fn metric_name() -> &'static str {
        const NAME: &str = "duplex UMI";
        NAME
    }
}
/// Accumulates family-size and UMI tallies and turns them into metric rows.
///
/// Callers feed it through the `record_*` methods and read results back
/// through the `*_metrics` methods; the latter only read the tallies.
pub struct DuplexMetricsCollector {
/// When `false`, `record_duplex_umi` ignores its input and
/// `duplex_umi_metrics` returns an empty vector.
collect_duplex_umi_counts: bool,
/// Family size -> number of CS families recorded with that size.
cs_family_sizes: HashMap<usize, usize>,
/// Family size -> number of SS families recorded with that size.
ss_family_sizes: HashMap<usize, usize>,
/// Family size -> number of DS families recorded with that size.
ds_family_sizes: HashMap<usize, usize>,
/// `(larger size, smaller size)` -> number of duplex families recorded with
/// that pair; `record_duplex_family` normalizes the order before inserting.
duplex_family_sizes: HashMap<(usize, usize), usize>,
/// Per-UMI tallies fed by `record_umi`.
umi_counts: UmiCountTracker,
/// Per-UMI tallies fed by `record_duplex_umi` (only when collection is enabled).
duplex_umi_counts: UmiCountTracker,
}
impl DuplexMetricsCollector {
    /// Creates an empty collector.
    ///
    /// When `collect_duplex_umi_counts` is `false`, [`Self::record_duplex_umi`]
    /// does nothing and [`Self::duplex_umi_metrics`] always returns an empty vector.
    #[must_use]
    pub fn new(collect_duplex_umi_counts: bool) -> Self {
        Self {
            collect_duplex_umi_counts,
            cs_family_sizes: HashMap::new(),
            ss_family_sizes: HashMap::new(),
            ds_family_sizes: HashMap::new(),
            duplex_family_sizes: HashMap::new(),
            umi_counts: UmiCountTracker::new(),
            duplex_umi_counts: UmiCountTracker::new(),
        }
    }

    /// Counts one CS family of the given `size`.
    pub fn record_cs_family(&mut self, size: usize) {
        *self.cs_family_sizes.entry(size).or_default() += 1;
    }

    /// Counts one SS family of the given `size`.
    pub fn record_ss_family(&mut self, size: usize) {
        *self.ss_family_sizes.entry(size).or_default() += 1;
    }

    /// Counts one DS family of the given `size`.
    pub fn record_ds_family(&mut self, size: usize) {
        *self.ds_family_sizes.entry(size).or_default() += 1;
    }

    /// Counts one duplex family with the given per-strand sizes.
    ///
    /// The pair is stored with the larger size first, so `(3, 5)` and `(5, 3)`
    /// land in the same bucket.
    pub fn record_duplex_family(&mut self, ab_size: usize, ba_size: usize) {
        let key = (ab_size.max(ba_size), ab_size.min(ba_size));
        *self.duplex_family_sizes.entry(key).or_default() += 1;
    }

    /// Forwards one UMI observation to the single-UMI tracker.
    pub fn record_umi(&mut self, umi: &str, raw_count: usize, error_count: usize, is_unique: bool) {
        self.umi_counts.record(umi, raw_count, error_count, is_unique);
    }

    /// Forwards one duplex UMI observation to the duplex tracker.
    ///
    /// Does nothing when the collector was created with
    /// `collect_duplex_umi_counts == false`.
    pub fn record_duplex_umi(
        &mut self,
        umi: &str,
        raw_count: usize,
        error_count: usize,
        is_unique: bool,
    ) {
        if self.collect_duplex_umi_counts {
            self.duplex_umi_counts.record(umi, raw_count, error_count, is_unique);
        }
    }

    /// Builds one [`FamilySizeMetric`] per family size from `1` up to the largest
    /// size recorded for any of the CS/SS/DS kinds.
    ///
    /// Returns an empty vector when nothing has been recorded. Sizes without any
    /// recorded family still get a row (with zero counts), so the result is dense.
    /// Families recorded with size `0` get no row but still count toward the
    /// per-kind totals used as denominators.
    #[must_use]
    pub fn family_size_metrics(&self) -> Vec<FamilySizeMetric> {
        let max_size = self
            .cs_family_sizes
            .keys()
            .chain(self.ss_family_sizes.keys())
            .chain(self.ds_family_sizes.keys())
            .copied()
            .max()
            .unwrap_or(0);

        let cs_total: usize = self.cs_family_sizes.values().sum();
        let ss_total: usize = self.ss_family_sizes.values().sum();
        let ds_total: usize = self.ds_family_sizes.values().sum();

        let count_of =
            |sizes: &HashMap<usize, usize>, size: usize| sizes.get(&size).copied().unwrap_or(0);

        let mut metrics: Vec<FamilySizeMetric> = (1..=max_size)
            .map(|size| {
                let mut metric = FamilySizeMetric::new(size);
                metric.cs_count = count_of(&self.cs_family_sizes, size);
                metric.cs_fraction = frac(metric.cs_count, cs_total);
                metric.ss_count = count_of(&self.ss_family_sizes, size);
                metric.ss_fraction = frac(metric.ss_count, ss_total);
                metric.ds_count = count_of(&self.ds_family_sizes, size);
                metric.ds_fraction = frac(metric.ds_count, ds_total);
                metric
            })
            .collect();

        // Walk from the largest size down, keeping running *integer* counts and
        // dividing once per row. Summing the already-rounded per-size fractions
        // instead lets rounding error pile up (ten equal buckets would report
        // 0.9999999999999999 at size 1).
        let (mut cs_at_least, mut ss_at_least, mut ds_at_least) = (0usize, 0usize, 0usize);
        for metric in metrics.iter_mut().rev() {
            cs_at_least += metric.cs_count;
            ss_at_least += metric.ss_count;
            ds_at_least += metric.ds_count;
            metric.cs_fraction_gt_or_eq_size = frac(cs_at_least, cs_total);
            metric.ss_fraction_gt_or_eq_size = frac(ss_at_least, ss_total);
            metric.ds_fraction_gt_or_eq_size = frac(ds_at_least, ds_total);
        }
        metrics
    }

    /// Builds one [`DuplexFamilySizeMetric`] per recorded `(ab, ba)` pair, sorted by
    /// `ab_size` and then `ba_size`.
    ///
    /// `fraction_gt_or_eq_size` is the share of recorded families whose AB size is
    /// at least the row's `ab_size` and whose BA size is at least its `ba_size`.
    #[must_use]
    pub fn duplex_family_size_metrics(&self) -> Vec<DuplexFamilySizeMetric> {
        let total: usize = self.duplex_family_sizes.values().sum();
        let mut metrics: Vec<DuplexFamilySizeMetric> = self
            .duplex_family_sizes
            .iter()
            .map(|(&(ab, ba), &count)| {
                let mut metric = DuplexFamilySizeMetric::new(ab, ba);
                metric.count = count;
                metric.fraction = frac(count, total);
                metric
            })
            .collect();
        // Every row has a distinct (ab, ba) key, so an unstable sort cannot reorder ties.
        metrics.sort_unstable();

        if total == 0 {
            return metrics;
        }

        // Dense (max_ab + 1) x (max_ba + 1) grid of counts, indexed row-major by (ab, ba).
        // NOTE(review): memory grows with the product of the two largest sizes.
        let max_ab = metrics.iter().map(|m| m.ab_size).max().unwrap_or(0);
        let max_ba = metrics.iter().map(|m| m.ba_size).max().unwrap_or(0);
        let cols = max_ba + 1;
        let mut grid = vec![0usize; (max_ab + 1) * cols];
        for metric in &metrics {
            grid[metric.ab_size * cols + metric.ba_size] = metric.count;
        }

        // Suffix-sum along BA within each row, then along AB within each column.
        // Afterwards grid[a][b] holds the count of families with ab >= a and ba >= b.
        for a in 0..=max_ab {
            for b in (0..max_ba).rev() {
                grid[a * cols + b] += grid[a * cols + b + 1];
            }
        }
        for b in 0..=max_ba {
            for a in (0..max_ab).rev() {
                grid[a * cols + b] += grid[(a + 1) * cols + b];
            }
        }

        for metric in &mut metrics {
            let at_least = grid[metric.ab_size * cols + metric.ba_size];
            metric.fraction_gt_or_eq_size = frac(at_least, total);
        }
        metrics
    }

    /// Returns the single-UMI metrics produced by the UMI tracker.
    #[must_use]
    pub fn umi_metrics(&self) -> Vec<UmiMetric> {
        self.umi_counts.to_metrics()
    }

    /// Builds one [`DuplexUmiMetric`] per duplex UMI, sorted by descending
    /// `unique_observations`.
    ///
    /// `umi_metrics` supplies the single-UMI `fraction_unique_observations` values
    /// used for the expected fraction: a duplex UMI `"A-B"` gets
    /// `fraction(A) * fraction(B)`, with `0.0` for a half that is not listed and
    /// `0.0` overall when the UMI contains no `-`.
    ///
    /// Returns an empty vector when duplex UMI collection is disabled.
    #[must_use]
    pub fn duplex_umi_metrics(&self, umi_metrics: &[UmiMetric]) -> Vec<DuplexUmiMetric> {
        if !self.collect_duplex_umi_counts {
            return Vec::new();
        }

        let single_umi_fractions: HashMap<&str, f64> =
            umi_metrics.iter().map(|m| (m.umi.as_str(), m.fraction_unique_observations)).collect();
        let fraction_of =
            |single: &str| single_umi_fractions.get(single).copied().unwrap_or(0.0);

        let total_raw = self.duplex_umi_counts.total_raw();
        let total_unique = self.duplex_umi_counts.total_unique();

        let mut metrics: Vec<DuplexUmiMetric> = self
            .duplex_umi_counts
            .iter()
            .map(|(umi, raw, errors, unique)| {
                let mut metric = DuplexUmiMetric::new(umi.to_string());
                metric.raw_observations = raw;
                metric.raw_observations_with_errors = errors;
                metric.unique_observations = unique;
                metric.fraction_raw_observations = frac(raw, total_raw);
                metric.fraction_unique_observations = frac(unique, total_unique);
                metric.fraction_unique_observations_expected = umi
                    .split_once('-')
                    .map_or(0.0, |(first, second)| fraction_of(first) * fraction_of(second));
                metric
            })
            .collect();

        // Stable sort: rows with equal `unique_observations` keep the tracker's
        // iteration order.
        metrics.sort_by_key(|m| std::cmp::Reverse(m.unique_observations));
        metrics
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the family-size row for `size`, panicking with `missing` when absent.
    fn row_for<'a>(
        rows: &'a [FamilySizeMetric],
        size: usize,
        missing: &str,
    ) -> &'a FamilySizeMetric {
        rows.iter().find(|row| row.family_size == size).expect(missing)
    }

    /// Returns the UMI row for `umi`, panicking with `missing` when absent.
    fn umi_row<'a>(rows: &'a [UmiMetric], umi: &str, missing: &str) -> &'a UmiMetric {
        rows.iter().find(|row| row.umi.as_str() == umi).expect(missing)
    }

    #[test]
    fn test_family_size_metric_new() {
        let fresh = FamilySizeMetric::new(5);
        assert_eq!(fresh.family_size, 5);
        assert_eq!(fresh.cs_count, 0);
        assert!(fresh.cs_fraction.abs() < f64::EPSILON);
        assert_eq!(fresh.ss_count, 0);
        assert_eq!(fresh.ds_count, 0);
    }

    #[test]
    fn test_duplex_family_size_metric_new() {
        let fresh = DuplexFamilySizeMetric::new(10, 5);
        assert_eq!(fresh.ab_size, 10);
        assert_eq!(fresh.ba_size, 5);
        assert_eq!(fresh.count, 0);
        assert!(fresh.fraction.abs() < f64::EPSILON);
    }

    #[test]
    fn test_duplex_family_size_metric_ordering() {
        // Listed in the order the `Ord` impl should place them.
        let smallest = DuplexFamilySizeMetric::new(5, 3);
        let middle = DuplexFamilySizeMetric::new(5, 4);
        let largest = DuplexFamilySizeMetric::new(6, 2);
        assert!(smallest < middle);
        assert!(smallest < largest);
        assert!(middle < largest);
    }

    #[test]
    fn test_duplex_family_size_metric_equality() {
        let left = DuplexFamilySizeMetric::new(5, 3);
        let same = DuplexFamilySizeMetric::new(5, 3);
        let other = DuplexFamilySizeMetric::new(5, 4);
        assert_eq!(left, same);
        assert_ne!(left, other);
    }

    #[test]
    fn test_duplex_yield_metric_new() {
        let fresh = DuplexYieldMetric::new(0.5);
        assert!((fresh.fraction - 0.5).abs() < f64::EPSILON);
        assert_eq!(fresh.read_pairs, 0);
        assert_eq!(fresh.cs_families, 0);
        assert_eq!(fresh.ds_duplexes, 0);
    }

    #[test]
    fn test_duplex_umi_metric_new() {
        let fresh = DuplexUmiMetric::new("ACGT-TGCA".to_string());
        assert_eq!(fresh.umi, "ACGT-TGCA");
        assert_eq!(fresh.raw_observations, 0);
        assert!(fresh.fraction_unique_observations_expected.abs() < f64::EPSILON);
    }

    #[test]
    fn test_record_cs_family() {
        let mut collector = DuplexMetricsCollector::new(false);
        for &size in &[5, 5, 10] {
            collector.record_cs_family(size);
        }

        let rows = collector.family_size_metrics();
        let five = row_for(&rows, 5, "family_size 5 metric should be present");
        assert_eq!(five.cs_count, 2);
        let ten = row_for(&rows, 10, "family_size 10 metric should be present");
        assert_eq!(ten.cs_count, 1);
    }

    #[test]
    fn test_record_ss_family() {
        let mut collector = DuplexMetricsCollector::new(false);
        for _ in 0..3 {
            collector.record_ss_family(3);
        }

        let rows = collector.family_size_metrics();
        let three = row_for(&rows, 3, "family_size 3 metric should be present");
        assert_eq!(three.ss_count, 3);
    }

    #[test]
    fn test_record_ds_family() {
        let mut collector = DuplexMetricsCollector::new(false);
        collector.record_ds_family(2);

        let rows = collector.family_size_metrics();
        let two = row_for(&rows, 2, "family_size 2 metric should be present");
        assert_eq!(two.ds_count, 1);
    }

    #[test]
    fn test_record_duplex_family_normalization() {
        let mut collector = DuplexMetricsCollector::new(false);
        // Both orderings must land in the same (larger, smaller) bucket.
        for &(ab, ba) in &[(3, 5), (5, 3)] {
            collector.record_duplex_family(ab, ba);
        }

        let rows = collector.duplex_family_size_metrics();
        assert_eq!(rows.len(), 1);
        let only = &rows[0];
        assert_eq!(only.ab_size, 5);
        assert_eq!(only.ba_size, 3);
        assert_eq!(only.count, 2);
    }

    #[test]
    fn test_record_umi() {
        let mut collector = DuplexMetricsCollector::new(false);
        collector.record_umi("AAAA", 10, 2, true);
        collector.record_umi("AAAA", 5, 1, false);
        collector.record_umi("CCCC", 8, 0, true);

        let rows = collector.umi_metrics();
        assert_eq!(rows.len(), 2);

        let aaaa = umi_row(&rows, "AAAA", "AAAA UMI metric should be present");
        assert_eq!(aaaa.raw_observations, 15);
        assert_eq!(aaaa.raw_observations_with_errors, 3);
        assert_eq!(aaaa.unique_observations, 1);

        let cccc = umi_row(&rows, "CCCC", "CCCC UMI metric should be present");
        assert_eq!(cccc.raw_observations, 8);
        assert_eq!(cccc.unique_observations, 1);
    }

    #[test]
    fn test_record_duplex_umi_disabled() {
        let mut collector = DuplexMetricsCollector::new(false);
        collector.record_duplex_umi("AAAA-TTTT", 10, 0, true);

        let singles = collector.umi_metrics();
        let duplex_rows = collector.duplex_umi_metrics(&singles);
        assert!(duplex_rows.is_empty());
    }

    #[test]
    fn test_record_duplex_umi_enabled() {
        let mut collector = DuplexMetricsCollector::new(true);
        collector.record_duplex_umi("AAAA-TTTT", 10, 2, true);
        collector.record_umi("AAAA", 5, 0, true);
        collector.record_umi("TTTT", 5, 0, true);

        let singles = collector.umi_metrics();
        let duplex_rows = collector.duplex_umi_metrics(&singles);
        assert_eq!(duplex_rows.len(), 1);
        let only = &duplex_rows[0];
        assert_eq!(only.umi, "AAAA-TTTT");
        assert_eq!(only.raw_observations, 10);
    }

    #[test]
    fn test_family_size_metrics_fractions() {
        let mut collector = DuplexMetricsCollector::new(false);
        for &size in &[1, 1, 2, 3] {
            collector.record_cs_family(size);
        }

        let rows = collector.family_size_metrics();

        let one = row_for(&rows, 1, "family_size 1 metric should be present");
        assert_eq!(one.cs_count, 2);
        assert!((one.cs_fraction - 0.5).abs() < 0.001);
        assert!((one.cs_fraction_gt_or_eq_size - 1.0).abs() < 0.001);

        let three = row_for(&rows, 3, "family_size 3 metric should be present");
        assert_eq!(three.cs_count, 1);
        assert!((three.cs_fraction - 0.25).abs() < 0.001);
        assert!((three.cs_fraction_gt_or_eq_size - 0.25).abs() < 0.001);
    }

    #[test]
    fn test_duplex_family_size_metrics_sorting() {
        let mut collector = DuplexMetricsCollector::new(false);
        for &(ab, ba) in &[(5, 3), (2, 1), (5, 2)] {
            collector.record_duplex_family(ab, ba);
        }

        let rows = collector.duplex_family_size_metrics();
        assert_eq!(rows[0].ab_size, 2);
        assert_eq!(rows[0].ba_size, 1);
        assert_eq!(rows[1].ab_size, 5);
        assert_eq!(rows[1].ba_size, 2);
        assert_eq!(rows[2].ab_size, 5);
        assert_eq!(rows[2].ba_size, 3);
    }

    #[test]
    fn test_duplex_umi_expected_frequency() {
        let mut collector = DuplexMetricsCollector::new(true);
        for umi in &["AAAA", "TTTT"] {
            collector.record_umi(umi, 10, 0, true);
        }
        collector.record_duplex_umi("AAAA-TTTT", 5, 0, true);

        let singles = collector.umi_metrics();
        let duplex_rows = collector.duplex_umi_metrics(&singles);
        assert_eq!(duplex_rows.len(), 1);
        let expected = duplex_rows[0].fraction_unique_observations_expected;
        assert!((expected - 0.25).abs() < 0.001);
    }

    #[test]
    fn test_empty_collector() {
        let collector = DuplexMetricsCollector::new(false);
        assert!(collector.family_size_metrics().is_empty());
        assert!(collector.duplex_family_size_metrics().is_empty());
        assert!(collector.umi_metrics().is_empty());
    }

    #[test]
    fn test_metric_trait_impl() {
        let names = [
            (FamilySizeMetric::metric_name(), "duplex family size"),
            (DuplexFamilySizeMetric::metric_name(), "duplex AB/BA family size"),
            (DuplexYieldMetric::metric_name(), "duplex yield"),
            (UmiMetric::metric_name(), "UMI"),
            (DuplexUmiMetric::metric_name(), "duplex UMI"),
        ];
        for &(actual, expected) in &names {
            assert_eq!(actual, expected);
        }
    }
}