use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use crate::utils::logsumexp_f64;
#[derive(Debug, Clone, Copy)]
pub struct VbxConfig {
pub fa: f64,
pub fb: f64,
pub max_iters: usize,
pub epsilon: f64,
pub init_smoothing: f64,
}
impl Default for VbxConfig {
fn default() -> Self {
Self {
fa: 0.07,
fb: 0.8,
max_iters: 20,
epsilon: 1e-4,
init_smoothing: 7.0,
}
}
}
pub fn vbx(
features: &ArrayView2<f32>,
phi: &ArrayView1<f32>,
gamma_init: &Array2<f32>,
config: &VbxConfig,
) -> (Array2<f32>, Array1<f32>) {
let (n_samples, dim) = features.dim();
let n_speakers = gamma_init.ncols();
let fa = config.fa;
let fb = config.fb;
let fa_over_fb = fa / fb;
let features_f64 = features.mapv(|v| v as f64);
let phi_f64: Array1<f64> = phi.mapv(|v| v as f64);
let mut gamma = gamma_init.mapv(|v| v as f64);
let mut pi = Array1::from_elem(n_speakers, 1.0 / n_speakers as f64);
let frame_constants: Array1<f64> = features_f64
.rows()
.into_iter()
.map(|row| -0.5 * (row.dot(&row) + dim as f64 * (2.0 * std::f64::consts::PI).ln()))
.collect();
let phi_sqrt = phi_f64.mapv(f64::sqrt);
let mut rho = features_f64;
for mut row in rho.rows_mut() {
row *= &phi_sqrt;
}
let mut prev_elbo = f64::NEG_INFINITY;
let mut scratch = Array1::<f64>::zeros(n_speakers);
for iter in 0..config.max_iters {
let n_k: Array1<f64> = gamma.sum_axis(Axis(0));
let mut inv_l = Array2::zeros((n_speakers, dim));
let mut alpha = Array2::zeros((n_speakers, dim));
for speaker_idx in 0..n_speakers {
for dim_idx in 0..dim {
inv_l[[speaker_idx, dim_idx]] =
1.0 / (1.0 + fa_over_fb * n_k[speaker_idx] * phi_f64[dim_idx]);
}
let mut f_k = Array1::<f64>::zeros(dim);
for sample_idx in 0..n_samples {
f_k.scaled_add(gamma[[sample_idx, speaker_idx]], &rho.row(sample_idx));
}
for dim_idx in 0..dim {
alpha[[speaker_idx, dim_idx]] =
fa_over_fb * inv_l[[speaker_idx, dim_idx]] * f_k[dim_idx];
}
}
let mut log_p = Array2::<f64>::zeros((n_samples, n_speakers));
for sample_idx in 0..n_samples {
for speaker_idx in 0..n_speakers {
let rho_dot_alpha: f64 = rho.row(sample_idx).dot(&alpha.row(speaker_idx));
let penalty: f64 = (0..dim)
.map(|dim_idx| {
(inv_l[[speaker_idx, dim_idx]]
+ alpha[[speaker_idx, dim_idx]] * alpha[[speaker_idx, dim_idx]])
* phi_f64[dim_idx]
})
.sum();
log_p[[sample_idx, speaker_idx]] =
fa * (rho_dot_alpha - 0.5 * penalty + frame_constants[sample_idx]);
}
}
let lpi: Array1<f64> = pi.mapv(|p| (p + 1e-8).ln());
let mut log_p_x = Array1::<f64>::zeros(n_samples);
for sample_idx in 0..n_samples {
scratch.assign(&log_p.row(sample_idx));
scratch += &lpi;
log_p_x[sample_idx] = logsumexp_f64(&scratch.view());
}
for sample_idx in 0..n_samples {
for speaker_idx in 0..n_speakers {
gamma[[sample_idx, speaker_idx]] =
(log_p[[sample_idx, speaker_idx]] + lpi[speaker_idx] - log_p_x[sample_idx])
.exp();
}
}
pi = gamma.sum_axis(Axis(0));
let pi_sum = pi.sum();
pi /= pi_sum;
let log_px_sum: f64 = log_p_x.sum();
let reg: f64 = inv_l
.iter()
.zip(alpha.iter())
.map(|(&il, &a)| il.ln() - il - a * a + 1.0)
.sum();
let elbo = log_px_sum + fb * 0.5 * reg;
if iter > 0 && elbo - prev_elbo < config.epsilon {
break;
}
prev_elbo = elbo;
}
let gamma_f32 = gamma.mapv(|v| v as f32);
let pi_f32 = pi.mapv(|v| v as f32);
(gamma_f32, pi_f32)
}
pub fn cluster_vbx(
ahc_labels: &[usize],
features: &ArrayView2<f32>,
phi: &ArrayView1<f32>,
config: &VbxConfig,
) -> (Array2<f32>, Array1<f32>) {
let gamma_init = build_gamma_init(ahc_labels, config.init_smoothing);
vbx(features, phi, &gamma_init, config)
}
fn build_gamma_init(labels: &[usize], smoothing: f64) -> Array2<f32> {
let num_samples = labels.len();
let num_speakers = labels.iter().copied().max().unwrap_or(0) + 1;
let mut gamma = Array2::<f32>::zeros((num_samples, num_speakers));
for (row, &label) in labels.iter().enumerate() {
gamma[[row, label]] = 1.0;
}
if smoothing < 0.0 {
return gamma;
}
let smoothing_f32 = smoothing as f32;
for mut row in gamma.rows_mut() {
row *= smoothing_f32;
let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
row.mapv_inplace(|v| (v - max).exp());
let denom = row.sum();
row /= denom;
}
gamma
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
use ndarray::{Array1, Array2, array};
use ndarray_npy::ReadNpyExt;
use std::fs::File;
use std::path::PathBuf;
fn fixture_path(name: &str) -> PathBuf {
PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("fixtures")
.join(name)
}
#[test]
fn two_clusters_with_vbx() {
let features = array![
[10.0, 0.0],
[10.1, 0.1],
[9.9, -0.1],
[-10.0, 0.0],
[-10.1, 0.1],
[-9.9, -0.1],
];
let phi = array![1.0, 1.0];
let mut gamma_init = Array2::zeros((6, 2));
for t in 0..3 {
gamma_init[[t, 0]] = 0.999;
gamma_init[[t, 1]] = 0.001;
}
for t in 3..6 {
gamma_init[[t, 0]] = 0.001;
gamma_init[[t, 1]] = 0.999;
}
let (gamma, _pi) = vbx(
&features.view(),
&phi.view(),
&gamma_init,
&VbxConfig::default(),
);
let labels: Vec<usize> = gamma
.rows()
.into_iter()
.map(|row| {
row.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.total_cmp(b))
.unwrap()
.0
})
.collect();
assert_eq!(labels[0], labels[1]);
assert_eq!(labels[0], labels[2]);
assert_eq!(labels[3], labels[4]);
assert_eq!(labels[3], labels[5]);
assert_ne!(labels[0], labels[3]);
}
#[test]
fn gamma_init_is_smoothed_one_hot() {
let gamma = build_gamma_init(&[0, 0, 1], 7.0);
assert_eq!(gamma.dim(), (3, 2));
assert!(gamma[[0, 0]] > gamma[[0, 1]]);
assert!(gamma[[2, 1]] > gamma[[2, 0]]);
}
#[test]
fn cluster_vbx_matches_python_fixture() {
let ahc_labels: Array1<i64> =
Array1::read_npy(File::open(fixture_path("pipeline_ahc_clusters.npy")).unwrap())
.unwrap();
let features: Array2<f64> =
Array2::read_npy(File::open(fixture_path("pipeline_plda_features.npy")).unwrap())
.unwrap();
let phi: Array1<f64> =
Array1::read_npy(File::open(fixture_path("pipeline_plda_phi.npy")).unwrap()).unwrap();
let expected_gamma: Array2<f64> =
Array2::read_npy(File::open(fixture_path("pipeline_vbx_gamma.npy")).unwrap()).unwrap();
let expected_pi: Array1<f64> =
Array1::read_npy(File::open(fixture_path("pipeline_vbx_pi.npy")).unwrap()).unwrap();
let ahc_labels: Vec<usize> = ahc_labels.iter().map(|value| *value as usize).collect();
let features = features.mapv(|value| value as f32);
let phi = phi.mapv(|value| value as f32);
let (gamma, pi) = cluster_vbx(
&ahc_labels,
&features.view(),
&phi.view(),
&VbxConfig::default(),
);
for (lhs, rhs) in gamma.iter().zip(expected_gamma.iter()) {
assert_abs_diff_eq!(*lhs, *rhs as f32, epsilon = 1e-4);
}
for (lhs, rhs) in pi.iter().zip(expected_pi.iter()) {
assert_abs_diff_eq!(*lhs, *rhs as f32, epsilon = 1e-5);
}
}
}