use ndarray::{Array2, ArrayView2, Axis};
use rand::Rng;
use rand::seq::IndexedRandom;
use rand_chacha::ChaCha8Rng;
use rand_chacha::rand_core::SeedableRng;
pub fn l2_normalize_cols(z: ArrayView2<'_, f64>) -> Array2<f64> {
let mut out = z.to_owned();
for mut col in out.axis_iter_mut(Axis(1)) {
let norm = col.iter().map(|v| v * v).sum::<f64>().sqrt();
if norm > 0.0 {
col.mapv_inplace(|v| v / norm);
}
}
out
}
/// Chooses `k` initial centroids from the columns of `z_cos` with k-means++
/// seeding.
///
/// `z_cos` is read as a `(d, n)` matrix with one sample per column
/// (presumably already L2-normalized, judging by the name; this is not
/// checked here). The first centroid is a uniformly random sample. Every
/// further centroid is drawn with probability proportional to each sample's
/// squared Euclidean distance to its nearest already chosen centroid. When
/// all of those distances are zero, the next centroid is drawn uniformly
/// instead.
///
/// Randomness comes from a `ChaCha8Rng` seeded with `seed`, so the same
/// input and seed always produce the same result.
///
/// # Returns
///
/// A `(d, k)` matrix whose columns are the selected samples after being
/// passed through [`l2_normalize_cols`].
///
/// # Panics
///
/// Panics if `k` is not in `1..=n`, where `n` is the number of columns of
/// `z_cos`.
pub fn kmeans_pp_init(z_cos: ArrayView2<'_, f64>, k: usize, seed: u64) -> Array2<f64> {
    let (d, n) = z_cos.dim();
    assert!(k >= 1 && k <= n, "k must be in 1..=n (got k={k}, n={n})");

    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    let mut centers = Array2::<f64>::zeros((d, k));

    // The first centroid is still drawn via `choose` on an index slice, so
    // the sequence of RNG calls is the same as before this change.
    let indices: Vec<usize> = (0..n).collect();
    let first = *indices
        .choose(&mut rng)
        .expect("n >= k >= 1, so there is at least one sample");
    centers.column_mut(0).assign(&z_cos.column(first));

    // dists[i] = squared distance from sample i to its nearest chosen centroid.
    let mut dists = vec![f64::INFINITY; n];
    for (i, col) in z_cos.axis_iter(Axis(1)).enumerate() {
        let diff = &col - &centers.column(0);
        dists[i] = diff.dot(&diff);
    }

    for c in 1..k {
        let total: f64 = dists.iter().sum();
        let idx = if total == 0.0 {
            // Every sample coincides with a chosen centroid; weights carry no
            // information, so fall back to a uniform draw.
            rng.random_range(0..n)
        } else {
            // Inverse-CDF sampling: walk the weights until the drawn mass is
            // used up.
            let mut pick: f64 = rng.random::<f64>() * total;
            let mut chosen = 0usize;
            for (i, &d2) in dists.iter().enumerate() {
                // Zero-weight samples have zero selection probability; skip
                // them so a draw of exactly 0.0 cannot land on one.
                if d2 > 0.0 {
                    // Remember the latest positive-weight sample: if rounding
                    // leaves `pick` slightly above zero after the whole scan,
                    // this is the fallback instead of index 0.
                    chosen = i;
                    pick -= d2;
                    if pick <= 0.0 {
                        break;
                    }
                }
            }
            chosen
        };
        centers.column_mut(c).assign(&z_cos.column(idx));

        // Tighten each sample's nearest-centroid distance with the new centroid.
        for (i, col) in z_cos.axis_iter(Axis(1)).enumerate() {
            let diff = &col - &centers.column(c);
            let d2 = diff.dot(&diff);
            if d2 < dists[i] {
                dists[i] = d2;
            }
        }
    }

    l2_normalize_cols(centers.view())
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Non-zero columns are rescaled to unit length.
    #[test]
    fn l2_normalize_unit_columns() {
        let input = array![[3.0, 0.0], [4.0, 5.0]];
        let normalized = l2_normalize_cols(input.view());
        // ((row, column), expected value)
        let expected: [((usize, usize), f64); 4] =
            [((0, 0), 0.6), ((1, 0), 0.8), ((0, 1), 0.0), ((1, 1), 1.0)];
        for &((row, col), want) in expected.iter() {
            assert_abs_diff_eq!(normalized[[row, col]], want, epsilon = 1e-12);
        }
    }

    /// An all-zero column must come back untouched.
    #[test]
    fn l2_normalize_zero_column_stays_zero() {
        let zeros = array![[0.0], [0.0]];
        let normalized = l2_normalize_cols(zeros.view());
        assert_eq!(normalized, array![[0.0], [0.0]]);
    }

    /// Both centroids returned by the seeding have unit L2 norm.
    #[test]
    fn kmeans_pp_returns_unit_norm_centroids() {
        let raw = array![[1.0, 0.0, 0.5, -0.5], [0.0, 1.0, 0.5, -0.5]];
        let samples = l2_normalize_cols(raw.view());
        let centers = kmeans_pp_init(samples.view(), 2, 42);
        for col_idx in 0..2 {
            let sum_of_squares: f64 = centers.column(col_idx).iter().map(|x| x * x).sum();
            assert_abs_diff_eq!(sum_of_squares.sqrt(), 1.0, epsilon = 1e-12);
        }
    }

    /// Two runs with the same seed yield identical centroids.
    #[test]
    fn kmeans_pp_deterministic_on_seed() {
        let raw = array![[1.0, 0.0, 0.7, 0.3], [0.0, 1.0, 0.7, 0.3]];
        let samples = l2_normalize_cols(raw.view());
        let run = || kmeans_pp_init(samples.view(), 2, 7);
        assert_eq!(run(), run());
    }
}