grappes 0.2.0

Implements various clustering algorithms such as k-Means variants
Documentation
use std::f32;
use std::ops::AddAssign;

use ndarray::prelude::s;
use ndarray::{Array, ArrayBase, Axis, Ix1, Ix2};

use rand::distributions::{Uniform, WeightedIndex};
use rand::prelude::*;
use rand::seq::SliceRandom;
use rand::Rng;

// Pairwise squared Euclidean distances: entry (i, j) of the result is the
// squared distance between row i of `points1` and row j of `points2`.
//
// Panics (through `unwrap`) when `points1` cannot be broadcast to the shape
// (points1 rows, points2 rows, points2 columns).
#[inline]
fn squared_distances(points1: &Array<f32, Ix2>, points2: &Array<f32, Ix2>) -> Array<f32, Ix2> {
    let pairwise_shape = (points1.nrows(), points2.nrows(), points2.ncols());

    // View `points1` as (rows, 1, cols) and repeat it along the new middle axis
    let expanded = points1.view().insert_axis(Axis(1));
    let lhs = expanded.broadcast(pairwise_shape).unwrap();

    // Coordinate-wise differences, squared, then summed over the coordinate axis
    let differences = &lhs - points2;
    let squared = differences.mapv(|c| c.powi(2));
    squared.sum_axis(Axis(2))
}

// Reassign every point to the centroid with the smallest squared distance.
//
// `clusters` is updated in place; the return value is `true` when at least
// one entry was modified.
fn update_clusters(
    points: &Array<f32, Ix2>,
    centroids: &Array<f32, Ix2>,
    clusters: &mut Array<usize, Ix1>,
) -> bool {
    // Row p holds the squared distances from point p to every centroid
    let distances = squared_distances(points, centroids);

    let mut any_reassigned = false;
    for (assigned, row) in clusters.iter_mut().zip(distances.axis_iter(Axis(0))) {
        // Linear scan keeping the first strictly smaller distance; index 0 is
        // the result when no distance is below infinity.
        let mut nearest = 0;
        let mut nearest_dist = f32::INFINITY;
        for (idx, &dist) in row.indexed_iter() {
            if dist < nearest_dist {
                nearest = idx;
                nearest_dist = dist;
            }
        }

        // Only write (and flag a change) when the label actually differs
        if *assigned != nearest {
            *assigned = nearest;
            any_reassigned = true;
        }
    }

    any_reassigned
}

// Recompute each centroid as the mean of the points currently labelled with it.
//
// Rows of `centroids` whose cluster has no point are left untouched.
fn update_centroids(
    points: &Array<f32, Ix2>,
    centroids: &mut Array<f32, Ix2>,
    clusters: &Array<usize, Ix1>,
) {
    // Number of points seen so far for each cluster
    let mut member_counts: Array<f32, Ix1> = ArrayBase::zeros(centroids.nrows());

    // Accumulate the per-cluster coordinate sums directly in `centroids`
    for (row_idx, point) in points.axis_iter(Axis(0)).enumerate() {
        let label = clusters[row_idx];
        let is_first_member = member_counts[label] == 0.;

        let mut centroid = centroids.row_mut(label);
        if is_first_member {
            // The first member overwrites the previous centroid value
            centroid.assign(&point);
        } else {
            centroid.add_assign(&point);
        }

        member_counts[label] += 1.;
    }

    // Turn sums into means, skipping clusters without members
    for (mut centroid, &count) in centroids.axis_iter_mut(Axis(0)).zip(member_counts.iter()) {
        if count > 0. {
            centroid /= count;
        }
    }
}

// KMeans++ initialization
//
// Picks `k` rows of `points` as starting centroids. The first one is drawn
// uniformly; each following one is drawn with probability proportional to
// the squared distance between a point and its closest already-chosen
// centroid.
//
// When no valid weighted distribution can be built (for instance every
// remaining weight is zero because all points coincide with a chosen
// centroid), the next centroid is drawn uniformly instead of panicking.
//
// Panics if `points` has no rows, or if `k` is 0 (row 0 is always written).
fn kmeanspp_init_centroids<R: Rng>(
    rng: &mut R,
    points: &Array<f32, Ix2>,
    k: usize,
) -> Array<f32, Ix2> {
    // Every row is overwritten below; zero-filling avoids an `unsafe`
    // uninitialized allocation.
    let mut centroids: Array<f32, Ix2> = ArrayBase::zeros((k, points.ncols()));
    let point_idxs: Vec<usize> = (0..points.nrows()).collect();

    let point_idx = *point_idxs
        .choose(rng)
        .expect("k-means++ initialization requires at least one point");

    centroids.row_mut(0).assign(&points.row(point_idx));

    for i in 1..k {
        // Compute minimum distance between each point and a generated centroid
        let gen_centroids = centroids.slice(s![..i, ..]).to_owned();
        let distances = squared_distances(points, &gen_centroids).map_axis(Axis(1), |ds| {
            ds.fold(f32::INFINITY, |mind, &d| if d < mind { d } else { mind })
        });

        let point_idx = match WeightedIndex::new(distances.iter()) {
            Ok(distrib) => distrib.sample(rng),
            // Degenerate weights: any point is as good as another
            Err(_) => *point_idxs
                .choose(rng)
                .expect("k-means++ initialization requires at least one point"),
        };

        centroids.row_mut(i).assign(&points.row(point_idx));
    }

    centroids
}

/// KMeans++ algorithm
pub fn kmeanspp_with_centroids<R: Rng>(
    rng: &mut R,
    points: &Array<f32, Ix2>,
    k: usize,
) -> (Array<usize, Ix1>, Array<f32, Ix2>) {
    if k == 0 {
        return (
            ArrayBase::zeros(points.nrows()),
            ArrayBase::zeros((k, points.ncols())),
        );
    }

    let mut centroids = kmeanspp_init_centroids(rng, points, k);
    let mut clusters = unsafe { ArrayBase::uninitialized(points.nrows()) };

    while update_clusters(points, &centroids, &mut clusters) {
        update_centroids(points, &mut centroids, &clusters);
    }

    (clusters, centroids)
}

/// Runs [`kmeanspp_with_centroids`] and returns only the cluster labels,
/// discarding the centroids.
pub fn kmeanspp<R: Rng>(rng: &mut R, points: &Array<f32, Ix2>, k: usize) -> Array<usize, Ix1> {
    let (clusters, _centroids) = kmeanspp_with_centroids(rng, points, k);
    clusters
}

// Forgy initialization
//
// Uses randomly chosen rows of `points` as the starting centroids: `k` row
// indices are requested from `choose_multiple` and the matching rows are
// copied into a new array.
fn forgy_init_centroids<R: Rng>(
    rng: &mut R,
    points: &Array<f32, Ix2>,
    k: usize,
) -> Array<f32, Ix2> {
    // Candidate row indices, one per point
    let all_rows: Vec<usize> = (0..points.nrows()).collect();

    // Random subset of the candidates
    let chosen_rows: Vec<usize> = all_rows.choose_multiple(rng, k).cloned().collect();

    points.select(Axis(0), &chosen_rows)
}

/// KMeans algorithm with Forgy initialization
///
/// Clusters the rows of `points`, using randomly selected points as the
/// initial centroids and then alternating assignment and centroid updates
/// until no point changes cluster.
///
/// Returns `(clusters, centroids)`: `clusters` has one label per row of
/// `points`, and `centroids` has one row per cluster.
///
/// When `k == 0` no clustering is attempted: the labels are all zero and the
/// centroid array has zero rows.
pub fn kmeans_forgy_with_centroids<R: Rng>(
    rng: &mut R,
    points: &Array<f32, Ix2>,
    k: usize,
) -> (Array<usize, Ix1>, Array<f32, Ix2>) {
    if k == 0 {
        return (
            ArrayBase::zeros(points.nrows()),
            ArrayBase::zeros((k, points.ncols())),
        );
    }

    let mut centroids = forgy_init_centroids(rng, points, k);

    // Labels start from a defined value: `update_clusters` reads the previous
    // label of every point, so the buffer must not be uninitialized.
    let mut clusters: Array<usize, Ix1> = Array::zeros(points.nrows());

    // The first assignment is unconditional. Its "changed" flag is ignored
    // because the initial zero labels are placeholders, not a real partition,
    // and the centroids must be refined at least once.
    update_clusters(points, &centroids, &mut clusters);

    loop {
        update_centroids(points, &mut centroids, &clusters);
        if !update_clusters(points, &centroids, &mut clusters) {
            break;
        }
    }

    (clusters, centroids)
}

/// Runs [`kmeans_forgy_with_centroids`] and returns only the cluster labels,
/// discarding the centroids.
pub fn kmeans_forgy<R: Rng>(rng: &mut R, points: &Array<f32, Ix2>, k: usize) -> Array<usize, Ix1> {
    let (clusters, _centroids) = kmeans_forgy_with_centroids(rng, points, k);
    clusters
}

// Random partition initialization
//
// Produces `npoints` labels, each sampled from `Uniform::new(0, k)`.
fn random_part_init_clusters<R: Rng>(rng: &mut R, npoints: usize, k: usize) -> Array<usize, Ix1> {
    let label_distribution = Uniform::new(0, k);

    // One independent draw per point, in point order
    let labels: Vec<usize> = (0..npoints)
        .map(|_| label_distribution.sample(rng))
        .collect();

    Array::from(labels)
}

/// KMeans algorithm with random partition initialization
///
/// Assigns every row of `points` to a random cluster, derives the centroids
/// from that partition, and then alternates assignment and centroid updates
/// until no point changes cluster.
///
/// Returns `(clusters, centroids)`: `clusters` has one label per row of
/// `points`, and `centroids` has `k` rows.
///
/// When `k == 0` no clustering is attempted: the labels are all zero and the
/// centroid array has zero rows.
pub fn kmeans_random_part_with_centroids<R: Rng>(
    rng: &mut R,
    points: &Array<f32, Ix2>,
    k: usize,
) -> (Array<usize, Ix1>, Array<f32, Ix2>) {
    if k == 0 {
        return (
            ArrayBase::zeros(points.nrows()),
            ArrayBase::zeros((k, points.ncols())),
        );
    }

    // Zero-filled rather than uninitialized: `update_centroids` leaves the row
    // of a cluster without members untouched, and the random partition may
    // leave some clusters empty, so every row needs a defined value.
    let mut centroids: Array<f32, Ix2> = Array::zeros((k, points.ncols()));
    let mut clusters = random_part_init_clusters(rng, points.nrows(), k);

    update_centroids(points, &mut centroids, &clusters);

    while update_clusters(points, &centroids, &mut clusters) {
        update_centroids(points, &mut centroids, &clusters);
    }

    (clusters, centroids)
}

/// Runs [`kmeans_random_part_with_centroids`] and returns only the cluster
/// labels, discarding the centroids.
pub fn kmeans_random_part<R: Rng>(
    rng: &mut R,
    points: &Array<f32, Ix2>,
    k: usize,
) -> Array<usize, Ix1> {
    let (clusters, _centroids) = kmeans_random_part_with_centroids(rng, points, k);
    clusters
}

#[cfg(test)]
mod tests {

    use super::*;
    use ndarray::prelude::array;
    use ndarray::{stack, Array, Axis};
    use ndarray_rand::rand::thread_rng;
    use ndarray_rand::rand_distr::Normal;
    use ndarray_rand::RandomExt;

    // Distances between three points on the diagonal and two reference points
    #[test]
    fn test_squared_distances() {
        let points1 = array![[0., 0.], [1., 1.], [2., 2.],];

        let points2 = array![[3., 3.], [4., 4.],];

        let distances = squared_distances(&points1, &points2);

        assert_eq!(distances, array![[18., 32.], [8., 18.], [2., 8.],]);
    }

    // Each centroid sits between two consecutive points, so points pair up
    #[test]
    fn test_update_clusters() {
        let points = array![
            [0., 0.],
            [1., 1.],
            [2., 2.],
            [3., 3.],
            [4., 4.],
            [5., 5.],
            [6., 6.],
            [7., 7.],
            [8., 8.],
            [9., 9.],
        ];

        let centroids = array![[0.5, 0.5], [2.5, 2.5], [4.5, 4.5], [6.5, 6.5], [8.5, 8.5],];

        let mut clusters = ArrayBase::zeros(10);

        assert!(update_clusters(&points, &centroids, &mut clusters));
        assert_eq!(clusters, array![0, 0, 1, 1, 2, 2, 3, 3, 4, 4]);
    }

    // Centroids must become the mean of the two points assigned to each cluster
    #[test]
    fn test_update_centroids() {
        let points = array![
            [0., 0.],
            [1., 1.],
            [2., 2.],
            [3., 3.],
            [4., 4.],
            [5., 5.],
            [6., 6.],
            [7., 7.],
            [8., 8.],
            [9., 9.],
        ];

        let mut centroids = ArrayBase::zeros((5, 2));

        let clusters = array![0, 0, 1, 1, 2, 2, 3, 3, 4, 4];

        update_centroids(&points, &mut centroids, &clusters);

        assert_eq!(
            centroids,
            array![[0.5, 0.5], [2.5, 2.5], [4.5, 4.5], [6.5, 6.5], [8.5, 8.5],]
        );
    }

    // Builds an (n_points, 2) array whose columns are drawn from two normal
    // distributions with the given means and standard deviations.
    fn gen_random_2d_points(
        n_points: usize,
        meanx: f32,
        stdx: f32,
        meany: f32,
        stdy: f32,
    ) -> Array<f32, Ix2> {
        let xs = Array::random((n_points, 1), Normal::new(meanx, stdx).unwrap());
        let ys = Array::random((n_points, 1), Normal::new(meany, stdy).unwrap());

        stack(Axis(1), &[xs.view(), ys.view()]).unwrap()
    }

    // Checks the structural guarantees of a clustering result: one label per
    // point, and every label refers to one of the `k` clusters.
    fn assert_valid_labels(clusters: &Array<usize, Ix1>, n_points: usize, k: usize) {
        assert_eq!(clusters.len(), n_points);
        assert!(clusters.iter().all(|&label| label < k));
    }

    // Smoke test on random data: the exact partition is not deterministic, so
    // only the structural properties of the labels are asserted.
    #[test]
    fn test_kmeans() {
        let mut pts = Array::zeros((0, 2));
        for _ in 0..10 {
            let meanx = thread_rng().gen_range(-10., 10.);
            let meany = thread_rng().gen_range(-10., 10.);
            let new_pts = gen_random_2d_points(256, meanx, 1., meany, 1.);

            pts = stack(Axis(0), &[pts.view(), new_pts.view()]).unwrap();
        }

        let clusters_pp = kmeanspp(&mut thread_rng(), &pts, 10);
        let clusters_forgy = kmeans_forgy(&mut thread_rng(), &pts, 10);
        let clusters_rp = kmeans_random_part(&mut thread_rng(), &pts, 10);

        println!("{:?}", clusters_pp);
        println!("{:?}", clusters_forgy);
        println!("{:?}", clusters_rp);

        assert_valid_labels(&clusters_pp, pts.nrows(), 10);
        assert_valid_labels(&clusters_forgy, pts.nrows(), 10);
        assert_valid_labels(&clusters_rp, pts.nrows(), 10);
    }
}