use crate::api::DistanceFunction;
use crate::memory::*;
use crate::{KMeans, KMeansConfig, KMeansState};
use rand::prelude::*;
use rayon::prelude::*;
use std::ops::{DerefMut, Range};
use std::simd::Simd;
/// Describes a contiguous window of samples that is processed together.
struct BatchInfo {
    /// Index (in samples, not buffer elements) of the first sample in the window.
    start_idx: usize,
    /// Number of samples in the window.
    batch_size: usize,
}

impl BatchInfo {
    /// Translates the sample window into an index range over a flat buffer
    /// that stores `stride` elements per sample.
    ///
    /// Passing `stride == 1` yields the range of the samples themselves, which
    /// is what per-sample arrays are indexed with.
    fn gen_range(&self, stride: usize) -> Range<usize> {
        let first = self.start_idx * stride;
        let len = self.batch_size * stride;
        first..first + len
    }
}
/// Type-level namespace for the mini-batch k-means routines.
///
/// The generic parameters select the element type (`T`), the SIMD lane count
/// (`LANES`) and the distance function (`D`). All functionality is exposed as
/// associated functions on the `impl` block; the struct itself stores no data.
pub(crate) struct Minibatch<T, const LANES: usize, D>
where
T: Primitive,
D: DistanceFunction<T, LANES>,
{
// Zero-sized marker so that `T` and `D` count as used type parameters.
_p: std::marker::PhantomData<(T, D)>,
}
impl<T, const LANES: usize, D> Minibatch<T, LANES, D>
where
T: Primitive,
Simd<T, LANES>: SupportedSimdArray<T, LANES>,
D: DistanceFunction<T, LANES>,
{
/// Assigns every sample inside `batch` to its nearest centroid.
///
/// For each sample of the batch (read from `shuffled_samples`), the distance to
/// the first `limit_k` centroids (or all `state.k` centroids when `limit_k` is
/// `None`) is evaluated with `data.distance_fn`. The index of the closest
/// centroid is written to `state.assignments` and the corresponding distance to
/// `state.centroid_distances`, both at the sample's position within the batch
/// window. Ties are resolved in favour of the centroid with the lowest index.
///
/// The work is distributed over the rayon thread pool.
///
/// # Panics
///
/// Panics if two distances cannot be ordered (`partial_cmp` returns `None`,
/// e.g. for NaN), if no centroid is considered at all, or if the batch window
/// exceeds the bounds of the involved buffers.
fn update_cluster_assignments(
    data: &KMeans<T, LANES, D>, state: &mut KMeansState<T>, batch: &BatchInfo, shuffled_samples: &[T], limit_k: Option<usize>,
) {
    let stride = data.p_samples.stride;
    let centroid_limit = limit_k.unwrap_or(state.k);
    // Lower bound for the number of samples handed to a single rayon task.
    let min_chunk_len = batch.batch_size / rayon::current_num_threads();

    let centroids = &state.centroids;
    let batch_samples = &shuffled_samples[batch.gen_range(stride)];
    let batch_assignments = &mut state.assignments[batch.gen_range(1)];
    let batch_distances = &mut state.centroid_distances[batch.gen_range(1)];

    batch_samples
        .par_chunks_exact(stride)
        .with_min_len(min_chunk_len)
        .zip(batch_assignments.par_iter_mut())
        .zip(batch_distances.par_iter_mut())
        .for_each(|((sample, assignment_slot), distance_slot)| {
            let (nearest_idx, nearest_dist) = centroids
                .chunks_exact_stride()
                .take(centroid_limit)
                .enumerate()
                .map(|(idx, centroid)| (idx, data.distance_fn.distance(sample, centroid)))
                .min_by(|(_, lhs), (_, rhs)| lhs.partial_cmp(rhs).unwrap())
                .unwrap();
            *assignment_slot = nearest_idx;
            *distance_slot = nearest_dist;
        });
}
/// Moves the centroids towards the samples of `batch`, one sample at a time.
///
/// Every sample in the batch increments the frequency counter of the centroid
/// it is currently assigned to and then pulls that centroid towards itself
/// with a per-centroid learning rate of `1 / frequency`:
/// `centroid = (1 - rate) * centroid + rate * sample`.
///
/// Samples are processed sequentially in batch order, so later samples see the
/// centroid positions already updated by earlier ones.
fn update_centroids(data: &KMeans<T, LANES, D>, state: &mut KMeansState<T>, batch: &BatchInfo, shuffled_samples: &[T]) {
    let stride = data.p_samples.stride;
    let frequencies = &mut state.centroid_frequency;
    let centroid_buffer = &mut state.centroids;
    let batch_assignments = &state.assignments[batch.gen_range(1)];
    let batch_samples = &shuffled_samples[batch.gen_range(stride)];

    for (sample, cluster) in batch_samples.chunks_exact(stride).zip(batch_assignments.iter().copied()) {
        frequencies[cluster] += 1;
        let learn_rate = T::one() / T::from(frequencies[cluster]).unwrap();
        let keep_rate = T::one() - learn_rate;

        // Elements of the centroid row that belongs to `cluster`.
        let centroid_row = centroid_buffer.bfr.iter_mut().skip(cluster * stride).take(stride);
        for (component, sample_component) in centroid_row.zip(sample.iter().cloned()) {
            *component = keep_rate * *component + learn_rate * sample_component;
        }
    }
}
/// Creates a randomly permuted copy of all samples.
///
/// Returns the permutation together with the shuffled buffer: row `i` of the
/// returned buffer is a copy of sample `permutation[i]` of `data.p_samples`.
/// The permutation is drawn from the random generator stored in `config`.
fn shuffle_samples(data: &KMeans<T, LANES, D>, config: &KMeansConfig<'_, T>) -> (Vec<usize>, StrideBuffer<T>) {
    let mut order = (0..data.sample_cnt).collect::<Vec<usize>>();
    {
        // Keep the RefCell borrow limited to the shuffle itself.
        let mut rng = config.rnd.borrow_mut();
        order.shuffle(rng.deref_mut());
    }

    let mut shuffled = StrideBuffer::new::<LANES>(data.sample_cnt, data.sample_dims);
    for (row, &original_idx) in shuffled.iter_mut().zip(order.iter()) {
        row.copy_from_slice(&data.p_samples[original_idx]);
    }
    (order, shuffled)
}
/// Brings the per-sample state back into the original sample order.
///
/// `shuffle_idxs[i]` is the original index of the sample that lives at
/// shuffled position `i`, so the entries found at position `i` of
/// `state.assignments` and `state.centroid_distances` must end up at position
/// `shuffle_idxs[i]`.
///
/// A single pass of `swap(i, shuffle_idxs[i])` does not achieve that for
/// arbitrary permutations (e.g. `[1, 2, 0]`), so the permutation is applied
/// in place by following its cycles instead.
///
/// `shuffle_idxs` must be a permutation of `0..shuffle_idxs.len()`.
fn unshuffle_state(shuffle_idxs: &[usize], state: &mut KMeansState<T>) {
    // Working copy of the permutation; every resolved position is turned into
    // a fixed point (`targets[p] == p`).
    let mut targets = shuffle_idxs.to_vec();
    for pos in 0..targets.len() {
        // Send the element currently stored at `pos` to its destination. The
        // element displaced from there lands at `pos` and is handled next,
        // until the element that belongs at `pos` has arrived.
        while targets[pos] != pos {
            let dest = targets[pos];
            state.assignments.swap(pos, dest);
            state.centroid_distances.swap(pos, dest);
            // After this swap `targets[dest] == dest`, i.e. `dest` is final.
            targets.swap(pos, dest);
        }
    }
}
/// Runs mini-batch k-means and returns the resulting state.
///
/// Steps:
/// 1. All samples are copied into a randomly shuffled buffer.
/// 2. `init` initializes the state (centroids), after which the
///    `config.init_done` callback is invoked.
/// 3. Every sample is assigned to its nearest centroid once.
/// 4. For up to `max_iter` iterations, a contiguous window of `batch_size`
///    shuffled samples is chosen at a random offset, reassigned, and used to
///    update the centroids. After each iteration `config.iteration_done` is
///    called and the abort strategy of `config` decides whether to continue.
/// 5. The per-sample state is brought back into the original sample order,
///    assignments are recomputed via `data.update_cluster_assignments`, and
///    cluster frequencies as well as the distance sum are refreshed.
///
/// # Parameters
///
/// * `data` - samples and distance function to cluster with
/// * `batch_size` - number of samples per mini-batch
/// * `k` - number of clusters
/// * `max_iter` - upper bound for the number of mini-batch iterations
/// * `init` - centroid initialization routine
/// * `config` - random generator, callbacks and abort strategy
///
/// # Panics
///
/// Panics if `k` or `batch_size` exceed the number of samples.
#[inline(always)]
pub fn calculate<F>(
    data: &KMeans<T, LANES, D>, batch_size: usize, k: usize, max_iter: usize, init: F, config: &KMeansConfig<'_, T>,
) -> KMeansState<T>
where
    for<'c> F: FnOnce(&KMeans<T, LANES, D>, &mut KMeansState<T>, &KMeansConfig<'c, T>),
{
    assert!(k <= data.sample_cnt);
    assert!(batch_size <= data.sample_cnt);

    let (shuffle_idxs, shuffled_samples) = Self::shuffle_samples(data, config);
    let mut state = KMeansState::new::<LANES>(data.sample_cnt, data.sample_dims, k);
    state.distsum = T::infinity();

    init(data, &mut state, config);
    (config.init_done)(&state);
    let mut abort_strategy = config.abort_strategy.create_logic();

    // Initial assignment of all (shuffled) samples, so that the distance sum
    // computed in the first iteration covers every sample.
    Self::update_cluster_assignments(
        data,
        &mut state,
        &BatchInfo {
            start_idx: 0,
            batch_size: data.sample_cnt,
        },
        &shuffled_samples.bfr,
        None,
    );

    // Exclusive upper bound for the random batch offset.
    // NOTE(review): because the bound is exclusive, the offset
    // `sample_cnt - batch_size` is never drawn and the last shuffled sample is
    // never part of a mini-batch. Left untouched, since widening the range
    // would change the random stream and thereby all seeded results.
    let start_idx_bound = data.sample_cnt - batch_size;

    for i in 1..=max_iter {
        // With `batch_size == sample_cnt` the only valid offset is 0; sampling
        // from the (then empty) range `0..0` would panic.
        let start_idx = if start_idx_bound == 0 {
            0
        } else {
            config.rnd.borrow_mut().gen_range(0..start_idx_bound)
        };
        let batch = BatchInfo { batch_size, start_idx };

        Self::update_cluster_assignments(data, &mut state, &batch, &shuffled_samples.bfr, None);
        // Sum over all samples, not just the current batch.
        let new_distsum = state.centroid_distances.iter().cloned().sum();
        Self::update_centroids(data, &mut state, &batch, &shuffled_samples.bfr);

        (config.iteration_done)(&state, i, new_distsum);
        if !abort_strategy.next(new_distsum) {
            break;
        }
        state.distsum = new_distsum;
    }

    Self::unshuffle_state(&shuffle_idxs, &mut state);
    data.update_cluster_assignments(&mut state, None);

    // Split the state into disjoint borrows so both tasks can run in parallel.
    let (assignments, centroid_frequency, centroid_distances, distsum) = (
        &state.assignments, &mut state.centroid_frequency, &mut state.centroid_distances, &mut state.distsum,
    );
    rayon::scope(|s| {
        s.spawn(|_| {
            // Only the refreshed frequencies are needed; the return value is not used.
            let _ = data.update_cluster_frequencies(assignments, centroid_frequency);
        });
        s.spawn(|_| {
            *distsum = centroid_distances.iter().cloned().sum();
        });
    });
    state
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::helpers::testing::{assert_kmeans_result_eq, KMeansShouldResult};
use crate::{AbortStrategy, EuclideanDistance};
// Regression test for mini-batch k-means on `f64` data.
//
// Clusters 150 two-dimensional samples (the function name suggests they are
// taken from the iris dataset) with a random generator seeded with 3, and
// compares the outcome against pinned reference values. Because the expected
// numbers depend on the exact random stream, any change to how random numbers
// are consumed requires regenerating them.
#[test]
fn iris_dataset_f64() {
// Flat, interleaved sample buffer: two values per sample.
let samples = vec![
1.4f64, 0.2, 1.4, 0.2, 1.3, 0.2, 1.5, 0.2, 1.4, 0.2, 1.7, 0.4, 1.4, 0.3, 1.5, 0.2, 1.4, 0.2, 1.5, 0.1, 1.5, 0.2, 1.6, 0.2, 1.4,
0.1, 1.1, 0.1, 1.2, 0.2, 1.5, 0.4, 1.3, 0.4, 1.4, 0.3, 1.7, 0.3, 1.5, 0.3, 1.7, 0.2, 1.5, 0.4, 1.0, 0.2, 1.7, 0.5, 1.9, 0.2,
1.6, 0.2, 1.6, 0.4, 1.5, 0.2, 1.4, 0.2, 1.6, 0.2, 1.6, 0.2, 1.5, 0.4, 1.5, 0.1, 1.4, 0.2, 1.5, 0.2, 1.2, 0.2, 1.3, 0.2, 1.4,
0.1, 1.3, 0.2, 1.5, 0.2, 1.3, 0.3, 1.3, 0.3, 1.3, 0.2, 1.6, 0.6, 1.9, 0.4, 1.4, 0.3, 1.6, 0.2, 1.4, 0.2, 1.5, 0.2, 1.4, 0.2,
4.7, 1.4, 4.5, 1.5, 4.9, 1.5, 4.0, 1.3, 4.6, 1.5, 4.5, 1.3, 4.7, 1.6, 3.3, 1.0, 4.6, 1.3, 3.9, 1.4, 3.5, 1.0, 4.2, 1.5, 4.0,
1.0, 4.7, 1.4, 3.6, 1.3, 4.4, 1.4, 4.5, 1.5, 4.1, 1.0, 4.5, 1.5, 3.9, 1.1, 4.8, 1.8, 4.0, 1.3, 4.9, 1.5, 4.7, 1.2, 4.3, 1.3,
4.4, 1.4, 4.8, 1.4, 5.0, 1.7, 4.5, 1.5, 3.5, 1.0, 3.8, 1.1, 3.7, 1.0, 3.9, 1.2, 5.1, 1.6, 4.5, 1.5, 4.5, 1.6, 4.7, 1.5, 4.4,
1.3, 4.1, 1.3, 4.0, 1.3, 4.4, 1.2, 4.6, 1.4, 4.0, 1.2, 3.3, 1.0, 4.2, 1.3, 4.2, 1.2, 4.2, 1.3, 4.3, 1.3, 3.0, 1.1, 4.1, 1.3,
6.0, 2.5, 5.1, 1.9, 5.9, 2.1, 5.6, 1.8, 5.8, 2.2, 6.6, 2.1, 4.5, 1.7, 6.3, 1.8, 5.8, 1.8, 6.1, 2.5, 5.1, 2.0, 5.3, 1.9, 5.5,
2.1, 5.0, 2.0, 5.1, 2.4, 5.3, 2.3, 5.5, 1.8, 6.7, 2.2, 6.9, 2.3, 5.0, 1.5, 5.7, 2.3, 4.9, 2.0, 6.7, 2.0, 4.9, 1.8, 5.7, 2.1,
6.0, 1.8, 4.8, 1.8, 4.9, 1.8, 5.6, 2.1, 5.8, 1.6, 6.1, 1.9, 6.4, 2.0, 5.6, 2.2, 5.1, 1.5, 5.6, 1.4, 6.1, 2.3, 5.6, 2.4, 5.5,
1.8, 4.8, 1.8, 5.4, 2.1, 5.6, 2.4, 5.1, 2.3, 5.1, 1.9, 5.9, 2.3, 5.7, 2.5, 5.2, 2.3, 5.0, 1.9, 5.2, 2.0, 5.4, 2.3, 5.1, 1.8,
];
let kmean: KMeans<f64, 8, _> = KMeans::new(&samples, 150, 2, EuclideanDistance);
// Fixed seed keeps the shuffling and batch selection reproducible.
let rnd = rand::rngs::StdRng::seed_from_u64(3);
let conf = KMeansConfig::build()
.random_generator(rnd)
.abort_strategy(AbortStrategy::NoImprovementForXIterations {
x: 5,
threshold: 0.0005f64,
abort_on_negative: true,
})
.build();
// Arguments are presumably batch size 30, k = 3 and at most 100 iterations,
// mirroring the parameter order of `Minibatch::calculate` — confirm against
// `KMeans::kmeans_minibatch`.
let res = kmean.kmeans_minibatch(30, 3, 100, KMeans::init_kmeanplusplus, &conf);
// Pinned reference result for the seed and configuration above.
let should = KMeansShouldResult {
distsum: 31.691483430123924,
sample_dims: 2,
assignments: vec![
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2,
0, 2, 2, 0, 0, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
],
centroid_distances: vec![
0.008367214532872126, 0.008367214532872126, 0.03307309688581347, 0.0036613321799307407, 0.008367214532872126,
0.07248486159169505, 0.00748486159169566, 0.0036613321799307407, 0.008367214532872126, 0.02454368512110721,
0.0036613321799307407, 0.018955449826989393, 0.029249567474048596, 0.16336721453287262, 0.07777897923875488,
0.021896626297577815, 0.051308391003460546, 0.00748486159169566, 0.05336721453287151, 0.0027789792387542737,
0.05424956747404797, 0.021896626297577815, 0.2271907439446376, 0.11160250865051857, 0.18483780276816517,
0.018955449826989393, 0.037190743944636465, 0.0036613321799307407, 0.008367214532872126, 0.018955449826989393,
0.018955449826989393, 0.021896626297577815, 0.02454368512110721, 0.008367214532872126, 0.0036613321799307407,
0.07777897923875488, 0.03307309688581347, 0.029249567474048596, 0.03307309688581347, 0.0036613321799307407,
0.032190743944637, 0.032190743944637, 0.03307309688581347, 0.1354260380622835, 0.20307309688581224, 0.00748486159169566,
0.018955449826989393, 0.008367214532872126, 0.0036613321799307407, 0.008367214532872126, 0.11522168220591258,
0.02902449910732018, 0.30057379488196956, 0.1404329498115445, 0.06691182305098226, 0.029869569529855745,
0.15437661178337703, 1.2864892878397125, 0.06775689347351782, 0.21212309065661458, 0.9022639357270366,
0.035362527276333384, 0.2917005554453479, 0.11522168220591258, 0.5888836540368955, 0.0015597103749257503,
0.02902449910732018, 0.22958787938901032, 0.02902449910732018, 0.30339069629041787, 0.3514188653045034, 0.1404329498115445,
0.30057379488196956, 0.15606675262844816, 0.014094921642531275, 0.0015597103749257503, 0.19310900614957455,
0.4976160484030958, 0.02902449910732018, 0.9022639357270366, 0.40550337234675576, 0.5980385836143609, 0.25296816107915016,
0.5381776147959252, 0.02902449910732018, 0.05860196389605243, 0.1247991469946448, 0.011982245586193527,
0.07832027375520695, 0.1404329498115445, 0.04240478079746134, 0.05733435826225004, 0.17085548502281234, 1.2864892878397125,
0.03620759769886895, 0.06663013291013677, 0.03620759769886895, 0.014094921642531275, 1.9424047807974576,
0.07832027375520695, 0.2971061862244855, 0.34371332908163893, 0.058356186224487565, 0.0787133290816339,
0.03549904336734527, 0.88460618622448, 0.10817942868478458, 0.48496332908162676, 0.09478475765306044, 0.37514190051019847,
0.3188919005102101, 0.1597847576530653, 0.026213329081634155, 0.4408561862244964, 0.4196061862244948, 0.18049904336734995,
0.10067761479592055, 1.0978204719387654, 1.5890704719387636, 0.4184611188256314, 0.05264190051020309, 0.5828204719387825,
1.0874633290816231, 0.44930618924816623, 0.002284757653060791, 0.1908561862244871, 0.3514188653045034, 0.44930618924816623,
0.0042490433673475075, 0.24442761479591799, 0.22407047193877133, 0.5533561862244832, 0.01942761479591874,
0.5563484427692932, 0.45799904336734926, 0.244784757653056, 0.10978475765306102, 0.10067761479592055, 0.3514188653045034,
0.06817761479592066, 0.10978475765306102, 0.3644276147959236, 0.34371332908163893, 0.10871332908162987,
0.18299904336734557, 0.2624633290816364, 0.46567761479592523, 0.21692761479592287, 0.11853475765306296,
0.38853475765306766,
],
centroids: vec![
4.360563380281689, 1.402112676056339, 1.4735294117647069, 0.25441176470588234, 5.659821428571433, 2.074107142857144,
],
};
assert_kmeans_result_eq(should, res);
}
// Regression test for mini-batch k-means on `f32` data.
//
// Same samples, seed and configuration as the `f64` variant, but with
// single-precision values; the pinned reference result therefore differs.
// The expected numbers depend on the exact random stream, so any change to
// how random numbers are consumed requires regenerating them.
#[test]
fn iris_dataset_f32() {
// Flat, interleaved sample buffer: two values per sample.
let samples = vec![
1.4f32, 0.2, 1.4, 0.2, 1.3, 0.2, 1.5, 0.2, 1.4, 0.2, 1.7, 0.4, 1.4, 0.3, 1.5, 0.2, 1.4, 0.2, 1.5, 0.1, 1.5, 0.2, 1.6, 0.2, 1.4,
0.1, 1.1, 0.1, 1.2, 0.2, 1.5, 0.4, 1.3, 0.4, 1.4, 0.3, 1.7, 0.3, 1.5, 0.3, 1.7, 0.2, 1.5, 0.4, 1.0, 0.2, 1.7, 0.5, 1.9, 0.2,
1.6, 0.2, 1.6, 0.4, 1.5, 0.2, 1.4, 0.2, 1.6, 0.2, 1.6, 0.2, 1.5, 0.4, 1.5, 0.1, 1.4, 0.2, 1.5, 0.2, 1.2, 0.2, 1.3, 0.2, 1.4,
0.1, 1.3, 0.2, 1.5, 0.2, 1.3, 0.3, 1.3, 0.3, 1.3, 0.2, 1.6, 0.6, 1.9, 0.4, 1.4, 0.3, 1.6, 0.2, 1.4, 0.2, 1.5, 0.2, 1.4, 0.2,
4.7, 1.4, 4.5, 1.5, 4.9, 1.5, 4.0, 1.3, 4.6, 1.5, 4.5, 1.3, 4.7, 1.6, 3.3, 1.0, 4.6, 1.3, 3.9, 1.4, 3.5, 1.0, 4.2, 1.5, 4.0,
1.0, 4.7, 1.4, 3.6, 1.3, 4.4, 1.4, 4.5, 1.5, 4.1, 1.0, 4.5, 1.5, 3.9, 1.1, 4.8, 1.8, 4.0, 1.3, 4.9, 1.5, 4.7, 1.2, 4.3, 1.3,
4.4, 1.4, 4.8, 1.4, 5.0, 1.7, 4.5, 1.5, 3.5, 1.0, 3.8, 1.1, 3.7, 1.0, 3.9, 1.2, 5.1, 1.6, 4.5, 1.5, 4.5, 1.6, 4.7, 1.5, 4.4,
1.3, 4.1, 1.3, 4.0, 1.3, 4.4, 1.2, 4.6, 1.4, 4.0, 1.2, 3.3, 1.0, 4.2, 1.3, 4.2, 1.2, 4.2, 1.3, 4.3, 1.3, 3.0, 1.1, 4.1, 1.3,
6.0, 2.5, 5.1, 1.9, 5.9, 2.1, 5.6, 1.8, 5.8, 2.2, 6.6, 2.1, 4.5, 1.7, 6.3, 1.8, 5.8, 1.8, 6.1, 2.5, 5.1, 2.0, 5.3, 1.9, 5.5,
2.1, 5.0, 2.0, 5.1, 2.4, 5.3, 2.3, 5.5, 1.8, 6.7, 2.2, 6.9, 2.3, 5.0, 1.5, 5.7, 2.3, 4.9, 2.0, 6.7, 2.0, 4.9, 1.8, 5.7, 2.1,
6.0, 1.8, 4.8, 1.8, 4.9, 1.8, 5.6, 2.1, 5.8, 1.6, 6.1, 1.9, 6.4, 2.0, 5.6, 2.2, 5.1, 1.5, 5.6, 1.4, 6.1, 2.3, 5.6, 2.4, 5.5,
1.8, 4.8, 1.8, 5.4, 2.1, 5.6, 2.4, 5.1, 2.3, 5.1, 1.9, 5.9, 2.3, 5.7, 2.5, 5.2, 2.3, 5.0, 1.9, 5.2, 2.0, 5.4, 2.3, 5.1, 1.8,
];
let kmean: KMeans<f32, 8, _> = KMeans::new(&samples, 150, 2, EuclideanDistance);
// Fixed seed keeps the shuffling and batch selection reproducible.
let rnd = rand::rngs::StdRng::seed_from_u64(3);
let conf = KMeansConfig::build()
.random_generator(rnd)
.abort_strategy(AbortStrategy::NoImprovementForXIterations {
x: 5,
threshold: 0.0005f32,
abort_on_negative: true,
})
.build();
// Arguments are presumably batch size 30, k = 3 and at most 100 iterations,
// mirroring the parameter order of `Minibatch::calculate` — confirm against
// `KMeans::kmeans_minibatch`.
let res = kmean.kmeans_minibatch(30, 3, 100, KMeans::init_kmeanplusplus, &conf);
// Pinned reference result for the seed and configuration above.
let should = KMeansShouldResult {
distsum: 31.751724,
sample_dims: 2,
assignments: vec![
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0,
0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2,
2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
],
centroid_distances: vec![
0.01267357, 0.01267357, 0.037013218, 0.0083339345, 0.01267357, 0.06493769, 0.005315068, 0.0083339345, 0.01267357,
0.03569244, 0.0083339345, 0.023994308, 0.040032074, 0.17305095, 0.08135281, 0.01361693, 0.042296212, 0.005315068,
0.052296188, 0.0009754318, 0.05965469, 0.01361693, 0.2300321, 0.09757918, 0.19097538, 0.023994308, 0.029277302,
0.0083339345, 0.01267357, 0.023994308, 0.023994308, 0.01361693, 0.03569244, 0.01267357, 0.0083339345, 0.08135281,
0.037013218, 0.040032074, 0.037013218, 0.0083339345, 0.029654715, 0.029654715, 0.037013218, 0.114560306, 0.19625838,
0.005315068, 0.023994308, 0.01267357, 0.0083339345, 0.01267357, 0.16893126, 0.075018674, 0.3967574, 0.08893235, 0.12545326,
0.041105583, 0.24284437, 1.0950203, 0.09154017, 0.16545418, 0.73588943, 0.043714777, 0.1880627, 0.16893126, 0.48719388,
0.0176275, 0.075018674, 0.13849738, 0.075018674, 0.20458452, 0.48719242, 0.08893235, 0.3967574, 0.17501816, 0.00023629011,
0.0176275, 0.25936627, 0.39401728, 0.075018674, 0.73588943, 0.29415, 0.4567587, 0.17154106, 0.3755859, 0.075018674,
0.12197524, 0.1958878, 0.010670955, 0.039367035, 0.08893235, 0.023714393, 0.09849672, 0.10197579, 1.0950203, 0.009801679,
0.02284512, 0.009801679, 0.00023629011, 1.7306727, 0.039367035, 0.4355868, 0.20323333, 0.1412727, 0.061272323, 0.09970424,
1.1500959, 0.18893182, 0.6500962, 0.12950772, 0.5397043, 0.18578245, 0.07146844, 0.0048020603, 0.28166473, 0.31597906,
0.12166494, 0.057154696, 1.3967625, 1.9475476, 0.5271919, 0.09813553, 0.39754698, 1.3716642, 0.45244873, 0.03303728,
0.27774292, 0.48719242, 0.45244873, 0.008919688, 0.26440942, 0.34440956, 0.7593119, 0.03146885, 0.4730368, 0.4110758,
0.39460605, 0.13656716, 0.057154696, 0.48719242, 0.020684395, 0.13656716, 0.25342983, 0.20323333, 0.20637095, 0.24323381,
0.17754751, 0.2991156, 0.10990014, 0.08578265, 0.24068421,
],
centroids: vec![4.297827, 1.3152173, 1.4716982, 0.28679252, 5.529412, 2.0372543],
};
assert_kmeans_result_eq(should, res);
}
}