use rayon::prelude::*;
use salmon_eqclass::CollapsedEqClasses;
use statrs::function::gamma::digamma;
/// Cutoff applied to `alpha + prior` before calling `digamma`: values that are not
/// strictly greater than this (including NaN) skip `digamma` and get weight 0.
const DIGAMMA_MIN: f64 = 1e-10;
/// Equivalence classes flattened into contiguous, offset-indexed arrays.
///
/// Class `i` owns the half-open range `starts[i]..starts[i + 1]` of `labels` and
/// `combined` (see [`PackedEqClasses::class`]); `from_collapsed` appends each
/// class's `weights` to `weights` in the same pass.
#[derive(Debug, Clone)]
pub struct PackedEqClasses {
    /// Transcript ids of all packed classes, concatenated in class order.
    pub labels: Vec<u32>,
    /// Class boundaries into `labels`: starts with 0 and holds one more entry than
    /// there are classes.
    pub starts: Vec<u32>,
    /// Per-member combined weights (from `combined_weights`); `class` slices this
    /// with the same ranges as `labels`.
    pub combined: Vec<f64>,
    /// Each packed class's `weights`, concatenated in class order.
    pub weights: Vec<f64>,
    /// Read count of each class; its length is the number of classes.
    pub counts: Vec<u64>,
    /// Number of transcripts, as passed to `from_collapsed`; the step functions in
    /// this module size their per-transcript buffers from it.
    pub num_txps: usize,
    /// Sum of `counts` over all packed classes.
    pub total_count: u64,
}

impl PackedEqClasses {
    /// Packs every valid class of `eq`, in `eq.classes` iteration order.
    ///
    /// Groups whose `valid` flag is false are skipped and contribute neither labels
    /// nor counts.
    ///
    /// # Panics
    /// Panics if the total number of packed labels exceeds `u32::MAX`, because the
    /// class offsets in `starts` are stored as `u32`.
    pub fn from_collapsed(eq: &CollapsedEqClasses, num_txps: usize) -> Self {
        let n = eq.classes.len();
        let mut labels = Vec::new();
        let mut starts = Vec::with_capacity(n + 1);
        let mut combined = Vec::new();
        let mut weights = Vec::new();
        let mut counts = Vec::with_capacity(n);
        starts.push(0u32);
        let mut total_count = 0u64;
        for (group, value) in &eq.classes {
            if !group.valid {
                continue;
            }
            labels.extend_from_slice(&group.txps);
            combined.extend_from_slice(&value.combined_weights);
            weights.extend_from_slice(&value.weights);
            counts.push(value.count);
            total_count += value.count;
            // A bare `as u32` would silently wrap past u32::MAX and corrupt every
            // later class boundary; fail loudly instead.
            let end = labels.len();
            assert!(
                end <= u32::MAX as usize,
                "PackedEqClasses: more than u32::MAX labels; class offsets overflow u32"
            );
            starts.push(end as u32);
        }
        Self {
            labels,
            starts,
            combined,
            weights,
            counts,
            num_txps,
            total_count,
        }
    }

    /// Number of packed classes.
    #[inline]
    pub fn num_classes(&self) -> usize {
        self.counts.len()
    }

    /// Returns the transcript ids of class `i` together with their combined weights.
    ///
    /// # Panics
    /// Panics if `i + 1` is not a valid index into `starts` (for a value built by
    /// `from_collapsed`, whenever `i >= self.num_classes()`), or if `labels` /
    /// `combined` are shorter than the class's end offset.
    #[inline]
    pub fn class(&self, i: usize) -> (&[u32], &[f64]) {
        let s = self.starts[i] as usize;
        let e = self.starts[i + 1] as usize;
        (&self.labels[s..e], &self.combined[s..e])
    }
}
/// Threshold on a multi-transcript class's weighted denominator, set to the
/// smallest positive normal `f64`. A class's count is distributed only when its
/// denominator is strictly greater than this. Otherwise the class contributes
/// nothing to the step output, and `redistribute_truncated` adds its count to the
/// reported dropped mass.
const MIN_EQ_CLASS_WEIGHT: f64 = f64::MIN_POSITIVE;
/// Redistributes class read counts using only transcripts that survived truncation.
///
/// Transcript `i` is truncated when `alpha_conv[i] < min_alpha`. A truncated
/// transcript gets weight 0 and never receives mass. Every other transcript is
/// weighted as follows:
/// - with `use_vbem`: `exp(digamma(alpha_conv[i] + prior_alphas[i]))`, or 0 if that
///   sum is not above `DIGAMMA_MIN`;
/// - otherwise: `alpha_conv[i]`.
///
/// A multi-transcript class splits its count in proportion to
/// `weight * combined weight`. A single-transcript class gives its whole count to
/// its transcript, provided that transcript is not truncated.
///
/// # Returns
/// `(alpha_out, dropped)`:
/// - `alpha_out` holds the per-transcript assigned counts and has length `p.num_txps`.
/// - `dropped` is the total count of classes that could not be assigned: a class
///   with no transcripts, a single-transcript class whose transcript is truncated,
///   or a multi-transcript class whose weighted denominator is not above
///   `MIN_EQ_CLASS_WEIGHT`.
///
/// # Panics
/// Panics if `alpha_conv` has fewer than `p.num_txps` entries, if `counts` has
/// fewer than `p.num_classes()` entries, or (with `use_vbem`) if `prior_alphas` is
/// too short.
pub(crate) fn redistribute_truncated(
    p: &PackedEqClasses,
    counts: &[u64],
    alpha_conv: &[f64],
    prior_alphas: &[f64],
    min_alpha: f64,
    use_vbem: bool,
) -> (Vec<f64>, f64) {
    let n = p.num_txps;
    let inactive: Vec<bool> = alpha_conv.iter().map(|&a| a < min_alpha).collect();
    // The global digamma normalisation used by `fill_exp_theta` is omitted here:
    // each class is split by a ratio of weights, so a common factor cancels.
    let basis: Vec<f64> = (0..n)
        .map(|i| {
            if inactive[i] {
                0.0
            } else if use_vbem {
                let ap = alpha_conv[i] + prior_alphas[i];
                if ap > DIGAMMA_MIN {
                    digamma(ap).exp()
                } else {
                    0.0
                }
            } else {
                alpha_conv[i]
            }
        })
        .collect();

    let mut alpha_out = vec![0.0f64; n];
    let mut dropped = 0.0f64;
    let mut scratch: Vec<f64> = Vec::with_capacity(64);
    for ci in 0..p.num_classes() {
        let count = counts[ci] as f64;
        let (tids, ws) = p.class(ci);
        match tids {
            // A class with no transcripts has nowhere to send its reads; report it
            // instead of panicking on `tids[0]`.
            [] => dropped += count,
            [tid] => {
                let tid = *tid as usize;
                if inactive[tid] {
                    dropped += count;
                } else {
                    alpha_out[tid] += count;
                }
            }
            _ => {
                scratch.clear();
                scratch.extend(
                    tids.iter()
                        .zip(ws)
                        .map(|(&tid, &w)| basis[tid as usize] * w),
                );
                let denom = scratch.iter().fold(0.0, |acc, &v| acc + v);
                if denom > MIN_EQ_CLASS_WEIGHT {
                    let inv = count / denom;
                    for (&tid, &v) in tids.iter().zip(&scratch) {
                        if v > 0.0 {
                            alpha_out[tid as usize] += v * inv;
                        }
                    }
                } else {
                    // Every member is truncated (or has zero weight): nothing to
                    // receive this class's reads.
                    dropped += count;
                }
            }
        }
    }
    (alpha_out, dropped)
}
/// Sequential EM step: overwrites `alpha_out` with each transcript's share of the
/// class counts under the current abundances `alpha_in`.
///
/// A single-transcript class hands its whole count to that transcript. A larger
/// class splits its count in proportion to `alpha_in[t] * combined weight`. If
/// that weighted sum is not above `MIN_EQ_CLASS_WEIGHT`, the class contributes
/// nothing. `scratch` is a reusable buffer for the per-member products, so no
/// allocation is needed per class.
pub(crate) fn em_step_seq(
    p: &PackedEqClasses,
    counts: &[u64],
    alpha_in: &[f64],
    alpha_out: &mut [f64],
    scratch: &mut Vec<f64>,
) {
    alpha_out.fill(0.0);
    for class_idx in 0..p.num_classes() {
        let mass = counts[class_idx] as f64;
        let (members, member_weights) = p.class(class_idx);
        if members.len() <= 1 {
            alpha_out[members[0] as usize] += mass;
            continue;
        }
        scratch.clear();
        scratch.extend(
            members
                .iter()
                .zip(member_weights)
                .map(|(&tid, &w)| alpha_in[tid as usize] * w),
        );
        let denom: f64 = scratch.iter().fold(0.0, |acc, &v| acc + v);
        if denom > MIN_EQ_CLASS_WEIGHT {
            let scale = mass / denom;
            for (&tid, &share) in members.iter().zip(scratch.iter()) {
                if !share.is_nan() {
                    alpha_out[tid as usize] += share * scale;
                }
            }
        }
    }
}
/// Writes into `alpha_out[t]` the sum of `buf[t]` over every shard buffer.
///
/// Each output element is computed independently by a fold over the shards in
/// order, so the result does not depend on how rayon schedules the elements.
/// Every shard must be at least as long as `alpha_out`.
fn reduce_shards(shards: &[Vec<f64>], alpha_out: &mut [f64]) {
    alpha_out
        .par_iter_mut()
        .enumerate()
        .for_each(|(tid, slot)| {
            *slot = shards.iter().fold(0.0, |acc, shard| acc + shard[tid]);
        });
}
/// Parallel EM step over `shards.len()` contiguous chunks of classes.
///
/// Computes the same update as `em_step_seq`, up to floating-point summation
/// order. `alpha_out` is overwritten, and each class's count is split across its
/// transcripts in proportion to `alpha_in[t] * combined weight`. Shard `s` first
/// accumulates its own chunk of classes into its buffer, and `reduce_shards` then
/// sums the buffers into `alpha_out`. Every shard buffer must be at least as long
/// as `alpha_out`.
///
/// When `shards` is empty there are no buffers to accumulate into. In that case
/// the step runs on the current thread and writes into `alpha_out` directly.
pub(crate) fn em_step_par(
    p: &PackedEqClasses,
    counts: &[u64],
    alpha_in: &[f64],
    alpha_out: &mut [f64],
    shards: &mut [Vec<f64>],
) {
    let nclasses = p.num_classes();
    if shards.is_empty() {
        // Without shard buffers `reduce_shards` would only zero `alpha_out`,
        // silently discarding every read; accumulate in place instead.
        alpha_out.fill(0.0);
        em_accumulate_classes(p, counts, alpha_in, 0..nclasses, alpha_out);
        return;
    }
    let chunk = nclasses.div_ceil(shards.len());
    shards.par_iter_mut().enumerate().for_each(|(s, buf)| {
        buf.fill(0.0);
        // Trailing shards get an empty range when there are fewer classes than
        // `shards.len() * chunk`.
        let start = (s * chunk).min(nclasses);
        let end = ((s + 1) * chunk).min(nclasses);
        em_accumulate_classes(p, counts, alpha_in, start..end, buf);
    });
    reduce_shards(shards, alpha_out);
}

/// Adds the EM share of every class in `classes` into `buf`, without clearing `buf`.
///
/// A single-transcript class adds its whole count. A larger class splits its count
/// in proportion to `alpha_in[t] * combined weight`, and is skipped when that
/// weighted sum is not above `MIN_EQ_CLASS_WEIGHT`.
fn em_accumulate_classes(
    p: &PackedEqClasses,
    counts: &[u64],
    alpha_in: &[f64],
    classes: std::ops::Range<usize>,
    buf: &mut [f64],
) {
    for ci in classes {
        let count = counts[ci] as f64;
        let (tids, ws) = p.class(ci);
        if tids.len() > 1 {
            let mut denom = 0.0;
            for (&tid, &w) in tids.iter().zip(ws) {
                denom += alpha_in[tid as usize] * w;
            }
            if denom > MIN_EQ_CLASS_WEIGHT {
                let inv = count / denom;
                for (&tid, &w) in tids.iter().zip(ws) {
                    let v = alpha_in[tid as usize] * w;
                    if !v.is_nan() {
                        buf[tid as usize] += v * inv;
                    }
                }
            }
        } else {
            buf[tids[0] as usize] += count;
        }
    }
}
/// Fills `exp_theta[i]` with `exp(digamma(alpha_in[i] + prior_alphas[i]) - digamma(S))`,
/// where `S` is the sum of `alpha_in + prior_alphas` over all transcripts.
///
/// Entries whose shifted value is not strictly greater than `DIGAMMA_MIN`
/// (including NaN) are set to 0 without calling `digamma`.
fn fill_exp_theta(alpha_in: &[f64], prior_alphas: &[f64], exp_theta: &mut [f64]) {
    let alpha_sum = alpha_in
        .iter()
        .zip(prior_alphas)
        .map(|(a, p)| a + p)
        .sum::<f64>();
    let log_norm = digamma(alpha_sum);
    for (i, &alpha) in alpha_in.iter().enumerate() {
        let shifted = alpha + prior_alphas[i];
        let weight = if shifted > DIGAMMA_MIN {
            (digamma(shifted) - log_norm).exp()
        } else {
            0.0
        };
        exp_theta[i] = weight;
    }
}
/// Sequential VBEM step: refreshes `exp_theta` from `alpha_in + prior_alphas`, then
/// overwrites `alpha_out` with each transcript's share of the class counts.
///
/// A single-transcript class hands its whole count to that transcript. A larger
/// class splits its count in proportion to `exp_theta[t] * combined weight`.
/// Transcripts with non-positive `exp_theta` contribute and receive nothing. If
/// the weighted sum is not above `MIN_EQ_CLASS_WEIGHT`, the class is skipped.
/// `scratch` is a reusable buffer for the per-member products.
pub(crate) fn vbem_step_seq(
    p: &PackedEqClasses,
    counts: &[u64],
    prior_alphas: &[f64],
    alpha_in: &[f64],
    alpha_out: &mut [f64],
    exp_theta: &mut [f64],
    scratch: &mut Vec<f64>,
) {
    fill_exp_theta(alpha_in, prior_alphas, exp_theta);
    alpha_out.fill(0.0);
    for class_idx in 0..p.num_classes() {
        let mass = counts[class_idx] as f64;
        let (members, member_weights) = p.class(class_idx);
        if members.len() <= 1 {
            alpha_out[members[0] as usize] += mass;
            continue;
        }
        scratch.clear();
        for (&tid, &w) in members.iter().zip(member_weights) {
            let theta = exp_theta[tid as usize];
            scratch.push(if theta > 0.0 { theta * w } else { 0.0 });
        }
        let denom: f64 = scratch.iter().fold(0.0, |acc, &v| acc + v);
        if denom > MIN_EQ_CLASS_WEIGHT {
            let scale = mass / denom;
            members
                .iter()
                .zip(scratch.iter())
                .filter(|&(_, &share)| share > 0.0)
                .for_each(|(&tid, &share)| alpha_out[tid as usize] += share * scale);
        }
    }
}
/// Parallel VBEM step over `shards.len()` contiguous chunks of classes.
///
/// First refreshes `exp_theta` from `alpha_in + prior_alphas`. It then computes the
/// same class split as the sequential VBEM step, up to floating-point summation
/// order: each class's count goes to its members in proportion to
/// `exp_theta[t] * combined weight`, and members with non-positive `exp_theta` are
/// skipped. Shard `s` first accumulates its own chunk of classes into its buffer,
/// and `reduce_shards` then sums the buffers into `alpha_out`. Every shard buffer
/// must be at least as long as `alpha_out`.
///
/// When `shards` is empty there are no buffers to accumulate into. In that case
/// the step runs on the current thread and writes into `alpha_out` directly.
pub(crate) fn vbem_step_par(
    p: &PackedEqClasses,
    counts: &[u64],
    prior_alphas: &[f64],
    alpha_in: &[f64],
    alpha_out: &mut [f64],
    exp_theta: &mut [f64],
    shards: &mut [Vec<f64>],
) {
    fill_exp_theta(alpha_in, prior_alphas, exp_theta);
    let nclasses = p.num_classes();
    // Shared, read-only view so the parallel closure can capture it.
    let exp_theta: &[f64] = exp_theta;
    if shards.is_empty() {
        // Without shard buffers `reduce_shards` would only zero `alpha_out`,
        // silently discarding every read; accumulate in place instead.
        alpha_out.fill(0.0);
        vbem_accumulate_classes(p, counts, exp_theta, 0..nclasses, alpha_out);
        return;
    }
    let chunk = nclasses.div_ceil(shards.len());
    shards.par_iter_mut().enumerate().for_each(|(s, buf)| {
        buf.fill(0.0);
        // Trailing shards get an empty range when there are fewer classes than
        // `shards.len() * chunk`.
        let start = (s * chunk).min(nclasses);
        let end = ((s + 1) * chunk).min(nclasses);
        vbem_accumulate_classes(p, counts, exp_theta, start..end, buf);
    });
    reduce_shards(shards, alpha_out);
}

/// Adds the VBEM share of every class in `classes` into `buf`, without clearing `buf`.
///
/// A single-transcript class adds its whole count. A larger class splits its count
/// in proportion to `exp_theta[t] * combined weight`, ignoring members whose
/// `exp_theta` is not positive. It is skipped when that weighted sum is not above
/// `MIN_EQ_CLASS_WEIGHT`.
fn vbem_accumulate_classes(
    p: &PackedEqClasses,
    counts: &[u64],
    exp_theta: &[f64],
    classes: std::ops::Range<usize>,
    buf: &mut [f64],
) {
    for ci in classes {
        let count = counts[ci] as f64;
        let (tids, ws) = p.class(ci);
        if tids.len() > 1 {
            let mut denom = 0.0;
            for (&tid, &w) in tids.iter().zip(ws) {
                let et = exp_theta[tid as usize];
                if et > 0.0 {
                    denom += et * w;
                }
            }
            if denom > MIN_EQ_CLASS_WEIGHT {
                let inv = count / denom;
                for (&tid, &w) in tids.iter().zip(ws) {
                    let et = exp_theta[tid as usize];
                    if et > 0.0 {
                        buf[tid as usize] += et * w * inv;
                    }
                }
            }
        } else {
            buf[tids[0] as usize] += count;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use salmon_eqclass::{EquivalenceClassBuilder, TranscriptGroup};

    /// Packs `(transcripts, count)` classes over `num_txps` transcripts. Every member
    /// gets weight 1.0, and an all-ones vector is passed to `update_eff_lengths`.
    fn packed(classes: &[(Vec<u32>, u64)], num_txps: usize) -> PackedEqClasses {
        let builder = EquivalenceClassBuilder::new();
        for (members, count) in classes.iter() {
            builder.add_group(
                TranscriptGroup::new(members.clone()),
                vec![1.0; members.len()],
                *count,
            );
        }
        let mut collapsed = builder.finish();
        let ones = vec![1.0; num_txps];
        collapsed.update_eff_lengths(&ones);
        PackedEqClasses::from_collapsed(&collapsed, num_txps)
    }

    #[test]
    fn redistribute_moves_truncated_mass_to_comembers_no_rescale() {
        // Transcript 1 falls below min_alpha, so its co-member absorbs all 100 reads.
        let p = packed(&[(vec![0, 1], 100)], 2);
        let (out, dropped) = redistribute_truncated(
            &p,
            &p.counts,
            &[100.0, 1e-12],
            &[0.0, 0.0],
            1e-8,
            false,
        );
        assert_eq!(dropped, 0.0);
        assert!(
            (out[0] - 100.0).abs() < 1e-9,
            "co-member should get the mass: {out:?}"
        );
        assert_eq!(out[1], 0.0, "truncated transcript stays 0");
        let total = out[0] + out[1];
        assert!((total - 100.0).abs() < 1e-9, "mass preserved exactly");
    }

    #[test]
    fn redistribute_reports_fully_truncated_class_mass() {
        // Each transcript sits alone in its class; the truncated one cannot pass
        // its reads to anybody.
        let p = packed(&[(vec![0], 5), (vec![1], 3)], 2);
        let alpha = [10.0, 1e-12];
        let (out, dropped) =
            redistribute_truncated(&p, &p.counts, &alpha, &[0.0, 0.0], 1e-8, false);
        assert_eq!(out[0], 5.0);
        assert_eq!(out[1], 0.0);
        assert_eq!(dropped, 3.0, "fully-truncated class mass must be reported");
    }

    #[test]
    fn redistribute_vbem_prior_does_not_revive_truncated() {
        // A positive prior on transcript 1 must not give it weight once truncated.
        let p = packed(&[(vec![0, 1], 100)], 2);
        let alpha = [100.0, 1e-12];
        let priors = [0.01, 0.01];
        let (out, dropped) =
            redistribute_truncated(&p, &p.counts, &alpha, &priors, 1e-8, true);
        assert_eq!(dropped, 0.0);
        assert_eq!(
            out[1], 0.0,
            "VBEM prior must not revive a truncated transcript"
        );
        assert!(
            (out[0] - 100.0).abs() < 1e-9,
            "all mass to the surviving co-member"
        );
    }
}