pub mod sb;
pub mod sbd;
pub mod seq;
pub mod stat;
pub use sb::{BreakSequence, StickBreaking, StickWeights};
pub use sbd::StickBreakingDiscrete;
pub use seq::StickSequence;
pub use stat::StickBreakingDiscreteSuffStat;
use crate::{
dist::UnitPowerLawError,
traits::{HasDensity, InverseCdf, Sampleable, Support},
};
#[cfg(feature = "rkyv")]
use rkyv::{Archive, Deserialize, Serialize};
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
/// A Beta(1, α) distribution on the unit interval, parameterised by `alpha`.
///
/// The log-density implemented for this type is
/// `ln(alpha) + (alpha - 1) * ln(1 - x)`.
#[cfg_attr(feature = "rkyv", derive(Serialize, Deserialize, Archive))]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
#[derive(Clone, Debug, PartialEq)]
pub struct HalfBeta {
/// Shape parameter α. `new` and `set_alpha` only accept finite values
/// strictly greater than zero.
///
/// NOTE(review): this field is `pub`, so it can be assigned directly; doing
/// so skips that validation and leaves the cached `alpha_ln` stale. Prefer
/// `set_alpha`.
pub alpha: f64,
/// Cached `alpha.ln()`, refreshed whenever `new` or `set_alpha` succeeds.
alpha_ln: f64,
}
impl HalfBeta {
    /// Creates a `HalfBeta` with shape `alpha`, caching `ln(alpha)`.
    ///
    /// # Errors
    ///
    /// * `UnitPowerLawError::AlphaTooLow` when `alpha <= 0.0` (this includes
    ///   negative infinity, because the sign is checked first).
    /// * `UnitPowerLawError::AlphaNotFinite` when `alpha` is NaN or positive
    ///   infinity.
    pub fn new(alpha: f64) -> Result<Self, UnitPowerLawError> {
        // NaN compares false here and is therefore reported by the
        // finiteness check below.
        if alpha <= 0.0 {
            return Err(UnitPowerLawError::AlphaTooLow { alpha });
        }
        if !alpha.is_finite() {
            return Err(UnitPowerLawError::AlphaNotFinite { alpha });
        }

        let alpha_ln = alpha.ln();
        Ok(Self { alpha, alpha_ln })
    }

    /// Returns the shape parameter α.
    pub fn alpha(&self) -> f64 {
        self.alpha
    }

    /// Replaces the shape parameter, applying the same validation as
    /// [`HalfBeta::new`] and refreshing the cached logarithm.
    ///
    /// # Errors
    ///
    /// Returns the same errors as [`HalfBeta::new`]; on error `self` is left
    /// untouched.
    pub fn set_alpha(&mut self, alpha: f64) -> Result<(), UnitPowerLawError> {
        // `alpha` and `alpha_ln` are the only fields, so swapping in a freshly
        // validated value is equivalent to updating both in place.
        *self = Self::new(alpha)?;
        Ok(())
    }

    /// Returns the cached value of `ln(alpha)`.
    pub fn alpha_ln(&self) -> f64 {
        self.alpha_ln
    }
}
impl HasDensity<f64> for HalfBeta {
    /// Log-density at `x`: `ln(alpha) + (alpha - 1) * ln(1 - x)`.
    ///
    /// `x` is not checked against the support; values above 1 yield NaN
    /// because the logarithm of a negative number is NaN.
    fn ln_f(&self, x: &f64) -> f64 {
        let exponent = self.alpha - 1.0;
        let ln_complement = (1.0 - *x).ln();

        // With alpha == 1 the density is constant (uniform on the unit
        // interval), but at x == 1 the general formula evaluates
        // `-inf * 0`, which is NaN. Return the constant term instead.
        if exponent == 0.0 && ln_complement == f64::NEG_INFINITY {
            return self.alpha_ln();
        }

        ln_complement.mul_add(exponent, self.alpha_ln())
    }
}
impl Support<f64> for HalfBeta {
    /// Returns `true` when `x` lies in the closed interval `[0, 1]`.
    ///
    /// NaN is rejected, since every ordering comparison against NaN is false.
    fn supports(&self, x: &f64) -> bool {
        (0.0..=1.0).contains(x)
    }
}
impl InverseCdf<f64> for HalfBeta {
    /// Quantile function: the `x` for which the CDF
    /// `F(x) = 1 - (1 - x)^alpha` equals `p`.
    ///
    /// The result increases with `p`, mapping `0` to `0` and `1` to `1`.
    /// `p` is not range-checked here.
    fn invcdf(&self, p: f64) -> f64 {
        // Solving p = 1 - (1 - x)^alpha for x gives 1 - (1 - p)^(1 / alpha).
        // Raising `p` itself (rather than `1 - p`) to the power would invert
        // the survival function instead and return a value decreasing in `p`.
        1.0 - (1.0 - p).powf(self.alpha.recip())
    }
}
impl Sampleable<f64> for HalfBeta {
    /// Draws one value by generating a random `f64` with `rng.random()` and
    /// mapping it through `invcdf`.
    ///
    /// NOTE(review): presumably `rng.random()` is uniform on the unit
    /// interval — confirm against the `rand` version in use.
    fn draw<R: rand::Rng>(&self, rng: &mut R) -> f64 {
        self.invcdf(rng.random::<f64>())
    }
}