use serde::Serialize;
use std::collections::BTreeMap;
use std::fs::File;
use std::path::Path;
use crate::data_types::merge_benchmark::{MergeBenchmark, MergeClassification};
use crate::data_types::multi_region::MultiRegion;
use crate::data_types::variants::VariantType;
/// Key for the pass/fail tally: the merge classification, the variant type, and the
/// position of the input variant set within the region (also used to index the tag list
/// when the summary is written).
type LookupKey = (MergeClassification, VariantType, usize);
/// Accumulates per-variant pass/fail counts across merge benchmarks and writes them out
/// as a delimited summary table.
#[derive(Default)]
pub struct MergeSummaryWriter {
/// Maps each lookup key to `(pass_count, fail_count)`.
/// A `BTreeMap` is used, so rows are emitted in sorted key order.
pass_fail_counts: BTreeMap<LookupKey, (u64, u64)>,
}
/// One serialized record of the merge summary; the fields are serialized in declaration order.
#[derive(Serialize)]
struct MergeSummaryRow {
/// `Display` rendering of the merge classification.
merge_reason: String,
/// `Debug` rendering of the variant type.
variant_type: String,
/// Position of the input variant set this row describes.
vcf_index: usize,
/// Caller-supplied label for the input at `vcf_index`.
vcf_label: String,
/// Number of variants counted as passing for this key.
pass_variants: u64,
/// Number of variants counted as failing for this key.
fail_variants: u64,
}
impl MergeSummaryRow {
    /// Assembles a summary record from one tally entry.
    ///
    /// The classification is rendered through its `Display` implementation and the
    /// variant type through its `Debug` implementation; every other argument is stored
    /// unchanged.
    pub fn new(
        merge_classification: &MergeClassification,
        variant_type: &VariantType,
        vcf_index: usize,
        vcf_label: String,
        pass_variants: u64,
        fail_variants: u64,
    ) -> Self {
        let merge_reason = merge_classification.to_string();
        let variant_type = format!("{:?}", variant_type);
        Self {
            merge_reason,
            variant_type,
            vcf_index,
            vcf_label,
            pass_variants,
            fail_variants,
        }
    }
}
impl MergeSummaryWriter {
    /// Adds every variant in `region` to the pass/fail tally for `comparison`'s classification.
    ///
    /// A variant is counted as passing when the index of the input set it belongs to is
    /// selected by the classification:
    /// * `Different` selects no input,
    /// * `NoConflict` / `MajorityAgree` select the listed `indices`,
    /// * `ConflictSelection` selects the single `index`,
    /// * `BasepairIdentical` selects every input.
    ///
    /// All other variants are counted as failing. Counts are keyed by
    /// `(classification, variant type, input index)`.
    pub fn add_merge_benchmark(&mut self, region: &MultiRegion, comparison: &MergeBenchmark) {
        let merge_classification = comparison.merge_classification();

        // Decide membership directly from the classification instead of materializing a
        // temporary list of passing indices on every call.
        let is_passing_index = |input_index: usize| match merge_classification {
            MergeClassification::Different => false,
            MergeClassification::NoConflict { indices }
            | MergeClassification::MajorityAgree { indices } => indices.contains(&input_index),
            MergeClassification::ConflictSelection { index } => *index == input_index,
            MergeClassification::BasepairIdentical => true,
        };

        for (i, variants) in region.variants().iter().enumerate() {
            // Pass/fail is a property of the whole input set, so compute it once per set.
            let is_passing = is_passing_index(i);
            for v in variants.iter() {
                let k = (merge_classification.clone(), v.variant_type(), i);
                let entry = self.pass_fail_counts.entry(k).or_default();
                if is_passing {
                    entry.0 += 1;
                } else {
                    entry.1 += 1;
                }
            }
        }
    }

    /// Writes the accumulated tally to `filename`, one `MergeSummaryRow` per key.
    ///
    /// The output is comma-delimited when the file extension is exactly `csv` and
    /// tab-delimited otherwise. `tags` supplies the label for each input index; it is
    /// indexed by the input position recorded in `add_merge_benchmark`.
    ///
    /// # Errors
    /// Returns an error if a recorded input index has no corresponding entry in `tags`
    /// (reported as an `InvalidInput` I/O error, before the output file is created), if
    /// the file cannot be created, or if serializing or flushing a row fails.
    pub fn write_summary<T: std::fmt::Display>(&mut self, filename: &Path, tags: &[T]) -> csv::Result<()> {
        // Validate up front: an out-of-range index would otherwise panic mid-write and
        // leave a truncated file behind.
        let max_index = self.pass_fail_counts.keys().map(|(_, _, vcf_index)| *vcf_index).max();
        if let Some(max_index) = max_index {
            if max_index >= tags.len() {
                let message = format!(
                    "no tag provided for VCF index {max_index}; only {} tag(s) were given",
                    tags.len()
                );
                return Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, message).into());
            }
        }

        let is_csv: bool = filename.extension().unwrap_or_default() == "csv";
        let delimiter: u8 = if is_csv { b',' } else { b'\t' };
        let mut csv_writer: csv::Writer<File> = csv::WriterBuilder::new()
            .delimiter(delimiter)
            .from_path(filename)?;

        for ((merge_class, variant_type, vcf_index), (pass_count, fail_count)) in self.pass_fail_counts.iter() {
            // In bounds: every recorded index was checked against `tags.len()` above.
            let vcf_label = tags[*vcf_index].to_string();
            let row = MergeSummaryRow::new(
                merge_class, variant_type, *vcf_index,
                vcf_label, *pass_count, *fail_count
            );
            csv_writer.serialize(row)?;
        }

        // Flush explicitly so write errors surface here rather than being dropped silently.
        csv_writer.flush()?;
        Ok(())
    }
}