use anyhow::bail;
use log::{debug, trace};
use priority_queue::PriorityQueue;
use std::cmp::Reverse;
use std::collections::BTreeMap;
use crate::data_types::coordinates::Coordinates;
use crate::data_types::phase_enums::Allele;
use crate::data_types::variants::Variant;
use crate::dwfa::haplotype_dwfa::{HapDWFAError, HaplotypeDWFA};
use crate::query_optimizer::order_variants;
/// Outcome of [`optimize_gt_alleles`]: the allele assignment selected for the truth and
/// query haplotypes, together with the number of errors that assignment required.
#[derive(Debug, Eq, PartialEq)]
pub struct OptimizedAlleles {
    /// Selected truth alleles (presumably one per truth variant, in input order — this
    /// depends on `order_variants`; TODO confirm).
    truth_alleles: Vec<Allele>,
    /// Selected query alleles (same ordering caveat as `truth_alleles`).
    query_alleles: Vec<Allele>,
    /// Error count for this assignment; zero means an exact match.
    num_errors: usize,
}

impl OptimizedAlleles {
    /// Bundles an allele assignment with its error count.
    pub fn new(truth_alleles: Vec<Allele>, query_alleles: Vec<Allele>, num_errors: usize) -> Self {
        Self { truth_alleles, query_alleles, num_errors }
    }

    /// Returns `true` when the assignment needed no errors at all.
    pub fn is_exact_match(&self) -> bool {
        self.num_errors() == 0
    }

    /// Borrowed view of the selected truth alleles.
    pub fn truth_alleles(&self) -> &[Allele] {
        self.truth_alleles.as_slice()
    }

    /// Borrowed view of the selected query alleles.
    pub fn query_alleles(&self) -> &[Allele] {
        self.query_alleles.as_slice()
    }

    /// Number of errors recorded for this assignment.
    pub fn num_errors(&self) -> usize {
        self.num_errors
    }
}
pub fn optimize_gt_alleles(
reference: &[u8], coordinates: &Coordinates,
truth_variants: &[Variant], truth_alleles: &[Allele],
query_variants: &[Variant], query_alleles: &[Allele],
) -> anyhow::Result<OptimizedAlleles> {
if truth_variants.len() != truth_alleles.len() {
bail!("truth values must have equal length");
}
if query_variants.len() != query_alleles.len() {
bail!("query values must have equal length");
}
debug!("Truth variants:");
for (v, a) in truth_variants.iter().zip(truth_alleles.iter()) {
debug!("\t{a:?} => {} {:?} {:?}", v.position(), v.allele0(), v.allele1());
}
debug!("Query variants:");
for (v, a) in query_variants.iter().zip(query_alleles.iter()) {
debug!("\t{a:?} => {} {:?} {:?}", v.position(), v.allele0(), v.allele1());
}
let all_variant_order: Vec<(usize, bool)> = order_variants(truth_variants, query_variants);
let total_variant_count = all_variant_order.len();
let mut next_node_id = 0;
let start = coordinates.start() as usize;
let root_node = ExactMatchNode::new(next_node_id, start);
assert!(root_node.is_exact_match());
next_node_id += 1;
let priority = root_node.priority();
let mut pqueue: PriorityQueue<ExactMatchNode, NodePriority> = Default::default();
pqueue.push(root_node, priority);
let mut best_error_count = usize::MAX;
let mut best_result = None;
let max_per_bucket = usize::MAX;
let mut bucket_counts = vec![0; total_variant_count];
let mut min_allele_sync = 0;
let mut failed_counts: std::collections::BTreeMap<usize, usize> = Default::default();
let auto_fail_threshold = 500;
let mut auto_fail_index = 0;
let mut auto_fail_counts = 0;
let start = std::time::Instant::now();
while let Some((current_node, _priority)) = pqueue.pop() {
if current_node.num_errors() >= best_error_count {
continue;
}
if start.elapsed().as_secs() > 300 {
bail!("300 second time limit reached");
}
let order_index = current_node.set_alleles();
if order_index == total_variant_count {
let mut final_node = current_node;
final_node.finalize_dwfas(reference, coordinates.end() as usize)?;
let final_errors = final_node.num_errors();
if final_node.is_exact_match() && final_errors < best_error_count {
best_error_count = final_errors;
best_result = Some(final_node);
}
continue;
}
if order_index < min_allele_sync {
continue;
}
if bucket_counts[order_index] >= max_per_bucket {
continue;
}
bucket_counts[order_index] += 1;
if current_node.is_synchronized() {
assert!(order_index >= min_allele_sync);
min_allele_sync = order_index;
auto_fail_counts = 0;
auto_fail_index = min_allele_sync;
for (i, c) in failed_counts.iter() {
debug!("\t{i} ({:?}) => {c}", all_variant_order[*i]);
}
failed_counts.clear();
debug!("New min_allele_sync={min_allele_sync}");
debug!("Sync stats: {} / {} => {} / {}",
current_node.hap_dwfa().truth_haplotype().alleles().len(),
current_node.hap_dwfa().query_haplotype().alleles().len(),
current_node.hap_dwfa().truth_haplotype().sequence().len(),
current_node.hap_dwfa().query_haplotype().sequence().len()
);
debug!("Current queue size: {}", pqueue.len());
let mut queue_stats: BTreeMap<(usize, usize), usize> = Default::default();
for (n, _) in pqueue.iter() {
let k = (n.set_alleles(), n.num_errors());
*queue_stats.entry(k).or_default() += 1;
}
for (k, v) in queue_stats.iter() {
debug!("\t{k:?} => {v}");
}
}
let (variant_index, is_truth) = all_variant_order[order_index];
let (current_variant, current_allele) = if is_truth {
(&truth_variants[variant_index], truth_alleles[variant_index])
} else {
(&query_variants[variant_index], query_alleles[variant_index])
};
let next_var_pos = if order_index == all_variant_order.len() - 1 {
coordinates.end() as usize
} else {
let (nvi, nt) = all_variant_order[order_index+1];
if nt { truth_variants[nvi].position() as usize } else { query_variants[nvi].position() as usize }
};
let sync_extension = Some(next_var_pos);
match current_allele {
Allele::Unknown => bail!("Allele::Unknown is not supported"),
Allele::Reference => {
let mut new_node = current_node;
let success = new_node.extend_variant(
reference, is_truth,
current_variant, Allele::Reference, sync_extension, false
)?;
if success && new_node.is_exact_match() {
let new_priority = new_node.priority();
pqueue.push(new_node, new_priority);
} else {
*failed_counts.entry(order_index).or_default() += 1;
}
},
Allele::Alternate => {
let extensions = [
(Allele::Reference, true), (Allele::Alternate, false) ];
for (allele, is_error) in extensions.into_iter() {
if order_index < auto_fail_index && allele != Allele::Reference {
continue;
}
let mut new_node = current_node.clone();
new_node.set_node_id(next_node_id);
next_node_id += 1;
let success = new_node.extend_variant(
reference, is_truth,
current_variant, allele, sync_extension, is_error
)?;
if success && new_node.is_exact_match() {
let new_priority = new_node.priority();
pqueue.push(new_node, new_priority);
} else {
*failed_counts.entry(order_index).or_default() += 1;
}
}
},
};
auto_fail_counts += 1;
if auto_fail_counts >= auto_fail_threshold {
let (auto_fail_sub_index, auto_fail_is_truth) = all_variant_order[auto_fail_index];
let before_size = pqueue.len();
pqueue = pqueue.into_iter()
.filter(|(node, _priority)| {
let alleles = if auto_fail_is_truth {
node.hap_dwfa().truth_haplotype().alleles()
} else {
node.hap_dwfa().query_haplotype().alleles()
};
let allele = if auto_fail_sub_index < alleles.len() {
alleles[auto_fail_sub_index]
} else {
Allele::Reference
};
allele == Allele::Reference
}).collect();
let after_size = pqueue.len();
trace!("Auto-fail triggered for variant {auto_fail_index}: pqueue size filtered from {before_size} to {after_size}.");
auto_fail_index += 1;
auto_fail_counts = 0;
}
}
debug!("Fewest errors: {best_error_count}");
let best_node = match best_result {
Some(bn) => bn,
None => bail!("No result found for problem")
};
let ret = OptimizedAlleles {
truth_alleles: best_node.hap_dwfa().truth_haplotype().alleles().to_vec(),
query_alleles: best_node.hap_dwfa().query_haplotype().alleles().to_vec(),
num_errors: best_node.num_errors()
};
Ok(ret)
}
/// One partial allele assignment in the best-first search: the paired haplotype alignment
/// built so far, plus how many extensions were flagged as errors.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct ExactMatchNode {
    /// Final tie-breaker in [`NodePriority`]; lower ids are popped first.
    node_id: u64,
    /// Truth/query haplotype alignment state.
    hap_dwfa: HaplotypeDWFA,
    /// Number of extensions made with `is_error == true`.
    num_errors: usize
}

/// Queue priority (the largest value is popped first). Nodes rank by fewest errors, then
/// by most alleles set without error, then by lowest node id.
type NodePriority = (Reverse<usize>, usize, Reverse<u64>);

impl ExactMatchNode {
    /// Creates an error-free root node whose haplotypes begin at `region_start`.
    pub fn new(
        node_id: u64, region_start: usize,
    ) -> Self {
        Self {
            node_id,
            hap_dwfa: HaplotypeDWFA::new(region_start, 0),
            num_errors: 0
        }
    }

    /// Extends the truth (`is_truth`) or query haplotype with `allele` of `variant`.
    ///
    /// Returns `Ok(false)` instead of an error when the DWFA hits its maximum edit
    /// distance. Any other DWFA error is propagated. When `is_error` is set, the error
    /// counter is bumped, unless a non-tolerated error was returned first.
    pub fn extend_variant(&mut self,
        reference: &[u8], is_truth: bool,
        variant: &Variant, allele: Allele, sync_extension: Option<usize>, is_error: bool
    ) -> Result<bool, HapDWFAError> {
        let outcome = self.hap_dwfa.extend_variant(reference, is_truth, variant, allele, sync_extension);
        let extended = match outcome {
            Ok(flag) => flag,
            Err(err) if is_allowed_error(&err) => {
                // A tolerated failure must leave the node visibly non-matching.
                assert!(!self.is_exact_match());
                false
            },
            Err(err) => return Err(err),
        };
        self.num_errors += usize::from(is_error);
        Ok(extended)
    }

    /// Closes both haplotypes out to `region_end`.
    ///
    /// Returns `Ok(true)` on success and `Ok(false)` on the tolerated max-edit-distance error.
    pub fn finalize_dwfas(&mut self, reference: &[u8], region_end: usize) -> Result<bool, HapDWFAError> {
        match self.hap_dwfa.finalize_dwfa(reference, region_end) {
            Ok(_) => Ok(true),
            Err(err) if is_allowed_error(&err) => {
                assert!(!self.is_exact_match());
                Ok(false)
            },
            Err(err) => Err(err),
        }
    }

    /// Current edit distance between the truth and query haplotypes.
    pub fn edit_distance(&self) -> usize {
        self.hap_dwfa.edit_distance()
    }

    /// True when the haplotypes currently differ by zero edits.
    pub fn is_exact_match(&self) -> bool {
        matches!(self.edit_distance(), 0)
    }

    /// Delegates to the DWFA's synchronization check.
    pub fn is_synchronized(&self) -> bool {
        self.hap_dwfa.is_synchronized()
    }

    /// Builds this node's [`NodePriority`].
    pub fn priority(&self) -> NodePriority {
        let errors = self.num_errors();
        let error_free_alleles = self.hap_dwfa.set_alleles() - errors;
        (Reverse(errors), error_free_alleles, Reverse(self.node_id))
    }

    /// Number of alleles assigned so far, as reported by the DWFA.
    pub fn set_alleles(&self) -> usize {
        self.hap_dwfa.set_alleles()
    }

    /// Borrow of the underlying haplotype alignment.
    pub fn hap_dwfa(&self) -> &HaplotypeDWFA {
        &self.hap_dwfa
    }

    /// Number of error-flagged extensions applied to this node.
    pub fn num_errors(&self) -> usize {
        self.num_errors
    }

    /// Overwrites the tie-breaking identifier (used when cloning a node into a new branch).
    pub fn set_node_id(&mut self, node_id: u64) {
        self.node_id = node_id;
    }
}
fn is_allowed_error(error: &HapDWFAError) -> bool {
matches!(
error,
HapDWFAError::DError { error: crate::dwfa::dynamic_wfa::DWFAError::MaxEditDistance }
)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference sequence shared by every test.
    const REFERENCE: &[u8] = b"ACGTACGTACGT";

    /// Runs `optimize_gt_alleles` over the whole of [`REFERENCE`] and unwraps the result.
    fn optimize_full_reference(
        truth_variants: &[Variant], truth_alleles: &[Allele],
        query_variants: &[Variant], query_alleles: &[Allele],
    ) -> OptimizedAlleles {
        let coordinates = Coordinates::new("mock".to_string(), 0, REFERENCE.len() as u64);
        optimize_gt_alleles(
            REFERENCE, &coordinates,
            truth_variants, truth_alleles,
            query_variants, query_alleles
        ).unwrap()
    }

    /// The C>G SNV at position 5 that several tests share.
    fn snv_at_5() -> Vec<Variant> {
        vec![Variant::new_snv(0, 5, b"C".to_vec(), b"G".to_vec()).unwrap()]
    }

    #[test]
    fn test_exact_match_node() {
        let mut node = ExactMatchNode::new(0, 0);
        // Query-only edits: (variant, chosen allele, counts as error).
        let edits = [
            (Variant::new_insertion(0, 4, b"A".to_vec(), b"AC".to_vec()).unwrap(), Allele::Alternate, false),
            (Variant::new_snv(0, 5, b"C".to_vec(), b"G".to_vec()).unwrap(), Allele::Reference, true),
            (Variant::new_snv(0, 8, b"A".to_vec(), b"G".to_vec()).unwrap(), Allele::Alternate, false),
        ];
        for (variant, allele, is_error) in &edits {
            node.extend_variant(REFERENCE, false, variant, *allele, None, *is_error).unwrap();
        }
        node.finalize_dwfas(REFERENCE, REFERENCE.len()).unwrap();

        assert_eq!(node.edit_distance(), 1);
        assert_eq!(node.num_errors(), 1);
        let truth = node.hap_dwfa().truth_haplotype();
        let query = node.hap_dwfa().query_haplotype();
        assert_eq!(truth.sequence(), REFERENCE);
        assert_eq!(truth.alleles(), &[]);
        assert_eq!(query.sequence(), b"ACGTACCGTGCGT");
        assert_eq!(query.alleles(), &[Allele::Alternate, Allele::Reference, Allele::Alternate]);
    }

    #[test]
    fn test_optimize_gt_alleles_match() {
        let variants = snv_at_5();
        let alleles = vec![Allele::Alternate];
        let result = optimize_full_reference(&variants, &alleles, &variants, &alleles);
        assert!(result.is_exact_match());
        assert_eq!(result.num_errors(), 0);
        assert_eq!(result.truth_alleles(), &alleles);
        assert_eq!(result.query_alleles(), &alleles);
    }

    #[test]
    fn test_optimize_gt_alleles_all_fn() {
        let truth_variants = snv_at_5();
        let truth_alleles = vec![Allele::Alternate];
        let result = optimize_full_reference(&truth_variants, &truth_alleles, &[], &[]);
        assert!(!result.is_exact_match());
        assert_eq!(result.num_errors(), 1);
        assert_eq!(result.truth_alleles(), &[Allele::Reference]);
        assert!(result.query_alleles().is_empty());
    }

    #[test]
    fn test_optimize_gt_alleles_all_fp() {
        let query_variants = snv_at_5();
        let query_alleles = vec![Allele::Alternate];
        let result = optimize_full_reference(&[], &[], &query_variants, &query_alleles);
        assert!(!result.is_exact_match());
        assert_eq!(result.num_errors(), 1);
        assert!(result.truth_alleles().is_empty());
        assert_eq!(result.query_alleles(), &[Allele::Reference]);
    }

    #[test]
    fn test_optimize_gt_alleles_diff_rep() {
        // One 3 bp substitution in truth vs. two SNVs in query describing the same haplotype.
        let truth_variants = vec![
            Variant::new_indel(0, 5, b"CGT".to_vec(), b"GGG".to_vec()).unwrap()
        ];
        let query_variants = vec![
            Variant::new_snv(0, 5, b"C".to_vec(), b"G".to_vec()).unwrap(),
            Variant::new_snv(0, 7, b"T".to_vec(), b"G".to_vec()).unwrap(),
        ];
        let truth_alleles = vec![Allele::Alternate];
        let query_alleles = vec![Allele::Alternate; 2];
        let result = optimize_full_reference(&truth_variants, &truth_alleles, &query_variants, &query_alleles);
        assert!(result.is_exact_match());
        assert_eq!(result.num_errors(), 0);
        assert_eq!(result.truth_alleles(), &truth_alleles);
        assert_eq!(result.query_alleles(), &query_alleles);
    }
}