use std::sync::atomic::Ordering;
use crate::util::constants;
use crate::util::oarfish_types::{AlnInfo, EMInfo, TranscriptInfo};
use atomic_float::AtomicF64;
use itertools::izip;
use num_format::{Locale, ToFormattedString};
use rand::rng as trng;
use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
use tracing::{info, span, trace};
use crate::bootstrap;
/// One entry of the equivalence map as consumed by the EM passes (presumably one equivalence
/// class): its alignments, the per-alignment probabilities, and the per-alignment coverage
/// probabilities. The three slices are walked in lock-step with `izip!`, so they are expected
/// to have the same length (NOTE(review): not enforced here).
type EqIterateT<'a> = (&'a [AlnInfo], &'a [f32], &'a [f64]);
#[inline]
fn m_step_par<DFn>(
eq_iterates: &[EqIterateT],
tinfo: &[TranscriptInfo],
model_coverage: bool,
density_fn: DFn,
prev_count: &mut [AtomicF64],
curr_counts: &mut [AtomicF64],
) where
DFn: Fn(usize, usize) -> f64 + Sync,
{
eq_iterates.par_iter().for_each_with(
&curr_counts,
|curr_counts, (alns, probs, coverage_probs)| {
let mut denom = 0.0_f64;
for (a, p, cp) in izip!(*alns, *probs, *coverage_probs) {
let target_id = a.ref_id as usize;
let txp_len = tinfo[target_id].lenf as usize;
let aln_len = a.alignment_span() as usize;
let prob = *p as f64;
let cov_prob = if model_coverage { *cp } else { 1.0 };
let dens_prob = density_fn(txp_len, aln_len);
denom +=
prev_count[target_id].load(Ordering::Relaxed) * prob * cov_prob * dens_prob;
}
if denom > constants::EM_DENOM_THRESH {
for (a, p, cp) in izip!(*alns, *probs, *coverage_probs) {
let target_id = a.ref_id as usize;
let txp_len = tinfo[target_id].lenf as usize;
let aln_len = a.alignment_span() as usize;
let prob = *p as f64;
let cov_prob = if model_coverage { *cp } else { 1.0 };
let dens_prob = density_fn(txp_len, aln_len);
let inc = (prev_count[target_id].load(Ordering::Relaxed)
* prob
* cov_prob
* dens_prob)
/ denom;
curr_counts[target_id].fetch_add(inc, Ordering::AcqRel);
}
}
},
);
}
/// Performs one serial EM pass over `eq_map_iter`, accumulating into `curr_counts`.
///
/// Despite its name, this fuses both halves of an EM iteration. For every equivalence-class entry
/// it computes each alignment's weight
/// `prev_count[t] * prob * cov_prob * density_fn(txp_len, aln_span)`. When the entry's total
/// weight exceeds `constants::EM_DENOM_THRESH`, it adds each alignment's normalized share to
/// `curr_counts[t]`.
///
/// * `model_coverage`: when false, coverage probabilities are ignored (treated as `1.0`).
/// * `density_fn`: called as `density_fn(transcript_length, alignment_span)`.
/// * `prev_count`: abundances from the previous pass; only read.
/// * `curr_counts`: incremented in place and never cleared here, so callers must zero it first.
#[inline]
fn m_step<'a, DFn, I: Iterator<Item = (&'a [AlnInfo], &'a [f32], &'a [f64])>>(
    eq_map_iter: I,
    tinfo: &[TranscriptInfo],
    model_coverage: bool,
    density_fn: DFn,
    prev_count: &[f64],
    curr_counts: &mut [f64],
) where
    DFn: Fn(usize, usize) -> f64,
{
    // Unnormalized weight of a single alignment, paired with the transcript it targets.
    // The multiplication order matches the original so results are bit-identical.
    let weight = |a: &AlnInfo, p: &f32, cp: &f64| -> (usize, f64) {
        let target_id = a.ref_id as usize;
        let txp_len = tinfo[target_id].lenf as usize;
        let aln_len = a.alignment_span() as usize;
        let prob = *p as f64;
        let cov_prob = if model_coverage { *cp } else { 1.0 };
        let dens_prob = density_fn(txp_len, aln_len);
        (target_id, prev_count[target_id] * prob * cov_prob * dens_prob)
    };

    for (alns, probs, coverage_probs) in eq_map_iter {
        let denom = izip!(alns, probs, coverage_probs)
            .fold(0.0_f64, |acc, (a, p, cp)| acc + weight(a, p, cp).1);
        // Skip entries whose total weight is negligible rather than dividing by ~0.
        if denom > constants::EM_DENOM_THRESH {
            for (a, p, cp) in izip!(alns, probs, coverage_probs) {
                let (target_id, w) = weight(a, p, cp);
                curr_counts[target_id] += w / denom;
            }
        }
    }
}
/// Runs the serial EM algorithm and returns the final abundance estimates (one per transcript).
///
/// * `make_iter`: called once per EM pass, plus once for the final pass, to obtain a fresh
///   iterator over `(alignments, probabilities, coverage probabilities)` entries.
/// * `do_log`: when true, progress is reported every 10 iterations (`info!` every 100th,
///   `trace!` otherwise).
///
/// Abundances start from `em_info.init_abundances` when present. Otherwise the aligned-read
/// total is spread uniformly over all transcripts. Iteration stops after `em_info.max_iter`
/// passes, or once more than 50 iterations have run and the largest absolute relative change
/// (over transcripts above `constants::MIN_READ_THRESH`) falls below
/// `em_info.convergence_thresh`. Abundances below `MIN_READ_THRESH` are zeroed before one final
/// pass, whose output is returned.
pub fn do_em<'a, I: Iterator<Item = (&'a [AlnInfo], &'a [f32], &'a [f64])> + 'a, F: Fn() -> I>(
    em_info: &'a EMInfo,
    make_iter: F,
    do_log: bool,
) -> Vec<f64> {
    let eq_map = em_info.eq_map;
    let fops = &eq_map.filter_opts;
    let tinfo: &[TranscriptInfo] = em_info.txp_info;
    let max_iter = em_info.max_iter;
    let convergence_thresh = em_info.convergence_thresh;
    let total_weight: f64 = eq_map.num_aligned_reads() as f64;

    let mut prev_counts: Vec<f64> = match em_info.init_abundances {
        Some(ref init_counts) => init_counts.clone(),
        None => vec![total_weight / (tinfo.len() as f64); tinfo.len()],
    };
    let mut curr_counts: Vec<f64> = vec![0.0f64; tinfo.len()];

    // Per-(transcript length, alignment span) density; 1.0 when no KDE model is configured.
    let density_fn = |x, y| -> f64 {
        match em_info.kde_model {
            Some(ref kde_model) => kde_model[(x, y)],
            _ => 1.,
        }
    };

    let mut niter = 0_u32;
    while niter < max_iter {
        m_step(
            make_iter(),
            tinfo,
            fops.model_coverage,
            density_fn,
            &mut prev_counts,
            &mut curr_counts,
        );

        // Largest *absolute* relative change among non-negligible transcripts. The absolute
        // value matters: a signed difference would ignore large relative decreases and could
        // declare convergence too early.
        let rel_diff = prev_counts
            .iter()
            .zip(curr_counts.iter())
            .filter(|&(&pc, _)| pc > constants::MIN_READ_THRESH)
            .map(|(&pc, &cc)| ((cc - pc) / pc).abs())
            .fold(0.0_f64, f64::max);

        // The freshly computed counts become the next pass's input; reuse the old buffer.
        std::mem::swap(&mut prev_counts, &mut curr_counts);
        curr_counts.fill(0.0_f64);

        if (rel_diff < convergence_thresh) && (niter > 50) {
            break;
        }
        niter += 1;
        if do_log && (niter % 10 == 0) {
            if niter % 100 == 0 {
                info!(
                    "iteration {}; rel diff {}",
                    niter.to_formatted_string(&Locale::en),
                    rel_diff
                );
            } else {
                trace!(
                    "iteration {}; rel diff {}",
                    niter.to_formatted_string(&Locale::en),
                    rel_diff
                );
            }
        }
    }

    // Drop negligible abundances so the final pass assigns them no mass.
    for x in prev_counts.iter_mut() {
        if *x < constants::MIN_READ_THRESH {
            *x = 0.0;
        }
    }
    m_step(
        make_iter(),
        tinfo,
        fops.model_coverage,
        density_fn,
        &mut prev_counts,
        &mut curr_counts,
    );
    curr_counts
}
/// Serial EM entry point: runs [`do_em`] over every entry of `em_info.eq_map` with progress
/// logging enabled, inside an `"em"` tracing span.
///
/// `_nthreads` is accepted only so the signature mirrors [`em_par`]; it is ignored.
// Allowed to be unused: kept as the serial alternative to `em_par`.
#[allow(dead_code)]
pub fn em(em_info: &EMInfo, _nthreads: usize) -> Vec<f64> {
    let em_span = span!(tracing::Level::INFO, "em");
    let _entered = em_span.enter();
    do_em(em_info, || em_info.eq_map.iter(), true)
}
/// Computes a single bootstrap replicate.
///
/// Draws a set of entry indices via `bootstrap::get_sample_inds` (presumably sampling with
/// replacement; see that function). It then runs [`do_em`] silently over the resampled entries
/// and returns the resulting abundances.
pub fn do_bootstrap(em_info: &EMInfo) -> Vec<f64> {
    let mut rng = trng();
    let sample_inds = bootstrap::get_sample_inds(em_info.eq_map.len(), &mut rng);
    do_em(
        em_info,
        || em_info.eq_map.random_sampling_iter(&sample_inds),
        false,
    )
}
/// Computes `num_boot` bootstrap replicates in parallel and returns one abundance vector per
/// replicate, in replicate order.
///
/// Replicates run on a dedicated rayon pool with `nthreads` threads; each one calls
/// [`do_bootstrap`].
///
/// # Panics
/// Panics if the rayon thread pool cannot be built.
pub fn bootstrap(em_info: &EMInfo, num_boot: u32, nthreads: usize) -> Vec<Vec<f64>> {
    let span = span!(tracing::Level::INFO, "bootstrap");
    let _guard = span.enter();
    info!("will collect {num_boot} bootstraps");
    let pool = rayon::ThreadPoolBuilder::new()
        .num_threads(nthreads)
        .build()
        .expect("failed to build the rayon thread pool for bootstrapping");
    pool.install(|| {
        (0..num_boot)
            .into_par_iter()
            .map(|i| {
                // `Span::enter` only affects the current thread, so the outer guard does not
                // cover rayon workers; enter a span here as well.
                let span = span!(tracing::Level::INFO, "bootstrap");
                let _guard = span.enter();
                info!("evaluating bootstrap replicate {}", i);
                do_bootstrap(em_info)
            })
            .collect()
    })
}
/// Parallel EM entry point; returns the final abundance estimates (one per transcript).
///
/// Entries of `em_info.eq_map` are collected once, then processed each pass with
/// [`m_step_par`] on a dedicated rayon pool of `nthreads` threads. Per-transcript counts are
/// `AtomicF64`s so concurrent entries can accumulate into them.
///
/// Initialization, the convergence test (largest absolute relative change among transcripts
/// above `constants::MIN_READ_THRESH`), the zeroing of negligible abundances, and the final
/// pass follow the same scheme as [`do_em`]. Progress is logged every 10 iterations.
///
/// # Panics
/// Panics if the rayon thread pool cannot be built.
pub fn em_par(em_info: &EMInfo, nthreads: usize) -> Vec<f64> {
    let span = span!(tracing::Level::INFO, "em");
    let _guard = span.enter();
    let pool = rayon::ThreadPoolBuilder::new()
        .num_threads(nthreads)
        .build()
        .expect("failed to build the rayon thread pool for EM");
    let eq_map = em_info.eq_map;
    let fops = &eq_map.filter_opts;
    let tinfo: &[TranscriptInfo] = em_info.txp_info;
    let max_iter = em_info.max_iter;
    let convergence_thresh = em_info.convergence_thresh;
    let total_weight: f64 = eq_map.num_aligned_reads() as f64;
    // Materialize the entries once so they can be traversed with `par_iter` on every pass.
    let eq_iterates: Vec<EqIterateT> = eq_map.iter().collect();

    let mut prev_counts: Vec<AtomicF64> = match em_info.init_abundances {
        Some(ref init_counts) => init_counts.iter().map(|&x| AtomicF64::new(x)).collect(),
        None => {
            let avg = total_weight / (tinfo.len() as f64);
            (0..tinfo.len()).map(|_| AtomicF64::new(avg)).collect()
        }
    };
    let mut curr_counts: Vec<AtomicF64> =
        (0..tinfo.len()).map(|_| AtomicF64::new(0.0f64)).collect();

    // Per-(transcript length, alignment span) density; 1.0 when no KDE model is configured.
    let density_fn = |x, y| -> f64 {
        match em_info.kde_model {
            Some(ref kde_model) => kde_model[(x, y)],
            _ => 1.,
        }
    };

    pool.install(|| {
        let mut niter = 0_u32;
        while niter < max_iter {
            m_step_par(
                &eq_iterates,
                tinfo,
                fops.model_coverage,
                density_fn,
                &mut prev_counts,
                &mut curr_counts,
            );

            // Largest *absolute* relative change among non-negligible transcripts. A signed
            // difference would ignore large relative decreases and converge too early.
            let rel_diff = prev_counts
                .iter()
                .zip(curr_counts.iter())
                .map(|(p, c)| (p.load(Ordering::Relaxed), c.load(Ordering::Relaxed)))
                .filter(|&(pc, _)| pc > constants::MIN_READ_THRESH)
                .map(|(pc, cc)| ((cc - pc) / pc).abs())
                .fold(0.0_f64, f64::max);

            // The freshly computed counts become the next pass's input; reuse the old buffer.
            std::mem::swap(&mut prev_counts, &mut curr_counts);
            curr_counts
                .par_iter()
                .for_each(|x| x.store(0.0f64, Ordering::Relaxed));

            // NOTE(review): the serial `do_em` requires `niter > 50` before accepting
            // convergence, while this path only requires `niter > 1`; confirm which is intended.
            if (rel_diff < convergence_thresh) && (niter > 1) {
                break;
            }
            niter += 1;
            if niter % 10 == 0 {
                if niter % 100 == 0 {
                    info!(
                        "iteration {}; rel diff {}",
                        niter.to_formatted_string(&Locale::en),
                        rel_diff
                    );
                } else {
                    trace!(
                        "iteration {}; rel diff {}",
                        niter.to_formatted_string(&Locale::en),
                        rel_diff
                    );
                }
            }
        }

        // Drop negligible abundances so the final pass assigns them no mass.
        for x in prev_counts.iter() {
            if x.load(Ordering::Relaxed) < constants::MIN_READ_THRESH {
                x.store(0.0, Ordering::Relaxed);
            }
        }
        m_step_par(
            &eq_iterates,
            tinfo,
            fops.model_coverage,
            density_fn,
            &mut prev_counts,
            &mut curr_counts,
        );
    });

    curr_counts
        .iter()
        .map(|x| x.load(Ordering::Relaxed))
        .collect::<Vec<f64>>()
}