#[cfg(feature = "parallel")]
use std::cmp::Ordering;
#[cfg(feature = "parallel")]
use std::sync::atomic::{AtomicUsize,AtomicU64};
#[cfg(feature = "parallel")]
use rayon::prelude::*;
#[cfg(feature = "logging")]
use lazy_static::lazy_static;
#[cfg(feature = "logging")]
lazy_static! {
    // Intended to initialize `env_logger` (which reads its config from the
    // environment) when logging is enabled.
    // NOTE(review): nothing visible in this file dereferences `_LOG`, and a
    // `lazy_static` value is only initialized on first access, so
    // `env_logger::init()` may never actually run — confirm. A library would
    // usually leave logger setup to the application instead.
    static ref _LOG: () = env_logger::init();
}
/// A point that can be clustered: a fixed number of coordinates, each
/// readable as an `f64`.
///
/// The functions in this module only ever ask for coordinates at positions
/// `0..dimensions()`.
pub trait Elem {
    /// How many coordinates this point has.
    fn dimensions(&self) -> usize;

    /// The coordinate at position `index`, as an `f64`.
    fn at(&self, index: usize) -> f64;
}
/// A cluster center: one `f64` coordinate per dimension.
///
/// `Clone` and `PartialEq` let callers copy centroids out of a clustering
/// result and compare them (e.g. in tests) without rebuilding the vector.
#[derive(Debug, Clone, PartialEq)]
pub struct Centroid(pub Vec<f64>);
/// Result of `kmeans`: the clustered elements, their cluster assignment and
/// the cluster centers.
#[derive(Debug)]
pub struct Clustering<'a, T> {
    /// The elements that were clustered, borrowed from the caller.
    pub elements: &'a [T],
    /// `membership[i]` is the index into `centroids` of the cluster that
    /// `elements[i]` was assigned to.
    pub membership: Vec<usize>,
    /// The cluster centers, indexed by cluster id.
    pub centroids: Vec<Centroid>,
}
/// Partitions `elems` into `k` clusters with Lloyd's k-means algorithm
/// (squared Euclidean distance). The assignment step runs in parallel with rayon.
///
/// Centroids are seeded by `initialize`. Each iteration then:
/// 1. assigns every element to its nearest centroid (ties keep the element's
///    current cluster, so assignments never flip between equidistant centers);
/// 2. moves every centroid to the mean of its members. A centroid whose
///    cluster is empty keeps its previous position.
///
/// The loop stops after `iter` iterations, or earlier once an iteration
/// reassigns no element. `membership` reflects the assignment made at the
/// start of the last iteration run. An empty `elems` yields an empty
/// clustering.
///
/// # Panics
///
/// Panics if `k == 0` while `elems` is non-empty.
#[cfg(feature = "parallel")]
pub fn kmeans<T: Elem + Sync>(k: usize, elems: &[T], iter: usize) -> Clustering<'_, T> {
    // Index of the centroid nearest to `elem`. The search starts from `current`
    // and only switches on a strictly smaller distance, so ties keep `current`.
    fn nearest_centroid<E: Elem, C: Elem>(elem: &E, centroids: &[C], current: usize) -> usize {
        let mut best = current;
        let mut best_dist = square_distance(elem, &centroids[current]);
        for (c, centroid) in centroids.iter().enumerate() {
            let dist = square_distance(elem, centroid);
            if dist < best_dist {
                best = c;
                best_dist = dist;
            }
        }
        best
    }

    if elems.is_empty() {
        return Clustering {
            elements: elems,
            membership: Vec::new(),
            centroids: Vec::new(),
        };
    }
    assert!(k > 0, "kmeans: k must be at least 1 when elems is non-empty");

    let mut centroids = initialize(k, elems);
    let mut membership = vec![0usize; elems.len()];

    // `it` is only read by the log statement, which is compiled out without
    // the `logging` feature.
    #[allow(unused_variables)]
    for it in 0..iter {
        // Assignment step. Each parallel task owns exactly one membership slot,
        // so plain `&mut usize` suffices (no atomics). Rayon sums the number of
        // reassignments.
        let changes: usize = membership
            .par_iter_mut()
            .zip(elems.par_iter())
            .map(|(clus, elem)| {
                let best = nearest_centroid(elem, &centroids, *clus);
                let changed = best != *clus;
                *clus = best;
                usize::from(changed)
            })
            .sum();

        // Update step. Sums are accumulated apart from the centroids so that a
        // cluster with no members keeps its old centroid, instead of collapsing
        // to the origin.
        let mut sums: Vec<Vec<f64>> = centroids.iter().map(|c| vec![0.0; c.0.len()]).collect();
        let mut counts = vec![0usize; centroids.len()];
        for (elem, &clus) in elems.iter().zip(&membership) {
            counts[clus] += 1;
            for (d, sum) in sums[clus].iter_mut().enumerate() {
                *sum += elem.at(d);
            }
        }
        for ((centroid, sum), &count) in centroids.iter_mut().zip(sums).zip(&counts) {
            if count > 0 {
                let size = count as f64;
                centroid.0 = sum.into_iter().map(|s| s / size).collect();
            }
        }

        // Checked after the update so the centroids are always member means,
        // even when the very first assignment is already stable.
        if changes == 0 {
            #[cfg(feature = "logging")]
            log::info!("clustering kmeans: short circuit after nb iter : {}", it);
            break;
        }
    }

    Clustering {
        elements: elems,
        membership,
        centroids,
    }
}
/// Partitions `elems` into `k` clusters with Lloyd's k-means algorithm
/// (squared Euclidean distance).
///
/// Centroids are seeded by `initialize`. Each iteration then:
/// 1. assigns every element to its nearest centroid (ties keep the element's
///    current cluster);
/// 2. moves every centroid to the mean of its members. A centroid whose
///    cluster is empty keeps its previous position.
///
/// The loop stops after `iter` iterations, or earlier once an iteration
/// reassigns no element. `membership` reflects the assignment made at the
/// start of the last iteration run. An empty `elems` yields an empty
/// clustering.
///
/// # Panics
///
/// Panics if `k == 0` while `elems` is non-empty.
#[cfg(not(feature = "parallel"))]
pub fn kmeans<T: Elem>(k: usize, elems: &[T], iter: usize) -> Clustering<'_, T> {
    // Index of the centroid nearest to `elem`. The search starts from `current`
    // and only switches on a strictly smaller distance, so ties keep `current`.
    fn nearest_centroid<E: Elem, C: Elem>(elem: &E, centroids: &[C], current: usize) -> usize {
        let mut best = current;
        let mut best_dist = square_distance(elem, &centroids[current]);
        for (c, centroid) in centroids.iter().enumerate() {
            let dist = square_distance(elem, centroid);
            if dist < best_dist {
                best = c;
                best_dist = dist;
            }
        }
        best
    }

    if elems.is_empty() {
        return Clustering {
            elements: elems,
            membership: Vec::new(),
            centroids: Vec::new(),
        };
    }
    assert!(k > 0, "kmeans: k must be at least 1 when elems is non-empty");

    let mut centroids = initialize(k, elems);
    let mut membership = vec![0usize; elems.len()];

    // `it` is only read by the log statement, which is compiled out without
    // the `logging` feature.
    #[allow(unused_variables)]
    for it in 0..iter {
        // Assignment step: count actual reassignments, not intermediate
        // improvements found while scanning the centroids.
        let mut changes = 0usize;
        for (clus, elem) in membership.iter_mut().zip(elems) {
            let best = nearest_centroid(elem, &centroids, *clus);
            if best != *clus {
                *clus = best;
                changes += 1;
            }
        }

        // Update step. Sums are accumulated apart from the centroids so that a
        // cluster with no members keeps its old centroid, instead of collapsing
        // to the origin.
        let mut sums: Vec<Vec<f64>> = centroids.iter().map(|c| vec![0.0; c.0.len()]).collect();
        let mut counts = vec![0usize; centroids.len()];
        for (elem, &clus) in elems.iter().zip(&membership) {
            counts[clus] += 1;
            for (d, sum) in sums[clus].iter_mut().enumerate() {
                *sum += elem.at(d);
            }
        }
        for ((centroid, sum), &count) in centroids.iter_mut().zip(sums).zip(&counts) {
            if count > 0 {
                let size = count as f64;
                centroid.0 = sum.into_iter().map(|s| s / size).collect();
            }
        }

        // Checked after the update so the centroids are always member means,
        // even when the very first assignment is already stable.
        if changes == 0 {
            #[cfg(feature = "logging")]
            log::info!("clustering kmeans: short circuit after nb iter : {}", it);
            break;
        }
    }

    Clustering {
        elements: elems,
        membership,
        centroids,
    }
}
/// Squared Euclidean distance between `a` and `b`.
///
/// Only the first `a.dimensions()` coordinates are compared. `b` is read at
/// each of those positions, so it should have at least as many dimensions;
/// for the `Vec`/slice impls in this module, reading past the end panics.
fn square_distance(a: &dyn Elem, b: &dyn Elem) -> f64 {
    (0..a.dimensions()).fold(0.0, |acc, idx| {
        let delta = b.at(idx) - a.at(idx);
        acc + delta * delta
    })
}
/// Chooses `k` starting centroids by farthest-point seeding.
///
/// The first centroid copies a random element (`rand::random` modulo the
/// input length). Each following centroid copies the not-yet-chosen element
/// whose squared distance to its nearest already-chosen centroid is largest.
/// On ties, the first such element wins.
///
/// Returns an empty vector when `k == 0` or `elems` is empty. If `k` exceeds
/// `elems.len()`, every element is eventually taken and the surplus centroids
/// are copies of `elems[0]`.
fn initialize<T: Elem>(k: usize, elems: &[T]) -> Vec<Centroid> {
    // Nothing to seed from; `% elems.len()` below would also divide by zero.
    if k == 0 || elems.is_empty() {
        return Vec::new();
    }
    let mut taken = vec![false; elems.len()];
    // Squared distance from each element to its nearest chosen centroid so far.
    // Each round updates it against the newest centroid only, so seeding costs
    // O(n * k) distance evaluations instead of O(n * k^2).
    let mut nearest = vec![f64::INFINITY; elems.len()];
    let mut centroids = Vec::new();

    let first = rand::random::<usize>() % elems.len();
    taken[first] = true;
    centroids.push(new_centroid(&elems[first]));

    for _ in 1..k {
        let newest = &centroids[centroids.len() - 1];
        let mut imax = 0;
        let mut dmax = f64::NEG_INFINITY;
        for (i, elem) in elems.iter().enumerate() {
            if taken[i] {
                continue;
            }
            let dx = square_distance(elem, newest);
            if dx < nearest[i] {
                nearest[i] = dx;
            }
            if nearest[i] > dmax {
                dmax = nearest[i];
                imax = i;
            }
        }
        taken[imax] = true;
        centroids.push(new_centroid(&elems[imax]));
    }
    centroids
}
/// Creates a centroid positioned exactly at `elem`, copying each of its
/// `dimensions()` coordinates.
fn new_centroid<T: Elem>(elem: &T) -> Centroid {
    let coords: Vec<f64> = (0..elem.dimensions()).map(|d| elem.at(d)).collect();
    Centroid(coords)
}
impl Elem for Centroid {
fn dimensions(&self) -> usize {
self.0.len()
}
fn at(&self, i: usize) -> f64 {
self.0[i]
}
}
macro_rules! elem {
($x: ty) => {
impl Elem for Vec<$x> {
fn dimensions(&self) -> usize {
self.len()
}
fn at(&self, i: usize) -> f64 {
self[i] as f64
}
}
impl Elem for &[$x] {
fn dimensions(&self) -> usize {
self.len()
}
fn at(&self, i: usize) -> f64 {
self[i] as f64
}
}
};
}
elem!(u8);
elem!(u16);
elem!(u32);
elem!(u64);
elem!(usize);
elem!(i8);
elem!(i16);
elem!(i32);
elem!(i64);
elem!(isize);
elem!(f32);
elem!(f64);
#[cfg(test)]
mod test {
    use crate::*;

    #[test]
    fn test_impl() {
        let a: &[i32] = &[1, 2, 3, 4];
        let b: &[i32] = &[5, 4, 3, 2];
        // (1-5)^2 + (2-4)^2 + (3-3)^2 + (4-2)^2 = 16 + 4 + 0 + 4
        let dst = square_distance(&a, &b);
        assert_eq!(24.0, dst);
    }

    #[test]
    fn example() {
        // Three well-separated groups of three 1-D points.
        let items: &[&[f32]] = &[
            &[ 1.0],
            &[ 1.1],
            &[ 0.9],
            &[10.0],
            &[11.1],
            &[10.9],
            &[30.0],
            &[31.1],
            &[30.9],
        ];
        let clus = kmeans(3, items, 1000);
        println!("centroids = {:?}", clus.centroids);
        println!("membership = {:?}", clus.membership);

        let m = &clus.membership;
        assert_eq!(m.len(), items.len());
        assert_eq!(clus.centroids.len(), 3);

        // Every group shares a single label...
        for group in m.chunks(3) {
            assert!(group.iter().all(|&c| c == group[0]), "membership = {:?}", m);
        }
        // ...and the three groups get distinct labels.
        assert!(m[0] != m[3] && m[3] != m[6] && m[0] != m[6], "membership = {:?}", m);

        // Each group's centroid sits at the mean of its members.
        let expected = [(0, 1.0), (3, 32.0 / 3.0), (6, 92.0 / 3.0)];
        for &(i, mean) in &expected {
            let c = &clus.centroids[m[i]];
            assert_eq!(c.0.len(), 1);
            assert!((c.0[0] - mean).abs() < 1e-5, "centroid {:?} != {}", c, mean);
        }
    }
}