use std::f32;
use std::ops::AddAssign;
use ndarray::prelude::s;
use ndarray::{Array, ArrayBase, Axis, Ix1, Ix2};
use rand::distributions::{Uniform, WeightedIndex};
use rand::prelude::*;
use rand::seq::SliceRandom;
use rand::Rng;
/// Pairwise squared Euclidean distances between the rows of two matrices.
///
/// Entry `(i, j)` of the result is the squared distance between row `i` of
/// `points1` and row `j` of `points2`. The output has shape
/// `(points1.nrows(), points2.nrows())`. The broadcast `unwrap` panics if the
/// two column counts cannot be broadcast together.
#[inline]
fn squared_distances(points1: &Array<f32, Ix2>, points2: &Array<f32, Ix2>) -> Array<f32, Ix2> {
    let target_shape = (points1.nrows(), points2.nrows(), points2.ncols());
    // Lift points1 to (n1, 1, d) and stretch it to (n1, n2, d) so that it lines
    // up element-wise with points2 broadcast along the leading axis.
    let lifted = points1.view().insert_axis(Axis(1));
    let lifted = lifted.broadcast(target_shape).unwrap();
    let diffs = &lifted - points2;
    let squared = diffs.mapv(|delta| delta.powi(2));
    squared.sum_axis(Axis(2))
}
/// Reassigns every point to its nearest centroid by squared Euclidean distance.
///
/// `clusters[p]` is overwritten with the index of the closest centroid row for
/// point row `p`. Because the comparison is strict, ties go to the lowest
/// centroid index, and NaN distances are never selected. Returns `true` if at
/// least one label changed.
fn update_clusters(
    points: &Array<f32, Ix2>,
    centroids: &Array<f32, Ix2>,
    clusters: &mut Array<usize, Ix1>,
) -> bool {
    let dist_matrix = squared_distances(points, centroids);
    let mut any_moved = false;
    for (label, row) in clusters.iter_mut().zip(dist_matrix.axis_iter(Axis(0))) {
        let mut nearest = 0;
        let mut nearest_dist = f32::INFINITY;
        for (centroid_idx, &dist) in row.iter().enumerate() {
            if dist < nearest_dist {
                nearest = centroid_idx;
                nearest_dist = dist;
            }
        }
        if *label != nearest {
            *label = nearest;
            any_moved = true;
        }
    }
    any_moved
}
/// Moves every centroid to the mean of the points currently assigned to it.
///
/// `clusters[p]` names the centroid row that point row `p` belongs to. Centroid
/// rows whose cluster received no points keep their previous values.
fn update_centroids(
    points: &Array<f32, Ix2>,
    centroids: &mut Array<f32, Ix2>,
    clusters: &Array<usize, Ix1>,
) {
    let mut members: Array<f32, Ix1> = ArrayBase::zeros(centroids.nrows());
    for (idx, point) in points.outer_iter().enumerate() {
        let label = clusters[idx];
        let is_first_member = members[label] == 0.;
        if is_first_member {
            // Overwrite the stale centroid instead of zeroing all rows up front;
            // this keeps empty clusters' old centroids intact.
            centroids.row_mut(label).assign(&point);
        } else {
            centroids.row_mut(label).add_assign(&point);
        }
        members[label] += 1.;
    }
    for (mut centroid, &count) in centroids.outer_iter_mut().zip(members.iter()) {
        if count > 0. {
            centroid /= count;
        }
    }
}
fn kmeanspp_init_centroids<R: Rng>(
rng: &mut R,
points: &Array<f32, Ix2>,
k: usize,
) -> Array<f32, Ix2> {
let mut centroids: Array<f32, Ix2> = unsafe { ArrayBase::uninitialized((k, points.ncols())) };
let point_idxs: Vec<usize> = (0..points.nrows()).collect();
let point_idx = *point_idxs.choose(rng).unwrap();
centroids.row_mut(0).assign(&points.row(point_idx));
for i in 1..k {
let gen_centroids = centroids.slice(s![..i, ..]).to_owned();
let distances = squared_distances(points, &gen_centroids).map_axis_mut(Axis(1), |ds| {
ds.fold(f32::INFINITY, |mind, &d| if d < mind { d } else { mind })
});
let distrib = WeightedIndex::new(&distances.to_vec()).unwrap();
let point_idx = point_idxs[distrib.sample(rng)];
centroids.row_mut(i).assign(&points.row(point_idx));
}
centroids
}
pub fn kmeanspp_with_centroids<R: Rng>(
rng: &mut R,
points: &Array<f32, Ix2>,
k: usize,
) -> (Array<usize, Ix1>, Array<f32, Ix2>) {
if k == 0 {
return (
ArrayBase::zeros(points.nrows()),
ArrayBase::zeros((k, points.ncols())),
);
}
let mut centroids = kmeanspp_init_centroids(rng, points, k);
let mut clusters = unsafe { ArrayBase::uninitialized(points.nrows()) };
while update_clusters(points, ¢roids, &mut clusters) {
update_centroids(points, &mut centroids, &clusters);
}
(clusters, centroids)
}
/// k-means with k-means++ seeding, returning only the per-point cluster labels.
///
/// Equivalent to the first element returned by `kmeanspp_with_centroids`.
pub fn kmeanspp<R: Rng>(rng: &mut R, points: &Array<f32, Ix2>, k: usize) -> Array<usize, Ix1> {
    let (labels, _centroids) = kmeanspp_with_centroids(rng, points, k);
    labels
}
/// Forgy initialization: uses randomly chosen rows of `points` (no row picked
/// twice) as the initial centroids.
///
/// NOTE: rand's `choose_multiple` caps the amount at the slice length. When `k`
/// exceeds `points.nrows()`, the result therefore has fewer than `k` rows.
fn forgy_init_centroids<R: Rng>(
    rng: &mut R,
    points: &Array<f32, Ix2>,
    k: usize,
) -> Array<f32, Ix2> {
    let every_row: Vec<usize> = (0..points.nrows()).collect();
    let mut picked: Vec<usize> = Vec::new();
    for &row_idx in every_row.choose_multiple(rng, k) {
        picked.push(row_idx);
    }
    points.select(Axis(0), &picked)
}
/// Clusters the rows of `points` into `k` groups using Lloyd's algorithm, seeded
/// with Forgy initialization (randomly chosen rows as centroids).
///
/// Returns `(clusters, centroids)`:
/// - `clusters[p]` is the centroid index assigned to row `p` of `points`.
/// - `centroids` holds one centroid per row.
///
/// For `k == 0`, every label is `0` and `centroids` has zero rows. Iteration
/// stops once an assignment pass changes no labels.
pub fn kmeans_forgy_with_centroids<R: Rng>(
    rng: &mut R,
    points: &Array<f32, Ix2>,
    k: usize,
) -> (Array<usize, Ix1>, Array<f32, Ix2>) {
    if k == 0 {
        return (
            ArrayBase::zeros(points.nrows()),
            ArrayBase::zeros((k, points.ncols())),
        );
    }
    let mut centroids = forgy_init_centroids(rng, points, k);
    // `update_clusters` reads the existing labels, so they must be initialized.
    // `usize::MAX` is never a valid centroid index. The first pass therefore
    // always reports a change, and the centroids are refined at least once.
    let mut clusters = Array::from_elem(points.nrows(), usize::MAX);
    while update_clusters(points, &centroids, &mut clusters) {
        update_centroids(points, &mut centroids, &clusters);
    }
    (clusters, centroids)
}
/// k-means with Forgy initialization, returning only the per-point cluster labels.
///
/// Equivalent to the first element returned by `kmeans_forgy_with_centroids`.
pub fn kmeans_forgy<R: Rng>(rng: &mut R, points: &Array<f32, Ix2>, k: usize) -> Array<usize, Ix1> {
    let (labels, _centroids) = kmeans_forgy_with_centroids(rng, points, k);
    labels
}
/// Random-partition initialization: draws an independent, uniformly distributed
/// label in `0..k` for each of the `npoints` points.
///
/// Panics if `k == 0`, because `Uniform::new` requires `low < high`.
fn random_part_init_clusters<R: Rng>(rng: &mut R, npoints: usize, k: usize) -> Array<usize, Ix1> {
    let label_dist = Uniform::new(0, k);
    let labels: Vec<usize> = (0..npoints).map(|_| rng.sample(&label_dist)).collect();
    Array::from(labels)
}
/// Clusters the rows of `points` into `k` groups using Lloyd's algorithm, seeded
/// by assigning every point to a uniformly random cluster.
///
/// Returns `(clusters, centroids)`:
/// - `clusters[p]` is the centroid index assigned to row `p` of `points`.
/// - `centroids` is a `(k, points.ncols())` array.
///
/// A cluster that never receives any point keeps a zero centroid. For `k == 0`,
/// every label is `0` and `centroids` has zero rows.
pub fn kmeans_random_part_with_centroids<R: Rng>(
    rng: &mut R,
    points: &Array<f32, Ix2>,
    k: usize,
) -> (Array<usize, Ix1>, Array<f32, Ix2>) {
    if k == 0 {
        return (
            ArrayBase::zeros(points.nrows()),
            ArrayBase::zeros((k, points.ncols())),
        );
    }
    // Zero-initialized: `update_centroids` leaves rows of empty clusters
    // untouched. If the random partition leaves a cluster empty, an
    // uninitialized buffer would later be read and returned.
    let mut centroids: Array<f32, Ix2> = ArrayBase::zeros((k, points.ncols()));
    let mut clusters = random_part_init_clusters(rng, points.nrows(), k);
    update_centroids(points, &mut centroids, &clusters);
    while update_clusters(points, &centroids, &mut clusters) {
        update_centroids(points, &mut centroids, &clusters);
    }
    (clusters, centroids)
}
/// k-means with random-partition initialization, returning only the per-point
/// cluster labels.
///
/// Equivalent to the first element returned by `kmeans_random_part_with_centroids`.
pub fn kmeans_random_part<R: Rng>(
    rng: &mut R,
    points: &Array<f32, Ix2>,
    k: usize,
) -> Array<usize, Ix1> {
    let (labels, _centroids) = kmeans_random_part_with_centroids(rng, points, k);
    labels
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::prelude::array;
    use ndarray::{stack, Array, Axis};
    use ndarray_rand::rand::thread_rng;
    use ndarray_rand::rand_distr::Normal;
    use ndarray_rand::RandomExt;

    #[test]
    fn test_squared_distances() {
        let points1 = array![[0., 0.], [1., 1.], [2., 2.],];
        let points2 = array![[3., 3.], [4., 4.],];
        let distances = squared_distances(&points1, &points2);
        assert_eq!(distances, array![[18., 32.], [8., 18.], [2., 8.],]);
    }

    #[test]
    fn test_update_clusters() {
        let points = array![
            [0., 0.],
            [1., 1.],
            [2., 2.],
            [3., 3.],
            [4., 4.],
            [5., 5.],
            [6., 6.],
            [7., 7.],
            [8., 8.],
            [9., 9.],
        ];
        let centroids = array![[0.5, 0.5], [2.5, 2.5], [4.5, 4.5], [6.5, 6.5], [8.5, 8.5],];
        let mut clusters = ArrayBase::zeros(10);
        assert!(update_clusters(&points, &centroids, &mut clusters));
        assert_eq!(clusters, array![0, 0, 1, 1, 2, 2, 3, 3, 4, 4]);
    }

    #[test]
    fn test_update_centroids() {
        let points = array![
            [0., 0.],
            [1., 1.],
            [2., 2.],
            [3., 3.],
            [4., 4.],
            [5., 5.],
            [6., 6.],
            [7., 7.],
            [8., 8.],
            [9., 9.],
        ];
        let mut centroids = ArrayBase::zeros((5, 2));
        let clusters = array![0, 0, 1, 1, 2, 2, 3, 3, 4, 4];
        update_centroids(&points, &mut centroids, &clusters);
        assert_eq!(
            centroids,
            array![[0.5, 0.5], [2.5, 2.5], [4.5, 4.5], [6.5, 6.5], [8.5, 8.5],]
        );
    }

    /// Samples `n_points` 2-D points whose x and y coordinates are drawn
    /// independently from the given normal distributions.
    fn gen_random_2d_points(
        n_points: usize,
        meanx: f32,
        stdx: f32,
        meany: f32,
        stdy: f32,
    ) -> Array<f32, Ix2> {
        let xs = Array::random((n_points, 1), Normal::new(meanx, stdx).unwrap());
        let ys = Array::random((n_points, 1), Normal::new(meany, stdy).unwrap());
        stack(Axis(1), &[xs.view(), ys.view()]).unwrap()
    }

    #[test]
    fn test_kmeans() {
        let mut pts = Array::zeros((0, 2));
        for _ in 0..10 {
            let meanx = thread_rng().gen_range(-10., 10.);
            let meany = thread_rng().gen_range(-10., 10.);
            let new_pts = gen_random_2d_points(256, meanx, 1., meany, 1.);
            pts = stack(Axis(0), &[pts.view(), new_pts.view()]).unwrap();
        }
        let clusters_pp = kmeanspp(&mut thread_rng(), &pts, 10);
        let clusters_forgy = kmeans_forgy(&mut thread_rng(), &pts, 10);
        let clusters_rp = kmeans_random_part(&mut thread_rng(), &pts, 10);
        // The data is random, so pin structural properties rather than exact
        // labels: one label per point, each a valid cluster index.
        for clusters in &[clusters_pp, clusters_forgy, clusters_rp] {
            assert_eq!(clusters.len(), pts.nrows());
            assert!(clusters.iter().all(|&c| c < 10));
        }
    }

    #[test]
    fn test_zero_clusters() {
        let points = array![[0., 0.], [1., 1.]];
        let (clusters, centroids) = kmeans_forgy_with_centroids(&mut thread_rng(), &points, 0);
        assert_eq!(clusters, array![0, 0]);
        assert_eq!(centroids.dim(), (0, 2));
    }
}