//! Chinese Restaurant Process (CRP) and Pitman-Yor samplers, together with a
//! small collapsed Gibbs sampler for CRP mixture models with Gaussian
//! likelihoods.
use std::collections::HashMap;
use scirs2_core::ndarray::{Array1, ArrayView2};
use crate::error::{ClusteringError, Result};
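/// Minimal 64-bit linear congruential generator (Knuth's MMIX constants) used
/// so sampling is deterministic and reproducible without an external RNG crate.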
#[derive(Debug, Clone)]
struct Lcg {
state: u64,
}
impl Lcg {
fn new(seed: u64) -> Self {
Self { state: seed.wrapping_add(1) }
}
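    /// Advances the generator and maps the top 53 bits of the state to a
    /// uniform `f64` in `[0, 1)`.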
fn next_f64(&mut self) -> f64 {
self.state = self.state
.wrapping_mul(6_364_136_223_846_793_005)
.wrapping_add(1_442_695_040_888_963_407);
let bits = self.state >> 11;
bits as f64 / (1u64 << 53) as f64
}
    /// Draws a uniform index in `[0, n)`; `n` must be at least 1. Currently
    /// unused by the samplers in this module but kept for bounded draws.
    #[allow(dead_code)]
    fn next_usize(&mut self, n: usize) -> usize {
        let u = self.next_f64();
        ((u * n as f64) as usize).min(n - 1)
    }
}
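/// Chinese Restaurant Process (CRP) sampler with concentration parameter
/// `alpha > 0`; larger `alpha` yields more occupied tables on average.
///
/// A minimal usage sketch (marked `ignore` because the crate path for
/// importing these types is not shown in this module):
///
/// ```ignore
/// let sampler = CRPSampler::with_seed(1.0, 7);
/// let seating = sampler.sample_seating(100).expect("seating");
/// assert_eq!(seating.len(), 100);
/// ```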
#[derive(Debug, Clone)]
pub struct CRPSampler {
pub alpha: f64,
pub seed: u64,
}
impl CRPSampler {
pub fn new(alpha: f64) -> Self {
Self { alpha, seed: 42 }
}
pub fn with_seed(alpha: f64, seed: u64) -> Self {
Self { alpha, seed }
}
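    /// Sequentially seats `n_customers` customers: customer `i` joins existing
    /// table `k` with probability `n_k / (i + alpha)` and opens a new table
    /// with probability `alpha / (i + alpha)`. Returns one table index per
    /// customer; table ids are contiguous starting at 0.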
pub fn sample_seating(&self, n_customers: usize) -> Result<Array1<usize>> {
if n_customers == 0 {
return Err(ClusteringError::InvalidInput(
"n_customers must be at least 1".to_string(),
));
}
if self.alpha <= 0.0 {
return Err(ClusteringError::InvalidInput(
"alpha must be positive".to_string(),
));
}
let mut rng = Lcg::new(self.seed);
let mut assignments = Array1::<usize>::zeros(n_customers);
let mut table_counts: Vec<usize> = Vec::new();
assignments[0] = 0;
table_counts.push(1);
for i in 1..n_customers {
let total = i as f64 + self.alpha;
let u = rng.next_f64() * total;
let mut cumulative = 0.0;
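            // Default to a new table; overwritten below if `u` lands within the
            // cumulative counts of an existing table.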
let mut chosen_table = table_counts.len();
for (k, &count) in table_counts.iter().enumerate() {
cumulative += count as f64;
if u < cumulative {
chosen_table = k;
break;
}
}
assignments[i] = chosen_table;
if chosen_table < table_counts.len() {
table_counts[chosen_table] += 1;
} else {
table_counts.push(1);
}
}
Ok(assignments)
}
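    /// Probability that the next customer joins table `k` (`Some(k)`) or a new
    /// table (`None`), given the seating of the previous customers. An
    /// unoccupied `k` is treated like a new table.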
pub fn crp_probability(
&self,
prev_assignments: &[usize],
k: Option<usize>,
) -> Result<f64> {
if self.alpha <= 0.0 {
return Err(ClusteringError::InvalidInput(
"alpha must be positive".to_string(),
));
}
let n_prev = prev_assignments.len();
let normaliser = n_prev as f64 + self.alpha;
match k {
None => {
Ok(self.alpha / normaliser)
}
Some(table) => {
let count = prev_assignments.iter().filter(|&&t| t == table).count();
if count == 0 {
Ok(self.alpha / normaliser)
} else {
Ok(count as f64 / normaliser)
}
}
}
}
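    /// Predictive distribution for the next customer: one probability per
    /// existing table (sorted by table id) followed by the new-table
    /// probability, together with the number of existing tables.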
pub fn predictive_distribution(
&self,
prev_assignments: &[usize],
) -> Result<(Vec<f64>, usize)> {
if self.alpha <= 0.0 {
return Err(ClusteringError::InvalidInput(
"alpha must be positive".to_string(),
));
}
let n_prev = prev_assignments.len();
let normaliser = n_prev as f64 + self.alpha;
let mut counts: HashMap<usize, usize> = HashMap::new();
for &t in prev_assignments {
*counts.entry(t).or_insert(0) += 1;
}
let n_tables = counts.len();
let mut probs: Vec<f64> = Vec::with_capacity(n_tables + 1);
let mut tables: Vec<usize> = counts.keys().copied().collect();
tables.sort_unstable();
for t in &tables {
probs.push(*counts.get(t).unwrap_or(&0) as f64 / normaliser);
}
probs.push(self.alpha / normaliser);
Ok((probs, n_tables))
}
}
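/// Configuration for the collapsed CRP Gibbs sampler: concentration `alpha`,
/// iteration and burn-in counts, a conjugate Normal prior N(`prior_mean`,
/// `prior_var`) on each cluster mean, observation variance `likelihood_var`,
/// and the RNG seed. `n_burnin` should be smaller than `n_iter` so that at
/// least one post-burn-in sample is collected.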
#[derive(Debug, Clone)]
pub struct CRPGibbsConfig {
pub alpha: f64,
pub n_iter: usize,
pub n_burnin: usize,
pub prior_mean: f64,
pub prior_var: f64,
pub likelihood_var: f64,
pub seed: u64,
}
impl CRPGibbsConfig {
pub fn new(alpha: f64, n_iter: usize) -> Self {
Self {
alpha,
n_iter,
n_burnin: n_iter / 4,
prior_mean: 0.0,
prior_var: 1.0,
likelihood_var: 0.1,
seed: 42,
}
}
}
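/// Output of [`gibbs_sampler_crp`]: majority-vote assignments with contiguous
/// labels, the number of clusters, the mean of each cluster (computed on the
/// per-row feature means), and the number of iterations that were run.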
#[derive(Debug, Clone)]
pub struct CRPGibbsResult {
pub assignments: Array1<usize>,
pub n_clusters: usize,
pub cluster_means: HashMap<usize, f64>,
pub n_iter: usize,
}
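/// Collapsed Gibbs sampler for a CRP mixture with a univariate conjugate
/// Normal-Normal model: each row is reduced to its mean across features, and
/// cluster means are integrated out analytically when scoring assignments.
///
/// A minimal usage sketch (marked `ignore` because the crate path for
/// importing these items is not shown in this module):
///
/// ```ignore
/// use scirs2_core::ndarray::Array2;
///
/// let data = Array2::from_shape_vec((4, 1), vec![0.0, 0.1, 5.0, 5.1]).expect("data");
/// let cfg = CRPGibbsConfig::new(1.0, 40);
/// let result = gibbs_sampler_crp(data.view(), &cfg).expect("gibbs");
/// assert_eq!(result.assignments.len(), 4);
/// ```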
pub fn gibbs_sampler_crp(
data: ArrayView2<f64>,
config: &CRPGibbsConfig,
) -> Result<CRPGibbsResult> {
let n = data.nrows();
if n == 0 {
return Err(ClusteringError::InvalidInput(
"data must have at least one row".to_string(),
));
}
if config.alpha <= 0.0 {
return Err(ClusteringError::InvalidInput(
"alpha must be positive".to_string(),
));
}
    let d = data.ncols();
    if d == 0 {
        return Err(ClusteringError::InvalidInput(
            "data must have at least one column".to_string(),
        ));
    }
let mut rng = Lcg::new(config.seed);
let mut assignments: Vec<usize> = vec![0usize; n];
let mut clusters: HashMap<usize, Vec<usize>> = HashMap::new();
clusters.insert(0, (0..n).collect());
let mut next_cluster_id = 1usize;
let tau2 = config.prior_var;
let sigma2 = config.likelihood_var;
let mu0 = config.prior_mean;
    // Log posterior-predictive density of row `idx` under a conjugate
    // Normal-Normal model. Each row is reduced to its mean across features, so
    // the model is univariate: x | mu ~ N(mu, sigma2), mu ~ N(mu0, tau2).
    let log_marginal = |members: &[usize], idx: usize| -> f64 {
        let n_k = members.len() as f64;
        // Posterior over the cluster mean given its current members.
        let tau2_n = 1.0 / (1.0 / tau2 + n_k / sigma2);
        let sum_x: f64 = members
            .iter()
            .map(|&i| (0..d).map(|j| data[[i, j]]).sum::<f64>() / d as f64)
            .sum();
        let mu_n = tau2_n * (mu0 / tau2 + sum_x / sigma2);
        // Posterior predictive for a new observation: N(mu_n, sigma2 + tau2_n).
        let pred_var = sigma2 + tau2_n;
        let x_mean = (0..d).map(|j| data[[idx, j]]).sum::<f64>() / d as f64;
        -0.5 * (2.0 * std::f64::consts::PI * pred_var).ln()
            - 0.5 * (x_mean - mu_n).powi(2) / pred_var
    };
let mut accumulated: Vec<Vec<usize>> = Vec::new();
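    // Collapsed Gibbs sweeps: each point is detached from its cluster, scored
    // against every existing cluster plus a fresh one, and reseated; sweeps
    // after burn-in are accumulated for the final majority vote.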
for iter in 0..config.n_iter {
for i in 0..n {
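            // Detach point i from its current cluster, dropping the cluster if
            // it becomes empty.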
let old_k = assignments[i];
let members = clusters.entry(old_k).or_default();
members.retain(|&m| m != i);
if members.is_empty() {
clusters.remove(&old_k);
}
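            // Log-weight of each existing cluster: CRP prior probability times
            // the marginal likelihood of point i under that cluster.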
let mut log_probs: Vec<(usize, f64)> = Vec::new();
let total_minus_i = n as f64 - 1.0;
for (&k, members) in &clusters {
let n_k = members.len() as f64;
let crp_prior = n_k / (total_minus_i + config.alpha);
                let log_lik = log_marginal(members, i);
log_probs.push((k, crp_prior.ln() + log_lik));
}
let new_cluster_logprob = (config.alpha / (total_minus_i + config.alpha)).ln()
                + log_marginal(&[], i);
log_probs.push((next_cluster_id, new_cluster_logprob));
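            // Draw the new assignment from the normalised weights using the
            // log-sum-exp trick for numerical stability.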
let log_vals: Vec<f64> = log_probs.iter().map(|(_, lp)| *lp).collect();
let max_lp = log_vals.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
let sum_exp: f64 = log_vals.iter().map(|&lp| (lp - max_lp).exp()).sum();
let u = rng.next_f64() * sum_exp;
let mut cumsum = 0.0;
let mut chosen_k = log_probs[0].0;
for (k, lp) in &log_probs {
cumsum += (lp - max_lp).exp();
if u <= cumsum {
chosen_k = *k;
break;
}
}
assignments[i] = chosen_k;
clusters.entry(chosen_k).or_default().push(i);
if chosen_k == next_cluster_id {
next_cluster_id += 1;
}
}
if iter >= config.n_burnin {
accumulated.push(assignments.clone());
}
}
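    // Summarise the post-burn-in samples by a per-point majority vote.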
let mut final_assignments = Array1::<usize>::zeros(n);
for i in 0..n {
let mut vote_counts: HashMap<usize, usize> = HashMap::new();
for sample in &accumulated {
*vote_counts.entry(sample[i]).or_insert(0) += 1;
}
let best = vote_counts
.into_iter()
.max_by_key(|(_, c)| *c)
.map(|(k, _)| k)
.unwrap_or(0);
final_assignments[i] = best;
}
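    // Remap the surviving labels to contiguous cluster ids starting at 0.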
let unique_labels: std::collections::BTreeSet<usize> =
final_assignments.iter().copied().collect();
let label_map: HashMap<usize, usize> = unique_labels
.into_iter()
.enumerate()
.map(|(new, old)| (old, new))
.collect();
    for v in final_assignments.iter_mut() {
        *v = label_map.get(&*v).copied().unwrap_or(*v);
    }
let n_clusters = label_map.len();
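    // Per-cluster mean of the per-row feature means, matching the univariate
    // model used during sampling.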
let mut cluster_means: HashMap<usize, f64> = HashMap::new();
let mut cluster_sums: HashMap<usize, (f64, usize)> = HashMap::new();
for i in 0..n {
let k = final_assignments[i];
let x_mean: f64 = (0..d).map(|j| data[[i, j]]).sum::<f64>() / d as f64;
let entry = cluster_sums.entry(k).or_insert((0.0, 0));
entry.0 += x_mean;
entry.1 += 1;
}
for (k, (sum, count)) in cluster_sums {
cluster_means.insert(k, sum / count as f64);
}
Ok(CRPGibbsResult {
assignments: final_assignments,
n_clusters,
cluster_means,
n_iter: config.n_iter,
})
}
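/// Pitman-Yor process with concentration `alpha` and discount `d` in `[0, 1)`
/// (requiring `alpha > -d`); `d = 0` recovers the CRP, while `d > 0` produces
/// heavier-tailed (power-law) table-size distributions.
///
/// A minimal usage sketch (marked `ignore` because the crate path for
/// importing the type is not shown in this module):
///
/// ```ignore
/// let py = PitmanYorProcess::with_seed(1.0, 0.5, 7).expect("valid parameters");
/// let seating = py.sample_seating(100).expect("seating");
/// assert_eq!(seating.len(), 100);
/// ```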
#[derive(Debug, Clone)]
pub struct PitmanYorProcess {
pub alpha: f64,
pub discount: f64,
pub seed: u64,
}
impl PitmanYorProcess {
pub fn new(alpha: f64, discount: f64) -> Result<Self> {
if discount < 0.0 || discount >= 1.0 {
return Err(ClusteringError::InvalidInput(
"discount must be in [0, 1)".to_string(),
));
}
if alpha <= -discount {
return Err(ClusteringError::InvalidInput(
"alpha must be > -discount".to_string(),
));
}
Ok(Self { alpha, discount, seed: 42 })
}
pub fn with_seed(alpha: f64, discount: f64, seed: u64) -> Result<Self> {
let mut py = Self::new(alpha, discount)?;
py.seed = seed;
Ok(py)
}
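    /// Sequential seating under the Pitman-Yor scheme: customer `i` opens a new
    /// table with probability `(alpha + K * d) / (i + alpha)` (with `K` the
    /// current number of tables) and otherwise joins table `k` with probability
    /// proportional to `n_k - d`.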
pub fn sample_seating(&self, n_customers: usize) -> Result<Array1<usize>> {
if n_customers == 0 {
return Err(ClusteringError::InvalidInput(
"n_customers must be at least 1".to_string(),
));
}
let mut rng = Lcg::new(self.seed);
let mut assignments = Array1::<usize>::zeros(n_customers);
let mut table_counts: Vec<usize> = Vec::new();
assignments[0] = 0;
table_counts.push(1);
for i in 1..n_customers {
let n_tables = table_counts.len();
let total = i as f64 + self.alpha;
let new_table_prob = (self.alpha + n_tables as f64 * self.discount) / total;
let u = rng.next_f64();
if u < new_table_prob || table_counts.is_empty() {
assignments[i] = n_tables;
table_counts.push(1);
} else {
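                // Otherwise join an existing table with weight proportional to
                // (count - discount).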
let mut cumulative = 0.0;
let denom: f64 = table_counts
.iter()
.map(|&c| (c as f64 - self.discount).max(0.0))
.sum();
let u_rescaled = rng.next_f64() * denom;
                let mut chosen = n_tables - 1;
                for (k, &count) in table_counts.iter().enumerate() {
cumulative += (count as f64 - self.discount).max(0.0);
if u_rescaled <= cumulative {
chosen = k;
break;
}
}
assignments[i] = chosen;
table_counts[chosen] += 1;
}
}
Ok(assignments)
}
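    /// Expected number of occupied tables after `n` customers: the CRP
    /// approximation `alpha * ln(1 + n / alpha)` when the discount is
    /// effectively zero, and the exact Pitman-Yor expectation otherwise.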
    pub fn expected_n_tables(&self, n: usize) -> f64 {
        if n == 0 {
            return 0.0;
        }
        let a = self.alpha;
        let d = self.discount;
        if d < 1e-10 {
            // CRP limit: E[K_n] ~ a * ln(1 + n / a).
            a * (1.0 + n as f64 / a).ln()
        } else {
            // Exact Pitman-Yor expectation, evaluated in log space:
            // E[K_n] = G(a+n+d) G(a+1) / (d G(a+n) G(a+d)) - a/d.
            let log_ratio = lgamma(a + n as f64 + d) + lgamma(a + 1.0)
                - lgamma(a + n as f64)
                - lgamma(a + d);
            (log_ratio.exp() - a) / d
        }
    }
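    /// Predictive probabilities for the next customer: one entry per existing
    /// table (proportional to `count - discount`) plus the new-table
    /// probability. The entries sum to 1 when `n_prev` equals the sum of
    /// `table_counts`.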
pub fn pitman_yor_process(
&self,
table_counts: &[usize],
n_prev: usize,
) -> Result<(Vec<f64>, f64)> {
if table_counts.is_empty() {
return Ok((vec![], 1.0));
}
let n_tables = table_counts.len();
let total = n_prev as f64 + self.alpha;
if total.abs() < 1e-14 {
return Err(ClusteringError::ComputationError(
"degenerate normaliser in PY process".to_string(),
));
}
let new_table_prob = (self.alpha + n_tables as f64 * self.discount) / total;
let existing_probs: Vec<f64> = table_counts
.iter()
.map(|&c| (c as f64 - self.discount).max(0.0) / total)
.collect();
Ok((existing_probs, new_table_prob))
}
}
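/// Expected log stick-breaking weights `E[log pi_k]` for `k = 0..t-1` under a
/// truncated Pitman-Yor stick-breaking prior, where stick `k` draws
/// `V ~ Beta(1 - discount, alpha + (k + 1) * discount)` and the expectations
/// follow from the digamma identity for Beta variables.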
pub fn py_e_log_stick_breaking(
alpha: f64,
discount: f64,
t: usize,
) -> Result<Array1<f64>> {
    if discount < 0.0 || discount >= 1.0 {
        return Err(ClusteringError::InvalidInput(
            "discount must be in [0, 1)".to_string(),
        ));
    }
    if alpha <= -discount {
        return Err(ClusteringError::InvalidInput(
            "alpha must be > -discount".to_string(),
        ));
    }
let mut e_log_pi = Array1::<f64>::zeros(t);
let mut cumulative = 0.0;
    for k in 0..t {
        // Stick k (0-indexed) has V ~ Beta(1 - d, alpha + (k + 1) * d) under
        // the Pitman-Yor stick-breaking construction.
        let a_k = 1.0 - discount;
        let b_k = alpha + (k as f64 + 1.0) * discount;
        let ab = a_k + b_k;
        if ab < 1e-14 {
            break;
        }
        // E[log pi_k] = E[log V_k] + sum_{j < k} E[log(1 - V_j)].
        let e_log_v = digamma(a_k) - digamma(ab);
        let e_log_1mv = digamma(b_k) - digamma(ab);
        e_log_pi[k] = e_log_v + cumulative;
        cumulative += e_log_1mv;
    }
Ok(e_log_pi)
}
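/// Digamma function via the recurrence `psi(x) = psi(x + 1) - 1/x` and an
/// asymptotic series once the argument reaches 6; non-positive inputs return
/// negative infinity.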
fn digamma(x: f64) -> f64 {
if x <= 0.0 {
return f64::NEG_INFINITY;
}
let mut v = x;
let mut result = 0.0;
while v < 6.0 {
result -= 1.0 / v;
v += 1.0;
}
result += v.ln() - 0.5 / v;
let iv2 = 1.0 / (v * v);
result -= iv2 * (1.0 / 12.0 - iv2 * (1.0 / 120.0 - iv2 / 252.0));
result
}
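/// Natural log of the gamma function using the Lanczos approximation
/// (g = 7, nine coefficients), with the reflection formula for `x < 0.5`.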
fn lgamma(x: f64) -> f64 {
const G: f64 = 7.0;
const C: [f64; 9] = [
0.999_999_999_999_809_9,
676.520_368_121_885_1,
-1_259.139_216_722_402_9,
771.323_428_777_653_1,
-176.615_029_162_140_6,
12.507_343_278_686_905,
-0.138_571_095_265_720_12,
9.984_369_578_019_572e-6,
1.505_632_735_149_312e-7,
];
    if x < 0.5 {
        // Reflection formula; abs() keeps the sine argument valid when x < 0.
        return std::f64::consts::PI.ln()
            - (std::f64::consts::PI * x).sin().abs().ln()
            - lgamma(1.0 - x);
    }
let xm1 = x - 1.0;
let mut sum = C[0];
for (i, &c) in C[1..].iter().enumerate() {
sum += c / (xm1 + i as f64 + 1.0);
}
let t = xm1 + G + 0.5;
(2.0 * std::f64::consts::PI).sqrt().ln() + sum.ln() + (xm1 + 0.5) * t.ln() - t
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::Array2;
#[test]
fn test_crp_seating_length() {
let sampler = CRPSampler::new(1.0);
let assignments = sampler.sample_seating(20).expect("sample");
assert_eq!(assignments.len(), 20);
}
#[test]
fn test_crp_seating_contiguous() {
let sampler = CRPSampler::new(1.0);
let assignments = sampler.sample_seating(50).expect("sample");
let max_k = *assignments.iter().max().expect("non-empty");
let unique: std::collections::BTreeSet<usize> = assignments.iter().copied().collect();
for k in 0..=max_k {
assert!(unique.contains(&k), "missing table {k}");
}
}
#[test]
fn test_crp_first_customer_table0() {
let sampler = CRPSampler::new(2.0);
let assignments = sampler.sample_seating(5).expect("sample");
assert_eq!(assignments[0], 0, "first customer must sit at table 0");
}
#[test]
fn test_crp_probability_new_table_empty() {
let sampler = CRPSampler::new(2.0);
let p_new = sampler.crp_probability(&[], None).expect("prob");
assert!((p_new - 1.0).abs() < 1e-10);
}
#[test]
fn test_crp_probability_existing() {
let sampler = CRPSampler::new(1.0);
let prev = vec![0usize, 0, 1];
let p0 = sampler.crp_probability(&prev, Some(0)).expect("prob");
assert!((p0 - 0.5).abs() < 1e-10, "p0 = {p0}");
let p1 = sampler.crp_probability(&prev, Some(1)).expect("prob");
assert!((p1 - 0.25).abs() < 1e-10, "p1 = {p1}");
let p_new = sampler.crp_probability(&prev, None).expect("prob");
assert!((p_new - 0.25).abs() < 1e-10, "p_new = {p_new}");
}
#[test]
fn test_crp_higher_alpha_more_tables() {
let sampler_low = CRPSampler::with_seed(0.1, 123);
let sampler_high = CRPSampler::with_seed(5.0, 123);
let a_low = sampler_low.sample_seating(100).expect("low");
let a_high = sampler_high.sample_seating(100).expect("high");
let n_tables_low: std::collections::HashSet<_> = a_low.iter().copied().collect();
let n_tables_high: std::collections::HashSet<_> = a_high.iter().copied().collect();
assert!(
n_tables_high.len() >= n_tables_low.len(),
"low={} high={}",
n_tables_low.len(),
n_tables_high.len()
);
}
#[test]
fn test_crp_invalid_alpha() {
let sampler = CRPSampler::new(-1.0);
assert!(sampler.sample_seating(10).is_err());
}
#[test]
fn test_crp_invalid_n_customers() {
let sampler = CRPSampler::new(1.0);
assert!(sampler.sample_seating(0).is_err());
}
#[test]
fn test_gibbs_crp_basic() {
let data = Array2::from_shape_vec(
(10, 1),
vec![0.0, 0.1, -0.1, 0.05, 0.0, 5.0, 4.9, 5.1, 5.0, 4.95],
)
.expect("data");
let cfg = CRPGibbsConfig::new(1.0, 20);
let result = gibbs_sampler_crp(data.view(), &cfg).expect("gibbs");
assert_eq!(result.assignments.len(), 10);
assert!(result.n_clusters >= 1);
assert_eq!(result.n_iter, 20);
}
#[test]
fn test_gibbs_crp_recovers_clusters() {
let data = Array2::from_shape_vec(
(8, 1),
vec![0.0, 0.1, -0.1, 0.05, 10.0, 9.9, 10.1, 10.05],
)
.expect("data");
let cfg = CRPGibbsConfig {
alpha: 1.0,
n_iter: 50,
n_burnin: 10,
prior_mean: 5.0,
prior_var: 25.0,
likelihood_var: 0.5,
seed: 99,
};
let result = gibbs_sampler_crp(data.view(), &cfg).expect("gibbs");
let label_low = result.assignments[0];
let label_high = result.assignments[4];
assert_ne!(
label_low, label_high,
"expected distinct clusters, got same label {label_low}"
);
}
#[test]
fn test_pitman_yor_seating_length() {
let py = PitmanYorProcess::new(1.0, 0.5).expect("py");
let assignments = py.sample_seating(30).expect("sample");
assert_eq!(assignments.len(), 30);
}
#[test]
fn test_pitman_yor_invalid_discount() {
assert!(PitmanYorProcess::new(1.0, 1.0).is_err());
assert!(PitmanYorProcess::new(1.0, -0.1).is_err());
}
#[test]
fn test_pitman_yor_discount_zero_like_crp() {
let py = PitmanYorProcess::with_seed(1.0, 0.0, 42).expect("py");
let crp = CRPSampler::with_seed(1.0, 42);
let py_a = py.sample_seating(50).expect("py sample");
let crp_a = crp.sample_seating(50).expect("crp sample");
let py_tables: std::collections::HashSet<_> = py_a.iter().copied().collect();
let crp_tables: std::collections::HashSet<_> = crp_a.iter().copied().collect();
let diff = (py_tables.len() as isize - crp_tables.len() as isize).abs();
assert!(diff <= 5, "PY tables={}, CRP tables={}", py_tables.len(), crp_tables.len());
}
#[test]
fn test_pitman_yor_predictive() {
let py = PitmanYorProcess::new(1.0, 0.3).expect("py");
let counts = vec![5usize, 3, 1];
let (existing, new_p) = py.pitman_yor_process(&counts, 9).expect("predictive");
for p in &existing {
assert!(*p >= 0.0, "negative prob {p}");
}
assert!(new_p >= 0.0);
let total: f64 = existing.iter().sum::<f64>() + new_p;
assert!((total - 1.0).abs() < 1e-10, "total = {total}");
}
#[test]
fn test_py_e_log_stick_breaking() {
let e_log = py_e_log_stick_breaking(1.0, 0.5, 5).expect("e_log");
assert_eq!(e_log.len(), 5);
for &v in e_log.iter() {
assert!(v.is_finite(), "non-finite value {v}");
}
}
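    // Sanity bounds for expected_n_tables: with at least one customer the
    // expectation must lie in [1, n]. Bounds only, no exact values, since the
    // closed form is evaluated through the local lgamma approximation.
    #[test]
    fn test_pitman_yor_expected_tables_bounds() {
        let py = PitmanYorProcess::new(1.0, 0.5).expect("py");
        for &n in &[1usize, 10, 100] {
            let e = py.expected_n_tables(n);
            assert!(
                e >= 1.0 - 1e-6 && e <= n as f64 + 1e-6,
                "n = {n}, expected tables = {e}"
            );
        }
    }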
#[test]
fn test_predictive_distribution() {
let sampler = CRPSampler::new(2.0);
let prev = vec![0usize, 0, 1, 2];
let (probs, n_tables) = sampler.predictive_distribution(&prev).expect("dist");
assert_eq!(n_tables, 3);
assert_eq!(probs.len(), 4);
let total: f64 = probs.iter().sum();
assert!((total - 1.0).abs() < 1e-10, "total = {total}");
}
}