use salmon_eqclass::CollapsedEqClasses;
mod online;
mod packed;
pub mod uncertainty;
pub use online::OnlineInference;
pub use packed::PackedEqClasses;
pub use uncertainty::{ambiguity_counts, bootstrap, gibbs_sample, GibbsOptions};
/// Tuning knobs for the EM / VBEM abundance optimizer.
#[derive(Clone, Debug)]
pub struct EmOptions {
    /// Hard upper bound on the number of update steps.
    pub max_iter: u32,
    /// Number of steps performed before convergence is first checked.
    pub min_iter: u32,
    /// Convergence is declared once the largest relative change among
    /// checked transcripts falls below this value.
    pub rel_diff_tol: f64,
    /// Only transcripts whose previous abundance exceeds this value take
    /// part in the convergence check.
    pub alpha_check_cutoff: f64,
    /// Final abundances below this value are reported as exactly zero.
    pub min_alpha: f64,
    /// Use the variational-Bayes update instead of plain EM.
    pub use_vbem: bool,
    /// Prior value used by the VBEM update (per transcript, or per
    /// nucleotide when `per_nucleotide_prior` is set).
    pub vb_prior: f64,
    /// When true and effective lengths are supplied (one per transcript),
    /// the prior becomes `vb_prior * max(eff_len, 1)`.
    pub per_nucleotide_prior: bool,
}

impl Default for EmOptions {
    /// Plain EM, 50..=10000 iterations, 1% relative tolerance, and a
    /// 1e-8 reporting floor.
    fn default() -> Self {
        EmOptions {
            max_iter: 10000,
            min_iter: 50,
            rel_diff_tol: 1e-2,
            alpha_check_cutoff: 0.01,
            min_alpha: 1e-8,
            use_vbem: false,
            vb_prior: 0.01,
            per_nucleotide_prior: false,
        }
    }
}
/// Outcome of an optimizer run.
#[derive(Clone, Debug)]
pub struct EmResult {
    /// Estimated abundance per transcript (index = transcript id). The
    /// `optimize*` entry points zero out values below `EmOptions::min_alpha`.
    pub alphas: Vec<f64>,
    /// Number of update steps actually performed.
    pub iters: u32,
    /// Whether the relative-change tolerance was met before `max_iter`.
    pub converged: bool,
}
/// Largest relative change `|new - old| / new` between two abundance vectors.
///
/// Only positions where the old value exceeds `cutoff` and the new value is
/// strictly positive are considered. If no position qualifies, the result is
/// `f64::NEG_INFINITY`, so callers can tell "nothing to compare" apart from
/// "no change". NaN ratios are ignored.
///
/// The two slices are expected to have equal length.
fn max_rel_diff(alpha_in: &[f64], alpha_out: &[f64], cutoff: f64) -> f64 {
    debug_assert_eq!(alpha_in.len(), alpha_out.len());
    alpha_in
        .iter()
        .zip(alpha_out)
        .filter(|&(&old, &new)| old > cutoff && new > 0.0)
        .map(|(&old, &new)| (new - old).abs() / new)
        // `f64::max` discards NaN operands, matching a `d > max` scan.
        .fold(f64::NEG_INFINITY, f64::max)
}
pub fn optimize(
eq: &CollapsedEqClasses,
num_txps: usize,
opts: &EmOptions,
eff_lens: Option<&[f64]>,
) -> EmResult {
let packed = PackedEqClasses::from_collapsed(eq, num_txps);
optimize_packed_with_init(&packed, opts, true, None, eff_lens)
}
/// Like [`optimize`], but starts from `init_alphas` when it has exactly one
/// entry per transcript (otherwise a uniform split of the total count is used).
pub fn optimize_with_init(
    eq: &CollapsedEqClasses,
    num_txps: usize,
    opts: &EmOptions,
    init_alphas: Option<&[f64]>,
    eff_lens: Option<&[f64]>,
) -> EmResult {
    optimize_packed_with_init(
        &PackedEqClasses::from_collapsed(eq, num_txps),
        opts,
        true,
        init_alphas,
        eff_lens,
    )
}
/// Optimizes directly on already-packed equivalence classes, with a uniform
/// initial guess and no effective-length information.
pub fn optimize_packed(p: &PackedEqClasses, opts: &EmOptions, parallel: bool) -> EmResult {
    let (no_init, no_eff_lens) = (None, None);
    optimize_packed_with_init(p, opts, parallel, no_init, no_eff_lens)
}
/// Runs the optimizer on packed classes using their own counts, then
/// flushes abundances below `opts.min_alpha` to exactly zero.
///
/// `parallel` selects the sharded update path; `init_alphas` and `eff_lens`
/// are forwarded unchanged to the iteration driver.
pub fn optimize_packed_with_init(
    p: &PackedEqClasses,
    opts: &EmOptions,
    parallel: bool,
    init_alphas: Option<&[f64]>,
    eff_lens: Option<&[f64]>,
) -> EmResult {
    let (raw_alphas, iters, converged) =
        run_em_counts(p, &p.counts, opts, parallel, opts.min_iter, init_alphas, eff_lens);
    let floor = opts.min_alpha;
    let alphas: Vec<f64> = raw_alphas
        .into_iter()
        .map(|a| if a < floor { 0.0 } else { a })
        .collect();
    EmResult {
        alphas,
        iters,
        converged,
    }
}
/// Iterates EM (or VBEM when `opts.use_vbem` is set) over `p` with the given
/// per-class `counts`.
///
/// Returns `(alphas, iterations_performed, converged)`. No `min_alpha`
/// thresholding is applied here.
///
/// - `init_alphas` is used as the starting point only if it has exactly one
///   entry per transcript. Otherwise every transcript starts at
///   `sum(counts) / num_txps`.
/// - `eff_lens` is only read for the per-nucleotide VBEM prior.
/// - Convergence is checked from iteration `min_iter` on. The rule is that the
///   maximum relative change (see `max_rel_diff`) is finite and below
///   `opts.rel_diff_tol`.
/// - With zero transcripts the function returns immediately as converged
///   after 0 iterations.
pub(crate) fn run_em_counts(
    p: &PackedEqClasses,
    counts: &[u64],
    opts: &EmOptions,
    parallel: bool,
    min_iter: u32,
    init_alphas: Option<&[f64]>,
    eff_lens: Option<&[f64]>,
) -> (Vec<f64>, u32, bool) {
    let num_txps = p.num_txps;
    // Nothing to estimate. Without this guard the loop would run `max_iter`
    // no-op steps and, since `max_rel_diff` of empty slices is -inf, report
    // a spurious non-convergence.
    if num_txps == 0 {
        return (Vec::new(), 0, true);
    }

    let total: u64 = counts.iter().sum();
    let init = total as f64 / num_txps as f64;
    let mut alphas = match init_alphas {
        Some(a) if a.len() == num_txps => a.to_vec(),
        _ => vec![init; num_txps],
    };
    let mut alphas_prime = vec![0.0f64; num_txps];

    // Prior and digamma scratch are only read by the VBEM steps, so plain EM
    // skips allocating them.
    let (prior_alphas, mut exp_theta): (Vec<f64>, Vec<f64>) = if opts.use_vbem {
        let prior = match (opts.per_nucleotide_prior, eff_lens) {
            (true, Some(el)) if el.len() == num_txps => {
                el.iter().map(|&l| opts.vb_prior * l.max(1.0)).collect()
            }
            _ => vec![opts.vb_prior; num_txps],
        };
        (prior, vec![0.0f64; num_txps])
    } else {
        (Vec::new(), Vec::new())
    };

    let mut scratch: Vec<f64> = Vec::with_capacity(64);
    // One accumulation buffer per worker (capped at 64) for the parallel path.
    let mut shards: Vec<Vec<f64>> = if parallel {
        let nshards = rayon::current_num_threads().clamp(1, 64);
        vec![vec![0.0f64; num_txps]; nshards]
    } else {
        Vec::new()
    };

    let mut converged = false;
    let mut it = 0u32;
    while it < opts.max_iter {
        // Each step reads `alphas` and fills `alphas_prime` with the next iterate.
        match (opts.use_vbem, parallel) {
            (false, true) => {
                packed::em_step_par(p, counts, &alphas, &mut alphas_prime, &mut shards)
            }
            (false, false) => {
                packed::em_step_seq(p, counts, &alphas, &mut alphas_prime, &mut scratch)
            }
            (true, true) => packed::vbem_step_par(
                p,
                counts,
                &prior_alphas,
                &alphas,
                &mut alphas_prime,
                &mut exp_theta,
                &mut shards,
            ),
            (true, false) => packed::vbem_step_seq(
                p,
                counts,
                &prior_alphas,
                &alphas,
                &mut alphas_prime,
                &mut exp_theta,
                &mut scratch,
            ),
        }
        it += 1;

        // Measure the change against the previous iterate before it is swapped out.
        let rel_diff = (it >= min_iter)
            .then(|| max_rel_diff(&alphas, &alphas_prime, opts.alpha_check_cutoff));
        std::mem::swap(&mut alphas, &mut alphas_prime);

        // -inf means no transcript cleared the check cutoff. That is deliberately
        // treated as "no evidence yet", not as convergence.
        if matches!(rel_diff, Some(d) if d.is_finite() && d < opts.rel_diff_tol) {
            converged = true;
            break;
        }
    }
    (alphas, it, converged)
}
#[cfg(test)]
mod tests {
    use super::*;
    use salmon_eqclass::{EquivalenceClassBuilder, TranscriptGroup};

    /// Builds collapsed classes from `(member transcripts, count)` pairs.
    ///
    /// Every member gets a 1.0 entry in the per-member vector passed to
    /// `add_group` (presumably conditional weights). All `num_txps`
    /// transcripts get effective length 1.0.
    fn build(classes: &[(Vec<u32>, u64)], num_txps: usize) -> CollapsedEqClasses {
        let builder = EquivalenceClassBuilder::new();
        classes.iter().for_each(|(members, count)| {
            let weights = vec![1.0; members.len()];
            builder.add_group(TranscriptGroup::new(members.clone()), weights, *count);
        });
        let mut collapsed = builder.finish();
        collapsed.update_eff_lengths(&vec![1.0; num_txps]);
        collapsed
    }

    #[test]
    fn unique_classes_recover_exact_counts() {
        // Singleton classes: each transcript's abundance should equal its class count.
        let eq = build(&[(vec![0], 30), (vec![1], 70)], 2);
        let opts = EmOptions::default();
        let res = optimize(&eq, 2, &opts, None);
        assert!((res.alphas[0] - 30.0).abs() < 1e-6);
        assert!((res.alphas[1] - 70.0).abs() < 1e-6);
    }

    #[test]
    fn shared_class_splits_by_unique_evidence() {
        // Unique evidence is 10 vs 90, so the 100 shared reads are expected to
        // split 10:90, giving 20 and 180.
        let eq = build(&[(vec![0], 10), (vec![1], 90), (vec![0, 1], 100)], 2);
        let opts = EmOptions::default();
        let res = optimize(&eq, 2, &opts, None);
        let total = res.alphas[0] + res.alphas[1];
        assert!((total - 200.0).abs() < 1e-6, "total={total}");
        assert!((res.alphas[0] - 20.0).abs() < 1e-2, "a0={}", res.alphas[0]);
        assert!((res.alphas[1] - 180.0).abs() < 1e-2, "a1={}", res.alphas[1]);
    }

    #[test]
    fn conserves_total_count() {
        // 50 + 30 + 20 reads in total must all be assigned somewhere.
        let eq = build(&[(vec![0, 1, 2], 50), (vec![1, 2], 30), (vec![2], 20)], 3);
        let opts = EmOptions::default();
        let res = optimize(&eq, 3, &opts, None);
        let total: f64 = res.alphas.iter().sum();
        assert!((total - 100.0).abs() < 1e-6, "total={total}");
    }

    #[test]
    fn vbem_runs_and_conserves_approximately() {
        // VBEM is only expected to conserve the total approximately.
        let eq = build(&[(vec![0], 30), (vec![1], 70), (vec![0, 1], 100)], 2);
        let opts = EmOptions {
            use_vbem: true,
            ..EmOptions::default()
        };
        let res = optimize(&eq, 2, &opts, None);
        let total: f64 = res.alphas.iter().sum();
        assert!((total - 200.0).abs() < 1.0, "total={total}");
        assert!(res.alphas[1] > res.alphas[0]);
    }

    #[test]
    fn effective_length_shifts_allocation() {
        // One shared class; transcript 1 has the shorter effective length
        // (100 vs 300) and is expected to receive more of the reads.
        let builder = EquivalenceClassBuilder::new();
        builder.add_group(TranscriptGroup::new(vec![0, 1]), vec![1.0, 1.0], 100);
        let mut eq = builder.finish();
        eq.update_eff_lengths(&[300.0, 100.0]);
        let opts = EmOptions::default();
        let res = optimize(&eq, 2, &opts, None);
        assert!(res.alphas[1] > res.alphas[0], "{:?}", res.alphas);
    }
}