use ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix2, Zip, s};
use rand::distr::Distribution;
use crate::kmeans::closest_centroid;
/// Strategy used to pick the initial centroids for k-means.
///
/// See [`KMeansInit::run`] for how each variant selects rows of the input data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KMeansInit {
    /// Pick `k` distinct rows of the data uniformly at random.
    Forgy,
    /// k-means++ style seeding: the first row is uniform, and later rows are sampled
    /// with probability weighted by their distance to the nearest chosen centroid.
    PlusPlus,
}
impl KMeansInit {
    /// Chooses `k_clusters` initial centroids from the rows of `data`.
    ///
    /// Each row of `data` is one sample. The result has shape `(k_clusters, features)`,
    /// where `features` is the number of columns of `data`. When `k_clusters == 0`
    /// both strategies return an empty `(0, features)` array.
    ///
    /// * [`KMeansInit::Forgy`] selects `k_clusters` distinct rows uniformly at random.
    /// * [`KMeansInit::PlusPlus`] selects the first row uniformly at random. Every
    ///   further row is sampled with probability proportional to the distance reported
    ///   by `closest_centroid` to the nearest centroid chosen so far.
    ///
    /// # Panics
    ///
    /// * `Forgy` panics if `k_clusters` is larger than the number of rows in `data`.
    /// * `PlusPlus` panics if `data` has no rows while `k_clusters > 0`.
    pub fn run(
        &self,
        k_clusters: usize,
        data: &ArrayBase<impl Data<Elem = f64>, Ix2>,
        rng: &mut impl rand::Rng,
    ) -> Array2<f64> {
        match self {
            KMeansInit::Forgy => {
                let (samples, _) = data.dim();
                let indices = rand::seq::index::sample(rng, samples, k_clusters).into_vec();
                data.select(Axis(0), &indices)
            }
            KMeansInit::PlusPlus => {
                let (samples, features) = data.dim();
                let mut centroids = Array2::<f64>::zeros((k_clusters, features));
                // Match `Forgy`: zero clusters yields an empty (0, features) array
                // instead of panicking on `row_mut(0)` below.
                if k_clusters == 0 {
                    return centroids;
                }

                centroids
                    .row_mut(0)
                    .assign(&data.row(rng.random_range(0..samples)));

                // Running distance from each point to its nearest chosen centroid.
                // It starts at +inf so the first update simply takes the distance
                // to centroid 0.
                let mut weights = Array1::<f64>::from_elem(samples, f64::INFINITY);

                for c_idx in 1..k_clusters {
                    // Only the centroid added in the previous step can lower a
                    // point's nearest distance. Compare against that single row
                    // instead of rescanning all `c_idx` centroids: O(n·k) total
                    // instead of O(n·k²). This assumes `closest_centroid` returns
                    // the minimum of independent per-row distances.
                    let newest = centroids.slice(s![c_idx - 1..c_idx, ..]);
                    Zip::from(data.outer_iter())
                        .and(&mut weights)
                        .par_for_each(|point, weight| {
                            let (_, dist) = closest_centroid(&point, &newest);
                            *weight = weight.min(dist);
                        });

                    // NOTE(review): classic k-means++ samples proportionally to the
                    // *squared* distance; confirm what `closest_centroid` returns.
                    //
                    // `WeightedIndex::new` fails when every weight is zero (all
                    // points coincide with chosen centroids) or a weight is invalid.
                    // In that case fall back to row 0.
                    let p_idx = rand::distr::weighted::WeightedIndex::new(weights.iter())
                        .map(|w_idx| w_idx.sample(rng))
                        .unwrap_or(0);
                    centroids.row_mut(c_idx).assign(&data.row(p_idx));
                }
                centroids
            }
        }
    }
}