mixt 0.1.0

Estimate mixture model weights for a fixed log-likelihood matrix.
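A minimal usage sketch, inferred from how the integration tests below call the API: the row-major layout of the flat log-likelihood matrix and the log-scale counts follow `random_loglls`, the numeric values are illustrative only, and the second element of the returned tuple is ignored here.

use mixt::optimize_flat;
use mixt::{BurnBackend, OptimizerOpts};
use mixt::optimizer::Algorithm;

fn main() {
    // Flat log-likelihood matrix for k = 2 components and n = 3 observations,
    // stored row-major: all observations for component 0, then component 1.
    let log_lls: Vec<f32> = vec![
        -0.1, -2.3, -1.2, // component 0
        -2.0, -0.4, -0.9, // component 1
    ];
    // One count per observation, supplied on the log scale: ln(1) = 0.
    let log_counts: Vec<f32> = vec![0.0; 3];
    // Uniform Dirichlet prior over the two mixture weights.
    let alphas: Vec<f32> = vec![1.0; 2];

    let mut opts: OptimizerOpts = Default::default();
    opts.device = BurnBackend::NdArray32;
    opts.algorithm = Algorithm::EM;

    let (weights, _) = optimize_flat(&log_lls, &log_counts, &alphas, Some(opts)).unwrap();
    println!("estimated mixture weights: {weights:?}");
}
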
// mixt: Riemannian conjugate gradient descent for estimating mixture model weights.
//
// Copyright 2025 Tommi Mäklin [tommi@maklin.fi].
//
// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
// License as published by the Free Software Foundation; either
// version 2.1 of the License, or (at your option) any later version.
//
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
// Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License along with this library; if not, write to the Free Software
// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301
// USA
//

use assert_approx_eq::assert_approx_eq;
use mixt::optimize_flat;
use mixt::BurnBackend;
use mixt::OptimizerOpts;
use mixt::optimizer::Algorithm;

use rand::rngs::ThreadRng;

use rand_distr::Distribution;
use rand_distr::{Gamma, Normal, Poisson, Uniform};
use rand_distr::weighted::WeightedIndex;

use statrs::distribution::Continuous;

/// Draw a single sample (a probability vector on the k-simplex) from the Dirichlet distribution
fn sample_dirichlet(
    alphas: &[f32],
    rng: &mut ThreadRng,
) -> Vec<f32> {
    // A Dirichlet draw is obtained by normalizing independent Gamma(alpha_i, 1) draws
    let ys: Vec<f32> = alphas
        .iter()
        .map(|&alpha| Gamma::new(alpha, 1.0).unwrap().sample(rng))
        .collect();
    let y_sum: f32 = ys.iter().sum();

    ys.iter().map(|y| y / y_sum).collect()
}

/// Sample n values from the normal distribution
fn sample_n_normal(
    mean: f32,
    sd: f32,
    n: usize,
    rng: &mut ThreadRng,
) -> Vec<f32> {
    let normal = Normal::new(mean, sd).unwrap();
    (0..n).map(|_| normal.sample(rng)).collect::<Vec<f32>>()
}

/// Sample n values from the gamma distribution
fn sample_n_gamma(
    shape: f32,
    scale: f32,
    n: usize,
    rng: &mut ThreadRng,
) -> Vec<f32> {
    let gamma = Gamma::new(shape, scale).unwrap();
    (0..n).map(|_| gamma.sample(rng)).collect::<Vec<f32>>()
}

/// Sample n values from the uniform distribution
fn sample_n_uniform(
    min: f32,
    max: f32,
    n: usize,
    rng: &mut ThreadRng,
) -> Vec<f32> {
    let uniform = Uniform::new(min, max).unwrap();
    (0..n).map(|_| uniform.sample(rng)).collect::<Vec<f32>>()
}

/// Sample n values from the Poisson distribution
fn sample_n_poisson(
    rate: f32,
    n: usize,
    rng: &mut ThreadRng,
) -> Vec<f32> {
    let poisson = Poisson::new(rate).unwrap();
    (0..n).map(|_| poisson.sample(rng)).collect::<Vec<f32>>()
}

/// Generate log-likelihoods for `n` observations drawn from a random mixture
/// of `k` univariate normals; returns the flattened k x n log-likelihood
/// matrix together with the ground-truth mixture weights.
fn random_loglls(
    k: usize,
    n: usize,
    rng: &mut ThreadRng,
) -> (Vec<f32>, Vec<f32>) {
    // Normal distribution parameters for generating observations; the standard
    // deviations are square roots of Gamma(1, 2) draws
    let means: Vec<f32> = sample_n_normal(0_f32, 10_f32, k, rng);
    let sds: Vec<f32> = sample_n_gamma(1_f32, 2_f32, k, rng).iter().map(|x| x.sqrt()).collect();

    let normals: Vec<_> = means.iter().zip(sds.iter()).map(|(mu, sigma)| {
        statrs::distribution::Normal::new(*mu as f64, *sigma as f64).unwrap()
    }).collect();

    // Generate random thetas ~ Dirichlet(alpha_1, ..., alpha_k) by sampling from
    // Gamma(alpha_i, 1) distributions, where each alpha_i ~ Unif(0, 1)
    //
    // This tends to produce thetas that are concentrated around a few values
    let alphas: Vec<f32> = sample_n_uniform(0_f32, 1_f32, k, rng);
    let thetas: Vec<f32> = sample_dirichlet(&alphas, rng);

    // Generate log likelihoods for a mixture of `k` normal distributions
    let dist = WeightedIndex::new(&thetas).unwrap();
    let mut log_lls: Vec<Vec<f32>> = vec![vec![0_f32; n]; k];
    for i in 0..n {
        let cluster: usize = dist.sample(rng);
        let obs: f32 = sample_n_normal(means[cluster], sds[cluster], 1, rng)[0];
        for j in 0..k {
            log_lls[j][i] = normals[j].ln_pdf(obs as f64) as f32;
        }
    }
    // Flatten to a row-major k x n matrix: all observations for component 0 first
    let log_lls: Vec<f32> = log_lls.into_iter().flatten().collect();

    (log_lls, thetas)
}

#[test]
fn rcg32_random() {
    let mut rng = rand::rng();

    let k: usize = 2;
    let n: usize = 1000;

    let (log_lls, thetas) = random_loglls(k, n, &mut rng);
    // Each observation carries a random Poisson(100) count, supplied on the log scale
    let log_counts: Vec<f32> = sample_n_poisson(100_f32, n, &mut rng).iter().map(|x| x.ln()).collect();
    let alphas: Vec<f32> = vec![1.0; k];

    let mut opts: OptimizerOpts = Default::default();
    opts.tolerance = 1e-16_f64;
    opts.max_iters = 1000;
    opts.device = BurnBackend::NdArray32;
    opts.algorithm = Algorithm::RCG;

    let (got, _) = optimize_flat(&log_lls, &log_counts, &alphas, Some(opts)).unwrap();

    got.iter()
        .zip(thetas.iter())
        .for_each(|(x, y)| assert_approx_eq!(x, y, 4e-2));
}

#[test]
fn em32_random() {
    let mut rng = rand::rng();

    let k: usize = 2;
    let n: usize = 1000;

    let (log_lls, thetas) = random_loglls(k, n, &mut rng);
    let log_counts: Vec<f32> = sample_n_poisson(100_f32, n, &mut rng).iter().map(|x| x.ln()).collect();
    let alphas: Vec<f32> = vec![1.0; k];

    let mut opts: OptimizerOpts = Default::default();
    opts.tolerance = 1e-16_f64;
    opts.max_iters = 1000;
    opts.device = BurnBackend::NdArray32;
    opts.algorithm = Algorithm::EM;

    let (got, _) = optimize_flat(&log_lls, &log_counts, &alphas, Some(opts)).unwrap();

    got.iter()
        .zip(thetas.iter())
        .for_each(|(x, y)| assert_approx_eq!(x, y, 4e-2));
}