use super::StickSequence;
use crate::dist::Mixture;
use crate::misc::ConvergentSequence;
use crate::misc::sorted_uniforms;
use crate::traits::{
Cdf, DiscreteDistr, Entropy, HasDensity, InverseCdf, Mode, Sampleable,
Support,
};
use rand::Rng;
use rand::seq::SliceRandom;
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
/// A discrete distribution over `usize` indices backed by a [`StickSequence`].
///
/// The probability of index `n` is the weight of the `n`-th stick, and the
/// survival function is read from the sequence's ccdf.
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub struct StickBreakingDiscrete {
/// The stick sequence that supplies the weights and the ccdf.
sticks: StickSequence,
}
impl StickBreakingDiscrete {
    /// Creates a distribution that draws its weights from `sticks`.
    pub fn new(sticks: StickSequence) -> StickBreakingDiscrete {
        Self { sticks }
    }

    /// Inverse of the complementary CDF.
    ///
    /// Asks the stick sequence for a ccdf whose last entry is below `p`, then
    /// returns one less than the position of the first ccdf entry that is
    /// strictly below `p`.
    ///
    /// `p` is expected to lie strictly between 0 and 1; this is only checked
    /// in debug builds.
    pub fn invccdf(&self, p: f64) -> usize {
        debug_assert!(p > 0.0 && p < 1.0);
        self.sticks.extendmap_ccdf(
            |ccdf| ccdf.last().unwrap() < &p,
            |ccdf| ccdf.iter().position(|q| *q < p).unwrap() - 1,
        )
    }

    /// Borrows the underlying stick sequence.
    pub fn stick_sequence(&self) -> &StickSequence {
        &self.sticks
    }

    /// Inverts the complementary CDF for a whole batch of probabilities.
    ///
    /// `ps` must be sorted in ascending order. The result holds exactly one
    /// index per probability, ordered from the largest probability down to
    /// the smallest, i.e. it matches
    /// `ps.iter().rev().map(|p| self.invccdf(*p))`.
    ///
    /// An empty `ps` yields an empty vector.
    pub fn multi_invccdf_sorted(&self, ps: &[f64]) -> Vec<usize> {
        // Nothing to invert; this also keeps the `ps[0]` access below valid.
        if ps.is_empty() {
            return Vec::new();
        }
        let n = ps.len();
        // The ccdf is decreasing while `ps` is increasing, so covering the
        // smallest probability covers all of them.
        let smallest = ps[0];
        self.sticks.extendmap_ccdf(
            |ccdf| *ccdf.last().unwrap() < smallest,
            |ccdf| {
                let mut result: Vec<usize> = Vec::with_capacity(n);
                // Consume the probabilities from largest to smallest while
                // scanning the ccdf from its start.
                let mut pending = ps.iter().rev().peekable();
                for (index, q) in ccdf.iter().skip(1).enumerate() {
                    while pending.next_if(|p| **p > *q).is_some() {
                        result.push(index);
                    }
                    // Stop as soon as every probability has been placed.
                    // Continuing would emit an extra index for each further
                    // ccdf entry below the smallest probability.
                    if pending.peek().is_none() {
                        break;
                    }
                }
                result
            },
        )
    }
}
/// Every `usize` index is treated as belonging to the support.
impl Support<usize> for StickBreakingDiscrete {
    /// Always returns `true`, whatever index is supplied.
    fn supports(&self, _index: &usize) -> bool {
        true
    }
}
impl Cdf<usize> for StickBreakingDiscrete {
    /// Survival function: the stick sequence's ccdf evaluated at `x + 1`.
    fn sf(&self, x: &usize) -> f64 {
        let next_index = *x + 1;
        self.sticks.ccdf(next_index)
    }

    /// Cumulative distribution function, computed as one minus [`Self::sf`].
    fn cdf(&self, x: &usize) -> f64 {
        let survival = self.sf(x);
        1.0 - survival
    }
}
impl InverseCdf<usize> for StickBreakingDiscrete {
    /// Inverts the CDF by inverting the complementary CDF at `1 - p`.
    fn invcdf(&self, p: f64) -> usize {
        let upper_tail = 1.0 - p;
        self.invccdf(upper_tail)
    }
}
// Empty impl: no items are defined here, so whatever `DiscreteDistr` offers
// for this type comes from the trait's own provided definitions.
impl DiscreteDistr<usize> for StickBreakingDiscrete {}
impl Mode<usize> for StickBreakingDiscrete {
    /// Index of the largest stick weight.
    ///
    /// The ccdf is requested until its last entry is below the weight of
    /// stick 0. Weights are recovered as differences of adjacent ccdf
    /// entries and the index of the greatest one is returned.
    ///
    /// Panics if two weights cannot be ordered (e.g. a NaN is present).
    fn mode(&self) -> Option<usize> {
        let first_weight = self.sticks.weight(0);
        self.sticks.extendmap_ccdf(
            |ccdf| *ccdf.last().unwrap() < first_weight,
            |ccdf| {
                ccdf.windows(2)
                    .map(|pair| pair[0] - pair[1])
                    .enumerate()
                    .max_by(|(_, lhs), (_, rhs)| lhs.partial_cmp(rhs).unwrap())
                    .map(|(index, _)| index)
            },
        )
    }
}
impl HasDensity<usize> for StickBreakingDiscrete {
    /// Probability mass at `n`: the weight of the `n`-th stick.
    fn f(&self, n: &usize) -> f64 {
        self.sticks.weight(*n)
    }

    /// Natural logarithm of [`Self::f`].
    fn ln_f(&self, n: &usize) -> f64 {
        let mass = self.f(n);
        mass.ln()
    }
}
impl Sampleable<usize> for StickBreakingDiscrete {
    /// Draws one index by inverting the complementary CDF at a uniform
    /// random probability.
    fn draw<R: Rng>(&self, rng: &mut R) -> usize {
        // `rng.random()` yields values in [0, 1), but `invccdf` requires
        // `0 < p < 1`, so an exact zero is rejected and redrawn.
        let u = loop {
            let candidate: f64 = rng.random();
            if candidate > 0.0 {
                break candidate;
            }
        };
        self.invccdf(u)
    }

    /// Draws `n` indices at once.
    ///
    /// Sorted uniforms are inverted in a single batch and the result is
    /// shuffled so the returned order carries no information about the
    /// sorted probabilities. Returns an empty vector when `n` is 0.
    fn sample<R: Rng>(&self, n: usize, mut rng: &mut R) -> Vec<usize> {
        // The batched inverse unwraps the first probability, so an empty
        // request must be answered before any probabilities are generated.
        if n == 0 {
            return Vec::new();
        }
        let ps = sorted_uniforms(n, &mut rng);
        let mut result = self.multi_invccdf_sorted(&ps);
        result.shuffle(&mut rng);
        result
    }
}
impl Entropy for StickBreakingDiscrete {
    /// Entropy as the limit of the running sums of `-p * ln(p)` over the
    /// indices 0, 1, 2, ..., taken with a tolerance of `1e-10`.
    fn entropy(&self) -> f64 {
        // Running total of the `-p * ln(p)` terms seen so far.
        let mut partial_sum = 0.0;
        (0..)
            .map(|n| self.f(&n))
            .map(|p| {
                partial_sum -= p * p.ln();
                partial_sum
            })
            .limit(1e-10)
    }
}
impl Entropy for &Mixture<StickBreakingDiscrete> {
    /// Entropy of the mixture as the limit of the running sums of
    /// `-p * ln(p)` over the indices 0, 1, 2, ..., taken with a tolerance
    /// of `1e-10`, where `p` is the mixture's own mass at each index.
    fn entropy(&self) -> f64 {
        // Running total of the `-p * ln(p)` terms seen so far.
        let mut partial_sum = 0.0;
        (0..)
            .map(|n| self.f(&n))
            .map(|p| {
                partial_sum -= p * p.ln();
                partial_sum
            })
            .limit(1e-10)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prelude::*;
    use rand::rng;

    /// The batched inverse must agree with inverting each probability on its
    /// own, taken from the largest probability down to the smallest.
    #[test]
    fn test_multi_invccdf_sorted() {
        let breaker = UnitPowerLaw::new(10.0).unwrap();
        let sbd = StickBreakingDiscrete::new(StickSequence::new(breaker, None));
        let mut generator = rng();
        let ps = sorted_uniforms(5, &mut generator);

        // Run the batched call first, as the original test does.
        let batched = sbd.multi_invccdf_sorted(&ps);
        let one_by_one: Vec<usize> =
            ps.iter().rev().map(|&p| sbd.invccdf(p)).collect();
        assert_eq!(batched, one_by_one);
    }
}