use noodles::bgzf;
use serde::Serialize;
use std::fs::File;
use std::path::Path;
use crate::data_types::compare_benchmark::CompareBenchmark;
use crate::data_types::compare_region::CompareRegion;
use crate::data_types::summary_metrics::{SummaryGtMetrics, SummaryMetrics};
use crate::data_types::grouped_metrics::MetricsType;
/// Writes per-region summary rows as delimited text through a multithreaded BGZF writer.
pub struct RegionSummaryWriter {
    /// Delimited-text serializer layered on top of the BGZF writer that owns the output file.
    csv_writer: csv::Writer<bgzf::io::MultithreadedWriter<File>>,
    /// Metric types to report; one row is emitted per entry for every region written.
    metrics_to_write: Vec<MetricsType>,
}
/// One serialized output row: the metrics of a single comparison type for a single region.
///
/// The row is written through serde, so the field names and their declaration order are part
/// of the output format; rename or reorder them only deliberately.
#[derive(Serialize)]
struct RegionSummaryRow {
    /// Identifier of the region, as reported by `CompareRegion::region_id`.
    region_id: u64,
    /// Display form of the region coordinates.
    coordinates: String,
    /// Which metric type this row describes.
    comparison: MetricsType,
    /// Sum of `truth_tp` and `truth_fn`.
    truth_total: u64,
    /// Truth-side true positive count.
    truth_tp: u64,
    /// Truth-side false negative count.
    truth_fn: u64,
    /// Sum of `query_tp` and `query_fp`.
    query_total: u64,
    /// Query-side true positive count.
    query_tp: u64,
    /// Query-side false positive count.
    query_fp: u64,
    /// Recall as computed by the metrics type; `None` when it reports no value.
    metric_recall: Option<f64>,
    /// Precision as computed by the metrics type; `None` when it reports no value.
    metric_precision: Option<f64>,
    /// F1 as computed by the metrics type; `None` when it reports no value.
    metric_f1: Option<f64>,
    /// Genotype-specific truth count; only populated for rows built with `new_gt`.
    truth_fn_gt: Option<u64>,
    /// Genotype-specific query count; only populated for rows built with `new_gt`.
    query_fp_gt: Option<u64>,
}
impl RegionSummaryRow {
    /// Builds a row from plain summary metrics.
    ///
    /// The genotype-only columns (`truth_fn_gt`, `query_fp_gt`) are left as `None`.
    ///
    /// # Arguments
    /// * `region` - supplies the region ID and the coordinates to display
    /// * `comparison` - the metric type this row is labelled with
    /// * `metrics` - counts and derived scores to copy into the row
    pub fn new(region: &CompareRegion, comparison: MetricsType, metrics: &SummaryMetrics) -> Self {
        // Pull the raw counts out once; the totals below are simple sums of these.
        let (truth_tp, truth_fn) = (metrics.truth_tp, metrics.truth_fn);
        let (query_tp, query_fp) = (metrics.query_tp, metrics.query_fp);

        Self {
            region_id: region.region_id(),
            coordinates: region.coordinates().to_string(),
            comparison,
            truth_total: truth_tp + truth_fn,
            truth_tp,
            truth_fn,
            query_total: query_tp + query_fp,
            query_tp,
            query_fp,
            metric_recall: metrics.recall(),
            metric_precision: metrics.precision(),
            metric_f1: metrics.f1(),
            truth_fn_gt: None,
            query_fp_gt: None,
        }
    }

    /// Builds a row from genotype summary metrics.
    ///
    /// Every shared column is filled exactly as in [`RegionSummaryRow::new`] from the embedded
    /// `summary_metrics`; the two genotype-only columns are then populated on top.
    ///
    /// # Arguments
    /// * `region` - supplies the region ID and the coordinates to display
    /// * `comparison` - the metric type this row is labelled with
    /// * `metrics` - genotype metrics wrapping the shared summary metrics
    pub fn new_gt(region: &CompareRegion, comparison: MetricsType, metrics: &SummaryGtMetrics) -> Self {
        let mut row = Self::new(region, comparison, &metrics.summary_metrics);
        row.truth_fn_gt = Some(metrics.truth_fn_gt);
        row.query_fp_gt = Some(metrics.query_fp_gt);
        row
    }
}
impl RegionSummaryWriter {
    /// Creates the output file and wraps it in a tab-delimited writer over multithreaded BGZF.
    ///
    /// # Arguments
    /// * `filename` - path of the file to create
    /// * `metrics_to_write` - metric types to emit for each region, in output order
    /// * `threads` - requested worker count; clamped to the range 1..=4
    ///
    /// # Errors
    /// Returns an error if the output file cannot be created.
    pub fn new(filename: &Path, metrics_to_write: Vec<MetricsType>, threads: usize) -> csv::Result<Self> {
        // The clamp guarantees a value of at least 1, so the NonZeroUsize conversion cannot fail.
        let worker_count = std::num::NonZeroUsize::new(threads.clamp(1, 4)).unwrap();
        let output_file = File::create(filename)?;
        let bgzf_writer = bgzf::io::MultithreadedWriter::with_worker_count(worker_count, output_file);
        let csv_writer = csv::WriterBuilder::new()
            .delimiter(b'\t')
            .from_writer(bgzf_writer);

        Ok(Self { csv_writer, metrics_to_write })
    }

    /// Serializes one row per configured metric type for the given region.
    ///
    /// Genotype rows are built from the genotype-specific metrics so that their extra columns
    /// are populated; every other metric type goes through the generic lookup.
    ///
    /// # Arguments
    /// * `region` - the region being summarized
    /// * `comparison` - benchmark results from which the joint metrics are read
    ///
    /// # Errors
    /// Returns an error if serializing a row fails.
    pub fn write_region_summary(&mut self, region: &CompareRegion, comparison: &CompareBenchmark) -> csv::Result<()> {
        // NOTE(review): nothing in this type flushes or finishes the writer explicitly; presumably
        // that happens when it is dropped, where I/O errors cannot be reported — confirm.
        let joint_metrics = comparison.group_metrics().joint_metrics();
        for &metric in &self.metrics_to_write {
            let row = match metric {
                MetricsType::Genotype => RegionSummaryRow::new_gt(region, metric, joint_metrics.gt()),
                _ => RegionSummaryRow::new(region, metric, joint_metrics.get_metrics(metric)),
            };
            self.csv_writer.serialize(&row)?;
        }
        Ok(())
    }
}