use crate::experimental::stick_breaking_process::stick_breaking::StickBreaking;
use crate::traits::{HasSuffStat, SuffStat};
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
/// Sufficient statistic for observations of a [`StickBreaking`] process,
/// where each observation is a sequence of stick lengths (`&[f64]`).
///
/// Every observed sequence is reduced to a number of breaks and a sum of
/// `ln(q)` terms (see `stick_stat_unit_powerlaw` in this module), so the
/// statistic stays constant-size no matter how many sequences are observed.
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
#[derive(Clone, Debug, PartialEq)]
pub struct StickBreakingSuffStat {
/// Number of stick sequences observed (one per `observe` call).
n: usize,
/// Total number of breaks summed over all observed sequences.
num_breaks: usize,
/// Sum over all breaks of `ln(q)`, where `q = 1 - len / remaining` for
/// the stick `len` split off from the running total `remaining`.
sum_log_q: f64,
}
impl Default for StickBreakingSuffStat {
fn default() -> Self {
Self::new()
}
}
impl StickBreakingSuffStat {
    /// Creates an empty statistic: no observations, no breaks, and a
    /// log-sum of zero.
    #[must_use]
    pub fn new() -> Self {
        Self {
            sum_log_q: 0.0,
            num_breaks: 0,
            n: 0,
        }
    }

    /// Total number of breaks accumulated across every observed sequence.
    #[must_use]
    pub fn num_breaks(&self) -> usize {
        self.num_breaks
    }

    /// Accumulated sum of `ln(q)` over all breaks, with
    /// `q = 1 - len / remaining` for each break.
    #[must_use]
    pub fn sum_log_q(&self) -> f64 {
        self.sum_log_q
    }
}
impl From<&&[f64]> for StickBreakingSuffStat {
fn from(x: &&[f64]) -> Self {
let mut stat = StickBreakingSuffStat::new();
stat.observe(x);
stat
}
}
/// Reduces one sequence of stick lengths to `(number_of_breaks, ln(prod q))`.
///
/// The sticks are visited from last to first while keeping `remaining`, the
/// running total of every stick visited so far (including the current one).
/// A stick strictly shorter than that total counts as a break, with retained
/// fraction `q = 1 - len / remaining`. A stick that is not strictly shorter
/// is skipped. That is always true of the last stick, and of any stick whose
/// successors sum to zero or are lost to rounding.
///
/// An empty slice, or one with no qualifying break, yields `(0, 0.0)`.
fn stick_stat_unit_powerlaw(sticks: &[f64]) -> (usize, f64) {
    let mut remaining: f64 = 0.0;
    let mut num_breaks: usize = 0;
    let mut prod_q: f64 = 1.0;

    for &len in sticks.iter().rev() {
        remaining += len;
        // The first stick visited equals its own running total, so it is the
        // leftover piece rather than a break.
        if len < remaining {
            num_breaks += 1;
            prod_q *= 1.0 - len / remaining;
        }
    }

    (num_breaks, prod_q.ln())
}
impl HasSuffStat<&[f64]> for StickBreaking {
type Stat = StickBreakingSuffStat;
fn empty_suffstat(&self) -> Self::Stat {
Self::Stat::new()
}
fn ln_f_stat(&self, stat: &Self::Stat) -> f64 {
let alpha = self.alpha();
let alpha_ln = self.break_tail().alpha_ln();
(stat.num_breaks as f64)
.mul_add(alpha_ln, (alpha - 1.0) * stat.sum_log_q)
}
}
impl SuffStat<&[f64]> for StickBreakingSuffStat {
    /// Number of stick sequences currently folded into this statistic.
    fn n(&self) -> usize {
        self.n
    }

    /// Folds one sequence of stick lengths into the statistic.
    fn observe(&mut self, sticks: &&[f64]) {
        let (breaks, log_q) = stick_stat_unit_powerlaw(sticks);
        self.n += 1;
        self.num_breaks += breaks;
        self.sum_log_q += log_q;
    }

    /// Removes one previously observed sequence of stick lengths.
    ///
    /// NOTE(review): the sequence must have been observed before. Forgetting
    /// from an empty statistic underflows the `usize` counters. That panics
    /// when overflow checks are enabled and wraps otherwise.
    fn forget(&mut self, sticks: &&[f64]) {
        let (breaks, log_q) = stick_stat_unit_powerlaw(sticks);
        self.n -= 1;
        self.num_breaks -= breaks;
        self.sum_log_q -= log_q;
    }

    /// Adds everything accumulated in `other` to `self`. An empty `other`
    /// (`n == 0`) leaves `self` unchanged.
    fn merge(&mut self, other: Self) {
        if other.n > 0 {
            self.n += other.n;
            self.sum_log_q += other.sum_log_q;
            self.num_breaks += other.num_breaks;
        }
    }
}