use anyhow::ensure;
use log::debug;
use priority_queue::PriorityQueue;
use std::cmp::Reverse;
use std::hash::Hash;
use crate::data_types::coordinates::Coordinates;
use crate::data_types::phase_enums::{Allele, PhasedZygosity};
use crate::data_types::variants::Variant;
use crate::dwfa::haplotype_dwfa::HaplotypeDWFA;
/// One optimal solution produced by [`optimize_sequences`]: the truth and query haplotype
/// pairs that were selected, plus the alignment metrics of each haplotype pair.
///
/// Index `1` fields come from the first haplotype DWFA and index `2` fields from the second.
#[derive(Clone, Debug)]
pub struct OptimizedHaplotypes {
/// Per-truth-variant zygosity rebuilt from the alleles assigned to truth haplotypes 1 and 2.
truth_zygosity: Vec<PhasedZygosity>,
/// Sequence of truth haplotype 1.
truth_seq1: Vec<u8>,
/// Sequence of truth haplotype 2.
truth_seq2: Vec<u8>,
/// Per-query-variant zygosity rebuilt from the alleles assigned to query haplotypes 1 and 2.
query_zygosity: Vec<PhasedZygosity>,
/// Sequence of query haplotype 1.
query_seq1: Vec<u8>,
/// Sequence of query haplotype 2.
query_seq2: Vec<u8>,
/// Edit distance reported by the haplotype-1 DWFA (truth_seq1 vs. query_seq1).
ed1: usize,
/// Edit distance reported by the haplotype-2 DWFA (truth_seq2 vs. query_seq2).
ed2: usize,
/// `variant_skip_distance()` of truth haplotype 1.
truth_vs1: usize,
/// `variant_skip_distance()` of truth haplotype 2.
truth_vs2: usize,
/// `variant_skip_distance()` of query haplotype 1.
query_vs1: usize,
/// `variant_skip_distance()` of query haplotype 2.
query_vs2: usize
}
impl OptimizedHaplotypes {
    /// Returns `true` when both edit distances and all four variant-skip distances are zero.
    pub fn is_exact_match(&self) -> bool {
        let penalties = [
            self.ed1,
            self.ed2,
            self.truth_vs1,
            self.truth_vs2,
            self.query_vs1,
            self.query_vs2,
        ];
        let total: usize = penalties.iter().sum();
        total == 0
    }

    /// Zygosity assigned to each truth variant in this solution.
    pub fn truth_zygosity(&self) -> &[PhasedZygosity] {
        self.truth_zygosity.as_slice()
    }

    /// Sequence of truth haplotype 1.
    pub fn truth_seq1(&self) -> &[u8] {
        self.truth_seq1.as_slice()
    }

    /// Sequence of truth haplotype 2.
    pub fn truth_seq2(&self) -> &[u8] {
        self.truth_seq2.as_slice()
    }

    /// Zygosity assigned to each query variant in this solution.
    pub fn query_zygosity(&self) -> &[PhasedZygosity] {
        self.query_zygosity.as_slice()
    }

    /// Sequence of query haplotype 1.
    pub fn query_seq1(&self) -> &[u8] {
        self.query_seq1.as_slice()
    }

    /// Sequence of query haplotype 2.
    pub fn query_seq2(&self) -> &[u8] {
        self.query_seq2.as_slice()
    }

    /// Edit distance between truth and query haplotype 1.
    pub fn ed1(&self) -> usize {
        let Self { ed1, .. } = *self;
        ed1
    }

    /// Edit distance between truth and query haplotype 2.
    pub fn ed2(&self) -> usize {
        let Self { ed2, .. } = *self;
        ed2
    }

    /// Variant-skip distance of truth haplotype 1.
    pub fn truth_vs1(&self) -> usize {
        let Self { truth_vs1, .. } = *self;
        truth_vs1
    }

    /// Variant-skip distance of truth haplotype 2.
    pub fn truth_vs2(&self) -> usize {
        let Self { truth_vs2, .. } = *self;
        truth_vs2
    }

    /// Variant-skip distance of query haplotype 1.
    pub fn query_vs1(&self) -> usize {
        let Self { query_vs1, .. } = *self;
        query_vs1
    }

    /// Variant-skip distance of query haplotype 2.
    pub fn query_vs2(&self) -> usize {
        let Self { query_vs2, .. } = *self;
        query_vs2
    }
}
pub fn optimize_sequences(
reference: &[u8], coordinates: &Coordinates,
truth_variants: &[Variant], truth_zygosity: &[PhasedZygosity],
query_variants: &[Variant], query_zygosity: &[PhasedZygosity],
max_branch_factor: usize,
) -> anyhow::Result<Vec<OptimizedHaplotypes>> {
debug!("Starting sequence optimization...");
ensure!(truth_variants.len() == truth_zygosity.len(), "truth values must have equal length");
ensure!(query_variants.len() == query_zygosity.len(), "query values must have equal length");
ensure!(max_branch_factor > 0, "max_branch_factor must be greater than 0");
let all_variant_order: Vec<(usize, bool)> = order_variants(truth_variants, query_variants);
let total_variant_count = all_variant_order.len();
let mut next_node_id = 0;
let start = coordinates.start() as usize;
let root_node = ComparisonNode::new(next_node_id, start);
next_node_id += 1;
let priority = root_node.priority();
let mut pqueue: PriorityQueue<ComparisonNode, NodePriority> = Default::default();
pqueue.push(root_node, priority);
let mut best_ed = usize::MAX;
let mut best_results = vec![];
let mut bucket_counts = vec![0; total_variant_count+1];
while let Some((current_node, _priority)) = pqueue.pop() {
if current_node.total_cost() > best_ed {
continue;
}
let order_index = current_node.set_alleles();
if bucket_counts[order_index] == 0 {
debug!("Best path to {order_index} = {:?}, {} <?> {}, {} <?> {}",
current_node.priority(),
current_node.hap_dwfa1().truth_haplotype().sequence().len(),
current_node.hap_dwfa1().query_haplotype().sequence().len(),
current_node.hap_dwfa2().truth_haplotype().sequence().len(),
current_node.hap_dwfa2().query_haplotype().sequence().len()
);
}
if bucket_counts[order_index] >= max_branch_factor {
continue;
}
bucket_counts[order_index] += 1;
if order_index == total_variant_count {
let mut final_node = current_node;
final_node.finalize_dwfas(reference, coordinates.end() as usize)?;
let final_cost = final_node.total_cost();
match final_cost.cmp(&best_ed) {
std::cmp::Ordering::Less => {
best_ed = final_cost;
best_results = vec![final_node];
},
std::cmp::Ordering::Equal => {
best_results.push(final_node);
},
std::cmp::Ordering::Greater => {},
};
continue;
}
let (variant_index, is_truth) = all_variant_order[order_index];
let (current_variant, current_zyg) = if is_truth {
(&truth_variants[variant_index], truth_zygosity[variant_index])
} else {
(&query_variants[variant_index], query_zygosity[variant_index])
};
let next_var_pos = if order_index == all_variant_order.len() - 1 {
coordinates.end() as usize
} else {
let (nvi, nt) = all_variant_order[order_index+1];
if nt { truth_variants[nvi].position() as usize } else { query_variants[nvi].position() as usize }
};
let sync_extension = Some(next_var_pos);
if current_zyg.is_heterozygous() {
if !is_truth || current_zyg == PhasedZygosity::UnphasedHeterozygous {
let extensions = [
(Allele::Reference, Allele::Alternate), (Allele::Alternate, Allele::Reference) ];
for (allele1, allele2) in extensions.into_iter() {
let mut new_node = current_node.clone();
new_node.set_node_id(next_node_id);
next_node_id += 1;
new_node.extend_variant(
reference, is_truth,
current_variant, allele1, allele2, sync_extension
)?;
let new_priority = new_node.priority();
pqueue.push(new_node, new_priority);
}
} else {
let (a1, a2) = match current_zyg {
PhasedZygosity::PhasedHet01 => (Allele::Reference, Allele::Alternate),
PhasedZygosity::PhasedHet10 => (Allele::Alternate, Allele::Reference),
_ => panic!("should not happen")
};
let mut new_node = current_node;
new_node.extend_variant(
reference, is_truth,
current_variant, a1, a2, sync_extension
)?;
let new_priority = new_node.priority();
pqueue.push(new_node, new_priority);
}
} else {
assert_eq!(current_zyg, PhasedZygosity::HomozygousAlternate);
let mut new_node = current_node;
new_node.extend_variant(
reference, is_truth,
current_variant, Allele::Alternate, Allele::Alternate, sync_extension
)?;
let new_priority = new_node.priority();
pqueue.push(new_node, new_priority);
};
}
debug!("Best ED: {best_ed}");
ensure!(!best_results.is_empty(), "no results found");
let ret = best_results.into_iter()
.map(|best_node| {
let hap_dwfa1 = best_node.hap_dwfa1();
let hap_dwfa2 = best_node.hap_dwfa2();
let truth_hap1 = hap_dwfa1.truth_haplotype();
let truth_hap2 = hap_dwfa2.truth_haplotype();
let query_hap1 = hap_dwfa1.query_haplotype();
let query_hap2 = hap_dwfa2.query_haplotype();
let truth_zygosity = convert_alleles_to_zygosity(truth_hap1.alleles(), truth_hap2.alleles());
let query_zygosity = convert_alleles_to_zygosity(query_hap1.alleles(), query_hap2.alleles());
OptimizedHaplotypes {
truth_zygosity,
truth_seq1: truth_hap1.sequence().to_vec(),
truth_seq2: truth_hap2.sequence().to_vec(),
query_zygosity,
query_seq1: query_hap1.sequence().to_vec(),
query_seq2: query_hap2.sequence().to_vec(),
ed1: hap_dwfa1.edit_distance(),
ed2: hap_dwfa2.edit_distance(),
truth_vs1: hap_dwfa1.truth_haplotype().variant_skip_distance(),
truth_vs2: hap_dwfa2.truth_haplotype().variant_skip_distance(),
query_vs1: hap_dwfa1.query_haplotype().variant_skip_distance(),
query_vs2: hap_dwfa2.query_haplotype().variant_skip_distance(),
}
}).collect();
Ok(ret)
}
/// Merges truth and query variants into a single list ordered by position.
///
/// Each entry is `(index, is_truth)`, where `index` points into `truth_variants` when
/// `is_truth` is `true` and into `query_variants` otherwise. The sort is stable, so at equal
/// positions truth variants precede query variants and each set keeps its input order.
pub fn order_variants(truth_variants: &[Variant], query_variants: &[Variant]) -> Vec<(usize, bool)> {
    let position_of = |&(index, from_truth): &(usize, bool)| {
        let source = if from_truth { truth_variants } else { query_variants };
        source[index].position()
    };

    let truth_entries = (0..truth_variants.len()).map(|i| (i, true));
    let query_entries = (0..query_variants.len()).map(|i| (i, false));
    let mut ordering: Vec<(usize, bool)> = truth_entries.chain(query_entries).collect();
    ordering.sort_by_key(position_of);
    ordering
}
/// Pairs up the alleles assigned to haplotype 1 and haplotype 2 and turns each pair into a
/// phased zygosity.
///
/// # Panics
/// * if the two allele slices have different lengths
/// * on any pair other than ref/alt, alt/ref, or alt/alt
fn convert_alleles_to_zygosity(alleles1: &[Allele], alleles2: &[Allele]) -> Vec<PhasedZygosity> {
    assert_eq!(alleles1.len(), alleles2.len());
    let mut zygosities = Vec::with_capacity(alleles1.len());
    for (first, second) in alleles1.iter().zip(alleles2) {
        let zygosity = match (first, second) {
            (Allele::Alternate, Allele::Alternate) => PhasedZygosity::HomozygousAlternate,
            (Allele::Alternate, Allele::Reference) => PhasedZygosity::PhasedHet10,
            (Allele::Reference, Allele::Alternate) => PhasedZygosity::PhasedHet01,
            _ => panic!("no impl"),
        };
        zygosities.push(zygosity);
    }
    zygosities
}
/// A node in the best-first search of `optimize_sequences`: one DWFA per haplotype, which
/// `extend_variant` and `finalize_dwfas` always advance together.
///
/// `Eq`/`Hash` are derived because `PriorityQueue` requires its items to be hashable and
/// comparable (items are keyed by their own value).
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct ComparisonNode {
/// Identifier used as the tie-breaker in `priority()`.
node_id: u64,
/// DWFA for haplotype 1 (receives `allele1` in `extend_variant`).
hap_dwfa1: HaplotypeDWFA,
/// DWFA for haplotype 2 (receives `allele2` in `extend_variant`).
hap_dwfa2: HaplotypeDWFA
}
/// Queue priority of a node: `(Reverse(total_cost), Reverse(node_id))`.
///
/// `PriorityQueue` pops the highest priority first, so the `Reverse` wrappers make the lowest
/// total cost come out first, with ties going to the lowest node id.
type NodePriority = (Reverse<usize>, Reverse<u64>);
impl ComparisonNode {
    /// Builds a search node whose two haplotype DWFAs both start at `region_start`.
    ///
    /// Both DWFAs receive `usize::MAX` as the second argument of `HaplotypeDWFA::new`.
    pub fn new(
        node_id: u64, region_start: usize,
    ) -> Self {
        Self {
            node_id,
            hap_dwfa1: HaplotypeDWFA::new(region_start, usize::MAX),
            hap_dwfa2: HaplotypeDWFA::new(region_start, usize::MAX),
        }
    }

    /// Applies `variant` to both haplotypes: `allele1` on haplotype 1, `allele2` on haplotype 2.
    ///
    /// Returns `true` only if both DWFA extensions reported `true`. Haplotype 2 is always
    /// extended when haplotype 1 succeeds (there is no short-circuit); an error from either
    /// extension is propagated immediately.
    pub fn extend_variant(&mut self,
        reference: &[u8], is_truth: bool,
        variant: &Variant, allele1: Allele, allele2: Allele, sync_extension: Option<usize>
    ) -> anyhow::Result<bool> {
        let first_extended = self.hap_dwfa1.extend_variant(reference, is_truth, variant, allele1, sync_extension)?;
        let second_extended = self.hap_dwfa2.extend_variant(reference, is_truth, variant, allele2, sync_extension)?;
        Ok(first_extended && second_extended)
    }

    /// Finalizes haplotype 1 and then haplotype 2 up to `region_end`.
    pub fn finalize_dwfas(&mut self, reference: &[u8], region_end: usize) -> anyhow::Result<()> {
        for dwfa in [&mut self.hap_dwfa1, &mut self.hap_dwfa2] {
            dwfa.finalize_dwfa(reference, region_end)?;
        }
        Ok(())
    }

    /// Combined cost of both haplotype DWFAs.
    pub fn total_cost(&self) -> usize {
        let (first, second) = (self.hap_dwfa1.total_cost(), self.hap_dwfa2.total_cost());
        first + second
    }

    /// Queue priority: lower cost first, then lower node id.
    pub fn priority(&self) -> NodePriority {
        let cost = self.total_cost();
        (Reverse(cost), Reverse(self.node_id))
    }

    /// Number of variants assigned so far, as reported by haplotype 1's DWFA.
    pub fn set_alleles(&self) -> usize {
        self.hap_dwfa1().set_alleles()
    }

    /// DWFA tracking haplotype 1.
    pub fn hap_dwfa1(&self) -> &HaplotypeDWFA {
        &self.hap_dwfa1
    }

    /// DWFA tracking haplotype 2.
    pub fn hap_dwfa2(&self) -> &HaplotypeDWFA {
        &self.hap_dwfa2
    }

    /// Overwrites the node id used for priority tie-breaking.
    pub fn set_node_id(&mut self, node_id: u64) {
        self.node_id = node_id;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const MAX_BRANCH_FACTOR_TEST: usize = 50;

    /// Reference sequence shared by every test.
    const REFERENCE: &[u8] = b"ACGTACGTACGT";

    /// Region covering the whole test reference.
    fn full_region() -> Coordinates {
        Coordinates::new("chrom".to_string(), 0, REFERENCE.len() as u64)
    }

    /// Runs the optimizer over the full test region and returns its first optimal solution.
    fn best_solution(
        truth_variants: &[Variant], truth_zygosity: &[PhasedZygosity],
        query_variants: &[Variant], query_zygosity: &[PhasedZygosity],
    ) -> OptimizedHaplotypes {
        let region = full_region();
        let mut solutions = optimize_sequences(
            REFERENCE, &region,
            truth_variants, truth_zygosity,
            query_variants, query_zygosity,
            MAX_BRANCH_FACTOR_TEST,
        ).unwrap();
        solutions.swap_remove(0)
    }

    #[test]
    fn test_comparison_node() {
        let mut node = ComparisonNode::new(0, 0);
        let steps = [
            (Variant::new_insertion(0, 4, b"A".to_vec(), b"AC".to_vec()).unwrap(), Allele::Alternate, Allele::Reference),
            (Variant::new_snv(0, 5, b"C".to_vec(), b"G".to_vec()).unwrap(), Allele::Reference, Allele::Alternate),
            (Variant::new_snv(0, 8, b"A".to_vec(), b"G".to_vec()).unwrap(), Allele::Alternate, Allele::Alternate),
        ];
        for (variant, allele1, allele2) in steps {
            node.extend_variant(REFERENCE, false, &variant, allele1, allele2, None).unwrap();
        }
        node.finalize_dwfas(REFERENCE, REFERENCE.len()).unwrap();

        assert_eq!(node.total_cost(), 4);

        let hap1 = node.hap_dwfa1().query_haplotype();
        assert_eq!(hap1.sequence(), b"ACGTACCGTGCGT");
        assert_eq!(hap1.alleles(), &[Allele::Alternate, Allele::Reference, Allele::Alternate]);

        let hap2 = node.hap_dwfa2().query_haplotype();
        assert_eq!(hap2.sequence(), b"ACGTAGGTGCGT");
        assert_eq!(hap2.alleles(), &[Allele::Reference, Allele::Alternate, Allele::Alternate]);
    }

    #[test]
    fn test_optimize_query_sequences_001() {
        let truth_variants = [
            Variant::new_insertion(0, 4, b"A".to_vec(), b"AC".to_vec()).unwrap(),
            Variant::new_snv(0, 5, b"C".to_vec(), b"G".to_vec()).unwrap(),
        ];
        let truth_zygosity = [PhasedZygosity::PhasedHet10, PhasedZygosity::PhasedHet01];
        let query_variants = [
            Variant::new_insertion(0, 4, b"A".to_vec(), b"AC".to_vec()).unwrap(),
            Variant::new_snv(0, 5, b"C".to_vec(), b"G".to_vec()).unwrap(),
            Variant::new_snv(0, 8, b"A".to_vec(), b"G".to_vec()).unwrap(),
        ];
        let query_zygosity = [
            PhasedZygosity::PhasedHet10,
            PhasedZygosity::PhasedHet01,
            PhasedZygosity::HomozygousAlternate,
        ];

        let best = best_solution(&truth_variants, &truth_zygosity, &query_variants, &query_zygosity);

        assert_eq!(best.ed1(), 1);
        assert_eq!(best.ed2(), 1);
        assert_eq!(best.truth_seq1(), b"ACGTACCGTACGT");
        assert_eq!(best.truth_seq2(), b"ACGTAGGTACGT");
        assert_eq!(best.truth_zygosity(), &truth_zygosity);
        assert_eq!(best.query_seq1(), b"ACGTACCGTGCGT");
        assert_eq!(best.query_seq2(), b"ACGTAGGTGCGT");
        assert_eq!(best.query_zygosity(), &query_zygosity);
    }

    #[test]
    fn test_optimize_query_sequences_all_fn() {
        let truth_variants = [
            Variant::new_insertion(0, 4, b"A".to_vec(), b"AC".to_vec()).unwrap(),
            Variant::new_snv(0, 5, b"C".to_vec(), b"G".to_vec()).unwrap(),
        ];
        let truth_zygosity = [PhasedZygosity::PhasedHet10, PhasedZygosity::PhasedHet01];
        let no_query_variants: [Variant; 0] = [];
        let no_query_zygosity: [PhasedZygosity; 0] = [];

        let best = best_solution(&truth_variants, &truth_zygosity, &no_query_variants, &no_query_zygosity);

        assert_eq!(best.ed1(), 1);
        assert_eq!(best.ed2(), 1);
        assert_eq!(best.query_seq1(), REFERENCE);
        assert_eq!(best.query_seq2(), REFERENCE);
        assert!(best.query_zygosity().is_empty());
    }

    #[test]
    fn test_optimize_query_sequences_multiallelic() {
        let expected_seq1 = b"ACGTAAGTACGT";
        let expected_seq2 = b"ACGTAGGTACGT";
        let shared_variants = [
            Variant::new_snv(0, 5, b"C".to_vec(), b"A".to_vec()).unwrap(),
            Variant::new_snv(0, 5, b"C".to_vec(), b"G".to_vec()).unwrap(),
        ];
        let truth_zygosity = [PhasedZygosity::PhasedHet10, PhasedZygosity::PhasedHet01];
        let query_zygosity = [PhasedZygosity::UnphasedHeterozygous; 2];

        let best = best_solution(&shared_variants, &truth_zygosity, &shared_variants, &query_zygosity);

        assert_eq!(best.ed1(), 0);
        assert_eq!(best.ed2(), 0);
        assert_eq!(best.query_seq1(), expected_seq1);
        assert_eq!(best.query_seq2(), expected_seq2);
        assert_eq!(best.query_zygosity(), &truth_zygosity);
    }

    #[test]
    fn test_optimize_query_sequences_incompatible() {
        // Two homozygous SNVs at the same position cannot both be applied, so one is skipped
        // on every haplotype of both call sets.
        let shared_variants = [
            Variant::new_snv(0, 5, b"C".to_vec(), b"A".to_vec()).unwrap(),
            Variant::new_snv(0, 5, b"C".to_vec(), b"G".to_vec()).unwrap(),
        ];
        let zygosity = [PhasedZygosity::HomozygousAlternate; 2];

        let best = best_solution(&shared_variants, &zygosity, &shared_variants, &zygosity);

        assert_eq!(best.ed1(), 0);
        assert_eq!(best.ed2(), 0);
        assert_eq!(best.truth_vs1(), 1);
        assert_eq!(best.truth_vs2(), 1);
        assert_eq!(best.query_vs1(), 1);
        assert_eq!(best.query_vs2(), 1);
        assert_eq!(best.query_zygosity(), &zygosity);
    }
}