use assert_approx_eq::assert_approx_eq;
use mixt::optimize_flat;
use mixt::BurnBackend;
use mixt::OptimizerOpts;
use mixt::optimizer::Algorithm;
use rand::rngs::ThreadRng;
use rand_distr::Distribution;
use rand_distr::{Gamma, Normal, Poisson, Uniform};
use rand_distr::weighted::WeightedIndex;
use statrs::distribution::Continuous;
/// Draws one sample from a Dirichlet(alphas) distribution by normalizing
/// independent Gamma(alpha_i, 1) variates.
///
/// Returns a length-`alphas.len()` vector of non-negative weights summing to 1.
/// Panics if any alpha is invalid for `Gamma::new` (e.g. non-positive).
fn sample_dirichlet(
    alphas: &[f32],
    rng: &mut ThreadRng,
) -> Vec<f32> {
    // One Gamma(alpha_i, 1) draw per concentration parameter.
    let draws: Vec<f32> = alphas
        .iter()
        .map(|&alpha| Gamma::new(alpha, 1.0).unwrap().sample(rng))
        .collect();
    // Summed in the same left-to-right order as the draws were made,
    // so the result is bit-identical to accumulating during sampling.
    let total: f32 = draws.iter().sum();
    draws.into_iter().map(|d| d / total).collect()
}
/// Draws `n` independent samples from a Normal(mean, sd) distribution.
///
/// Panics if `sd` is invalid for `Normal::new` (e.g. negative or NaN).
fn sample_n_normal(
    mean: f32,
    sd: f32,
    n: usize,
    rng: &mut ThreadRng,
) -> Vec<f32> {
    let normal = Normal::new(mean, sd).unwrap();
    let mut samples = Vec::with_capacity(n);
    for _ in 0..n {
        samples.push(normal.sample(rng));
    }
    samples
}
/// Draws `n` independent samples from a Gamma(shape, scale) distribution.
///
/// Panics if the parameters are invalid for `Gamma::new`.
fn sample_n_gamma(
    shape: f32,
    scale: f32,
    n: usize,
    rng: &mut ThreadRng,
) -> Vec<f32> {
    let gamma = Gamma::new(shape, scale).unwrap();
    let mut samples = Vec::with_capacity(n);
    for _ in 0..n {
        samples.push(gamma.sample(rng));
    }
    samples
}
/// Draws `n` independent samples from a Uniform[min, max) distribution.
///
/// Panics if the range is invalid for `Uniform::new` (e.g. min >= max).
fn sample_n_uniform(
    min: f32,
    max: f32,
    n: usize,
    rng: &mut ThreadRng,
) -> Vec<f32> {
    let uniform = Uniform::new(min, max).unwrap();
    let mut samples = Vec::with_capacity(n);
    for _ in 0..n {
        samples.push(uniform.sample(rng));
    }
    samples
}
/// Draws `n` independent samples from a Poisson(rate) distribution,
/// returned as `f32` values.
///
/// Panics if `rate` is invalid for `Poisson::new` (e.g. non-positive).
fn sample_n_poisson(
    rate: f32,
    n: usize,
    rng: &mut ThreadRng,
) -> Vec<f32> {
    let poisson = Poisson::new(rate).unwrap();
    let mut samples = Vec::with_capacity(n);
    for _ in 0..n {
        samples.push(poisson.sample(rng));
    }
    samples
}
/// Generates a synthetic k-component Gaussian mixture problem.
///
/// Draws random component parameters and mixture weights, samples `n`
/// observations from the mixture, and evaluates each observation's
/// log-density under every component.
///
/// Returns `(log_lls, thetas)` where `log_lls` is a flat, component-major
/// vector of length `k * n` (entry `j * n + i` = ln p(obs_i | component j))
/// and `thetas` are the true mixture weights.
fn random_loglls(
    k: usize,
    n: usize,
    rng: &mut ThreadRng,
) -> (Vec<f32>, Vec<f32>) {
    // Component means ~ N(0, 10); component sds are sqrt of Gamma(1, 2)
    // draws, so the component variances are Gamma-distributed.
    let means: Vec<f32> = sample_n_normal(0_f32, 10_f32, k, rng);
    let sds: Vec<f32> = sample_n_gamma(1_f32, 2_f32, k, rng).iter().map(|x| x.sqrt()).collect();
    // statrs f64 densities for evaluating ln_pdf with extra precision.
    let normals: Vec<_> = means.iter().zip(sds.iter()).map(|(mu, sigma)| {
        statrs::distribution::Normal::new(*mu as f64, *sigma as f64).unwrap()
    }).collect();
    // True mixture weights: Dirichlet with Uniform(0, 1) concentrations.
    let alphas: Vec<f32> = sample_n_uniform(0_f32, 1_f32, k, rng);
    let thetas: Vec<f32> = sample_dirichlet(&alphas, rng);
    let dist = WeightedIndex::new(&thetas).unwrap();
    // log_lls[j][i] = ln p(obs_i | component j).
    let mut log_lls: Vec<Vec<f32>> = vec![vec![0_f32; n]; k];
    for i in 0..n {
        // Draw the latent component, then one observation from it.
        let cluster: usize = dist.sample(rng);
        let obs: f32 = sample_n_normal(means[cluster], sds[cluster], 1, rng)[0];
        for j in 0..k {
            log_lls[j][i] = normals[j].ln_pdf(obs as f64) as f32;
        }
    }
    // Flatten component-major. `into_iter()` moves the inner vectors;
    // the previous `iter().cloned().flatten()` cloned each length-n row
    // before discarding it.
    let log_lls: Vec<f32> = log_lls.into_iter().flatten().collect();
    (log_lls, thetas)
}
/// End-to-end check of the RCG algorithm on an f32 NdArray backend:
/// the recovered mixture weights should be close to the true thetas.
#[test]
fn rcg32_random() {
    let mut rng = rand::rng();
    let k: usize = 2;
    let n: usize = 1000;
    let (log_lls, thetas) = random_loglls(k, n, &mut rng);
    // Per-observation log counts: ln of Poisson(100) draws.
    let log_counts: Vec<f32> = sample_n_poisson(100_f32, n, &mut rng)
        .into_iter()
        .map(|x| x.ln())
        .collect();
    let alphas: Vec<f32> = vec![1.0; k];
    let opts = OptimizerOpts {
        tolerance: 1e-16_f64,
        max_iters: 1000,
        device: BurnBackend::NdArray32,
        algorithm: Algorithm::RCG,
        ..Default::default()
    };
    let (got, _) = optimize_flat(&log_lls, &log_counts, &alphas, Some(opts)).unwrap();
    // Loose tolerance: the estimate only needs to land near the truth.
    for (x, y) in got.iter().zip(thetas.iter()) {
        assert_approx_eq!(x, y, 4e-2);
    }
}
/// End-to-end check of the EM algorithm on an f32 NdArray backend:
/// the recovered mixture weights should be close to the true thetas.
#[test]
fn em32_random() {
    let mut rng = rand::rng();
    let k: usize = 2;
    let n: usize = 1000;
    let (log_lls, thetas) = random_loglls(k, n, &mut rng);
    // Per-observation log counts: ln of Poisson(100) draws.
    let log_counts: Vec<f32> = sample_n_poisson(100_f32, n, &mut rng)
        .into_iter()
        .map(|x| x.ln())
        .collect();
    let alphas: Vec<f32> = vec![1.0; k];
    let opts = OptimizerOpts {
        tolerance: 1e-16_f64,
        max_iters: 1000,
        device: BurnBackend::NdArray32,
        algorithm: Algorithm::EM,
        ..Default::default()
    };
    let (got, _) = optimize_flat(&log_lls, &log_counts, &alphas, Some(opts)).unwrap();
    // Loose tolerance: the estimate only needs to land near the truth.
    for (x, y) in got.iter().zip(thetas.iter()) {
        assert_approx_eq!(x, y, 4e-2);
    }
}