use crate::memory::*;
use crate::AbortStrategy;
use core::simd::Simd;
use rand::prelude::*;
use rayon::prelude::*;
use std::cell::RefCell;
/// Callback that receives the k-means state; stored in [`KMeansConfig`] as `init_done`.
///
/// NOTE(review): presumably invoked once after centroid initialization — the call
/// sites live in `crate::variants`, confirm there.
pub type InitDoneCallbackFn<'a, T> = &'a dyn Fn(&KMeansState<T>);
/// Callback that receives the k-means state plus a `usize` and a `T`; stored in
/// [`KMeansConfig`] as `iteration_done`.
///
/// NOTE(review): the `usize` / `T` arguments are presumably the iteration index and
/// the current distance sum — confirm at the call sites in `crate::variants`.
pub type IterationDoneCallbackFn<'a, T> = &'a dyn Fn(&KMeansState<T>, usize, T);
/// Hooks and tunables for a k-means run: progress callbacks, the random number
/// generator and the abort strategy.
///
/// Obtain one via [`KMeansConfig::default`] or the builder from [`KMeansConfig::build`].
pub struct KMeansConfig<'a, T: Primitive> {
    /// Callback handed the state after initialization (see [`InitDoneCallbackFn`]).
    pub(crate) init_done: InitDoneCallbackFn<'a, T>,
    /// Per-iteration progress callback (see [`IterationDoneCallbackFn`]).
    pub(crate) iteration_done: IterationDoneCallbackFn<'a, T>,
    /// Random number generator. It sits behind a `RefCell` because the algorithms
    /// only receive `&KMeansConfig`, while drawing from an `RngCore` needs `&mut`.
    pub(crate) rnd: Box<RefCell<dyn RngCore>>,
    /// Criterion used to stop iterating early (presumably before `max_iter` is reached).
    pub(crate) abort_strategy: AbortStrategy<T>,
}
impl<T: Primitive> Default for KMeansConfig<'_, T> {
    /// Builds the stock configuration.
    ///
    /// Both callbacks are no-ops, the generator is `rand::thread_rng()`, and the
    /// abort strategy is `AbortStrategy::NoImprovement` with a threshold of `0.0005`.
    fn default() -> Self {
        let generator = rand::thread_rng();
        let threshold = T::from(0.0005).unwrap();
        Self {
            init_done: &|_| {},
            iteration_done: &|_, _, _| {},
            rnd: Box::new(RefCell::new(generator)),
            abort_strategy: AbortStrategy::<T>::NoImprovement { threshold },
        }
    }
}
impl<'a, T: Primitive> KMeansConfig<'a, T> {
    /// Returns a [`KMeansConfigBuilder`] seeded with [`KMeansConfig::default`].
    pub fn build() -> KMeansConfigBuilder<'a, T> {
        let config = Self::default();
        KMeansConfigBuilder { config }
    }
}
impl<T: Primitive> std::fmt::Debug for KMeansConfig<'_, T> {
    /// Formats the configuration opaquely as `KMeansConfig { .. }`.
    ///
    /// The fields are `&dyn Fn` callbacks, a boxed `dyn RngCore` and an
    /// `AbortStrategy`, none of which can be assumed to implement `Debug`. They are
    /// elided rather than printed. Emitting the type name keeps `{:?}` output
    /// non-empty and recognisable.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("KMeansConfig").finish_non_exhaustive()
    }
}
/// Fluent builder for [`KMeansConfig`], created by [`KMeansConfig::build`].
///
/// It starts from [`KMeansConfig::default`]. Every setter consumes the builder and
/// returns it for chaining, and [`KMeansConfigBuilder::build`] yields the finished
/// configuration. The type is `#[must_use]` so that a dropped builder, which also
/// drops the values configured so far, is reported by the compiler.
#[must_use = "a KMeansConfigBuilder does nothing until `.build()` is called"]
pub struct KMeansConfigBuilder<'a, T: Primitive> {
    config: KMeansConfig<'a, T>,
}
impl<'a, T: Primitive> KMeansConfigBuilder<'a, T> {
    /// Sets the callback that receives the state after initialization.
    ///
    /// NOTE(review): the invocation happens outside this file, presumably once right
    /// after the `init` function runs — confirm in `crate::variants`.
    pub fn init_done(mut self, init_done: InitDoneCallbackFn<'a, T>) -> Self {
        self.config.init_done = init_done;
        self
    }

    /// Sets the per-iteration progress callback.
    ///
    /// NOTE(review): presumably called with the state, the iteration index and the
    /// current distance sum — confirm in `crate::variants`.
    pub fn iteration_done(mut self, iteration_done: IterationDoneCallbackFn<'a, T>) -> Self {
        self.config.iteration_done = iteration_done;
        self
    }

    /// Replaces the random number generator (e.g. with a seeded one for
    /// reproducible runs).
    ///
    /// The `'static` bound is required because the generator is stored as a boxed
    /// `dyn RngCore` trait object.
    pub fn random_generator<R: RngCore + 'static>(mut self, rnd: R) -> Self {
        self.config.rnd = Box::new(RefCell::new(rnd));
        self
    }

    /// Sets the criterion used to stop iterating early.
    pub fn abort_strategy(mut self, abort_strategy: AbortStrategy<T>) -> Self {
        self.config.abort_strategy = abort_strategy;
        self
    }

    /// Consumes the builder and returns the finished configuration.
    #[must_use]
    pub fn build(self) -> KMeansConfig<'a, T> {
        self.config
    }
}
/// Mutable working state of a k-means run.
///
/// It is filled in by the initializers and updated by the algorithm variants.
#[derive(Clone, Debug)]
pub struct KMeansState<T: Primitive> {
    /// Number of clusters, i.e. number of centroids.
    pub k: usize,
    /// Accumulated distance value; starts at zero.
    ///
    /// NOTE(review): presumably the sum of sample-to-centroid distances — confirm in
    /// `crate::variants`.
    pub distsum: T,
    /// Centroid coordinates, one stride per centroid (`k` strides of `sample_dims` values).
    pub centroids: StrideBuffer<T>,
    /// Number of samples assigned to each centroid (`k` entries).
    pub centroid_frequency: Vec<usize>,
    /// Index of the centroid each sample is assigned to (one entry per sample).
    pub assignments: Vec<usize>,
    /// Distance from each sample to its assigned centroid; starts at infinity.
    pub centroid_distances: Vec<T>,
}
impl<T: Primitive> KMeansState<T> {
    /// Allocates a fresh state for `sample_cnt` samples of `sample_dims` dimensions
    /// and `k` clusters.
    ///
    /// On return, every sample is assigned to centroid 0 at infinite distance, every
    /// frequency is zero and `distsum` is zero. The centroid buffer is laid out for
    /// `LANES`-wide SIMD.
    pub(crate) fn new<const LANES: usize>(sample_cnt: usize, sample_dims: usize, k: usize) -> Self {
        let centroids = StrideBuffer::new::<LANES>(k, sample_dims);
        let centroid_frequency: Vec<usize> = vec![0; k];
        let assignments: Vec<usize> = vec![0; sample_cnt];
        let centroid_distances = vec![T::infinity(); sample_cnt];
        Self {
            k,
            distsum: T::zero(),
            centroids,
            centroid_frequency,
            assignments,
            centroid_distances,
        }
    }
}
/// Distance between a sample and a centroid, as used by [`KMeans`].
///
/// Implementations must be `Send + Sync` because `KMeans` evaluates distances from
/// rayon worker threads (`update_cluster_assignments`, `update_centroid_distances`).
///
/// NOTE(review): `a` and `b` are whole stride rows taken from `StrideBuffer`s, which
/// presumably share a padded length. `LANES` presumably selects the SIMD width of
/// the implementation — confirm against the implementors.
pub trait DistanceFunction<T, const LANES: usize>: Send + Sync {
    /// Returns the distance between the rows `a` and `b`.
    fn distance(&self, a: &[T], b: &[T]) -> T;
}
/// K-means driver over a fixed set of samples.
///
/// The samples are copied into a `LANES`-wide `StrideBuffer` by [`KMeans::new`]. The
/// clustering entry points (`kmeans_lloyd`, `kmeans_minibatch`) then run against
/// that buffer using the distance function `D`.
pub struct KMeans<T, const LANES: usize, D: DistanceFunction<T, LANES>>
where
    T: Primitive,
    Simd<T, LANES>: SupportedSimdArray<T, LANES>,
{
    /// Number of samples.
    pub(crate) sample_cnt: usize,
    /// Number of values per sample as given by the caller.
    ///
    /// NOTE(review): the buffer's `stride` may exceed this (presumably padded for
    /// SIMD) — confirm in `StrideBuffer::from_slice`.
    pub(crate) sample_dims: usize,
    /// The samples, one stride per sample.
    pub(crate) p_samples: StrideBuffer<T>,
    /// Distance function used for every sample/centroid comparison.
    pub(crate) distance_fn: D,
}
impl<T, const LANES: usize, D: DistanceFunction<T, LANES>> KMeans<T, LANES, D>
where
    T: Primitive,
    Simd<T, LANES>: SupportedSimdArray<T, LANES>,
{
    /// Creates a k-means instance over `sample_cnt` samples of `sample_dims` values each.
    ///
    /// `samples` is copied into an internal `StrideBuffer` laid out for `LANES`-wide
    /// SIMD. Presumably each sample's `sample_dims` values are contiguous in
    /// `samples`; `StrideBuffer::from_slice` defines the exact layout.
    ///
    /// # Panics
    /// Panics if `samples.len() != sample_cnt * sample_dims`.
    pub fn new(samples: &[T], sample_cnt: usize, sample_dims: usize, distance_fn: D) -> Self {
        assert_eq!(
            samples.len(),
            sample_cnt * sample_dims,
            "samples.len() must equal sample_cnt * sample_dims"
        );
        Self {
            sample_cnt,
            sample_dims,
            p_samples: StrideBuffer::from_slice::<LANES>(sample_dims, samples),
            distance_fn,
        }
    }

    /// Recomputes, for every sample, the distance to its currently assigned centroid.
    ///
    /// Reads `state.assignments` and overwrites `state.centroid_distances`. The
    /// assignments themselves are not changed. The work runs on the rayon thread pool.
    pub(crate) fn update_centroid_distances(&self, state: &mut KMeansState<T>) {
        let centroids = &state.centroids;
        // Require at least ~(samples / threads) samples per rayon job, so the input is
        // split into roughly one packet per thread rather than into tiny jobs.
        let work_packet_size = self.p_samples.bfr.len() / self.p_samples.stride / rayon::current_num_threads();
        self.p_samples
            .bfr
            .par_chunks_exact(self.p_samples.stride)
            .with_min_len(work_packet_size)
            .zip(state.assignments.par_iter().cloned())
            .zip(state.centroid_distances.par_iter_mut())
            .for_each(|((s, assignment), centroid_dist)| {
                *centroid_dist = self.distance_fn.distance(s, centroids.nth_stride(assignment));
            });
    }

    /// Assigns every sample to its nearest centroid.
    ///
    /// Only the first `limit_k` centroids are considered when it is `Some`; otherwise
    /// all `state.k` are. The limit is presumably there for initializers that fill
    /// the centroid buffer incrementally. On ties the lowest centroid index wins
    /// (`Iterator::min_by` keeps the first minimum). Overwrites both
    /// `state.assignments` and `state.centroid_distances`.
    ///
    /// # Panics
    /// Panics if no centroid is considered (`state.k == 0` or `limit_k == Some(0)`), or
    /// if two distances cannot be ordered (e.g. the distance function returned NaN).
    pub(crate) fn update_cluster_assignments(&self, state: &mut KMeansState<T>, limit_k: Option<usize>) {
        let centroids = &state.centroids;
        let k = limit_k.unwrap_or(state.k);
        // Same packet sizing as in `update_centroid_distances`.
        let work_packet_size = self.p_samples.bfr.len() / self.p_samples.stride / rayon::current_num_threads();
        self.p_samples
            .bfr
            .par_chunks_exact(self.p_samples.stride)
            .with_min_len(work_packet_size)
            .zip(state.assignments.par_iter_mut())
            .zip(state.centroid_distances.par_iter_mut())
            .for_each(|((s, assignment), centroid_dist)| {
                let (best_idx, best_dist) = centroids
                    .chunks_exact_stride()
                    .take(k)
                    .map(|c| self.distance_fn.distance(s, c))
                    .enumerate()
                    .min_by(|(_, d0), (_, d1)| {
                        d0.partial_cmp(d1)
                            .expect("distance function returned an unordered value (NaN)")
                    })
                    .expect("at least one centroid must be considered (k > 0)");
                *assignment = best_idx;
                *centroid_dist = best_dist;
            });
    }

    /// Counts how many samples are assigned to each centroid.
    ///
    /// Resets `centroid_frequency` to zero, then increments the slot named by every
    /// entry of `assignments`. Returns the number of distinct centroids that received
    /// at least one sample.
    ///
    /// # Panics
    /// Panics if an assignment is out of range for `centroid_frequency`.
    pub(crate) fn update_cluster_frequencies(&self, assignments: &[usize], centroid_frequency: &mut [usize]) -> usize {
        centroid_frequency.fill(0);
        let mut used_centroids_cnt = 0;
        for &centroid_id in assignments {
            // The first sample landing in a cluster marks that centroid as used.
            if centroid_frequency[centroid_id] == 0 {
                used_centroids_cnt += 1;
            }
            centroid_frequency[centroid_id] += 1;
        }
        used_centroids_cnt
    }

    /// Runs the Lloyd variant of k-means with `k` clusters.
    ///
    /// `init` seeds the centroids (see the `init_*` functions). `max_iter`
    /// presumably bounds the number of iterations. The actual algorithm lives in
    /// `crate::variants::Lloyd`.
    pub fn kmeans_lloyd<F>(&self, k: usize, max_iter: usize, init: F, config: &KMeansConfig<'_, T>) -> KMeansState<T>
    where
        for<'c> F: FnOnce(&KMeans<T, LANES, D>, &mut KMeansState<T>, &KMeansConfig<'c, T>),
    {
        crate::variants::Lloyd::calculate(self, k, max_iter, init, config)
    }

    /// Runs the mini-batch variant of k-means with `k` clusters and batches of
    /// `batch_size` samples.
    ///
    /// `init` seeds the centroids (see the `init_*` functions). The actual algorithm
    /// lives in `crate::variants::Minibatch`.
    pub fn kmeans_minibatch<F>(&self, batch_size: usize, k: usize, max_iter: usize, init: F, config: &KMeansConfig<'_, T>) -> KMeansState<T>
    where
        for<'c> F: FnOnce(&KMeans<T, LANES, D>, &mut KMeansState<T>, &KMeansConfig<'c, T>),
    {
        crate::variants::Minibatch::calculate(self, batch_size, k, max_iter, init, config)
    }

    /// Centroid initializer delegating to `crate::inits::kmeanplusplus` (k-means++).
    pub fn init_kmeanplusplus(kmean: &KMeans<T, LANES, D>, state: &mut KMeansState<T>, config: &KMeansConfig<'_, T>) {
        crate::inits::kmeanplusplus::calculate(kmean, state, config);
    }

    /// Centroid initializer delegating to `crate::inits::randompartition`.
    pub fn init_random_partition(kmean: &KMeans<T, LANES, D>, state: &mut KMeansState<T>, config: &KMeansConfig<'_, T>) {
        crate::inits::randompartition::calculate(kmean, state, config);
    }

    /// Centroid initializer delegating to `crate::inits::randomsample`.
    pub fn init_random_sample(kmean: &KMeans<T, LANES, D>, state: &mut KMeansState<T>, config: &KMeansConfig<'_, T>) {
        crate::inits::randomsample::calculate(kmean, state, config);
    }

    /// Returns an initializer that seeds the state from the caller-provided
    /// `centroids`, delegating to `crate::inits::precomputed`.
    ///
    /// The vector is moved into the returned closure, which can therefore be
    /// called more than once.
    pub fn init_precomputed(centroids: Vec<T>) -> impl Fn(&KMeans<T, LANES, D>, &mut KMeansState<T>, &KMeansConfig<'_, T>) {
        move |kmean, state, config| {
            crate::inits::precomputed::calculate(kmean, state, config, &centroids);
        }
    }
}
// Tests that compare the parallel assignment step against a brute-force scalar
// reference, plus benchmarks of that step for each float type / lane count.
//
// NOTE(review): `#[bench]` / `test::Bencher` require the unstable `test` feature at
// the crate root. `assert_approx_eq!` must be a macro defined elsewhere in the
// crate. Neither is visible here.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EuclideanDistance;
    use test::Bencher;
    // The dimension counts mix tiny values with values around 100, so `sample_dims`
    // is sometimes a multiple of `LANES` and sometimes not. This presumably
    // exercises the stride padding the test name refers to.
    #[test]
    fn padding_and_cluster_assignments() {
        calculate_cluster_assignments_multiplex(1);
        calculate_cluster_assignments_multiplex(2);
        calculate_cluster_assignments_multiplex(3);
        calculate_cluster_assignments_multiplex(97);
        calculate_cluster_assignments_multiplex(98);
        calculate_cluster_assignments_multiplex(99);
        calculate_cluster_assignments_multiplex(100);
    }
    // Runs the check for every (float type, lane count) combination. The tolerance
    // for `f32` is much looser, presumably to absorb precision / summation-order
    // differences of the SIMD path.
    fn calculate_cluster_assignments_multiplex(sample_dims: usize) {
        calculate_cluster_assignments::<f64, 8>(sample_dims, 1e-10f64);
        calculate_cluster_assignments::<f64, 4>(sample_dims, 1e-10f64);
        calculate_cluster_assignments::<f64, 2>(sample_dims, 1e-10f64);
        calculate_cluster_assignments::<f32, 16>(sample_dims, 1.2e-5f32);
        calculate_cluster_assignments::<f32, 8>(sample_dims, 1.2e-5f32);
        calculate_cluster_assignments::<f32, 4>(sample_dims, 1.2e-5f32);
        calculate_cluster_assignments::<f32, 2>(sample_dims, 1.2e-5f32);
    }
    // Checks `update_cluster_assignments` against an exhaustive scalar search over
    // 1000 seeded random samples in [0, 1) and k = 5 centroids.
    fn calculate_cluster_assignments<T, const LANES: usize>(sample_dims: usize, max_diff: T)
    where
        T: Primitive,
        Simd<T, LANES>: SupportedSimdArray<T, LANES>,
    {
        let sample_cnt = 1000;
        let k = 5;
        let mut samples = vec![T::zero(); sample_cnt * sample_dims];
        let mut rng = rand::rngs::StdRng::seed_from_u64(1337);
        samples.iter_mut().for_each(|i| *i = rng.gen_range(T::zero()..T::one()));
        let kmean = KMeans::new(&samples, sample_cnt, sample_dims, EuclideanDistance);
        let mut state = KMeansState::new::<LANES>(kmean.sample_cnt, sample_dims, k);
        // Seed the centroids by copying the leading elements of the sample buffer.
        // This equals the first k samples, assuming both buffers use the same stride.
        state
            .centroids
            .bfr
            .iter_mut()
            .zip(kmean.p_samples.bfr.iter())
            .for_each(|(c, s)| *c = *s);
        // Reference result: exhaustive search with the squared Euclidean distance (no
        // square root), taken over the full stride including any padding. The
        // comparison below therefore only passes if `EuclideanDistance` returns the
        // squared distance.
        let mut should_assignments = state.assignments.clone();
        let mut should_centroid_distances = state.centroid_distances.clone();
        kmean
            .p_samples
            .chunks_exact_stride()
            .zip(should_assignments.iter_mut())
            .zip(should_centroid_distances.iter_mut())
            .for_each(|((s, assignment), centroid_dist)| {
                let (best_idx, best_dist) = state
                    .centroids
                    .chunks_exact_stride()
                    .map(|c| {
                        s.iter()
                            .cloned()
                            .zip(c.iter().cloned())
                            .map(|(sv, cv)| sv - cv)
                            .map(|v| v * v)
                            .sum::<T>()
                    })
                    .enumerate()
                    .min_by(|(_, d0), (_, d1)| d0.partial_cmp(d1).unwrap())
                    .unwrap();
                *assignment = best_idx;
                *centroid_dist = best_dist;
            });
        // Implementation under test.
        kmean.update_cluster_assignments(&mut state, None);
        // Distances are compared within `max_diff`; assignments must match exactly.
        for i in 0..should_assignments.len() {
            assert_approx_eq!(state.centroid_distances[i], should_centroid_distances[i], max_diff);
        }
        assert_eq!(state.assignments, should_assignments);
    }
    #[bench]
    fn distance_matrix_calculation_benchmark_f64x8(b: &mut Bencher) { distance_matrix_calculation_benchmark::<f64, 8>(b); }
    #[bench]
    fn distance_matrix_calculation_benchmark_f64x4(b: &mut Bencher) { distance_matrix_calculation_benchmark::<f64, 4>(b); }
    #[bench]
    fn distance_matrix_calculation_benchmark_f64x2(b: &mut Bencher) { distance_matrix_calculation_benchmark::<f64, 2>(b); }
    #[bench]
    fn distance_matrix_calculation_benchmark_f32x16(b: &mut Bencher) { distance_matrix_calculation_benchmark::<f32, 16>(b); }
    #[bench]
    fn distance_matrix_calculation_benchmark_f32x8(b: &mut Bencher) { distance_matrix_calculation_benchmark::<f32, 8>(b); }
    #[bench]
    fn distance_matrix_calculation_benchmark_f32x4(b: &mut Bencher) { distance_matrix_calculation_benchmark::<f32, 4>(b); }
    #[bench]
    fn distance_matrix_calculation_benchmark_f32x2(b: &mut Bencher) { distance_matrix_calculation_benchmark::<f32, 2>(b); }
    // Times one full assignment pass over 20000 seeded random samples of 2000
    // dimensions, with k = LANES centroids seeded from the leading samples.
    fn distance_matrix_calculation_benchmark<T, const LANES: usize>(b: &mut Bencher)
    where
        T: Primitive,
        Simd<T, LANES>: SupportedSimdArray<T, LANES>,
    {
        let sample_cnt = 20000;
        let sample_dims = 2000;
        let k = LANES;
        let mut samples = vec![T::zero(); sample_cnt * sample_dims];
        let mut rng = rand::rngs::StdRng::seed_from_u64(1337);
        samples.iter_mut().for_each(|v| *v = rng.gen_range(T::zero()..T::one()));
        let kmean: KMeans<T, LANES, _> = KMeans::new(&samples, sample_cnt, sample_dims, EuclideanDistance);
        let mut state = KMeansState::new::<LANES>(kmean.sample_cnt, sample_dims, k);
        state
            .centroids
            .bfr
            .iter_mut()
            .zip(kmean.p_samples.bfr.iter())
            .for_each(|(c, s)| *c = *s);
        b.iter(|| {
            KMeans::update_cluster_assignments(&kmean, &mut state, None);
            // Returning a value lets `Bencher::iter` black-box the result so the work
            // is not optimized away. The clone itself is included in every measured
            // iteration.
            state.clone()
        });
    }
}