pub mod em;
pub mod rcg;
use crate::math::logsumexp;
use burn_tensor::Tensor;
use burn_tensor::backend::Backend;
/// Identifies one of the algorithms provided by this module.
///
/// The variant names mirror the `em` and `rcg` submodules declared in this
/// file; presumably each variant selects the routine in the submodule of the
/// same name (confirm against callers).
///
/// The enum is `#[non_exhaustive]`, so code in other crates must include a
/// wildcard arm when matching on it.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
#[non_exhaustive]
pub enum Algorithm {
    /// Counterpart of the `em` submodule (presumably expectation-maximisation;
    /// TODO confirm).
    EM,
    /// Counterpart of the `rcg` submodule. This is the value produced by
    /// `Algorithm::default()`.
    #[default]
    RCG,
}
impl std::str::FromStr for Algorithm {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"rcg" => Ok(Algorithm::RCG),
"em" => Ok(Algorithm::EM),
_ => Err(format!("'{}' is not a valid Algorithm variant", s)),
}
}
}
/// Computes one normalised weight per row of `gamma_z`.
///
/// For each row `i` the result is
/// `sum_j exp(gamma_z[i, j] + log_counts[j]) / total`, with the division
/// carried out in log space. `total` is obtained from
/// `crate::math::logsumexp` applied to `log_counts` (presumably
/// `sum_j exp(log_counts[j])`; confirm against that helper).
///
/// # Arguments
/// * `gamma_z` - rank-2 tensor. Going by the name, presumably log-space
///   assignment probabilities with one row per component (TODO confirm).
/// * `log_counts` - rank-1 tensor that is lifted to rank 2 and added to
///   `gamma_z`, so its length has to be compatible with the second dimension
///   of `gamma_z`.
///
/// # Returns
/// A rank-1 tensor with one entry per row of `gamma_z`.
pub fn mixture_components<B: Backend>(
    gamma_z: Tensor<B, 2>,
    log_counts: Tensor<B, 1>,
) -> Tensor<B, 1> {
    // Lift the counts to rank 2 so they can be combined with `gamma_z`.
    let log_weights: Tensor<B, 2> = log_counts.unsqueeze();

    // Log-total of the weights: reduced with `sum`, then lifted back to rank 2
    // so it can be subtracted from the per-row values below.
    let log_total: Tensor<B, 2> = logsumexp(log_weights.clone(), 1).sum().unsqueeze();

    // Weighted mass of every row, accumulated in linear space.
    // NOTE(review): `exp` is applied before the sum, so very large values of
    // `gamma_z + log_counts` could overflow; confirm the expected input range.
    let row_mass = gamma_z.add(log_weights).exp().sum_dim(1);

    // Normalise in log space, go back to linear space and drop the extra axis.
    row_mass.log().sub(log_total).exp().squeeze()
}
#[cfg(test)]
mod tests {
    use assert_approx_eq::assert_approx_eq;

    /// Checks `mixture_components` on a 4-row, 10-column `gamma_z` against
    /// four reference values.
    #[test]
    fn mixture_components() {
        use burn::backend::ndarray::NdArray;
        use burn_tensor::Tensor;

        use super::mixture_components;

        let device = Default::default();
        type Backend = NdArray<f32>;

        let gamma_z = Tensor::<Backend, 2>::from_data(
            [
                [ -0.0010899, -0.00104044, -0.000928571, -0.00104519, -0.000995734, -0.000883857, -0.000944069, -0.000894604, -0.000782716, -0.000853449 ],
                [ -7.15745, -7.1574, -7.15729, -7.15741, -7.15736, -7.15725, -7.15731, -7.15726, -7.15715, -7.51888 ],
                [ -8.82298, -8.82293, -8.82282, -9.1846, -9.18455, -9.18444, -13.418, -13.4179, -13.4178, -8.82274 ],
                [ -8.72199, -9.0836, -13.3169, -8.72195, -9.08356, -13.3169, -8.72184, -9.08346, -13.3168, -8.72175 ],
            ],
            &device,
        );
        let log_counts = Tensor::<Backend, 1>::from_data(
            [
                7.681099, 7.04316, 6.849066, 5.278115, 5.164786, 5.062595, 6.947937, 6.863803, 7.277248, 7.666222
            ],
            &device,
        );
        let expected = Tensor::<Backend, 1>::from_data(
            [
                0.999543, 0.00073079, 9.66135e-05, 0.000112505
            ],
            &device,
        );

        let got = mixture_components::<Backend>(gamma_z, log_counts);

        // `zip` stops at the shorter side, so compare the shapes first;
        // otherwise a result with too few elements would pass the
        // element-wise comparison below without being noticed.
        assert_eq!(got.dims(), expected.dims());

        let got_data = got.into_data();
        let expected_data = expected.into_data();
        // NOTE(review): the tolerance is an absolute 1e-2, which is larger than
        // three of the four expected values; consider a tighter or relative
        // bound once the achievable precision has been confirmed.
        got_data.iter().zip(expected_data.iter()).for_each(|(x, y): (f32, f32)| { assert_approx_eq!(x, y, 1e-2_f32) });
    }
}