use anyhow::{self};
use bio::bio_types::alignment::Alignment;
use crossbeam::channel::{Receiver, unbounded};
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufWriter, Write};
use std::sync::{Arc, Mutex};
use std::thread::scope;
use crate::combination::{CombinationKey, CombinationMatch, ObservedCombination};
use crate::errors::{LibraryError, ReadCountError};
use crate::filters::{FilterConfig, FilterReason, FilteredCounts, FilteredReads};
use crate::lib_spec::{DistanceMetric, Library};
use crate::library_combination::{LibraryCombination, LibraryCombinationKey};
use crate::logging::{Progress, ProgressStyle};
use crate::region::{ObservedRegion, RegionKey, RegionMatch};
use crate::seqs::ReadPair;
use crate::seqs::{ReadGroup, SeqPair};
use crate::utils::div_or_zero;
/// Accumulates per-combination read counts, the regions those combinations
/// are built from, filtered-read tallies, and (after `compare_to_library`)
/// the library they were matched against.
#[derive(Debug)]
pub struct ObservedCombinations {
/// Region identifiers accepted by this counter; also iterated to lay out
/// the per-region columns in the TSV writers.
region_ids: Vec<String>,
/// Every distinct region seen so far. Values are `Arc`s so the same region
/// can be referenced from several combinations.
regions: HashMap<RegionKey, Arc<Mutex<ObservedRegion>>>,
/// Observed combinations keyed by `CombinationKey`.
combinations: HashMap<CombinationKey, ObservedCombination>,
/// `None` until `compare_to_library` stores the library it was given.
library: Option<Library>,
/// `None` until `compare_to_library` stores its per-library-combination summary.
library_combinations: Option<HashMap<LibraryCombinationKey, LibraryCombination>>,
/// Tallies of reads rejected by the filters.
filtered_reads: FilteredReads,
/// Lookup from a read pair's `SeqPair` to the outcome previously stored for it.
cache: ObservedReads,
}
impl ObservedCombinations {
pub fn new(region_ids: Vec<String>, filter_config: FilterConfig) -> Self {
Self {
region_ids,
regions: HashMap::new(),
combinations: HashMap::new(),
library: None,
library_combinations: None,
filtered_reads: FilteredReads::new(filter_config),
cache: HashMap::new(),
}
}
pub fn merge(&mut self, new_counts: ObservedCombinations) -> Result<(), ReadCountError> {
if self.is_compared_to_library() || new_counts.is_compared_to_library() {
return Err(ReadCountError::Error {
desc: "Can't merge ObservedCombinations once library comparison has been run"
.to_string(),
});
}
if self.region_ids != new_counts.region_ids {
return Err(ReadCountError::Error {
desc: "Can't merge ObservedCombinations with different region_ids".to_string(),
});
}
self.filtered_reads.merge(new_counts.filtered_reads)?;
for (k, v) in new_counts.regions.iter() {
if !self.regions.contains_key(k) {
self.regions.insert(k.clone(), v.to_owned());
}
}
for (comb_key, mut new_comb) in new_counts.combinations.into_iter() {
match self.combinations.get_mut(&comb_key) {
Some(old_comb) => {
for (group, new_count) in &new_comb.counts {
match old_comb.counts.get_mut(group) {
Some(old_count) => *old_count += new_count,
None => {
old_comb.counts.insert(group.clone(), *new_count);
}
}
}
}
None => {
new_comb.regions.clear();
for reg_key in &comb_key.regions {
let arc = match self.regions.get(reg_key) {
Some(x) => x,
None => return Err(ReadCountError::Error {
desc: "Region key missing during combination merge after merging regions".to_string(),
}),
};
new_comb.regions.insert(reg_key.id.clone(), arc.clone());
}
self.combinations.insert(comb_key, new_comb);
}
}
}
Ok(())
}
/// Returns the number of distinct combinations recorded so far.
pub fn len(&self) -> usize {
self.combinations.len()
}
/// Returns `true` when no combination has been recorded.
pub fn is_empty(&self) -> bool {
self.combinations.is_empty()
}
/// Returns the total reported by the filtered-read tallies
/// (`FilteredReads::total`).
pub fn total_filtered(&self) -> u64 {
self.filtered_reads.total()
}
/// Records one observation of `comb_key` for `group`.
///
/// An existing combination simply has its count incremented. Otherwise a
/// new combination is created, reusing regions already present in
/// `self.regions` and registering any that are not.
///
/// # Errors
/// Returns `ReadCountError::UnexpectedRegion` (converted into
/// `anyhow::Error`) when the key contains a region whose id is not listed in
/// `self.region_ids`. In that case nothing is modified.
pub fn add_or_increment_combination(
    &mut self,
    comb_key: &CombinationKey,
    group: ReadGroup,
) -> Result<(), anyhow::Error> {
    if let Some(comb) = self.combinations.get_mut(comb_key) {
        comb.increment_count(group);
        return Ok(());
    }

    // Validate every region before touching `self.regions`, so a rejected
    // key cannot leave regions registered without an owning combination.
    for reg_key in &comb_key.regions {
        if !self.region_ids.contains(&reg_key.id) {
            return Err(ReadCountError::UnexpectedRegion {
                region: reg_key.id.clone(),
            }
            .into());
        }
    }

    let mut reg_map = HashMap::new();
    for reg_key in &comb_key.regions {
        let shared = match self.regions.get(reg_key) {
            Some(existing) => Arc::clone(existing),
            None => {
                let created = Arc::new(Mutex::new(ObservedRegion::new(
                    reg_key.id.clone(),
                    &reg_key.sequence,
                    reg_key.completeness,
                )));
                self.regions.insert(reg_key.clone(), Arc::clone(&created));
                created
            }
        };
        reg_map.insert(reg_key.id.clone(), shared);
    }

    let mut comb = ObservedCombination::new(reg_map, comb_key.sequence.clone());
    comb.increment_count(group);
    self.combinations.insert(comb_key.clone(), comb);
    Ok(())
}
/// Records `read` as filtered for `reason` by delegating to
/// `FilteredReads::increment_count`.
pub fn update_filter_count(&mut self, read: &ReadPair, reason: FilterReason) {
self.filtered_reads.increment_count(read, reason)
}
/// Delegates to `FilteredReads::filter_readpair`, forwarding `increment`.
///
/// Presumably returns `Some(reason)` when the read pair should be dropped and
/// `None` otherwise, with `increment` controlling whether the tally is
/// updated; confirm against `FilteredReads`.
pub fn filter_readpair(&mut self, record: &ReadPair, increment: bool) -> Option<FilterReason> {
self.filtered_reads.filter_readpair(record, increment)
}
/// Delegates to `FilteredReads::filter_alignment` with the forward alignment
/// and the optional reverse alignment, forwarding `increment`.
///
/// Presumably returns `Some(reason)` when the alignments fail a filter;
/// confirm against `FilteredReads`.
pub fn filter_alignment(
&mut self,
record: &ReadPair,
f_alignment: &Alignment,
r_alignment: Option<&Alignment>,
increment: bool,
) -> Option<FilterReason> {
self.filtered_reads
.filter_alignment(record, f_alignment, r_alignment, increment)
}
/// Stores `value` in the cache under `key`, replacing any previous entry for
/// the same key.
pub fn cache(&mut self, key: SeqPair, value: CacheHit) {
self.cache.insert(key, value);
}
/// Looks up the cached outcome for `record`'s sequence pair.
///
/// Returns `Ok(None)` on a cache miss. On a hit the cached value is
/// returned, and when `increment` is `true` the hit is also replayed: a
/// `CacheHit::Comb` increments that combination for the record's group and
/// a `CacheHit::Filter` bumps the filter tally for that reason.
///
/// # Errors
/// Propagates any error from `add_or_increment_combination`.
pub fn check_cache(
    &mut self,
    record: &ReadPair,
    increment: bool,
) -> Result<Option<CacheHit>, anyhow::Error> {
    let lookup = SeqPair::from_readpair(record);
    let cached: CacheHit = match self.cache.get(&lookup).cloned() {
        Some(found) => found,
        None => return Ok(None),
    };
    if !increment {
        return Ok(Some(cached));
    }
    match &cached {
        CacheHit::Comb(comb_key) => {
            self.add_or_increment_combination(comb_key, record.group.clone())?;
        }
        CacheHit::Filter(reason) => {
            self.update_filter_count(record, *reason);
        }
    }
    Ok(Some(cached))
}
pub fn compare_to_library(
&mut self,
library: Library,
progress_style: Option<&ProgressStyle>,
distance_metric: DistanceMetric,
max_matches: usize,
threads: usize,
) -> Result<(), LibraryError> {
let n_regs = self.regions.len() as u64;
let n_combs = self.combinations.len() as u64;
let mut reg_progress: Progress = Progress::from_style(
progress_style.unwrap_or(&ProgressStyle::new(None, false)),
"Matching regions:",
"Matched regions:",
Some(n_regs),
match distance_metric {
DistanceMetric::Hamming | DistanceMetric::Exact => 250000,
DistanceMetric::BoundedLevenshtein => std::cmp::max(n_regs / 10, 50000),
DistanceMetric::Levenshtein => std::cmp::max(n_regs / 20, 10000),
},
);
match threads.cmp(&1) {
std::cmp::Ordering::Less => {
return Err(LibraryError::Library {
desc: "Threads must be >0".to_string(),
});
}
std::cmp::Ordering::Equal => {
for r in self.regions.values() {
let mut reg = r.lock().unwrap();
let val = reg.compare_to_library(&library, distance_metric, max_matches);
reg.nearest_matches = val;
reg_progress.inc(1);
}
reg_progress.finish();
}
std::cmp::Ordering::Greater => {
let (reg_tx, reg_rx) = unbounded();
let (done_tx, done_rx) = unbounded();
let lib_arc = Arc::new(&library);
scope(|scope| {
for _ in 0..threads {
let rx: Receiver<Arc<Mutex<ObservedRegion>>> = reg_rx.clone();
let tx = done_tx.clone();
let lib = lib_arc.clone();
scope.spawn(move || {
while let Ok(region) = rx.recv() {
let mut reg = region.lock().unwrap();
let val =
reg.compare_to_library(&lib, distance_metric, max_matches);
reg.nearest_matches = val;
tx.send(()).expect("Main thread comparison reciever failed");
}
});
}
for region in self.regions.values() {
reg_tx
.send(region.clone())
.expect("Library comparison thread send failed");
}
drop(reg_tx);
drop(reg_rx);
drop(done_tx);
for _ in done_rx.iter() {
reg_progress.inc(1);
}
drop(done_rx);
});
reg_progress.finish();
}
}
let mut comb_progress: Progress = Progress::from_style(
progress_style.unwrap_or(&ProgressStyle::new(None, false)),
"Comparing combinations:",
"Compared combinations:",
Some(n_combs),
2000000,
);
for value in self.combinations.values_mut() {
value.library_matches =
value.compare_to_library(&self.region_ids, &library, distance_metric, max_matches);
comb_progress.inc(1);
}
comb_progress.finish();
let mut lib_summary_progress: Progress = Progress::from_style(
progress_style.unwrap_or(&ProgressStyle::new(None, false)),
"Summarising library matches:",
"Summarised library matches:",
Some(n_combs),
std::cmp::max(n_combs / 4, 250000),
);
let mut lib_combs: HashMap<LibraryCombinationKey, LibraryCombination> = HashMap::new();
for comb in self.combinations.values_mut() {
let mut key: LibraryCombinationKey = LibraryCombinationKey::new(Vec::with_capacity(5));
for reg in &self.region_ids {
match comb.regions.get(reg) {
None => key.regions.push((reg.to_string(), RegionMatch::Unmatched)),
Some(x) => {
let or = x.lock().unwrap();
key.regions.push((
reg.to_string(),
match &or.nearest_matches {
RegionMatch::Unmatched => RegionMatch::Unmatched,
RegionMatch::Overmatched { .. } => RegionMatch::Overmatched {
distance: 0,
matches: 0,
},
RegionMatch::NoLibrary { .. } => RegionMatch::NoLibrary {
seq: Some(or.seq.clone()),
},
RegionMatch::Uncompared => RegionMatch::Uncompared,
RegionMatch::Match { seq_match, .. } => RegionMatch::Match {
seq_match: seq_match.clone(),
distance: 0,
},
RegionMatch::MultiMatch { seq_matches, .. } => {
RegionMatch::MultiMatch {
seq_matches: seq_matches.to_vec(),
distance: 0,
}
}
},
))
}
}
}
match lib_combs.get_mut(&key) {
Some(x) => {
for (group, count) in &comb.counts {
x.increment_count(group, *count);
}
}
None => {
let mut x: LibraryCombination = LibraryCombination::new(
HashMap::from_iter(key.regions.clone()),
match &comb.library_matches {
CombinationMatch::Uncompared => CombinationMatch::Uncompared,
CombinationMatch::Match { ind, .. } => CombinationMatch::Match {
ind: *ind,
distance: 0,
},
CombinationMatch::MultiMatch { inds, .. } => {
CombinationMatch::MultiMatch {
inds: inds.clone(),
distance: 0,
}
}
CombinationMatch::Recombination { .. } => {
CombinationMatch::Recombination { distance: 0 }
}
CombinationMatch::Mismatch => CombinationMatch::Mismatch,
CombinationMatch::Nonmatch => CombinationMatch::Nonmatch,
},
);
for (group, count) in &comb.counts {
x.increment_count(group, *count);
}
lib_combs.insert(key, x);
}
}
lib_summary_progress.inc(1);
}
lib_summary_progress.finish();
self.library = Some(library);
self.library_combinations = Some(lib_combs);
Ok(())
}
/// Returns `true` once a library has been stored on `self`, which
/// `compare_to_library` does on success.
pub fn is_compared_to_library(&self) -> bool {
self.library.is_some()
}
/// Tallies combination counts by library-match category.
///
/// Each combination's `total_count()` is added to the `ReadSummary` field
/// that corresponds to its `library_matches` variant; `Match` and
/// `Recombination` are split into exact (`distance == 0`) and nearest
/// (any other distance). The filtered-read totals are copied in unchanged.
pub fn summarise(&self) -> ReadSummary {
    let mut read_summary = ReadSummary::empty();
    read_summary.filtered_reads = self.filtered_reads.totals.clone();
    for comb in self.combinations.values() {
        // Select the tally this combination belongs to, then add once.
        let bucket: &mut u64 = match comb.library_matches {
            CombinationMatch::Uncompared => &mut read_summary.uncompared,
            CombinationMatch::Match { distance, .. } if distance == 0 => {
                &mut read_summary.exact_match
            }
            CombinationMatch::Match { .. } => &mut read_summary.nearest_match,
            CombinationMatch::MultiMatch { .. } => &mut read_summary.multimatch,
            CombinationMatch::Recombination { distance } if distance == 0 => {
                &mut read_summary.exact_recombination
            }
            CombinationMatch::Recombination { .. } => &mut read_summary.nearest_recombination,
            CombinationMatch::Mismatch => &mut read_summary.mismatch,
            CombinationMatch::Nonmatch => &mut read_summary.nonmatch,
        };
        *bucket += comb.total_count() as u64;
    }
    read_summary
}
pub fn write_tsv(&self, file: File, sort: bool) -> Result<(), anyhow::Error> {
let mut count_writer = BufWriter::new(file);
let mut keys: Vec<(&CombinationKey, u32)> = self
.combinations
.iter()
.map(|x| (x.0, x.1.total_count()))
.collect();
if sort {
keys.sort_unstable_by_key(|x| 0 - i64::from(x.1));
}
write!(count_writer, "group\tforward\treverse\t")?;
for r in &self.region_ids {
write!(
count_writer,
"{}\t{}_nearest\t{}_distance\t{}_n_matches\t",
r, r, r, r
)?;
}
writeln!(
count_writer,
"combination_status\tcombination_distance\tcombinations_in_library\tcombination_id\tcount"
)?;
for (key, _) in keys {
let combination = self.combinations.get(key).expect(
"Combination key from extracted key list missing from ObservedCombinations",
);
write!(
count_writer,
"{}",
combination.to_tsv(&self.region_ids, self.library.as_ref())?
)?;
}
count_writer.flush()?;
Ok(())
}
pub fn write_summary_tsv(&self, file: File, sort: bool) -> Result<(), anyhow::Error> {
let combs = match &self.library_combinations {
None => {
return Err(ReadCountError::Error {
desc: "Combinations uncompared, compare before summarising".to_string(),
}
.into());
}
Some(x) => x,
};
let mut writer = BufWriter::new(file);
let mut keys: Vec<(&LibraryCombinationKey, u32)> =
combs.iter().map(|x| (x.0, x.1.total_count())).collect();
if sort {
keys.sort_unstable_by_key(|x| 0 - i64::from(x.1));
}
write!(writer, "group\t")?;
for r in &self.region_ids {
write!(writer, "{}\t", r)?;
}
writeln!(
writer,
"combination_status\tcombinations_in_library\tcombination_id\tcount"
)?;
for (key, _) in keys {
let combination = combs.get(key).expect(
"Combination key from extracted key list missing from ObservedCombinations",
);
write!(
writer,
"{}",
combination.to_tsv(&self.region_ids, self.library.as_ref())?
)?;
}
writer.flush()?;
Ok(())
}
/// Writes the filtered-read tallies to `file` by delegating to
/// `FilteredReads::write_filter_tsv`, forwarding `sort`.
///
/// Any error from the delegate is converted into `anyhow::Error`.
pub fn write_filtered_tsv(&self, file: File, sort: bool) -> Result<(), anyhow::Error> {
self.filtered_reads.write_filter_tsv(file, sort)?;
Ok(())
}
}
/// Cache mapping a read pair's sequences to the outcome previously stored for them.
pub type ObservedReads = HashMap<SeqPair, CacheHit>;
/// Outcome stored in the read cache for a sequence pair.
#[derive(Debug, Clone)]
pub enum CacheHit {
/// The pair maps to this combination key; `check_cache` increments that
/// combination when asked to.
Comb(CombinationKey),
/// The pair was filtered for this reason; `check_cache` bumps the matching
/// filter tally when asked to.
Filter(FilterReason),
}
/// Count totals per library-match category, plus the filtered-read totals.
///
/// The category fields are filled by `ObservedCombinations::summarise` from
/// each combination's `total_count()`.
pub struct ReadSummary {
/// Combinations whose match state is `CombinationMatch::Uncompared`.
pub uncompared: u64,
/// `CombinationMatch::Match` with distance 0.
pub exact_match: u64,
/// `CombinationMatch::Match` with a non-zero distance.
pub nearest_match: u64,
/// `CombinationMatch::MultiMatch`.
pub multimatch: u64,
/// `CombinationMatch::Recombination` with distance 0.
pub exact_recombination: u64,
/// `CombinationMatch::Recombination` with a non-zero distance.
pub nearest_recombination: u64,
/// `CombinationMatch::Mismatch`.
pub mismatch: u64,
/// `CombinationMatch::Nonmatch`.
pub nonmatch: u64,
/// Totals of filtered reads, copied from `FilteredReads::totals`.
pub filtered_reads: FilteredCounts,
}
impl ReadSummary {
/// Builds a summary from explicit values for every field.
///
/// Marked `dead_code`-allowed: nothing in this file calls it.
#[allow(dead_code)]
fn new(
uncompared: u64,
exact_match: u64,
nearest_match: u64,
multimatch: u64,
exact_recombination: u64,
nearest_recombination: u64,
mismatch: u64,
nonmatch: u64,
filtered_reads: FilteredCounts,
) -> Self {
Self {
uncompared,
exact_match,
nearest_match,
multimatch,
exact_recombination,
nearest_recombination,
mismatch,
nonmatch,
filtered_reads,
}
}
/// Builds a summary with every category at zero and a fresh
/// `FilteredCounts`; used as the starting point by `summarise`.
fn empty() -> Self {
Self {
uncompared: 0,
exact_match: 0,
nearest_match: 0,
multimatch: 0,
exact_recombination: 0,
nearest_recombination: 0,
mismatch: 0,
nonmatch: 0,
filtered_reads: FilteredCounts::new(),
}
}
/// Sums the eight match-category counters, i.e. everything except the
/// filtered-read totals.
fn total_unfiltered(&self) -> u64 {
    [
        self.uncompared,
        self.exact_match,
        self.nearest_match,
        self.multimatch,
        self.exact_recombination,
        self.nearest_recombination,
        self.mismatch,
        self.nonmatch,
    ]
    .iter()
    .sum()
}
/// Returns the unfiltered total plus the filtered-read total.
pub fn total(&self) -> u64 {
self.total_unfiltered() + self.filtered_reads.total()
}
/// Writes the summary to `file` as a long-format TSV.
///
/// Columns are `group`, `metric`, `count`, `overall_proportion` (count over
/// the grand total) and `group_proportion` (count over the unfiltered
/// total). An `all`/`total` row comes first, then one `unfiltered` row for
/// the unfiltered total and each match category, then the lines produced by
/// `FilteredCounts::to_long_tsv_lines`. Proportions are printed with four
/// decimals.
///
/// # Errors
/// Propagates I/O errors from writing or flushing.
pub fn write_tsv(self, file: File) -> Result<(), anyhow::Error> {
    let mut writer = BufWriter::new(file);
    let unfiltered_total = self.total_unfiltered();
    let total = self.total();

    writeln!(
        writer,
        "group\tmetric\tcount\toverall_proportion\tgroup_proportion"
    )?;

    // The grand total is 100% of itself unless nothing was counted at all.
    let whole = if total > 0 { 1.0 } else { 0.0 };
    writeln!(writer, "all\ttotal\t{}\t{:.4}\t{:.4}", total, whole, whole)?;

    // Every unfiltered row shares one layout, so drive them from a table
    // instead of repeating the same `writeln!` for each metric.
    let unfiltered_rows: [(&str, u64); 9] = [
        ("total", unfiltered_total),
        ("uncompared", self.uncompared),
        ("exact_match", self.exact_match),
        ("nearest_match", self.nearest_match),
        ("multimatch", self.multimatch),
        ("exact_recombination", self.exact_recombination),
        ("nearest_recombination", self.nearest_recombination),
        ("mismatch", self.mismatch),
        ("nonmatch", self.nonmatch),
    ];
    for &(metric, count) in &unfiltered_rows {
        writeln!(
            writer,
            "unfiltered\t{}\t{}\t{:.4}\t{:.4}",
            metric,
            count,
            div_or_zero(count as f32, total as f32),
            div_or_zero(count as f32, unfiltered_total as f32),
        )?;
    }

    write!(writer, "{}", self.filtered_reads.to_long_tsv_lines(total))?;
    writer.flush()?;
    Ok(())
}
}
#[cfg(test)]
mod tests {
// NOTE(review): this test module is currently empty; no unit tests are defined here.
}