use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Mutex;
/// Returns `ln(e^a + e^b)` computed stably (two-term log-sum-exp).
///
/// `f64::NEG_INFINITY` plays the role of `ln(0)` and is the identity: if either argument is
/// `-inf`, the other argument is returned unchanged. If both arguments are `+inf` the result
/// is `+inf`. A `NaN` argument propagates to the result.
#[inline]
fn log_add(a: f64, b: f64) -> f64 {
    if a == f64::NEG_INFINITY {
        return b;
    }
    if b == f64::NEG_INFINITY {
        return a;
    }
    // The general formula below would evaluate `lo - hi` as `inf - inf = NaN` here,
    // but ln(e^inf + e^inf) is +inf.
    if a == f64::INFINITY && b == f64::INFINITY {
        return f64::INFINITY;
    }
    // Factor out the larger term so the exponent is <= 0 and cannot overflow:
    // ln(e^hi + e^lo) = hi + ln(1 + e^(lo - hi)).
    let (hi, lo) = if a > b { (a, b) } else { (b, a) };
    hi + (lo - hi).exp().ln_1p()
}
/// An `f64` stored as its IEEE-754 bit pattern inside an [`AtomicU64`], so it can be read
/// and updated through a shared reference.
///
/// Every operation uses [`Ordering::Relaxed`]: each access is atomic on this value alone and
/// imposes no ordering on any other memory.
struct AtomicF64(AtomicU64);

impl AtomicF64 {
    /// Creates a cell holding `v`.
    fn new(v: f64) -> Self {
        AtomicF64(AtomicU64::new(f64::to_bits(v)))
    }

    /// Returns the current value (relaxed load).
    #[inline]
    fn load(&self) -> f64 {
        let bits = self.0.load(Ordering::Relaxed);
        f64::from_bits(bits)
    }

    /// Atomically replaces the stored value `s` with `log_add(s, x)`.
    ///
    /// Adding `-inf` (the log of zero) is a no-op and skips the atomic update entirely.
    /// Otherwise the compare-and-swap is retried until it succeeds, so contributions from
    /// concurrent callers are never lost.
    #[inline]
    fn log_add_assign(&self, x: f64) {
        if x == f64::NEG_INFINITY {
            return;
        }
        // The closure always yields `Some`, so `fetch_update` always finishes with `Ok`;
        // the previous bits it returns are not needed.
        let _ = self.0.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |bits| {
            Some(log_add(f64::from_bits(bits), x).to_bits())
        });
    }
}
/// Streaming assignment of fragments to reference targets, with per-target masses kept in
/// log space.
///
/// A target's effective log mass is `ln(e^prior + e^accumulated)`. The prior is
/// `ln(alpha * max(len, 1))`. The accumulated part starts at `-inf` (zero) and grows as
/// fragments are assigned. All methods take `&self`:
/// - masses are updated with atomic compare-and-swap;
/// - the forgetting-mass history sits behind a [`Mutex`];
/// - the fragment counter is an atomic.
pub struct OnlineInference {
    /// Per-target prior log mass, `ln(alpha * max(ref_len, 1))`.
    prior_mass: Vec<f64>,
    /// Per-target accumulated log mass; starts at `-inf` (no mass).
    mass: Vec<AtomicF64>,
    /// Exponent `c` used by `next_log_fm`.
    forgetting_factor: f64,
    /// Every log forgetting-mass value produced so far, starting with `0.0`.
    log_fm: Mutex<Vec<f64>>,
    /// Number of non-empty fragments passed to `assign_fragment`.
    num_assigned: AtomicU64,
    /// `collecting` reports `true` while fewer than this many fragments have been assigned.
    burnin_frags: u64,
}

impl OnlineInference {
    /// Builds the inference state for targets whose lengths are `ref_lens`.
    ///
    /// Target `i` gets:
    /// - prior log mass `ln(alpha * max(ref_lens[i], 1))`;
    /// - an empty (`-inf`) accumulated mass.
    ///
    /// The forgetting-mass history starts as `[0.0]` and the assigned-fragment counter at zero.
    pub fn new(ref_lens: &[u64], alpha: f64, forgetting_factor: f64, burnin_frags: u64) -> Self {
        // Zero-length targets are treated as length 1 so `ln` never sees `alpha * 0`.
        let prior_mass = ref_lens
            .iter()
            .map(|&len| f64::ln(alpha * len.max(1) as f64))
            .collect::<Vec<f64>>();
        let mass = (0..ref_lens.len())
            .map(|_| AtomicF64::new(f64::NEG_INFINITY))
            .collect();
        Self {
            prior_mass,
            mass,
            forgetting_factor,
            log_fm: Mutex::new(vec![0.0]),
            num_assigned: AtomicU64::new(0),
            burnin_frags,
        }
    }

    /// Effective log mass of target `tid`: `ln(e^prior + e^accumulated)`.
    ///
    /// Panics if `tid` is out of range.
    #[inline]
    pub fn mass_log(&self, tid: usize) -> f64 {
        let prior = self.prior_mass[tid];
        log_add(prior, self.mass[tid].load())
    }

    /// Computes, records, and returns the next log forgetting-mass value.
    ///
    /// With `k` values already recorded and `c = forgetting_factor`, the new value is
    /// `prev + c * ln(k) - ln((k + 1)^c - 1)`, where `prev` is the most recently recorded
    /// value. The history lock is held for the whole computation, so concurrent callers each
    /// advance the sequence by exactly one step.
    ///
    /// NOTE(review): the history gains one entry per call, but only its last entry is read
    /// here. Confirm the full history is needed elsewhere; otherwise a single value would do.
    pub fn next_log_fm(&self) -> f64 {
        let c = self.forgetting_factor;
        let mut history = self.log_fm.lock().unwrap();
        let steps = history.len() as f64;
        let prev = history[history.len() - 1];
        let next = prev + c * steps.ln() - ((steps + 1.0).powf(c) - 1.0).ln();
        history.push(next);
        next
    }

    /// `true` while fewer than `burnin_frags` fragments have been assigned.
    #[inline]
    pub fn collecting(&self) -> bool {
        self.num_assigned() < self.burnin_frags
    }

    /// Number of non-empty fragments assigned so far (relaxed load).
    #[inline]
    pub fn num_assigned(&self) -> u64 {
        self.num_assigned.load(Ordering::Relaxed)
    }

    /// Assigns one fragment across its candidate targets and returns its posterior.
    ///
    /// `maps` holds `(target id, log auxiliary weight)` pairs, one per candidate.
    /// - Each candidate's unnormalized log weight is `mass_log(tid) + log_aux`.
    /// - The result holds the normalized probabilities (linear space), in the same order
    ///   as `maps`.
    /// - If every log weight is `-inf`, the fragment is split uniformly (`1 / n` each).
    ///
    /// Side effects: each candidate's accumulated mass receives `log_fm + ln(posterior)`,
    /// and the assigned-fragment counter is incremented by one. An empty `maps` returns an
    /// empty vector and changes nothing.
    ///
    /// Panics if any target id is out of range.
    pub fn assign_fragment(&self, maps: &[(u32, f64)], log_fm: f64) -> Vec<f64> {
        if maps.is_empty() {
            return Vec::new();
        }
        let n = maps.len();

        // Every candidate is scored before this call updates any mass.
        let scores: Vec<f64> = maps
            .iter()
            .map(|&(tid, log_aux)| self.mass_log(tid as usize) + log_aux)
            .collect();
        let log_norm = scores
            .iter()
            .fold(f64::NEG_INFINITY, |acc, &score| log_add(acc, score));

        let mut posterior = Vec::with_capacity(n);
        if log_norm == f64::NEG_INFINITY {
            // No candidate carries any weight: fall back to a uniform split.
            let share = 1.0 / n as f64;
            let log_share = log_fm + share.ln();
            for &(tid, _) in maps {
                self.mass[tid as usize].log_add_assign(log_share);
                posterior.push(share);
            }
        } else {
            for (&(tid, _), &score) in maps.iter().zip(&scores) {
                let log_p = score - log_norm;
                self.mass[tid as usize].log_add_assign(log_fm + log_p);
                posterior.push(log_p.exp());
            }
        }

        self.num_assigned.fetch_add(1, Ordering::Relaxed);
        posterior
    }
}