use std::f32;
use std::ops::AddAssign;
use ndarray::{Array, ArrayBase, Axis, Ix1, Ix2};
use ndarray::prelude::s;
use ndarray_rand::rand::Rng;
use ndarray_rand::rand::seq::SliceRandom;
use ndarray_rand::rand::distributions::Uniform;
/// Computes the squared Euclidean distance between every row of `points1`
/// and every row of `points2`.
///
/// The result has shape `(points1.nrows(), points2.nrows())`; entry `(i, j)`
/// is the sum over columns of the squared difference between row `i` of
/// `points1` and row `j` of `points2`.
///
/// # Panics
///
/// Panics if `points1` cannot be broadcast to
/// `(points1.nrows(), points2.nrows(), points2.ncols())`.
#[inline]
fn squared_distances(points1: &Array<f32, Ix2>, points2: &Array<f32, Ix2>) -> Array<f32, Ix2> {
    let pairwise_shape = (points1.nrows(), points2.nrows(), points2.ncols());

    // View `points1` as (n1, 1, d) and stretch the new middle axis so that
    // each of its rows lines up against every row of `points2`.
    let expanded = points1.view().insert_axis(Axis(1));
    let expanded = expanded.broadcast(pairwise_shape).unwrap();

    // Per-coordinate differences, squared in place, then collapsed over the
    // coordinate axis.
    let mut deltas = &expanded - points2;
    deltas.mapv_inplace(|c| c.powi(2));
    deltas.sum_axis(Axis(2))
}
/// Reassigns every point to its nearest centroid.
///
/// For each row of `points`, the index of the centroid with the smallest
/// squared distance is written into the matching slot of `clusters`. Ties go
/// to the lowest centroid index because the comparison is strict. If no
/// distance is below infinity (including when there are no centroids), the
/// label is `0`.
///
/// Returns `true` if at least one label was changed by this call.
fn update_clusters(
    points: &Array<f32, Ix2>,
    centroids: &Array<f32, Ix2>,
    clusters: &mut Array<usize, Ix1>,
) -> bool {
    let distances = squared_distances(points, centroids);
    let mut changed = false;

    for (label, row) in clusters.iter_mut().zip(distances.axis_iter(Axis(0))) {
        // Linear scan for the arg-min of this point's distance row.
        let mut nearest = 0;
        let mut nearest_dist = f32::INFINITY;
        for (idx, &dist) in row.iter().enumerate() {
            if dist < nearest_dist {
                nearest = idx;
                nearest_dist = dist;
            }
        }

        if *label != nearest {
            *label = nearest;
            changed = true;
        }
    }

    changed
}
/// Recomputes each centroid as the mean of the points currently assigned to it.
///
/// `clusters[i]` is the centroid index of row `i` of `points`. Rows of
/// `centroids` whose cluster received no points are left untouched.
///
/// # Panics
///
/// Panics if `clusters` has fewer entries than `points` has rows, or if a
/// label is not a valid row index of `centroids`.
fn update_centroids(
    points: &Array<f32, Ix2>,
    centroids: &mut Array<f32, Ix2>,
    clusters: &Array<usize, Ix1>,
) {
    // Number of points accumulated into each centroid so far.
    let mut counts: Array<f32, Ix1> = ArrayBase::zeros(centroids.nrows());

    for (row_idx, point) in points.axis_iter(Axis(0)).enumerate() {
        let label = clusters[row_idx];
        let mut target = centroids.row_mut(label);
        if counts[label] == 0. {
            // First member of the cluster: overwrite the stale centroid.
            target.assign(&point);
        } else {
            target.add_assign(&point);
        }
        counts[label] += 1.;
    }

    // Turn the accumulated sums into means; empty clusters are skipped.
    for (mut centroid, &count) in centroids.axis_iter_mut(Axis(0)).zip(counts.iter()) {
        if count > 0. {
            centroid /= count;
        }
    }
}
fn kmeanspp_init_centroids<R: Rng>(rng: &mut R,
points: &Array<f32, Ix2>,
k: usize) -> Array<f32, Ix2>
{
let mut centroids: Array<f32, Ix2> = unsafe {
ArrayBase::uninitialized((k, points.ncols()))
};
let point_idxs: Vec<usize> = (0..points.nrows()).collect();
let point_idx = *point_idxs
.choose(rng)
.unwrap();
centroids.row_mut(0).assign(&points.row(point_idx));
for i in 1..k {
let gen_centroids = centroids.slice(s![..i, ..]).to_owned();
let dists = squared_distances(points, &gen_centroids)
.map_axis_mut(Axis(1), |ds| {
ds.fold(f32::INFINITY, |mind, &d| {
if d < mind { d } else {mind }
})
});
let point_idx = *point_idxs
.choose_weighted(rng, |&i| dists[i])
.unwrap();
centroids.row_mut(i).assign(&points.row(point_idx));
}
centroids
}
pub fn kmeanspp<R: Rng>(rng: &mut R,
points: &Array<f32, Ix2>,
k: usize) -> Array<usize, Ix1>
{
if k == 0 {
return ArrayBase::zeros(points.nrows());
}
let mut centroids = kmeanspp_init_centroids(rng, points, k);
let mut clusters = unsafe { ArrayBase::uninitialized(points.nrows()) };
while update_clusters(points, ¢roids, &mut clusters) {
update_centroids(points, &mut centroids, &clusters);
}
clusters
}
/// Picks initial centroids by sampling rows of `points` (Forgy initialisation).
///
/// Row indices are drawn with `choose_multiple(rng, k)` and the matching rows
/// are copied into a new array, one centroid per selected row.
fn forgy_init_centroids<R: Rng>(
    rng: &mut R,
    points: &Array<f32, Ix2>,
    k: usize,
) -> Array<f32, Ix2> {
    let all_rows: Vec<usize> = (0..points.nrows()).collect();
    let picked: Vec<usize> = all_rows.choose_multiple(rng, k).cloned().collect();
    points.select(Axis(0), &picked)
}
/// Clusters the rows of `points` into `k` groups using k-means with Forgy
/// initialisation (centroids start as randomly selected points).
///
/// Returns one label per row of `points`. When `k == 0` or `points` has no
/// rows, an all-zero label array of length `points.nrows()` is returned
/// without running the algorithm.
///
/// Assignment and centroid update steps alternate until an assignment step
/// changes no label.
pub fn kmeans_forgy<R: Rng>(rng: &mut R, points: &Array<f32, Ix2>, k: usize) -> Array<usize, Ix1> {
    if k == 0 || points.nrows() == 0 {
        return ArrayBase::zeros(points.nrows());
    }

    let mut centroids = forgy_init_centroids(rng, points, k);

    // Start from a sentinel that is never a valid label. The first assignment
    // pass then always reports a change, and no uninitialised memory is read.
    let mut clusters: Array<usize, Ix1> = Array::from_elem(points.nrows(), std::usize::MAX);

    while update_clusters(points, &centroids, &mut clusters) {
        update_centroids(points, &mut centroids, &clusters);
    }
    clusters
}
/// Builds an initial random partition: `npoints` labels, each sampled from
/// `Uniform::new(0, k)`.
fn random_part_init_clusters<R: Rng>(rng: &mut R, npoints: usize, k: usize) -> Array<usize, Ix1> {
    let label_range = Uniform::new(0, k);
    let labels: Vec<usize> = rng.sample_iter(label_range).take(npoints).collect();
    Array::from(labels)
}
/// Clusters the rows of `points` into `k` groups using k-means with random
/// partition initialisation (every point starts with a random label and the
/// initial centroids are the means of those random groups).
///
/// Returns one label per row of `points`. When `k == 0` or `points` has no
/// rows, an all-zero label array of length `points.nrows()` is returned
/// without running the algorithm.
///
/// Assignment and centroid update steps alternate until an assignment step
/// changes no label.
pub fn kmeans_random_part<R: Rng>(
    rng: &mut R,
    points: &Array<f32, Ix2>,
    k: usize,
) -> Array<usize, Ix1> {
    if k == 0 || points.nrows() == 0 {
        return ArrayBase::zeros(points.nrows());
    }

    // Zero-filled rather than uninitialised: `update_centroids` leaves the
    // rows of empty clusters untouched, so a label that the random partition
    // never drew would otherwise leave garbage in the centroid matrix.
    let mut centroids: Array<f32, Ix2> = Array::zeros((k, points.ncols()));
    let mut clusters = random_part_init_clusters(rng, points.nrows(), k);

    update_centroids(points, &mut centroids, &clusters);
    while update_clusters(points, &centroids, &mut clusters) {
        update_centroids(points, &mut centroids, &clusters);
    }
    clusters
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array, Axis, stack};
    use ndarray::prelude::array;
    use ndarray_rand::RandomExt;
    use ndarray_rand::rand::thread_rng;
    use ndarray_rand::rand_distr::Normal;

    /// Pairwise squared distances between points on the diagonal.
    #[test]
    fn test_squared_distances() {
        let points1 = array![
            [0., 0.],
            [1., 1.],
            [2., 2.],
        ];
        let points2 = array![
            [3., 3.],
            [4., 4.],
        ];
        let distances = squared_distances(&points1, &points2);
        assert_eq!(distances, array![
            [18., 32.],
            [8., 18.],
            [2., 8.],
        ]);
    }

    /// Each point must be assigned to the nearest of five centroids.
    #[test]
    fn test_update_clusters() {
        let points = array![
            [0., 0.],
            [1., 1.],
            [2., 2.],
            [3., 3.],
            [4., 4.],
            [5., 5.],
            [6., 6.],
            [7., 7.],
            [8., 8.],
            [9., 9.],
        ];
        let centroids = array![
            [0.5, 0.5],
            [2.5, 2.5],
            [4.5, 4.5],
            [6.5, 6.5],
            [8.5, 8.5],
        ];
        let mut clusters = ArrayBase::zeros(10);
        assert!(update_clusters(&points, &centroids, &mut clusters));
        assert_eq!(clusters, array![0, 0, 1, 1, 2, 2, 3, 3, 4, 4]);
    }

    /// Centroids must become the mean of the points assigned to them.
    #[test]
    fn test_update_centroids() {
        let points = array![
            [0., 0.],
            [1., 1.],
            [2., 2.],
            [3., 3.],
            [4., 4.],
            [5., 5.],
            [6., 6.],
            [7., 7.],
            [8., 8.],
            [9., 9.],
        ];
        let mut centroids = ArrayBase::zeros((5, 2));
        let clusters = array![0, 0, 1, 1, 2, 2, 3, 3, 4, 4];
        update_centroids(&points, &mut centroids, &clusters);
        assert_eq!(centroids, array![
            [0.5, 0.5],
            [2.5, 2.5],
            [4.5, 4.5],
            [6.5, 6.5],
            [8.5, 8.5],
        ]);
    }

    /// Draws `n_points` 2-D points whose x and y coordinates are sampled from
    /// independent normal distributions with the given means and standard
    /// deviations.
    fn gen_random_2d_points(
        n_points: usize,
        meanx: f32,
        stdx: f32,
        meany: f32,
        stdy: f32,
    ) -> Array<f32, Ix2> {
        let xs = Array::random(
            (n_points, 1), Normal::new(meanx, stdx).unwrap());
        let ys = Array::random(
            (n_points, 1), Normal::new(meany, stdy).unwrap());
        stack(Axis(1), &[xs.view(), ys.view()]).unwrap()
    }

    /// Smoke test: all three variants run to completion on ten random blobs
    /// and return one in-range label per point.
    #[test]
    fn test_kmeans() {
        let mut pts = Array::zeros((0, 2));
        for _ in 0..10 {
            let meanx = thread_rng().gen_range(-10., 10.);
            let meany = thread_rng().gen_range(-10., 10.);
            let new_pts = gen_random_2d_points(1024, meanx, 1., meany, 1.);
            pts = stack(Axis(0), &[pts.view(), new_pts.view()]).unwrap();
        }
        let clusters_pp = kmeanspp(&mut thread_rng(), &pts, 10);
        let clusters_forgy = kmeans_forgy(&mut thread_rng(), &pts, 10);
        let clusters_rp = kmeans_random_part(&mut thread_rng(), &pts, 10);

        for labels in &[&clusters_pp, &clusters_forgy, &clusters_rp] {
            assert_eq!(labels.len(), pts.nrows());
            assert!(labels.iter().all(|&c| c < 10));
        }

        println!("{:?}", clusters_pp);
        println!("{:?}", clusters_forgy);
        println!("{:?}", clusters_rp);
    }
}