mixt 0.1.0

Estimate mixture model weights for a fixed log-likelihood matrix.
Documentation
// mixt: Riemannian conjugate gradient descent for estimating mixture model weights.
//
// Copyright 2025 mixt contributors [https://github.com/tmaklin/mixt]
//
// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
// License as published by the Free Software Foundation; either
// version 2.1 of the License, or (at your option) any later version.
//
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
// Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License along with this library; if not, write to the Free Software
// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301
// USA
//

//! Algorithm implementations
//!
//! mixt currently implements two algorithms:
//! - [Riemannian conjugate gradient descent](rcg).
//! - [Expectation maximization](em).
//!
//! Both algorithms assume a likelihood function where the parameters are fixed
//! but the mixing proportions are unknown.
//!
//! Computation is performed using 32-bit floating point numbers by default.
//! Several numerical tricks are employed to keep the algorithms stable at
//! this limited precision; they can make the code somewhat unintuitive at
//! times.
//!

pub mod em;
pub mod rcg;

use crate::math::logsumexp;

use burn_tensor::Tensor;
use burn_tensor::backend::Backend;

/// Supported optimizer algorithms.
///
/// The default is [`Algorithm::RCG`]. A value can be obtained from its
/// lowercase name with [`str::parse`] (`"rcg"` or `"em"`).
///
/// This enum is
/// [non_exhaustive](https://doc.rust-lang.org/reference/attributes/type_system.html),
/// so code outside this crate that matches on it must include a wildcard arm.
/// New variants are not currently expected.
///
#[derive(Debug, Clone, Default, PartialEq, Eq)]
#[non_exhaustive]
pub enum Algorithm {
    /// Expectation maximization
    EM,
    /// Riemannian conjugate gradient descent
    #[default]
    RCG,
}

impl std::str::FromStr for Algorithm {
    /// Parse failures are reported as a human-readable message.
    type Err = String;

    /// Parse an algorithm from its lowercase name: `"rcg"` or `"em"`.
    ///
    /// Matching is exact and case-sensitive. Any other input is rejected with
    /// a message that quotes the offending string.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let algorithm = match s {
            "em" => Algorithm::EM,
            "rcg" => Algorithm::RCG,
            unknown => return Err(format!("'{unknown}' is not a valid Algorithm variant")),
        };
        Ok(algorithm)
    }
}

/// Compute mixture proportions from fitted values.
///
/// Each column of `exp(gamma_z)` is weighted by the matching entry of
/// `exp(log_counts)`, and the weighted values are summed along each row. The
/// row sums are then normalised by the total count
/// `exp(logsumexp(log_counts))`.
///
/// # Arguments
/// - `gamma_z`: log-scale fitted values. Its columns line up with the entries
///   of `log_counts`, since they are added together with broadcasting.
///   Presumably there is one row per mixture component (as in the tests) —
///   confirm against callers.
/// - `log_counts`: natural log of the count/weight attached to each column.
///
/// Returns a 1D tensor with one mixture proportion per row of `gamma_z`.
///
pub fn mixture_components<B: Backend>(
    gamma_z: Tensor::<B, 2>,
    log_counts: Tensor::<B, 1>,
) -> Tensor::<B, 1> {
    // Lift the counts to a row vector so they broadcast across `gamma_z`'s rows.
    let log_weights: Tensor<B, 2> = log_counts.unsqueeze();

    // log(total count), reduced to a single value and lifted back to rank 2 so
    // it can be subtracted from the per-row sums below.
    let log_total: Tensor<B, 2> = logsumexp(log_weights.clone(), 1).sum().unsqueeze();

    // Count-weighted sum of each row, taken in linear space.
    let row_sums = gamma_z.add(log_weights).exp().sum_dim(1);

    // The total is already in log space, so normalise there before exponentiating.
    row_sums.log().sub(log_total).exp().squeeze()
}

// Tests
#[cfg(test)]
mod tests {
    use assert_approx_eq::assert_approx_eq;

    #[test]
    fn mixture_components() {
        use burn::backend::ndarray::NdArray;
        use burn_tensor::Tensor;

        use super::mixture_components;

        let device = Default::default();
        type Backend = NdArray<f32>;

        let gamma_z = Tensor::<Backend, 2>::from_data(
            [
                [ -0.0010899, -0.00104044, -0.000928571, -0.00104519, -0.000995734, -0.000883857, -0.000944069, -0.000894604, -0.000782716, -0.000853449 ],
                [ -7.15745,   -7.1574,     -7.15729,     -7.15741,    -7.15736,     -7.15725,     -7.15731,     -7.15726,     -7.15715,     -7.51888 ],
                [ -8.82298,   -8.82293,    -8.82282,     -9.1846,     -9.18455,     -9.18444,     -13.418,      -13.4179,     -13.4178,     -8.82274 ],
                [ -8.72199,   -9.0836,     -13.3169,     -8.72195,    -9.08356,     -13.3169,     -8.72184,     -9.08346,     -13.3168,     -8.72175 ],
            ],
            &device,
        );

        let log_counts = Tensor::<Backend, 1>::from_data(
            [
                7.681099, 7.04316, 6.849066, 5.278115, 5.164786, 5.062595, 6.947937, 6.863803, 7.277248, 7.666222
            ],
            &device,
        );

        let expected = Tensor::<Backend, 1>::from_data(
            [
                0.999543, 0.00073079, 9.66135e-05, 0.000112505
            ],
            &device,
        );

        let got = mixture_components::<Backend>(gamma_z, log_counts);

        // `zip` below stops at the shorter input, so a result with the wrong
        // number of proportions (e.g. empty) would otherwise pass silently.
        assert_eq!(got.dims(), expected.dims());

        let got_data = got.into_data();
        let expected_data = expected.into_data();

        got_data
            .iter()
            .zip(expected_data.iter())
            .for_each(|(x, y): (f32, f32)| assert_approx_eq!(x, y, 1e-2_f32));
    }
}