use crate::experimental::stick::sb::StickBreaking;
use crate::experimental::stick::sbd::StickBreakingDiscrete;
use crate::traits::{HasSuffStat, SuffStat};
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
/// Sufficient statistic for a [`StickBreaking`] distribution observed through
/// vectors of stick lengths.
///
/// Every observed stick vector contributes one to `n`, its break count to
/// `num_breaks`, and `ln` of the product of its remaining fractions
/// `1 - len / remaining` to `sum_log_q`.
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
#[derive(Clone, Debug, PartialEq)]
pub struct StickBreakingSuffStat {
    /// Number of stick vectors currently summarized.
    n: usize,
    /// Total number of breaks across all summarized stick vectors.
    num_breaks: usize,
    /// Accumulated log of the remaining fractions across all stick vectors.
    sum_log_q: f64,
}

impl Default for StickBreakingSuffStat {
    /// Equivalent to [`StickBreakingSuffStat::new`].
    fn default() -> Self {
        StickBreakingSuffStat::new()
    }
}

impl StickBreakingSuffStat {
    /// Creates an empty statistic: no observations, no breaks, zero log sum.
    #[must_use]
    pub fn new() -> Self {
        let (n, num_breaks, sum_log_q) = (0, 0, 0.0);
        Self {
            n,
            num_breaks,
            sum_log_q,
        }
    }

    /// Total number of breaks accumulated so far.
    #[must_use]
    pub fn num_breaks(&self) -> usize {
        self.num_breaks
    }

    /// Accumulated sum of log remaining fractions.
    #[must_use]
    pub fn sum_log_q(&self) -> f64 {
        self.sum_log_q
    }
}
impl From<&&[f64]> for StickBreakingSuffStat {
fn from(x: &&[f64]) -> Self {
let mut stat = StickBreakingSuffStat::new();
stat.observe(x);
stat
}
}
/// Computes the per-vector contribution to [`StickBreakingSuffStat`].
///
/// The sticks are walked from last to first while accumulating the mass
/// `remaining` of the current stick plus all sticks after it. A stick counts
/// as a break only if it is strictly shorter than that mass, so the final
/// (tail) stick never counts. Each break contributes the fraction
/// `q = 1 - len / remaining`.
///
/// Returns `(number_of_breaks, ln(product of q))`.
fn stick_stat_unit_powerlaw(sticks: &[f64]) -> (usize, f64) {
    let mut tail_mass = 0.0;
    let mut breaks = 0;
    let mut q_product = 1.0;
    for &len in sticks.iter().rev() {
        tail_mass += len;
        if len < tail_mass {
            breaks += 1;
            q_product *= 1.0 - len / tail_mass;
        }
    }
    (breaks, q_product.ln())
}
impl HasSuffStat<&[f64]> for StickBreaking {
type Stat = StickBreakingSuffStat;
fn empty_suffstat(&self) -> Self::Stat {
Self::Stat::new()
}
fn ln_f_stat(&self, stat: &Self::Stat) -> f64 {
let alpha = self.alpha();
let alpha_ln = self.break_tail().alpha_ln();
(stat.num_breaks as f64)
.mul_add(alpha_ln, (alpha - 1.0) * stat.sum_log_q)
}
}
impl SuffStat<&[f64]> for StickBreakingSuffStat {
    /// Number of stick vectors currently summarized.
    fn n(&self) -> usize {
        self.n
    }

    /// Adds one stick-length vector: bumps `n` and accumulates its break count
    /// and log remaining-fraction sum.
    fn observe(&mut self, sticks: &&[f64]) {
        let (num_breaks, sum_log_q) = stick_stat_unit_powerlaw(sticks);
        self.n += 1;
        self.num_breaks += num_breaks;
        self.sum_log_q += sum_log_q;
    }

    /// Removes one previously observed stick-length vector.
    ///
    /// The caller must pass a vector that was actually observed; otherwise
    /// `num_breaks` and `sum_log_q` become inconsistent.
    ///
    /// # Panics
    /// Panics if the statistic holds no observations. Without this check,
    /// `self.n -= 1` would silently wrap to `usize::MAX` in release builds.
    fn forget(&mut self, sticks: &&[f64]) {
        assert!(self.n > 0, "No observations to forget.");
        let (num_breaks, sum_log_q) = stick_stat_unit_powerlaw(sticks);
        self.n -= 1;
        self.num_breaks -= num_breaks;
        self.sum_log_q -= sum_log_q;
    }

    /// Folds `other` into `self`; merging an empty statistic is a no-op.
    fn merge(&mut self, other: Self) {
        if other.n == 0 {
            return;
        }
        self.n += other.n;
        self.sum_log_q += other.sum_log_q;
        self.num_breaks += other.num_breaks;
    }
}
/// Sufficient statistic for a [`StickBreakingDiscrete`] distribution.
///
/// `counts[i]` holds how many times the value `i` has been observed.
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
#[derive(Clone, Debug, PartialEq)]
pub struct StickBreakingDiscreteSuffStat {
    /// Observation counts indexed by observed value.
    counts: Vec<usize>,
}

impl StickBreakingDiscreteSuffStat {
    /// Creates a statistic with no observations.
    #[must_use]
    pub fn new() -> Self {
        Self::from_counts(Vec::new())
    }

    /// Wraps an existing per-value count vector as-is (no validation).
    #[must_use]
    pub fn from_counts(counts: Vec<usize>) -> Self {
        Self { counts }
    }

    /// For every index `i`, returns `(sum of counts[i + 1..], counts[i])`.
    ///
    /// The first element is the number of observations strictly greater than
    /// `i`; the second is the number equal to `i`.
    #[must_use]
    pub fn break_pairs(&self) -> Vec<(usize, usize)> {
        let total: usize = self.counts.iter().sum();
        self.counts
            .iter()
            .scan(total, |above, &here| {
                *above -= here;
                Some((*above, here))
            })
            .collect()
    }

    /// Borrow of the raw per-value count vector.
    #[must_use]
    pub fn counts(&self) -> &Vec<usize> {
        &self.counts
    }
}
impl From<&[usize]> for StickBreakingDiscreteSuffStat {
fn from(data: &[usize]) -> Self {
let mut stat = StickBreakingDiscreteSuffStat::new();
stat.observe_many(data);
stat
}
}
impl Default for StickBreakingDiscreteSuffStat {
fn default() -> Self {
Self::new()
}
}
impl HasSuffStat<usize> for StickBreakingDiscrete {
    type Stat = StickBreakingDiscreteSuffStat;

    /// Returns an empty [`StickBreakingDiscreteSuffStat`].
    fn empty_suffstat(&self) -> Self::Stat {
        Self::Stat::new()
    }

    /// Log-likelihood of the summarized data: the sum over indices `i` of
    /// `counts[i] * ln(weights[i])`.
    ///
    /// The stick sequence is first extended via `ensure_breaks(counts.len())`.
    /// NOTE(review): this assumes that call yields at least `counts.len()`
    /// weights; if it yields fewer, `zip` silently ignores the surplus counts.
    fn ln_f_stat(&self, stat: &Self::Stat) -> f64 {
        self.stick_sequence().ensure_breaks(stat.counts.len());
        self.stick_sequence().with_inner(|inner| {
            inner
                .weights
                .iter()
                .zip(stat.counts.iter())
                // Unobserved values contribute 0. Skipping them avoids
                // `0 * ln(0) = 0 * -inf = NaN` when a far-out weight has
                // underflowed to zero (e.g. trailing zero counts after `forget`).
                .filter(|&(_, &ct)| ct > 0)
                .map(|(w, &ct)| w.ln() * ct as f64)
                .sum()
        })
    }
}
impl SuffStat<usize> for StickBreakingDiscreteSuffStat {
    /// Total number of observations (sum of all per-value counts).
    fn n(&self) -> usize {
        self.counts.iter().sum()
    }

    /// Records one observation of value `i`, growing `counts` with zeros as needed.
    fn observe(&mut self, i: &usize) {
        if *i >= self.counts.len() {
            self.counts.resize(*i + 1, 0);
        }
        self.counts[*i] += 1;
    }

    /// Removes one observation of value `i`.
    ///
    /// # Panics
    /// Panics with "No observations of {i} to forget." when `i` has no recorded
    /// observations. This includes an index beyond the end of `counts`, which
    /// previously hit a generic index-out-of-bounds panic instead.
    fn forget(&mut self, i: &usize) {
        match self.counts.get_mut(*i) {
            Some(ct) if *ct > 0 => *ct -= 1,
            _ => panic!("No observations of {i} to forget."),
        }
    }

    /// Adds `other`'s per-value counts into `self`, growing `self` if `other`
    /// covers more values.
    fn merge(&mut self, other: Self) {
        if other.counts.len() > self.counts.len() {
            self.counts.resize(other.counts.len(), 0);
        }
        for (mine, theirs) in self.counts.iter_mut().zip(other.counts) {
            *mine += theirs;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each pair is (observations at later indices, observations at this index).
    #[test]
    fn test_break_pairs() {
        let stat = StickBreakingDiscreteSuffStat::from_counts(vec![1, 2, 3]);
        let expected = vec![(5, 1), (3, 2), (0, 3)];
        assert_eq!(stat.break_pairs(), expected);
    }

    /// Observing grows `counts`; forgetting decrements without shrinking it.
    #[test]
    fn test_observe_and_forget() {
        let mut stat = StickBreakingDiscreteSuffStat::new();
        for value in [1, 2, 2] {
            stat.observe(&value);
        }
        stat.forget(&2);
        assert_eq!(stat.counts, vec![0, 1, 1]);
        assert_eq!(stat.n(), 2);
    }

    /// `new()` and `default()` must produce the same empty statistic.
    #[test]
    fn test_new_is_default() {
        assert_eq!(
            StickBreakingDiscreteSuffStat::new(),
            StickBreakingDiscreteSuffStat::default()
        );
    }
}