use rand::{Rng, SeedableRng};
use rand_distr::{Binomial, Distribution, Gamma};
use rand_pcg::Pcg64Mcg;
use rayon::prelude::*;
use crate::packed::PackedEqClasses;
use crate::{run_em_counts, EmOptions};
/// Threshold on the summed scaled member weights (`1000 * mu[t] * w`) of an
/// equivalence class in `gibbs_round`: at or below this value the class's fragments
/// are split uniformly among its members instead of proportionally.
const MIN_EQ_CLASS_WEIGHT: f64 = f64::MIN_POSITIVE;
/// Draws one multinomial sample of `total` trials over categories with
/// (unnormalised) weights `weights`, using sequential conditional binomials:
/// category `i` receives `Binomial(trials still unassigned, weights[i] / weight still unassigned)`.
///
/// Returns one count per category (`weights.len()` entries). The final category — or the
/// current one as soon as the running weight total is no longer positive — absorbs every
/// trial that is still unassigned. Returns all zeros when `total == 0` or `weights` is empty.
fn multinomial(total: u64, weights: &[f64], rng: &mut impl Rng) -> Vec<u64> {
    let mut counts = vec![0u64; weights.len()];
    if total == 0 || weights.is_empty() {
        return counts;
    }
    let last = weights.len() - 1;
    let mut trials_left = total;
    let mut weight_left: f64 = weights.iter().sum();
    for (idx, (slot, &w)) in counts.iter_mut().zip(weights).enumerate() {
        if trials_left == 0 {
            break;
        }
        if idx == last || weight_left <= 0.0 {
            *slot = trials_left;
            break;
        }
        let prob = (w / weight_left).clamp(0.0, 1.0);
        // Skip the RNG entirely for degenerate probabilities.
        let drawn = match prob {
            q if q >= 1.0 => trials_left,
            q if q <= 0.0 => 0,
            q => Binomial::new(trials_left, q).unwrap().sample(rng),
        };
        *slot = drawn;
        trials_left -= drawn;
        weight_left -= w;
    }
    counts
}
/// Nonparametric bootstrap over equivalence-class counts.
///
/// For each replicate index `bs` in `0..num_bootstraps`, resamples `p.total_count`
/// fragments multinomially with the observed class counts as weights, runs
/// `run_em_counts` on the resampled counts, and post-processes the result with
/// `finalize_truncate_redistribute`. Replicates run in parallel via rayon; each one's
/// `Pcg64Mcg` stream is seeded only from `seed` and its replicate index.
///
/// Returns one abundance vector per replicate, in replicate order.
pub fn bootstrap(
    p: &PackedEqClasses,
    opts: &EmOptions,
    num_bootstraps: u32,
    seed: u64,
) -> Vec<Vec<f64>> {
    let total = p.total_count;
    let class_weights: Vec<f64> = p.counts.iter().map(|&c| c as f64).collect();

    let replicate = |bs: u32| -> Vec<f64> {
        // Mix the replicate index into `seed` so every replicate gets its own stream.
        let stream_seed = seed ^ u64::from(bs).wrapping_mul(0x9E3779B97F4A7C15);
        let mut rng = Pcg64Mcg::seed_from_u64(stream_seed);
        let resampled = multinomial(total, &class_weights, &mut rng);
        let (em_alphas, _, _) = run_em_counts(p, &resampled, opts, false, 50, None, None);
        let (abundances, _dropped) =
            crate::finalize_truncate_redistribute(p, &resampled, em_alphas, opts, None);
        abundances
    };

    (0..num_bootstraps).into_par_iter().map(replicate).collect()
}
/// Configuration for `gibbs_sample`.
#[derive(Debug, Clone)]
pub struct GibbsOptions {
    /// Number of posterior samples to return; `0` makes `gibbs_sample` return none.
    pub num_samples: u32,
    /// Number of `gibbs_round` sweeps run per recorded sample.
    pub thinning: u32,
    /// Prior pseudo-count added to each transcript's count before its rate is drawn.
    pub prior: f64,
    /// When `true`, `prior` is used as-is for every transcript; when `false` it is
    /// multiplied by the transcript's effective length (floored at 1.0).
    pub per_transcript_prior: bool,
}

impl Default for GibbsOptions {
    /// No samples, a thinning of 16, and a per-transcript prior of 1e-3.
    fn default() -> Self {
        GibbsOptions {
            per_transcript_prior: true,
            prior: 1e-3,
            thinning: 16,
            num_samples: 0,
        }
    }
}
/// Offset added to a transcript's effective length when `gibbs_round` draws its rate
/// `mu` from a Gamma distribution (the draw uses scale `1 / (GIBBS_BETA + eff_len)`).
const GIBBS_BETA: f64 = 0.1;
/// Performs one Gibbs sweep: redraws transcript rates, then reassigns every
/// equivalence class's fragments among its member transcripts.
///
/// 1. For each transcript `i` in `active`, draws
///    `mu[i] ~ Gamma(shape = txp_count[i] + prior_alphas[i], scale = 1 / (GIBBS_BETA + eff_lens[i]))`,
///    or sets `mu[i] = 0` when that shape is not positive, then resets `txp_count[i]` to 0.
/// 2. For each class with more than one member, splits the class count multinomially with
///    weights `1000 * mu[t] * w` (`w` = the member's stored class weight); if those weights sum
///    to at most `MIN_EQ_CLASS_WEIGHT` the split is uniform. A single-member class adds its
///    whole count to that transcript. Drawn counts are added into `txp_count`.
#[allow(clippy::too_many_arguments)]
fn gibbs_round(
    p: &PackedEqClasses,
    active: &[u32],
    eff_lens: &[f64],
    prior_alphas: &[f64],
    txp_count: &mut [f64],
    mu: &mut [f64],
    rng: &mut impl Rng,
) {
    // Step 1: redraw each active transcript's rate and clear its fragment tally.
    for txp in active.iter().map(|&t| t as usize) {
        let shape = txp_count[txp] + prior_alphas[txp];
        let scale = 1.0 / (GIBBS_BETA + eff_lens[txp]);
        mu[txp] = if shape > 0.0 {
            Gamma::new(shape, scale).unwrap().sample(rng)
        } else {
            0.0
        };
        txp_count[txp] = 0.0;
    }

    // Step 2: reassign fragments class by class; `class_probs` is reused as scratch space.
    let mut class_probs: Vec<f64> = Vec::with_capacity(64);
    for class in 0..p.num_classes() {
        let n_frags = p.counts[class];
        let (lo, hi) = (p.starts[class] as usize, p.starts[class + 1] as usize);
        let members = &p.labels[lo..hi];
        let member_weights = &p.weights[lo..hi];

        if members.len() <= 1 {
            txp_count[members[0] as usize] += n_frags as f64;
            continue;
        }

        class_probs.clear();
        class_probs.extend(
            members
                .iter()
                .zip(member_weights)
                .map(|(&t, &w)| 1000.0 * mu[t as usize] * w),
        );
        let total_weight = class_probs.iter().fold(0.0, |acc, &v| acc + v);
        if total_weight <= MIN_EQ_CLASS_WEIGHT {
            // Every member is (numerically) weightless: fall back to a uniform split.
            class_probs.fill(1.0);
        }

        let draws = multinomial(n_frags, &class_probs, rng);
        for (&t, &k) in members.iter().zip(&draws) {
            txp_count[t as usize] += k as f64;
        }
    }
}
/// Draws posterior abundance samples with a Gibbs sampler over equivalence classes.
///
/// The `opts.num_samples` requested samples are divided among independent chains:
/// 8 when at least 200 samples are requested, 4 from 100, 2 from 50, otherwise 1.
/// The last chain also takes any remainder. Chains run in parallel via rayon. Each chain
/// starts from `init_alphas`, with entries zeroed for transcripts that occur in no
/// equivalence class. It uses its own `Pcg64Mcg` stream, seeded from `seed` mixed with
/// the chain index. Between recorded samples it performs `opts.thinning` calls to
/// `gibbs_round`; a thinning of 0 is treated as 1.
///
/// Transcript priors are `opts.prior` when `opts.per_transcript_prior` is set, and
/// `opts.prior * max(eff_lens[i], 1.0)` otherwise.
///
/// Each recorded sample holds `mu[t] * eff_lens[t]` for every transcript, with values at
/// or below `1e-8` set to zero. It is then rescaled so its entries sum to `p.total_count`;
/// it is left all-zero if nothing survives the cutoff.
///
/// Returns `opts.num_samples` vectors of length `p.num_txps`, ordered chain by chain.
/// Returns an empty vector when `opts.num_samples == 0`. `eff_lens` and `init_alphas`
/// are indexed by transcript id, so both are expected to cover all `p.num_txps` transcripts.
pub fn gibbs_sample(
    p: &PackedEqClasses,
    eff_lens: &[f64],
    init_alphas: &[f64],
    opts: &GibbsOptions,
    seed: u64,
) -> Vec<Vec<f64>> {
    let num_txps = p.num_txps;
    let num_samples = opts.num_samples as usize;
    if num_samples == 0 {
        return Vec::new();
    }
    // With zero rounds per sample `mu` is never drawn, so every sample would come out
    // all-zero (and not sum to `total_count`); always run at least one round.
    let thinning = opts.thinning.max(1);

    // A transcript is active when it belongs to at least one equivalence class.
    let mut active_flag = vec![false; num_txps];
    for &t in &p.labels {
        active_flag[t as usize] = true;
    }
    let active: Vec<u32> = (0..num_txps as u32)
        .filter(|&t| active_flag[t as usize])
        .collect();

    let prior_alphas: Vec<f64> = (0..num_txps)
        .map(|i| {
            if opts.per_transcript_prior {
                opts.prior
            } else {
                opts.prior * eff_lens[i].max(1.0)
            }
        })
        .collect();

    // Inactive transcripts never receive fragments, so they start at zero.
    let mut init = init_alphas.to_vec();
    for (i, &is_active) in active_flag.iter().enumerate() {
        if !is_active {
            init[i] = 0.0;
        }
    }

    let nchains: usize = match num_samples {
        200.. => 8,
        100.. => 4,
        50.. => 2,
        _ => 1,
    };
    // Contiguous [start, end) slices of the sample range; the last chain takes the remainder.
    let step = num_samples / nchains;
    let bounds: Vec<(usize, usize)> = (0..nchains)
        .map(|c| {
            let start = c * step;
            let end = if c == nchains - 1 {
                num_samples
            } else {
                start + step
            };
            (start, end)
        })
        .collect();

    let per_chain: Vec<Vec<Vec<f64>>> = bounds
        .par_iter()
        .enumerate()
        .map(|(c, &(start, end))| {
            let mut rng =
                Pcg64Mcg::seed_from_u64(seed ^ (c as u64).wrapping_mul(0xD1B54A32D192ED03));
            let mut txp_count = init.clone();
            let mut mu = vec![0.0f64; num_txps];
            let mut out: Vec<Vec<f64>> = Vec::with_capacity(end - start);
            for _ in start..end {
                for _ in 0..thinning {
                    gibbs_round(
                        p,
                        &active,
                        eff_lens,
                        &prior_alphas,
                        &mut txp_count,
                        &mut mu,
                        &mut rng,
                    );
                }
                // Scale each rate by its effective length, drop negligible values, and
                // normalise the sample to the observed total fragment count.
                let mut sample = vec![0.0f64; num_txps];
                let mut denom = 0.0;
                for t in 0..num_txps {
                    let ext = mu[t] * eff_lens[t];
                    if ext > 1e-8 {
                        sample[t] = ext;
                        denom += ext;
                    }
                }
                if denom > 0.0 {
                    let scale = p.total_count as f64 / denom;
                    for s in &mut sample {
                        *s *= scale;
                    }
                }
                out.push(sample);
            }
            out
        })
        .collect();

    // `collect` on this indexed parallel iterator preserves chain order, and each chain
    // covers the next contiguous slice of sample indices, so concatenation restores
    // sample order 0..num_samples.
    per_chain.into_iter().flatten().collect()
}
/// Tallies, per transcript, how many fragments are unique to it versus shared.
///
/// For every equivalence class, its count goes to `unique[t]` when the class has a
/// single member `t`, and to `ambig[t]` for every member `t` when it has more than one.
/// Class counts above `u32::MAX` are clamped to `u32::MAX`. Running totals saturate
/// at `u32::MAX` instead of overflowing.
///
/// Returns `(unique, ambig)`, each with `p.num_txps` entries.
pub fn ambiguity_counts(p: &PackedEqClasses) -> (Vec<u32>, Vec<u32>) {
    let mut unique = vec![0u32; p.num_txps];
    let mut ambig = vec![0u32; p.num_txps];
    for ci in 0..p.num_classes() {
        let s = p.starts[ci] as usize;
        let e = p.starts[ci + 1] as usize;
        let tids = &p.labels[s..e];
        // Clamp rather than `as u32`, which would silently keep only the low 32 bits.
        let count = p.counts[ci].min(u64::from(u32::MAX)) as u32;
        if tids.len() > 1 {
            for &t in tids {
                let slot = &mut ambig[t as usize];
                *slot = slot.saturating_add(count);
            }
        } else {
            let slot = &mut unique[tids[0] as usize];
            *slot = slot.saturating_add(count);
        }
    }
    (unique, ambig)
}
#[cfg(test)]
mod tests {
    use super::*;
    use salmon_eqclass::{EquivalenceClassBuilder, TranscriptGroup};

    /// Builds packed equivalence classes from `(member transcript ids, count)` pairs,
    /// giving every member weight 1.0 and every transcript effective length 1.0.
    fn packed(classes: &[(Vec<u32>, u64)], num_txps: usize) -> PackedEqClasses {
        let builder = EquivalenceClassBuilder::new();
        for (members, n) in classes {
            builder.add_group(
                TranscriptGroup::new(members.clone()),
                vec![1.0; members.len()],
                *n,
            );
        }
        let mut eqs = builder.finish();
        eqs.update_eff_lengths(&vec![1.0; num_txps]);
        PackedEqClasses::from_collapsed(&eqs, num_txps)
    }

    /// Mean of column `txp` across `reps`, dividing by the fixed replicate count `n`.
    fn column_mean(reps: &[Vec<f64>], txp: usize, n: f64) -> f64 {
        reps.iter().map(|r| r[txp]).sum::<f64>() / n
    }

    #[test]
    fn bootstrap_mean_near_point_estimate() {
        let eqs = packed(&[(vec![0], 300), (vec![1], 700)], 2);
        let reps = bootstrap(&eqs, &EmOptions::default(), 50, 12345);
        assert_eq!(reps.len(), 50);
        let m0 = column_mean(&reps, 0, 50.0);
        let m1 = column_mean(&reps, 1, 50.0);
        assert!((m0 - 300.0).abs() < 30.0, "m0={m0}");
        assert!((m1 - 700.0).abs() < 30.0, "m1={m1}");
        for r in &reps {
            assert!(((r[0] + r[1]) - 1000.0).abs() < 1e-6);
        }
    }

    #[test]
    fn bootstrap_variance_grows_with_ambiguity() {
        let eqs = packed(&[(vec![0], 10), (vec![1], 10), (vec![0, 1], 980)], 2);
        let reps = bootstrap(&eqs, &EmOptions::default(), 100, 7);
        let m0 = column_mean(&reps, 0, 100.0);
        let var0 = reps.iter().map(|r| (r[0] - m0).powi(2)).sum::<f64>() / 100.0;
        assert!(
            var0 > 0.0,
            "ambiguous transcript should have nonzero bootstrap variance"
        );
    }

    #[test]
    fn gibbs_runs_and_conserves_scale() {
        let eqs = packed(&[(vec![0], 300), (vec![1], 700)], 2);
        let opts = GibbsOptions {
            num_samples: 20,
            thinning: 8,
            ..Default::default()
        };
        let draws = gibbs_sample(&eqs, &[1.0, 1.0], &[300.0, 700.0], &opts, 99);
        assert_eq!(draws.len(), 20);
        for d in &draws {
            let tot = d[0] + d[1];
            assert!(
                (tot - 1000.0).abs() < 50.0,
                "gibbs total {tot} not near 1000"
            );
        }
    }

    #[test]
    fn ambiguity_counts_split() {
        let eqs = packed(&[(vec![0], 30), (vec![1], 70), (vec![0, 1], 100)], 2);
        let (uniq, amb) = ambiguity_counts(&eqs);
        assert_eq!(uniq, vec![30, 70]);
        assert_eq!(amb, vec![100, 100]);
    }
}