use crate::shared::data_structures::RangeArray1;
use crate::shared::feature::Feature;
use crate::shared::{InferenceParameters, VJAlignment};
use crate::v_dj::Features;
use crate::vdj::feature::{AggregatedFeatureSpanD, FeatureDJ};
use crate::vdj::AggregatedFeatureStartJ;
use std::cmp;
pub struct AggregatedFeatureStartDAndJ {
pub start_d5: i64,
pub end_d5: i64,
likelihood: RangeArray1,
dirty_likelihood: RangeArray1,
feature_j: AggregatedFeatureStartJ,
pub most_likely_d_end: i64,
pub most_likely_j_start: i64,
pub most_likely_d_index: usize,
}
impl AggregatedFeatureStartDAndJ {
pub fn new(
j_alignment: &VJAlignment,
feature_ds: &[AggregatedFeatureSpanD],
feat_insdj: &FeatureDJ,
feat: &Features,
ip: &InferenceParameters,
) -> Option<AggregatedFeatureStartDAndJ> {
let Some(feature_j) = AggregatedFeatureStartJ::new(j_alignment, feat, ip) else {
return None;
};
let mut likelihood = RangeArray1::zeros((
feature_ds.iter().map(|x| x.start_d5).min().unwrap(),
feature_ds.iter().map(|x| x.end_d5).max().unwrap(),
));
let mut total_likelihood = 0.;
let mut most_likely_d_end = 0;
let mut most_likely_d_index = 0;
let mut most_likely_j_start = 0;
let mut best_likelihood = 0.;
feature_ds.iter().for_each(|agg_d| {
let ll_dj = feat.d.likelihood((agg_d.index, feature_j.index)); feat_insdj.iter().for_each(|(d_end, j_start, ll_ins_dj)| {
if (j_start >= feature_j.start_j5) && (j_start < feature_j.end_j5) {
let ll_ins_jd = feature_j.likelihood(j_start) * ll_ins_dj * ll_dj;
agg_d.iter_fixed_dend(d_end).for_each(|(d_start, ll_deld)| {
let ll = ll_ins_jd * ll_deld;
if ll > ip.min_likelihood {
if ll > best_likelihood {
most_likely_d_end = d_end;
most_likely_d_index = agg_d.index;
most_likely_j_start = j_start;
best_likelihood = ll;
}
*likelihood.get_mut(d_start) += ll;
total_likelihood += ll;
}
});
}
});
});
if total_likelihood == 0. {
return None;
}
Some(AggregatedFeatureStartDAndJ {
start_d5: likelihood.min,
end_d5: likelihood.max,
dirty_likelihood: RangeArray1::zeros(likelihood.dim()),
likelihood,
feature_j: feature_j,
most_likely_d_end,
most_likely_j_start,
most_likely_d_index,
})
}
pub fn j_start_seq(&self) -> usize {
self.feature_j.start_seq
}
pub fn j_index(&self) -> usize {
self.feature_j.index
}
pub fn likelihood(&self, sd: i64) -> f64 {
self.likelihood.get(sd)
}
pub fn max_likelihood(&self) -> f64 {
self.likelihood.max_value()
}
pub fn dirty_update(&mut self, sd: i64, likelihood: f64) {
*self.dirty_likelihood.get_mut(sd) += likelihood;
}
pub fn disaggregate(
&mut self,
j_alignment: &VJAlignment,
feature_ds: &mut [AggregatedFeatureSpanD],
feat_insdj: &mut FeatureDJ,
feat: &mut Features,
ip: &InferenceParameters,
) {
let (min_sj, max_sj) = (
cmp::max(self.feature_j.start_j5, feat_insdj.min_sj()),
cmp::min(self.feature_j.end_j5, feat_insdj.max_sj()),
);
for agg_deld in feature_ds.iter_mut() {
let (min_ed, max_ed) = (
cmp::max(agg_deld.start_d3, feat_insdj.min_ed()),
cmp::min(agg_deld.end_d3, feat_insdj.max_ed()),
);
let ll_dj = feat.d.likelihood((agg_deld.index, self.feature_j.index));
for d_start in agg_deld.start_d5..agg_deld.end_d5 {
let dirty_proba = self.dirty_likelihood.get(d_start);
let ratio_old_new = dirty_proba / self.likelihood(d_start);
for j_start in min_sj..max_sj {
let ll_j = self.feature_j.likelihood(j_start);
if ll_dj * ll_j < ip.min_likelihood {
continue;
}
for d_end in cmp::max(d_start - 1, min_ed)..cmp::min(max_ed, j_start + 1) {
let ll_deld = agg_deld.likelihood(d_start, d_end);
let ll_insdj = feat_insdj.likelihood(d_end, j_start);
let ll = ll_dj * ll_insdj * ll_j * ll_deld;
if ll > ip.min_likelihood {
let corrected_proba = ratio_old_new * ll;
if dirty_proba > 0. {
feat.d.dirty_update(
(agg_deld.index, self.feature_j.index),
corrected_proba,
);
if ip.infer_insertions {
feat_insdj.dirty_update(d_end, j_start, corrected_proba);
}
self.feature_j.dirty_update(j_start, corrected_proba);
agg_deld.dirty_update(d_start, d_end, corrected_proba);
}
}
}
}
}
}
self.feature_j.disaggregate(j_alignment, feat, ip)
}
}