use crate::calling::haplotypes::haplotypes::{
AlleleFreqDist, CandidateMatrix, Haplotype, HaplotypeGraph, PriorTypes, VariantCalls,
};
use crate::calling::haplotypes::hla::PopFreq;
use bio::stats::probs::adaptive_integration;
use bio::stats::{bayesian::model, LogProb, Prob};
use bv::BitVec;
use derefable::Derefable;
use derive_new::new;
use ordered_float::NotNan;
use petgraph::visit::Bfs;
use std::collections::{BTreeMap, HashMap};
/// Scalar type used for allele frequencies and haplotype fractions in this
/// file: an `f64` wrapped in `ordered_float::NotNan`.
pub type AlleleFreq = NotNan<f64>;
/// One candidate assignment of fractions to haplotypes; this is the event
/// type of every model component in this file.
///
/// Entry `i` is the fraction assigned to the `i`-th haplotype: the marginal
/// builds the vector in haplotype-index order and the population prior zips
/// it with its haplotype list. Dereferences to the inner `Vec<AlleleFreq>`.
#[derive(Hash, PartialEq, Eq, Clone, Debug, Derefable, PartialOrd)]
pub struct HaplotypeFractions(#[deref] pub Vec<AlleleFreq>);
/// Configuration for marginalising the joint probability over all haplotype
/// fraction vectors (see `calc_marginal`).
///
/// `derive_new` generates `Marginal::new` taking the fields in declaration
/// order, so the field order is part of the constructor's signature.
#[derive(Debug, new)]
pub(crate) struct Marginal {
// Number of haplotypes; the recursion in `calc_marginal` stops once the
// haplotype index reaches this value.
n_haplotypes: usize,
// Haplotypes, indexed by the same haplotype index as the fraction vector.
haplotypes: Vec<Haplotype>,
// Selects how each fraction is marginalised: a fixed grid for `Diploid`,
// adaptive integration for `Uniform` and `DiploidSubclonal`.
prior_info: PriorTypes,
// Graph searched breadth-first by the equivalence-class constraint; when
// `None` the constraint never rejects anything.
haplotype_graph: Option<HaplotypeGraph>,
// Turns the equivalence-class constraint on or off.
enable_equivalence_class_constraint: bool,
// Application name; the equivalence-class constraint is only applied when
// this equals "hla".
application: String,
}
impl Marginal {
    /// Recursively marginalises `joint_prob` over the fractions of the
    /// haplotypes `haplotype_index..n_haplotypes`, given the fractions already
    /// chosen for the haplotypes before `haplotype_index`.
    ///
    /// * `data` - observations handed through to `joint_prob`.
    /// * `haplotype_index` - index of the haplotype whose fraction is chosen
    ///   at this recursion level.
    /// * `fractions` - fractions chosen so far, one per earlier haplotype. The
    ///   slice is only read; every recursion level works on its own extended
    ///   copy.
    /// * `joint_prob` - evaluated once per complete fraction vector.
    ///
    /// How the fraction at this level is marginalised:
    /// * once all haplotypes have a fraction, `joint_prob` is evaluated;
    /// * the last haplotype always receives the remainder `1 - sum(fractions)`;
    /// * if nothing remains to distribute, the haplotype receives `0`;
    /// * with `PriorTypes::Diploid`, the log-probabilities at the grid points
    ///   `0`, `0.5` and `1` that still fit into the remainder are summed;
    /// * with `PriorTypes::Uniform` or `PriorTypes::DiploidSubclonal`, the
    ///   density is integrated adaptively over `[0, remainder]`.
    ///
    /// # Panics
    ///
    /// Panics under the conditions listed on `has_active_equivalent`.
    pub(crate) fn calc_marginal<
        F: FnMut(&<Self as model::Marginal>::Event, &<Self as model::Marginal>::Data) -> LogProb,
    >(
        &self,
        data: &Data,
        haplotype_index: usize,
        fractions: &mut [AlleleFreq],
        joint_prob: &mut F,
    ) -> LogProb {
        if haplotype_index == self.n_haplotypes {
            let event = HaplotypeFractions(fractions.to_vec());
            return joint_prob(&event, data);
        }

        let zero: AlleleFreq = NotNan::new(0.0).unwrap();
        let one: AlleleFreq = NotNan::new(1.0).unwrap();
        // Nothing below mutates the caller's slice, so work on a shared view.
        let assigned: &[AlleleFreq] = fractions;
        let assigned_total = assigned.iter().sum::<NotNan<f64>>();
        let fraction_upper_bound = one - assigned_total;

        // Log-density of assigning `fraction` to the current haplotype,
        // marginalised over all remaining haplotypes.
        let mut density = |fraction: AlleleFreq| {
            if self.enable_equivalence_class_constraint
                && fraction > zero
                && assigned.len() > 1
                && self.application == "hla"
                && self.has_active_equivalent(haplotype_index, assigned)
            {
                return LogProb::ln_zero();
            }
            let mut extended = assigned.to_vec();
            extended.push(fraction);
            self.calc_marginal(data, haplotype_index + 1, &mut extended, joint_prob)
        };

        if haplotype_index == self.n_haplotypes - 1 {
            // The last haplotype has no freedom left: it takes the remainder.
            return density(fraction_upper_bound);
        }
        if fraction_upper_bound == zero {
            return density(zero);
        }

        // Exhaustive on purpose: adding a prior type must force a decision
        // here instead of ending up in a runtime panic.
        match self.prior_info {
            PriorTypes::Diploid => {
                let mut probs = Vec::new();
                for level in [0.0, 0.5, 1.0] {
                    let point: AlleleFreq = NotNan::new(level).unwrap();
                    if assigned_total + point <= one {
                        probs.push(density(point));
                    }
                }
                LogProb::ln_sum_exp(&probs)
            }
            PriorTypes::Uniform | PriorTypes::DiploidSubclonal => {
                adaptive_integration::ln_integrate_exp(
                    density,
                    zero,
                    fraction_upper_bound,
                    NotNan::new(0.1).unwrap(),
                )
            }
        }
    }

    /// Returns `true` if some haplotype reachable from haplotype
    /// `haplotype_index` in the haplotype graph (the start node included) is
    /// one of the earlier haplotypes `0..haplotype_index` and already has a
    /// fraction greater than zero.
    ///
    /// Returns `false` when no haplotype graph is configured.
    ///
    /// # Panics
    ///
    /// With a graph configured, panics if the haplotype name has fewer than
    /// two `:`-separated fields, if the node for the haplotype and its group
    /// is missing from the graph, or if `fractions` or the haplotype list is
    /// shorter than `haplotype_index`.
    fn has_active_equivalent(&self, haplotype_index: usize, fractions: &[AlleleFreq]) -> bool {
        let haplotype_graph = match &self.haplotype_graph {
            Some(graph) => graph,
            None => return false,
        };

        let current_haplotype = &self.haplotypes[haplotype_index];
        // Graph nodes are keyed by (haplotype, group), where the group is the
        // first two `:`-separated fields of the haplotype name.
        let fields = current_haplotype.split(':').collect::<Vec<&str>>();
        let haplotype_group = Haplotype(format!("{}:{}", fields[0], fields[1]));
        let index = haplotype_graph
            .get_node_index(&(current_haplotype.clone(), haplotype_group))
            .unwrap();

        let zero: AlleleFreq = NotNan::new(0.0).unwrap();
        // Borrow the earlier haplotypes and fractions once rather than
        // copying both for every visited node.
        let earlier_haplotypes = &self.haplotypes[..haplotype_index];
        let earlier_fractions = &fractions[..haplotype_index];

        let mut bfs = Bfs::new(&**haplotype_graph, index);
        while let Some(nx) = bfs.next(&**haplotype_graph) {
            let haplotype_query = &haplotype_graph[nx].0;
            let already_present = earlier_haplotypes
                .iter()
                .zip(earlier_fractions.iter())
                .any(|(h, f)| h == haplotype_query && *f > zero);
            if already_present {
                return true;
            }
        }
        false
    }
}
impl model::Marginal for Marginal {
    type Event = HaplotypeFractions;
    type Data = Data;
    type BaseEvent = HaplotypeFractions;

    /// Marginal probability of `data`: starts the recursive marginalisation
    /// at the first haplotype with no fractions assigned yet.
    fn compute<F>(&self, data: &Self::Data, joint_prob: &mut F) -> LogProb
    where
        F: FnMut(&Self::Event, &Self::Data) -> LogProb,
    {
        let mut assigned = Vec::<AlleleFreq>::new();
        self.calc_marginal(data, 0, assigned.as_mut_slice(), joint_prob)
    }
}
/// Observations the model is evaluated against.
///
/// The likelihood pairs the values of `candidate_matrix` with the entries of
/// `variant_calls` position by position, so both must iterate in the same
/// order. `derive_new` generates `Data::new(candidate_matrix, variant_calls)`.
#[derive(Debug, new)]
pub struct Data {
// Its values are pairs of bit vectors indexed by haplotype; the likelihood
// names them `genotypes` and `covered`.
pub candidate_matrix: CandidateMatrix,
// Calls providing an allele frequency distribution (`afd`) and a depth
// (`dp`) each.
pub variant_calls: VariantCalls,
}
/// Stateless likelihood of the observed variant calls given a vector of
/// haplotype fractions.
#[derive(Debug, new)]
pub(crate) struct Likelihood;
impl model::Likelihood<Cache> for Likelihood {
type Event = HaplotypeFractions;
type Data = Data;
fn compute(&self, event: &Self::Event, data: &Self::Data, payload: &mut Cache) -> LogProb {
self.compute_varlociraptor(event, data, payload)
}
}
impl Likelihood {
fn compute_varlociraptor(
&self,
event: &HaplotypeFractions,
data: &Data,
_cache: &mut Cache,
) -> LogProb {
let candidate_matrix_values: Vec<(BitVec, BitVec)> =
data.candidate_matrix.values().cloned().collect();
let variant_calls: Vec<(AlleleFreqDist, i32)> = data
.variant_calls
.iter()
.map(|(_, call)| (call.afd.clone(), call.dp))
.collect();
let mut final_prob = LogProb::ln_one();
candidate_matrix_values
.iter()
.zip(variant_calls.iter())
.for_each(|((genotypes, covered), (afd, dp))| {
if *dp != 0 {
let mut denom = NotNan::new(1.0).unwrap();
let mut vaf_sum = NotNan::new(0.0).unwrap();
event.iter().enumerate().for_each(|(i, fraction)| {
if genotypes[i as u64] && covered[i as u64] {
vaf_sum += *fraction;
}
else if !genotypes[i as u64] && !covered[i as u64] {
denom -= *fraction;
}
});
if denom > NotNan::new(0.0).unwrap() {
vaf_sum /= denom;
}
vaf_sum = NotNan::new((vaf_sum * NotNan::new(100.0).unwrap()).round()).unwrap()
/ NotNan::new(100.0).unwrap();
if !afd.is_empty() {
final_prob += afd.vaf_query(&vaf_sum).unwrap();
} else {
final_prob += LogProb::ln_one();
}
} else {
final_prob += LogProb::ln_one();
}
});
final_prob
}
}
/// Prior over haplotype fractions determined solely by the selected prior
/// type. `derive_new` generates `PloidyPrior::new(prior)`.
#[derive(Debug, new)]
pub(crate) struct PloidyPrior {
// Prior type deciding which fraction vectors receive non-zero probability.
pub prior: PriorTypes,
}
impl model::Prior for PloidyPrior {
type Event = HaplotypeFractions;
fn compute(&self, event: &Self::Event) -> LogProb {
if self.prior == PriorTypes::Diploid {
let mut prior_prob = LogProb::ln_one();
event.iter().for_each(|fraction| {
if *fraction == NotNan::new(0.0).unwrap()
|| *fraction == NotNan::new(0.5).unwrap()
|| *fraction == NotNan::new(1.0).unwrap()
{
prior_prob += LogProb::from(Prob(1.0 / 3.0))
} else {
prior_prob += LogProb::ln_zero()
}
});
prior_prob
} else if self.prior == PriorTypes::DiploidSubclonal {
if event
.iter()
.filter(|&n| n > &NotNan::new(0.0).unwrap())
.count()
> 4
{
LogProb::ln_zero()
} else {
LogProb::ln_one()
}
} else {
LogProb::ln_one()
}
}
}
/// Prior that weights a fraction vector by population frequencies of the
/// haplotypes. All fields are borrowed from the caller.
pub(crate) struct PopulationPrior<'a> {
// Haplotypes, in the same order as the entries of the fraction vector.
pub haplotypes: &'a Vec<Haplotype>,
// Population frequency per haplotype, keyed by the haplotype's string form;
// haplotypes without an entry do not affect the prior.
pub pop_freqs: &'a BTreeMap<String, f64>,
// Prior type; selects the multiplier applied to each fraction
// (2 for `Diploid`, 3 for `DiploidSubclonal`, 1 for `Uniform`).
pub ploidy_prior: &'a PriorTypes,
}
impl<'a> model::Prior for PopulationPrior<'a> {
    type Event = HaplotypeFractions;

    /// Prior log-probability of `event` from population frequencies.
    ///
    /// Every haplotype with a known population frequency `f` contributes the
    /// factor `max(f, 1e-12) ^ (fraction * scale)`, where `scale` is 2 for
    /// `Diploid`, 3 for `DiploidSubclonal` and 1 for `Uniform`. Haplotypes
    /// without a frequency entry contribute nothing.
    fn compute(&self, event: &Self::Event) -> LogProb {
        // Lower bound that keeps a zero frequency from zeroing the product.
        let floor = 1e-12;
        let scale: f64 = match self.ploidy_prior {
            PriorTypes::Uniform => 1.0,
            PriorTypes::Diploid => 2.0,
            PriorTypes::DiploidSubclonal => 3.0,
        };

        self.haplotypes
            .iter()
            .zip(event.iter())
            .filter_map(|(haplotype, fraction)| {
                let key = haplotype.to_string();
                self.pop_freqs.get(&key).map(|&freq| {
                    let exponent = fraction.into_inner() * scale;
                    LogProb::from(Prob(freq.max(floor).powf(exponent)))
                })
            })
            .fold(LogProb::ln_one(), |acc, term| acc + term)
    }
}
/// Product of a mandatory prior and an optional second prior over the same
/// fraction vector.
pub(crate) struct CombinedPrior<P1, P2> {
// Prior that is always evaluated.
pub p1: P1,
// Optional second prior; when present its log-probability is added to that
// of `p1`.
pub p2: Option<P2>,
}
impl<P1, P2> model::Prior for CombinedPrior<P1, P2>
where
    P1: model::Prior<Event = HaplotypeFractions>,
    P2: model::Prior<Event = HaplotypeFractions>,
{
    type Event = HaplotypeFractions;

    /// Log-probability of `event` under `p1`, plus the log-probability under
    /// `p2` when a second prior is configured.
    fn compute(&self, event: &Self::Event) -> LogProb {
        let first = self.p1.compute(event);
        self.p2
            .as_ref()
            .map_or(first, |second| first + second.compute(event))
    }
}
/// Stateless posterior whose event and base event are both
/// `HaplotypeFractions`.
#[derive(Debug, new)]
pub(crate) struct Posterior;
impl model::Posterior for Posterior {
    type Event = HaplotypeFractions;
    type BaseEvent = HaplotypeFractions;
    type Data = Data;

    /// Evaluates `joint_prob` for `event` and `data` and returns its result
    /// unchanged; event and base event are the same type, so no mapping
    /// between them is needed.
    fn compute<F>(&self, event: &Self::Event, data: &Self::Data, joint_prob: &mut F) -> LogProb
    where
        F: FnMut(&Self::BaseEvent, &Self::Data) -> LogProb,
    {
        (*joint_prob)(event, data)
    }
}
/// Payload type handed to the likelihood: a two-level map from a `usize` key
/// and an allele frequency to a log-probability. Dereferences to the inner
/// `HashMap`.
///
/// The likelihood in this file receives the cache but neither reads nor
/// writes it.
#[derive(Debug, Derefable, Default)]
pub(crate) struct Cache(#[deref] HashMap<usize, HashMap<AlleleFreq, LogProb>>);