use anyhow::ensure;
use indexmap::IndexMap;
use serde::Serialize;
use std::collections::BTreeMap;
use std::fs::File;
use std::path::Path;
use crate::data_types::compare_benchmark::CompareBenchmark;
use crate::data_types::grouped_metrics::{GroupMetrics, GroupTypeMetrics, MetricsType};
use crate::data_types::summary_metrics::{SummaryGtMetrics, SummaryMetrics};
use crate::data_types::variants::VariantType;
use crate::parsing::stratifications::Stratifications;
/// Accumulates benchmark metrics across comparison blocks and writes them out as a
/// delimited summary table.
#[derive(Default)]
pub struct SummaryWriter {
/// Label copied into the `compare_label` column of every row written.
compare_label: String,
/// Comparison types to emit; each group is written once per entry, in this order.
metrics_to_write: Vec<MetricsType>,
/// Running total over every comparison passed to `add_comparison_benchmark`.
all_metrics: GroupTypeMetrics,
/// Running totals per stratification, keyed by stratification label.
strat_metrics: IndexMap<String, GroupTypeMetrics>,
/// Number of comparisons added through `add_comparison_benchmark`.
solved_blocks: u64,
/// Number of times `inc_error_blocks` has been called.
error_blocks: u64
}
/// One record of the summary table.
///
/// Rows are handed to `csv::Writer::serialize`, so the field names and their declaration
/// order are what the output columns are derived from; do not reorder or rename fields
/// without intending to change the file format.
#[derive(Serialize)]
struct SummaryRow {
/// Caller-supplied label identifying the comparison.
compare_label: String,
/// Comparison type the counts in this row belong to.
comparison: MetricsType,
/// `"ALL"` for the whole-dataset rows, otherwise a stratification label.
region_label: String,
/// Filter label; `write_summary` currently always passes `"ALL"`.
filter: String,
/// `"ALL"`, the `Debug` name of a `VariantType`, or a joint category label.
variant_type: String,
/// `truth_tp + truth_fn`.
truth_total: u64,
/// Copied from the metrics' `truth_tp` (presumably truth-side true positives).
truth_tp: u64,
/// Copied from the metrics' `truth_fn` (presumably truth-side false negatives).
truth_fn: u64,
/// `query_tp + query_fp`.
query_total: u64,
/// Copied from the metrics' `query_tp` (presumably query-side true positives).
query_tp: u64,
/// Copied from the metrics' `query_fp` (presumably query-side false positives).
query_fp: u64,
/// Result of the metrics' `recall()`; `None` is passed through as-is.
metric_recall: Option<f64>,
/// Result of the metrics' `precision()`; `None` is passed through as-is.
metric_precision: Option<f64>,
/// Result of the metrics' `f1()`; `None` is passed through as-is.
metric_f1: Option<f64>,
/// Genotype-specific count; `Some` only for rows built with `SummaryRow::new_gt`.
truth_fn_gt: Option<u64>,
/// Genotype-specific count; `Some` only for rows built with `SummaryRow::new_gt`.
query_fp_gt: Option<u64>
}
impl SummaryRow {
    /// Builds a row from plain summary metrics.
    ///
    /// The totals are derived as `tp + fn` (truth) and `tp + fp` (query); the two
    /// genotype-only columns are left as `None`.
    pub fn new(
        compare_label: String, comparison: MetricsType, region_label: String, filter: String, variant_type: String,
        metrics: &SummaryMetrics
    ) -> Self {
        let (truth_tp, truth_fn) = (metrics.truth_tp, metrics.truth_fn);
        let (query_tp, query_fp) = (metrics.query_tp, metrics.query_fp);

        Self {
            compare_label,
            comparison,
            region_label,
            filter,
            variant_type,
            truth_total: truth_tp + truth_fn,
            truth_tp,
            truth_fn,
            query_total: query_tp + query_fp,
            query_tp,
            query_fp,
            metric_recall: metrics.recall(),
            metric_precision: metrics.precision(),
            metric_f1: metrics.f1(),
            truth_fn_gt: None,
            query_fp_gt: None,
        }
    }

    /// Builds a row from genotype metrics.
    ///
    /// The shared columns come from the embedded `summary_metrics`; the two genotype-only
    /// columns are filled from `truth_fn_gt` and `query_fp_gt`.
    pub fn new_gt(
        compare_label: String, comparison: MetricsType, region_label: String, filter: String, variant_type: String,
        metrics: &SummaryGtMetrics
    ) -> Self {
        // All shared columns are read from the embedded summary; borrow it once.
        let base = &metrics.summary_metrics;
        let (truth_tp, truth_fn) = (base.truth_tp, base.truth_fn);
        let (query_tp, query_fp) = (base.query_tp, base.query_fp);

        Self {
            compare_label,
            comparison,
            region_label,
            filter,
            variant_type,
            truth_total: truth_tp + truth_fn,
            truth_tp,
            truth_fn,
            query_total: query_tp + query_fp,
            query_tp,
            query_fp,
            metric_recall: base.recall(),
            metric_precision: base.precision(),
            metric_f1: base.f1(),
            truth_fn_gt: Some(metrics.truth_fn_gt),
            query_fp_gt: Some(metrics.query_fp_gt),
        }
    }
}
impl SummaryWriter {
pub fn new(compare_label: String, metrics_to_write: Vec<MetricsType>, stratifications: Option<&Stratifications>) -> Self {
let strat_metrics = if let Some(strat) = stratifications {
let labels = strat.labels();
labels.into_iter()
.map(|l| (l, Default::default()))
.collect()
} else {
Default::default()
};
Self {
compare_label,
metrics_to_write,
all_metrics: Default::default(),
strat_metrics,
solved_blocks: 0,
error_blocks: 0
}
}
pub fn add_comparison_benchmark(&mut self, comparison: &CompareBenchmark) {
let group_metrics = comparison.group_metrics();
self.all_metrics += group_metrics;
self.solved_blocks += 1;
if let Some(contained_indices) = comparison.containment_regions() {
for &ci in contained_indices.iter() {
self.strat_metrics[ci] += group_metrics;
}
}
}
pub fn inc_error_blocks(&mut self) {
self.error_blocks += 1;
}
pub fn write_summary(&mut self, filename: &Path) -> anyhow::Result<()> {
let is_csv: bool = filename.extension().unwrap_or_default() == "csv";
let delimiter: u8 = if is_csv { b',' } else { b'\t' };
let mut csv_writer: csv::Writer<File> = csv::WriterBuilder::new()
.delimiter(delimiter)
.from_path(filename)?;
let joint_indel_label = "JointIndel".to_string();
let joint_indel_types = [VariantType::Insertion, VariantType::Deletion, VariantType::Indel];
let joint_tr_label = "JointTandemRepeat".to_string();
let joint_tr_types = [VariantType::TrExpansion, VariantType::TrContraction];
let joint_sv_label = "JointStructuralVariant".to_string();
let joint_sv_types = [
VariantType::SvInsertion,
VariantType::SvDeletion,
VariantType::SvDuplication,
VariantType::SvInversion,
VariantType::SvBreakend,
];
let joint_categories = [
(joint_indel_label.clone(), joint_indel_types.as_slice()),
(joint_sv_label.clone(), joint_sv_types.as_slice()),
(joint_tr_label.clone(), joint_tr_types.as_slice()),
];
write_group(
&mut csv_writer, self.compare_label.clone(), "ALL".to_string(), "ALL".to_string(),
&self.metrics_to_write,
&self.all_metrics,
&joint_categories
)?;
for (strat_label, strat_result) in self.strat_metrics.iter() {
write_group(
&mut csv_writer, self.compare_label.clone(), "ALL".to_string(), strat_label.clone(),
&self.metrics_to_write,
strat_result,
&joint_categories
)?;
}
csv_writer.flush()?;
Ok(())
}
pub fn all_metrics(&self) -> &GroupTypeMetrics {
&self.all_metrics
}
pub fn solved_blocks(&self) -> u64 {
self.solved_blocks
}
pub fn error_blocks(&self) -> u64 {
self.error_blocks
}
}
/// Writes the rows of one metrics group for every requested comparison type.
///
/// Genotype comparisons are routed to `write_gt_category`, which also fills the
/// genotype-only columns; every other comparison type goes through `write_category`.
///
/// # Errors
/// Propagates the first error returned by either category writer.
fn write_group(
    csv_writer: &mut csv::Writer<File>,
    compare_label: String, filter: String, region_label: String,
    metrics_to_write: &[MetricsType],
    group_metrics: &GroupTypeMetrics,
    joint_categories: &[(String, &[VariantType])],
) -> anyhow::Result<()> {
    for metric in metrics_to_write.iter().copied() {
        if matches!(metric, MetricsType::Genotype) {
            write_gt_category(
                csv_writer, compare_label.clone(), filter.clone(), metric, region_label.clone(),
                group_metrics.joint_metrics().gt(), group_metrics.variant_metrics(),
                joint_categories,
            )?;
        } else {
            write_category(
                csv_writer, compare_label.clone(), filter.clone(), metric, region_label.clone(),
                group_metrics,
                joint_categories,
            )?;
        }
    }
    Ok(())
}
#[allow(clippy::too_many_arguments)]
fn write_category(
csv_writer: &mut csv::Writer<File>,
compare_label: String, filter: String, comparison_type: MetricsType, region_label: String,
group_type_metrics: &GroupTypeMetrics,
joint_categories: &[(String, &[VariantType])],
) -> csv::Result<()> {
let all_row = SummaryRow::new(
compare_label.clone(), comparison_type, region_label.clone(), filter.clone(), "ALL".to_string(),
group_type_metrics.joint_metrics().get_metrics(comparison_type)
);
csv_writer.serialize(&all_row)?;
for (variant_type, metrics) in group_type_metrics.variant_metrics().iter() {
let metrics = metrics.get_metrics(comparison_type);
if metrics.is_empty() {
continue;
}
let v_row = SummaryRow::new(
compare_label.clone(), comparison_type, region_label.clone(), filter.clone(), format!("{variant_type:?}"),
metrics
);
csv_writer.serialize(&v_row)?;
}
for (joint_label, joint_types) in joint_categories {
let mut joint_metrics = SummaryMetrics::default();
for variant_type in joint_types.iter() {
if let Some(metrics) = group_type_metrics.variant_metrics().get(variant_type) {
joint_metrics += *metrics.get_metrics(comparison_type);
}
}
if !joint_metrics.is_empty() {
let joint_row = SummaryRow::new(
compare_label.clone(), comparison_type, region_label.clone(), filter.clone(), joint_label.clone(),
&joint_metrics
);
csv_writer.serialize(&joint_row)?;
}
}
Ok(())
}
/// Writes the rows for the genotype comparison of a metrics group, including the
/// genotype-only columns.
///
/// Row order: the `ALL` row built from `full_metrics` (always written), then one row per
/// variant type whose summary metrics are non-empty, then one row per joint category
/// whose summed summary metrics are non-empty.
///
/// # Errors
/// Fails if `comparison_type` is not `MetricsType::Genotype`, or if the CSV writer
/// reports a serialization error.
#[allow(clippy::too_many_arguments)]
fn write_gt_category(
    csv_writer: &mut csv::Writer<File>,
    compare_label: String, filter: String, comparison_type: MetricsType, region_label: String,
    full_metrics: &SummaryGtMetrics,
    type_metrics: &BTreeMap<VariantType, GroupMetrics>,
    joint_categories: &[(String, &[VariantType])],
) -> anyhow::Result<()> {
    ensure!(comparison_type == MetricsType::Genotype, "write_gt_category requires a GT input");

    // Every row shares the same leading columns; only the variant label and counts differ.
    let build_row = |variant_label: String, counts: &SummaryGtMetrics| {
        SummaryRow::new_gt(
            compare_label.clone(), comparison_type, region_label.clone(), filter.clone(), variant_label,
            counts,
        )
    };

    // Overall row is emitted unconditionally, even when all counts are zero.
    csv_writer.serialize(&build_row("ALL".to_string(), full_metrics))?;

    // Per-variant rows, skipping variant types with nothing to report.
    for (variant_type, group) in type_metrics {
        let counts = group.gt();
        if !counts.summary_metrics.is_empty() {
            csv_writer.serialize(&build_row(format!("{variant_type:?}"), counts))?;
        }
    }

    // Joint rows: sum the member variant types that are present, skip empty sums.
    for (joint_label, joint_types) in joint_categories {
        let mut combined = SummaryGtMetrics::default();
        for group in joint_types.iter().filter_map(|variant_type| type_metrics.get(variant_type)) {
            combined += *group.gt();
        }

        if !combined.summary_metrics.is_empty() {
            csv_writer.serialize(&build_row(joint_label.clone(), &combined))?;
        }
    }
    Ok(())
}