use crate::error::{StatsError, StatsResult as Result};
use scirs2_core::random::{rngs::StdRng, Beta as RandBeta, CoreRandom, Distribution, Rng, SeedableRng};
/// Finite-atom approximation of a Beta process.
///
/// Holds `n_atoms` atoms whose weights (`atom_probs`, each in `[0, 1]`) are
/// filled in by `sample` or produced by `posterior`. Until then the process is
/// unsampled and `atom_probs` is empty.
#[derive(Debug, Clone)]
pub struct BetaProcess {
/// Concentration parameter `c`; `new` rejects values `<= 0`.
pub c: f64,
/// Total mass of the base measure; each atom receives `base_measure_mass / n_atoms`.
pub base_measure_mass: f64,
/// Number of atoms in the finite approximation; `new` requires at least 1.
pub n_atoms: usize,
/// Sampled atom weights, clamped to `[0, 1]`; empty until the process is sampled.
pub atom_probs: Vec<f64>,
/// Atom identifiers; `new` initialises them to `0..n_atoms`.
/// NOTE(review): not read by any code visible in this file.
pub atom_locations: Vec<usize>,
/// Whether `atom_probs` currently holds sampled weights.
pub is_sampled: bool,
}
impl BetaProcess {
    /// Lower bound applied to Beta shape parameters so the sampler never
    /// receives a zero or negative shape.
    const MIN_SHAPE: f64 = 1e-10;

    /// Smallest probability used when taking logarithms in
    /// [`BetaProcess::log_prob_features`], keeping the result finite.
    const MIN_LOG_PROB: f64 = 1e-300;

    /// Creates an unsampled process with `n_atoms` atoms.
    ///
    /// # Errors
    /// - `StatsError::DomainError` if `c` or `base_measure_mass` is not a finite
    ///   number strictly greater than zero (NaN and infinity are rejected).
    /// - `StatsError::InvalidArgument` if `n_atoms` is zero.
    pub fn new(c: f64, base_measure_mass: f64, n_atoms: usize) -> Result<Self> {
        // `is_finite` also rejects NaN, which a bare `c <= 0.0` comparison lets through.
        if !c.is_finite() || c <= 0.0 {
            return Err(StatsError::DomainError(format!(
                "Beta Process c must be finite and > 0, got {c}"
            )));
        }
        if !base_measure_mass.is_finite() || base_measure_mass <= 0.0 {
            return Err(StatsError::DomainError(format!(
                "base_measure_mass must be finite and > 0, got {base_measure_mass}"
            )));
        }
        if n_atoms == 0 {
            return Err(StatsError::InvalidArgument("n_atoms must be >= 1".into()));
        }
        Ok(Self {
            c,
            base_measure_mass,
            n_atoms,
            atom_probs: Vec::new(),
            atom_locations: (0..n_atoms).collect(),
            is_sampled: false,
        })
    }

    /// Shape parameters `(alpha, beta)` of the Beta prior shared by every atom.
    ///
    /// With `h = base_measure_mass / n_atoms`, each atom weight is drawn from
    /// `Beta(c * h, c * (1 - h))`. `1 - h` is floored at `MIN_SHAPE` so a mass
    /// larger than `n_atoms` still yields a valid (extremely skewed) prior.
    /// `sample`, `posterior` and `expected_atoms_above` all use this single
    /// definition so they agree on the prior.
    fn prior_shape(&self) -> (f64, f64) {
        let h_per_atom = self.base_measure_mass / self.n_atoms as f64;
        let alpha = (self.c * h_per_atom).max(Self::MIN_SHAPE);
        let beta = (self.c * (1.0 - h_per_atom).max(Self::MIN_SHAPE)).max(Self::MIN_SHAPE);
        (alpha, beta)
    }

    /// Draws one weight per atom from the shared Beta prior and marks the
    /// process as sampled. Previously drawn weights are replaced.
    ///
    /// If an error is returned, `self` is left unchanged.
    ///
    /// # Errors
    /// `StatsError::ComputationError` if the Beta distribution cannot be constructed.
    pub fn sample<R: Rng>(&mut self, rng: &mut CoreRandom<R>) -> Result<()> {
        let (alpha, beta) = self.prior_shape();
        // The parameters are identical for every atom, so build the distribution
        // once, before touching `self`.
        let dist = RandBeta::new(alpha, beta).map_err(|e| {
            StatsError::ComputationError(format!("Beta sampling error: {e}"))
        })?;
        self.atom_probs = (0..self.n_atoms)
            .map(|_| dist.sample(rng).clamp(0.0, 1.0))
            .collect();
        self.is_sampled = true;
        Ok(())
    }

    /// Draws one binary feature vector. Entry `k` is `true` when a uniform draw
    /// falls below `atom_probs[k]`.
    ///
    /// # Errors
    /// `StatsError::InvalidInput` if the process has not been sampled.
    pub fn draw_bernoulli<R: Rng>(&self, rng: &mut CoreRandom<R>) -> Result<Vec<bool>> {
        if !self.is_sampled {
            return Err(StatsError::InvalidInput(
                "Beta process must be sampled first (call .sample())".into(),
            ));
        }
        Ok(self
            .atom_probs
            .iter()
            .map(|&pi| sample_uniform_01(rng) < pi)
            .collect())
    }

    /// Expected number of active features per Bernoulli draw.
    ///
    /// After sampling, this is the sum of the sampled atom weights. Before
    /// sampling, it returns `base_measure_mass`, which equals the prior mean of
    /// that sum when `base_measure_mass <= n_atoms`.
    pub fn expected_active_features(&self) -> f64 {
        if self.is_sampled {
            self.atom_probs.iter().sum()
        } else {
            self.base_measure_mass
        }
    }

    /// Log-probability of the binary feature vector `z` under the sampled weights.
    ///
    /// Weights are clamped to `[0, 1]`. Each log term is floored at
    /// `ln(MIN_LOG_PROB)`, so a feature with probability exactly 0 or 1
    /// contributes a large negative but finite value.
    ///
    /// # Errors
    /// - `StatsError::InvalidInput` if the process has not been sampled.
    /// - `StatsError::DimensionMismatch` if `z.len() != n_atoms`.
    pub fn log_prob_features(&self, z: &[bool]) -> Result<f64> {
        if !self.is_sampled {
            return Err(StatsError::InvalidInput(
                "Beta process must be sampled first".into(),
            ));
        }
        if z.len() != self.n_atoms {
            return Err(StatsError::DimensionMismatch(format!(
                "z has {} entries, expected {}",
                z.len(),
                self.n_atoms
            )));
        }
        let log_floor = Self::MIN_LOG_PROB.ln();
        let log_p: f64 = z
            .iter()
            .zip(&self.atom_probs)
            .map(|(&zk, &pi)| {
                let pi = pi.clamp(0.0, 1.0);
                if zk {
                    pi.max(Self::MIN_LOG_PROB).ln()
                } else {
                    // ln(1 - pi) via ln_1p stays accurate for small pi. The floor is
                    // applied after the subtraction because `1.0 - 1e-300` rounds to
                    // `1.0` in f64, so clamping pi itself cannot prevent ln(0).
                    (-pi).ln_1p().max(log_floor)
                }
            })
            .sum();
        Ok(log_p)
    }

    /// Approximate expected number of atoms whose prior weight exceeds `threshold`.
    ///
    /// Uses the same per-atom Beta prior as `sample`, approximated by a normal
    /// distribution with matching mean and variance. Results are rough when the
    /// prior is strongly skewed (small `c * base_measure_mass / n_atoms`).
    ///
    /// For `threshold <= 0`, every atom counts and the result is `n_atoms`.
    /// For `threshold >= 1`, no weight can exceed it and the result is `0`.
    pub fn expected_atoms_above(&self, threshold: f64) -> f64 {
        if threshold <= 0.0 {
            return self.n_atoms as f64;
        }
        if threshold >= 1.0 {
            return 0.0;
        }
        let (alpha, beta) = self.prior_shape();
        let total = alpha + beta;
        let mean = alpha / total;
        let var = alpha * beta / (total.powi(2) * (total + 1.0));
        if var < 1e-15 {
            // Essentially a point mass: every atom is either above or below.
            return if mean > threshold {
                self.n_atoms as f64
            } else {
                0.0
            };
        }
        let z = (threshold - mean) / var.sqrt();
        self.n_atoms as f64 * normal_cdf_complement(z)
    }

    /// Samples atom weights from the conjugate posterior given feature counts.
    ///
    /// Atom `k` appears `m_k = feature_counts[k]` times across `n_obs`
    /// observations. Its weight is drawn from
    /// `Beta(alpha + m_k, beta + n_obs - m_k)`, where `(alpha, beta)` is the
    /// prior used by `sample`. The method returns a new, already-sampled
    /// process and does not modify `self`.
    ///
    /// # Errors
    /// - `StatsError::DimensionMismatch` if `feature_counts.len() != n_atoms`.
    /// - `StatsError::InvalidArgument` if any count exceeds `n_obs`.
    /// - `StatsError::ComputationError` if a Beta distribution cannot be constructed.
    pub fn posterior<R: Rng>(
        &self,
        feature_counts: &[usize],
        n_obs: usize,
        rng: &mut CoreRandom<R>,
    ) -> Result<Self> {
        if feature_counts.len() != self.n_atoms {
            return Err(StatsError::DimensionMismatch(format!(
                "feature_counts has {} entries, expected {}",
                feature_counts.len(),
                self.n_atoms
            )));
        }
        // A feature cannot be active in more observations than were made.
        if let Some((k, &m_k)) = feature_counts
            .iter()
            .enumerate()
            .find(|&(_, &count)| count > n_obs)
        {
            return Err(StatsError::InvalidArgument(format!(
                "feature_counts[{k}] = {m_k} exceeds n_obs = {n_obs}"
            )));
        }
        let (alpha, beta) = self.prior_shape();
        let mut post = Self::new(self.c, self.base_measure_mass, self.n_atoms)?;
        let mut probs = Vec::with_capacity(self.n_atoms);
        for &count in feature_counts {
            let m_k = count as f64;
            let alpha_post = (alpha + m_k).max(Self::MIN_SHAPE);
            let beta_post = (beta + n_obs as f64 - m_k).max(Self::MIN_SHAPE);
            let dist = RandBeta::new(alpha_post, beta_post).map_err(|e| {
                StatsError::ComputationError(format!("Beta sampling error: {e}"))
            })?;
            probs.push(dist.sample(rng).clamp(0.0, 1.0));
        }
        post.atom_probs = probs;
        post.is_sampled = true;
        Ok(post)
    }
}
/// Draws a single `f64` from the uniform distribution built by
/// `Uniform::new(0.0, 1.0)`.
///
/// NOTE(review): the range is presumably half-open `[0, 1)`, as in `rand`'s
/// `Uniform::new`; confirm against the scirs2_core re-export.
/// If the distribution cannot be constructed, the function returns `0.5`.
/// That should not happen for these fixed, valid bounds.
fn sample_uniform_01<R: Rng>(rng: &mut CoreRandom<R>) -> f64 {
    use scirs2_core::random::Uniform;
    let unit_interval = Uniform::new(0.0_f64, 1.0);
    unit_interval.map_or(0.5, |dist| dist.sample(rng))
}
/// Upper-tail probability `P(Z > z)` of a standard normal variable.
///
/// Uses the Abramowitz & Stegun 26.2.17 polynomial approximation, whose
/// absolute error is below about 7.5e-8. Negative `z` is handled through the
/// symmetry `P(Z > z) = 1 - P(Z > -z)`.
fn normal_cdf_complement(z: f64) -> f64 {
    const P: f64 = 0.2316419;
    // Coefficients b1..b5, evaluated in Horner form by the fold below.
    const B: [f64; 5] = [
        0.319_381_53,
        -0.356_563_782,
        1.781_477_937,
        -1.821_255_978,
        1.330_274_429,
    ];
    let t = 1.0 / (1.0 + P * z.abs());
    let series = t * B.iter().rev().fold(0.0, |acc, &b| acc * t + b);
    let density = (-0.5 * z * z).exp() / (2.0 * std::f64::consts::PI).sqrt();
    let upper = density * series;
    if z < 0.0 {
        1.0 - upper
    } else {
        upper
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructs a process with the given parameters and samples its atom
    /// weights from `rng`, panicking if either step fails.
    fn sampled_process<R: Rng>(
        c: f64,
        mass: f64,
        n_atoms: usize,
        rng: &mut CoreRandom<R>,
    ) -> BetaProcess {
        let mut process = BetaProcess::new(c, mass, n_atoms).expect("construction failed");
        process.sample(rng).expect("sampling failed");
        process
    }

    #[test]
    fn test_beta_process_construction() {
        assert!(BetaProcess::new(1.0, 2.0, 10).is_ok());
        // Each triple violates exactly one constructor precondition.
        for (c, mass, n_atoms) in [(0.0, 2.0, 10), (1.0, 0.0, 10), (1.0, 2.0, 0)] {
            assert!(BetaProcess::new(c, mass, n_atoms).is_err());
        }
    }

    #[test]
    fn test_beta_process_sample() {
        let mut rng = CoreRandom::seed(42);
        let process = sampled_process(1.0, 2.0, 10, &mut rng);
        assert!(process.is_sampled);
        assert_eq!(process.atom_probs.len(), 10);
        assert!(process.atom_probs.iter().all(|p| (0.0..=1.0).contains(p)));
    }

    #[test]
    fn test_beta_process_draw_bernoulli() {
        let mut rng = CoreRandom::seed(7);
        let process = sampled_process(2.0, 1.0, 5, &mut rng);
        let draws = process.draw_bernoulli(&mut rng).expect("draw failed");
        assert_eq!(draws.len(), 5);
    }

    #[test]
    fn test_beta_process_unsampled_error() {
        let process = BetaProcess::new(1.0, 2.0, 5).expect("construction failed");
        let mut rng = CoreRandom::seed(0);
        assert!(process.draw_bernoulli(&mut rng).is_err());
        assert!(process
            .log_prob_features(&[true, false, true, false, true])
            .is_err());
    }

    #[test]
    fn test_beta_process_log_prob() {
        let mut rng = CoreRandom::seed(5);
        let process = sampled_process(1.0, 2.0, 3, &mut rng);
        let lp = process
            .log_prob_features(&[true, false, true])
            .expect("log_prob failed");
        assert!(lp.is_finite());
        assert!(lp <= 0.0);
        // A feature vector of the wrong length must be rejected.
        assert!(process.log_prob_features(&[true, false]).is_err());
    }

    #[test]
    fn test_beta_process_posterior() {
        let mut rng = CoreRandom::seed(42);
        let prior = sampled_process(1.0, 3.0, 4, &mut rng);
        let counts = [3usize, 1, 0, 2];
        let post = prior
            .posterior(&counts, 5, &mut rng)
            .expect("posterior failed");
        assert_eq!(post.atom_probs.len(), 4);
        assert!(post.is_sampled);
        assert!(prior.posterior(&[1, 2], 5, &mut rng).is_err());
    }

    #[test]
    fn test_expected_active_features() {
        let mut process = BetaProcess::new(1.0, 5.0, 50).expect("construction failed");
        // Before sampling, the reported value is the base-measure mass.
        assert!((process.expected_active_features() - 5.0).abs() < 1e-10);
        let mut rng = CoreRandom::seed(42);
        process.sample(&mut rng).expect("sampling failed");
        let expected = process.expected_active_features();
        assert!(expected > 0.0, "expected active features = {expected}");
    }
}