use linalg::{Matrix, MatrixSlice, Axes, Vector, BaseMatrix};
use learning::{LearningResult, UnSupModel};
use learning::error::{Error, ErrorKind};
use rand::{Rng, thread_rng};
use libnum::abs;
use std::fmt::Debug;
/// K-Means Classification model.
///
/// Contains option for centroids: `None` until the model has been trained.
/// Specifies iterations, number of classes, and the centroid
/// initialization algorithm.
#[derive(Debug)]
pub struct KMeansClassifier<InitAlg: Initializer> {
    /// Maximum number of iterations to run during training.
    iters: usize,
    /// The number of classes (clusters).
    k: usize,
    /// The fitted centroids; `None` before `train` has completed successfully.
    centroids: Option<Matrix<f64>>,
    /// The algorithm used to produce the initial centroids.
    init_algorithm: InitAlg,
}
impl<InitAlg: Initializer> UnSupModel<Matrix<f64>, Vector<usize>> for KMeansClassifier<InitAlg> {
fn predict(&self, inputs: &Matrix<f64>) -> LearningResult<Vector<usize>> {
if let Some(ref centroids) = self.centroids {
Ok(KMeansClassifier::<InitAlg>::find_closest_centroids(centroids.as_slice(), inputs).0)
} else {
Err(Error::new_untrained())
}
}
fn train(&mut self, inputs: &Matrix<f64>) -> LearningResult<()> {
try!(self.init_centroids(inputs));
let mut cost = 0.0;
let eps = 1e-14;
for _i in 0..self.iters {
let (idx, distances) = try!(self.get_closest_centroids(inputs));
self.update_centroids(inputs, idx);
let cost_i = distances.sum();
if abs(cost - cost_i) < eps {
break;
}
cost = cost_i;
}
Ok(())
}
}
impl KMeansClassifier<KPlusPlus> {
pub fn new(k: usize) -> KMeansClassifier<KPlusPlus> {
KMeansClassifier {
iters: 100,
k: k,
centroids: None,
init_algorithm: KPlusPlus,
}
}
}
impl<InitAlg: Initializer> KMeansClassifier<InitAlg> {
    /// Constructs an untrained k-means classifier model.
    ///
    /// Requires the number of classes, the maximum number of iterations,
    /// and the centroid initialization algorithm to use.
    pub fn new_specified(k: usize, iters: usize, algo: InitAlg) -> KMeansClassifier<InitAlg> {
        KMeansClassifier {
            iters: iters,
            k: k,
            centroids: None,
            init_algorithm: algo,
        }
    }

    /// The number of classes (clusters).
    pub fn k(&self) -> usize {
        self.k
    }

    /// The maximum number of training iterations.
    pub fn iters(&self) -> usize {
        self.iters
    }

    /// The centroid initialization algorithm.
    pub fn init_algorithm(&self) -> &InitAlg {
        &self.init_algorithm
    }

    /// The trained centroids; `None` until `train` has completed.
    pub fn centroids(&self) -> &Option<Matrix<f64>> {
        &self.centroids
    }

    /// Set the maximum number of training iterations.
    pub fn set_iters(&mut self, iters: usize) {
        self.iters = iters;
    }

    /// Initialize the centroids via the configured algorithm.
    ///
    /// Used internally within the model. Returns an error when `k` exceeds
    /// the number of data points, or when the initializer produces a matrix
    /// with the wrong dimensions.
    fn init_centroids(&mut self, inputs: &Matrix<f64>) -> LearningResult<()> {
        if self.k > inputs.rows() {
            Err(Error::new(ErrorKind::InvalidData,
                           format!("Number of clusters ({0}) exceeds number of data points \
                                    ({1}).",
                                   self.k,
                                   inputs.rows())))
        } else {
            let centroids = try!(self.init_algorithm.init_centroids(self.k, inputs));
            // Sanity-check the initializer's output before accepting it.
            if centroids.rows() != self.k {
                Err(Error::new(ErrorKind::InvalidState,
                               "Initial centroids must have exactly k rows."))
            } else if centroids.cols() != inputs.cols() {
                Err(Error::new(ErrorKind::InvalidState,
                               "Initial centroids must have the same column count as inputs."))
            } else {
                self.centroids = Some(centroids);
                Ok(())
            }
        }
    }

    /// Recompute each centroid as the mean of the rows currently assigned
    /// to its class.
    ///
    /// Used internally within the model.
    fn update_centroids(&mut self, inputs: &Matrix<f64>, classes: Vector<usize>) {
        let mut new_centroids = Vec::with_capacity(self.k * inputs.cols());

        // Group the row indices of `inputs` by assigned class.
        let mut row_indexes = vec![Vec::new(); self.k];
        for (i, c) in classes.into_vec().into_iter().enumerate() {
            if let Some(class_rows) = row_indexes.get_mut(c) {
                class_rows.push(i);
            }
        }

        for vec_i in row_indexes {
            // NOTE(review): a class with no assigned rows takes the mean over
            // zero rows here — the result depends on `Matrix::mean`; confirm
            // upstream how empty clusters should behave.
            let mat_i = inputs.select_rows(&vec_i);
            new_centroids.extend(mat_i.mean(Axes::Row).into_vec());
        }

        self.centroids = Some(Matrix::new(self.k, inputs.cols(), new_centroids));
    }

    /// Find the closest centroid for each input row, or fail when the
    /// centroids have not been initialized.
    fn get_closest_centroids(&self,
                             inputs: &Matrix<f64>)
                             -> LearningResult<(Vector<usize>, Vector<f64>)> {
        if let Some(ref c) = self.centroids {
            Ok(KMeansClassifier::<InitAlg>::find_closest_centroids(c.as_slice(), inputs))
        } else {
            Err(Error::new(ErrorKind::InvalidState,
                           "Centroids not correctly initialized."))
        }
    }

    /// For each input row, compute the index of the closest centroid and the
    /// squared euclidean distance to it.
    fn find_closest_centroids(centroids: MatrixSlice<f64>,
                              inputs: &Matrix<f64>)
                              -> (Vector<usize>, Vector<f64>) {
        let mut idx = Vec::with_capacity(inputs.rows());
        let mut distances = Vec::with_capacity(inputs.rows());

        for i in 0..inputs.rows() {
            // Repeats row i once per centroid (like repmat) so one subtraction
            // yields the difference from every centroid at once.
            let centroid_diff = centroids - inputs.select_rows(&vec![i; centroids.rows()]);
            // Squared euclidean distance from row i to each centroid.
            // (Fixes the mangled `&cent…` -> `¢…` encoding in the original.)
            let dist = centroid_diff.elemul(&centroid_diff).sum_cols();

            let (min_idx, min_dist) = dist.argmin();
            idx.push(min_idx);
            distances.push(min_dist);
        }

        (Vector::new(idx), Vector::new(distances))
    }
}
/// Trait for algorithms that produce the initial centroids for k-means.
pub trait Initializer: Debug {
    /// Initialize the `k` centroids from the given input data, returning a
    /// matrix of centroid rows or an error.
    fn init_centroids(&self, k: usize, inputs: &Matrix<f64>) -> LearningResult<Matrix<f64>>;
}
/// The Forgy initialization scheme: the centroids are `k` distinct rows
/// sampled uniformly at random from the input data.
#[derive(Debug)]
pub struct Forgy;

impl Initializer for Forgy {
    fn init_centroids(&self, k: usize, inputs: &Matrix<f64>) -> LearningResult<Matrix<f64>> {
        // Rejection-sample row indices until k distinct ones are collected.
        let mut rng = thread_rng();
        let mut chosen: Vec<usize> = Vec::with_capacity(k);
        while chosen.len() < k {
            let candidate = rng.gen_range(0, inputs.rows());
            if !chosen.contains(&candidate) {
                chosen.push(candidate);
            }
        }
        Ok(inputs.select_rows(&chosen))
    }
}
/// The Random Partition initialization scheme: each input row is assigned
/// to a random cluster (after seeding every cluster with one row so none is
/// empty), and each centroid is the mean of its cluster.
#[derive(Debug)]
pub struct RandomPartition;

impl Initializer for RandomPartition {
    fn init_centroids(&self, k: usize, inputs: &Matrix<f64>) -> LearningResult<Matrix<f64>> {
        // Seed clusters 0..k with rows 0..k so every cluster is non-empty.
        let mut random_assignments = (0..k).map(|i| vec![i]).collect::<Vec<Vec<usize>>>();

        // Assign the remaining rows uniformly at random.
        let mut rng = thread_rng();
        for i in k..inputs.rows() {
            let idx = rng.gen_range(0, k);
            // Safe indexing replaces the original `unsafe get_unchecked_mut`:
            // `gen_range(0, k)` guarantees idx < k, and `random_assignments`
            // has exactly k entries, so the bounds check cannot fail.
            random_assignments[idx].push(i);
        }

        // Each centroid is the mean of the rows in its cluster.
        let mut init_centroids = Vec::with_capacity(k * inputs.cols());
        for vec_i in random_assignments {
            let mat_i = inputs.select_rows(&vec_i);
            init_centroids.extend(mat_i.mean(Axes::Row).into_vec());
        }

        Ok(Matrix::new(k, inputs.cols(), init_centroids))
    }
}
/// The k-means++ initialization scheme: the first centroid is a uniformly
/// random input row; each subsequent centroid is an input row sampled with
/// probability proportional to its squared distance from the centroids
/// chosen so far.
#[derive(Debug)]
pub struct KPlusPlus;

impl Initializer for KPlusPlus {
    fn init_centroids(&self, k: usize, inputs: &Matrix<f64>) -> LearningResult<Matrix<f64>> {
        let mut rng = thread_rng();
        // Centroid rows are accumulated flat (row-major) into this buffer.
        let mut init_centroids = Vec::with_capacity(k * inputs.cols());

        // First centroid: a uniformly random input row.
        let first_cen = rng.gen_range(0usize, inputs.rows());
        // SAFETY: `first_cen < inputs.rows()` by construction of `gen_range`.
        unsafe {
            init_centroids.extend_from_slice(inputs.get_row_unchecked(first_cen));
        }
        for i in 1..k {
            unsafe {
                // View the i centroids chosen so far as an i x cols matrix
                // over the buffer, without copying.
                // SAFETY: the buffer holds exactly i * inputs.cols() elements
                // at this point, and `as_ptr` is re-taken each iteration in
                // case the Vec's storage has moved (presumably it cannot,
                // given `with_capacity(k * cols)` above — TODO confirm).
                let temp_centroids = MatrixSlice::from_raw_parts(init_centroids.as_ptr(),
                                                                 i,
                                                                 inputs.cols(),
                                                                 inputs.cols());
                let (_, dist) =
                    KMeansClassifier::<KPlusPlus>::find_closest_centroids(temp_centroids, &inputs);
                // Reject non-finite distances (e.g. from NaN/inf inputs)
                // before using them as sampling weights.
                if !dist.data().iter().all(|x| x.is_finite()) {
                    return Err(Error::new(ErrorKind::InvalidData,
                                          "Input data led to invalid centroid distances during \
                                           initialization."));
                }
                // Sample the next centroid proportionally to squared distance.
                let next_cen = sample_discretely(dist);
                // SAFETY: `sample_discretely` returns an index into `dist`,
                // which has one entry per input row.
                init_centroids.extend_from_slice(inputs.get_row_unchecked(next_cen));
            }
        }
        Ok(Matrix::new(k, inputs.cols(), init_centroids))
    }
}
/// Draw an index from an unnormalized discrete distribution.
///
/// Samples a threshold uniformly in `[0, sum)` and returns the first index
/// whose cumulative weight exceeds it.
///
/// # Panics
///
/// Panics when the distribution vector is empty, or when no cumulative
/// weight ever exceeds the sampled threshold.
fn sample_discretely(unnorm_dist: Vector<f64>) -> usize {
    assert!(unnorm_dist.size() > 0, "No entries in distribution vector.");

    let total = unnorm_dist.sum();
    let threshold = thread_rng().gen_range(0.0f64, total);

    let mut cumulative = 0.0;
    for (index, weight) in unnorm_dist.data().iter().enumerate() {
        cumulative += *weight;
        if threshold < cumulative {
            return index;
        }
    }

    panic!("No random value was sampled! There may be more clusters than unique data points.");
}