1use std::cmp::Ordering;
2use std::fmt::Debug;
3
4use crate::k_means::{KMeansParams, KMeansValidParams};
5use crate::IncrKMeansError;
6use crate::{k_means::errors::KMeansError, KMeansInit};
7use linfa::{prelude::*, DatasetBase, Float};
8use linfa_nn::distance::{Distance, L2Dist};
9use ndarray::{Array1, Array2, ArrayBase, Axis, Data, DataMut, Ix1, Ix2, Zip};
10use ndarray_rand::rand::{Rng, SeedableRng};
11use rand_xoshiro::Xoshiro256Plus;
12
13#[cfg(feature = "serde")]
14use serde_crate::{Deserialize, Serialize};
15
16#[cfg_attr(
17 feature = "serde",
18 derive(Serialize, Deserialize),
19 serde(crate = "serde_crate")
20)]
21#[derive(Clone, Debug, PartialEq)]
22pub struct KMeans<F: Float, D: Distance<F>> {
179 centroids: Array2<F>,
180 cluster_count: Array1<F>,
181 inertia: F,
182 dist_fn: D,
183}
184
185impl<F: Float> KMeans<F, L2Dist> {
186 pub fn params(nclusters: usize) -> KMeansParams<F, Xoshiro256Plus, L2Dist> {
187 KMeansParams::new(nclusters, Xoshiro256Plus::seed_from_u64(42), L2Dist)
188 }
189
190 pub fn params_with_rng<R: Rng>(nclusters: usize, rng: R) -> KMeansParams<F, R, L2Dist> {
191 KMeansParams::new(nclusters, rng, L2Dist)
192 }
193}
194
195impl<F: Float, D: Distance<F>> KMeans<F, D> {
196 pub fn params_with<R: Rng>(nclusters: usize, rng: R, dist_fn: D) -> KMeansParams<F, R, D> {
197 KMeansParams::new(nclusters, rng, dist_fn)
198 }
199
200 pub fn centroids(&self) -> &Array2<F> {
203 &self.centroids
204 }
205
206 pub fn cluster_count(&self) -> &Array1<F> {
208 &self.cluster_count
209 }
210
211 pub fn inertia(&self) -> F {
215 self.inertia
216 }
217}
218
219impl<F: Float, R: Rng + Clone, DA: Data<Elem = F>, T, D: Distance<F>>
220 Fit<ArrayBase<DA, Ix2>, T, KMeansError> for KMeansValidParams<F, R, D>
221{
222 type Object = KMeans<F, D>;
223
224 fn fit(
230 &self,
231 dataset: &DatasetBase<ArrayBase<DA, Ix2>, T>,
232 ) -> Result<Self::Object, KMeansError> {
233 let mut rng = self.rng().clone();
234 let observations = dataset.records().view();
235 let n_samples = dataset.nsamples();
236
237 let mut min_inertia = F::infinity();
238 let mut best_centroids = None;
239 let mut memberships = Array1::zeros(n_samples);
240 let mut dists = Array1::zeros(n_samples);
241
242 let n_runs = self.n_runs();
243
244 for _ in 0..n_runs {
245 let mut centroids =
246 self.init_method()
247 .run(self.dist_fn(), self.n_clusters(), observations, &mut rng);
248 let mut n_iter = 0;
249 let inertia = loop {
250 update_memberships_and_dists(
251 self.dist_fn(),
252 ¢roids,
253 &observations,
254 &mut memberships,
255 &mut dists,
256 );
257 let new_centroids = compute_centroids(¢roids, &observations, &memberships);
258 let distance = self
259 .dist_fn()
260 .distance(centroids.view(), new_centroids.view());
261 centroids = new_centroids;
262 n_iter += 1;
263 if distance < self.tolerance() || n_iter == self.max_n_iterations() {
264 break dists.sum();
265 }
266 };
267
268 if inertia < min_inertia {
272 min_inertia = inertia;
273 best_centroids = Some(centroids.clone());
274 }
275 }
276
277 match best_centroids {
278 Some(centroids) => {
279 let mut cluster_count = Array1::zeros(self.n_clusters());
280 memberships
281 .iter()
282 .for_each(|&c| cluster_count[c] += F::one());
283 Ok(KMeans {
284 centroids,
285 cluster_count,
286 inertia: min_inertia / F::cast(dataset.nsamples()),
287 dist_fn: self.dist_fn().clone(),
288 })
289 }
290 _ => Err(KMeansError::InertiaError),
291 }
292 }
293}
294
295impl<'a, F: Float + Debug, R: Rng + Clone, DA: Data<Elem = F>, T, D: 'a + Distance<F> + Debug>
296 FitWith<'a, ArrayBase<DA, Ix2>, T, IncrKMeansError<KMeans<F, D>>>
297 for KMeansValidParams<F, R, D>
298{
299 type ObjectIn = Option<KMeans<F, D>>;
300 type ObjectOut = KMeans<F, D>;
301
302 fn fit_with(
310 &self,
311 model: Self::ObjectIn,
312 dataset: &'a DatasetBase<ArrayBase<DA, Ix2>, T>,
313 ) -> Result<Self::ObjectOut, IncrKMeansError<Self::ObjectOut>> {
314 let observations = dataset.records().view();
315 let n_samples = dataset.nsamples();
316
317 let mut model = match model {
318 Some(model) => model,
319 None => {
320 let centroids = if let KMeansInit::Precomputed(centroids) = self.init_method() {
321 centroids.clone()
323 } else {
324 let mut rng = self.rng().clone();
325 let mut dists = Array1::zeros(n_samples);
326 (0..self.n_runs())
329 .map(|_| {
330 let centroids = self.init_method().run(
331 self.dist_fn(),
332 self.n_clusters(),
333 observations,
334 &mut rng,
335 );
336 update_min_dists(self.dist_fn(), ¢roids, &observations, &mut dists);
337 (centroids, dists.sum())
338 })
339 .min_by(|(_, d1), (_, d2)| {
340 if d1 < d2 {
341 Ordering::Less
342 } else {
343 Ordering::Greater
344 }
345 })
346 .unwrap()
347 .0
348 };
349 KMeans {
350 centroids,
351 cluster_count: Array1::zeros(self.n_clusters()),
352 inertia: F::zero(),
353 dist_fn: self.dist_fn().clone(),
354 }
355 }
356 };
357
358 let mut memberships = Array1::zeros(n_samples);
359 let mut dists = Array1::zeros(n_samples);
360 update_memberships_and_dists(
361 self.dist_fn(),
362 &model.centroids,
363 &observations,
364 &mut memberships,
365 &mut dists,
366 );
367 let new_centroids = compute_centroids_incremental(
368 &observations,
369 &memberships,
370 &model.centroids,
371 &mut model.cluster_count,
372 );
373 model.inertia = dists.sum() / F::cast(n_samples);
374 let dist = self
375 .dist_fn()
376 .distance(model.centroids.view(), new_centroids.view());
377 model.centroids = new_centroids;
378
379 if dist < self.tolerance() {
380 Ok(model)
381 } else {
382 Err(IncrKMeansError::NotConverged(model))
383 }
384 }
385}
386
387impl<F: Float, DA: Data<Elem = F>, D: Distance<F>> Transformer<&ArrayBase<DA, Ix2>, Array1<F>>
388 for KMeans<F, D>
389{
390 fn transform(&self, observations: &ArrayBase<DA, Ix2>) -> Array1<F> {
393 let mut dists = Array1::zeros(observations.nrows());
394 update_min_dists(
395 &self.dist_fn,
396 &self.centroids,
397 &observations.view(),
398 &mut dists,
399 );
400 dists
401 }
402}
403
404impl<F: Float, DA: Data<Elem = F>, D: Distance<F>> PredictInplace<ArrayBase<DA, Ix2>, Array1<usize>>
405 for KMeans<F, D>
406{
407 fn predict_inplace(&self, observations: &ArrayBase<DA, Ix2>, memberships: &mut Array1<usize>) {
413 assert_eq!(
414 observations.nrows(),
415 memberships.len(),
416 "The number of data points must match the number of memberships."
417 );
418
419 update_cluster_memberships(
420 &self.dist_fn,
421 &self.centroids,
422 &observations.view(),
423 memberships,
424 );
425 }
426
427 fn default_target(&self, x: &ArrayBase<DA, Ix2>) -> Array1<usize> {
428 Array1::zeros(x.nrows())
429 }
430}
431
432impl<F: Float, DA: Data<Elem = F>, D: Distance<F>> PredictInplace<ArrayBase<DA, Ix1>, usize>
433 for KMeans<F, D>
434{
435 fn predict_inplace(&self, observation: &ArrayBase<DA, Ix1>, membership: &mut usize) {
440 *membership = closest_centroid(&self.dist_fn, &self.centroids, observation).0;
441 }
442
443 fn default_target(&self, _x: &ArrayBase<DA, Ix1>) -> usize {
444 0
445 }
446}
447
448fn compute_centroids<F: Float>(
455 old_centroids: &Array2<F>,
456 observations: &ArrayBase<impl Data<Elem = F>, Ix2>,
458 cluster_memberships: &ArrayBase<impl Data<Elem = usize>, Ix1>,
460) -> Array2<F> {
461 let n_clusters = old_centroids.nrows();
462 let mut counts: Array1<usize> = Array1::ones(n_clusters);
463 let mut centroids = Array2::zeros((n_clusters, observations.ncols()));
464
465 Zip::from(observations.rows())
466 .and(cluster_memberships)
467 .for_each(|observation, &cluster_membership| {
468 let mut centroid = centroids.row_mut(cluster_membership);
469 centroid += &observation;
470 counts[cluster_membership] += 1;
471 });
472 centroids += old_centroids;
474
475 Zip::from(centroids.rows_mut())
476 .and(&counts)
477 .for_each(|mut centroid, &cnt| centroid /= F::cast(cnt));
478 centroids
479}
480
481fn compute_centroids_incremental<F: Float>(
485 observations: &ArrayBase<impl Data<Elem = F>, Ix2>,
486 cluster_memberships: &ArrayBase<impl Data<Elem = usize>, Ix1>,
487 old_centroids: &ArrayBase<impl Data<Elem = F>, Ix2>,
488 counts: &mut ArrayBase<impl DataMut<Elem = F>, Ix1>,
489) -> Array2<F> {
490 let mut centroids = old_centroids.to_owned();
491 Zip::from(observations.rows())
493 .and(cluster_memberships)
494 .for_each(|obs, &c| {
495 counts[c] += F::one();
499 let shift = (&obs - ¢roids.row(c)) / counts[c];
500 let mut centroid = centroids.row_mut(c);
501 centroid += &shift;
502 });
503 centroids
504}
505
506pub(crate) fn update_cluster_memberships<F: Float, D: Distance<F>>(
508 dist_fn: &D,
509 centroids: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
510 observations: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
511 cluster_memberships: &mut ArrayBase<impl DataMut<Elem = usize>, Ix1>,
512) {
513 Zip::from(observations.axis_iter(Axis(0)))
514 .and(cluster_memberships)
515 .par_for_each(|observation, cluster_membership| {
516 *cluster_membership = closest_centroid(dist_fn, centroids, &observation).0
517 });
518}
519
520pub(crate) fn update_min_dists<F: Float, D: Distance<F>>(
522 dist_fn: &D,
523 centroids: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
524 observations: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
525 dists: &mut ArrayBase<impl DataMut<Elem = F>, Ix1>,
526) {
527 Zip::from(observations.axis_iter(Axis(0)))
528 .and(dists)
529 .par_for_each(|observation, dist| {
530 *dist = closest_centroid(dist_fn, centroids, &observation).1
531 });
532}
533
534pub(crate) fn update_memberships_and_dists<F: Float, D: Distance<F>>(
536 dist_fn: &D,
537 centroids: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
538 observations: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
539 cluster_memberships: &mut ArrayBase<impl DataMut<Elem = usize>, Ix1>,
540 dists: &mut ArrayBase<impl DataMut<Elem = F>, Ix1>,
541) {
542 Zip::from(observations.axis_iter(Axis(0)))
543 .and(cluster_memberships)
544 .and(dists)
545 .par_for_each(|observation, cluster_membership, dist| {
546 let (m, d) = closest_centroid(dist_fn, centroids, &observation);
547 *cluster_membership = m;
548 *dist = d;
549 });
550}
551
552pub(crate) fn closest_centroid<F: Float, D: Distance<F>>(
555 dist_fn: &D,
556 centroids: &ArrayBase<impl Data<Elem = F>, Ix2>,
558 observation: &ArrayBase<impl Data<Elem = F>, Ix1>,
560) -> (usize, F) {
561 let iterator = centroids.rows().into_iter();
562
563 let first_centroid = centroids.row(0);
564 let (mut closest_index, mut minimum_distance) = (
565 0,
566 dist_fn.rdistance(first_centroid.view(), observation.view()),
567 );
568
569 for (centroid_index, centroid) in iterator.enumerate() {
570 let distance = dist_fn.rdistance(centroid.view(), observation.view());
571 if distance < minimum_distance {
572 closest_index = centroid_index;
573 minimum_distance = distance;
574 }
575 }
576 (closest_index, minimum_distance)
577}
578
#[cfg(test)]
mod tests {
    use super::super::KMeansInit;
    use super::*;
    use crate::KMeansParamsError;
    use approx::assert_abs_diff_eq;
    use linfa_nn::distance::L1Dist;
    use ndarray::{array, concatenate, Array, Array1, Array2, Axis};
    use ndarray_rand::rand::prelude::ThreadRng;
    use ndarray_rand::rand::SeedableRng;
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;

    /// Compile-time check that the public types are `Send + Sync + Sized + Unpin`.
    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<KMeans<f64, L2Dist>>();
        has_autotraits::<KMeansParamsError>();
        has_autotraits::<KMeansError>();
        has_autotraits::<IncrKMeansError<String>>();
    }

    /// Piecewise test function applied element-wise: `x^2` below 0.4, `3x + 1` on
    /// `[0.4, 0.8)` and `sin(10x)` from 0.8 upwards.
    fn function_test_1d(x: &Array2<f64>) -> Array2<f64> {
        let mut y = Array2::zeros(x.dim());
        Zip::from(&mut y).and(x).for_each(|yi, &xi| {
            if xi < 0.4 {
                *yi = xi * xi;
            } else if (0.4..0.8).contains(&xi) {
                *yi = 3. * xi + 1.;
            } else {
                *yi = f64::sin(10. * xi);
            }
        });
        y
    }

    // Sums the reduced distance between every observation and the centroid of the cluster
    // it has been assigned to.
    macro_rules! calc_inertia {
        ($dist:expr, $centroids:expr, $obs:expr, $memberships:expr) => {
            $obs.rows()
                .into_iter()
                .zip($memberships.iter())
                .map(|(row, &c)| $dist.rdistance(row.view(), $centroids.row(c).view()))
                .sum::<f64>()
        };
    }

    // Computes a fresh membership array for the given observations and centroids.
    macro_rules! calc_memberships {
        ($dist:expr, $centroids:expr, $obs:expr) => {{
            let mut memberships = Array1::zeros($obs.nrows());
            update_cluster_memberships(&$dist, &$centroids, &$obs, &mut memberships);
            memberships
        }};
    }

    /// Minimum distances are reduced distances: squared for `L2Dist`, plain for `L1Dist`.
    #[test]
    fn test_min_dists() {
        let centroids = array![[0.0, 1.0], [40.0, 10.0]];
        let observations = array![[3.0, 4.0], [1.0, 3.0], [25.0, 15.0]];
        let mut dists = Array1::zeros(observations.nrows());

        update_min_dists(&L2Dist, &centroids, &observations, &mut dists);
        assert_abs_diff_eq!(dists, array![18.0, 5.0, 250.0]);
        update_min_dists(&L1Dist, &centroids, &observations, &mut dists);
        assert_abs_diff_eq!(dists, array![6.0, 3.0, 20.0]);
    }

    /// Fits the same data with a single run and with the default number of runs, for every
    /// non-precomputed initialisation method, and cross-checks `predict`, `transform` and
    /// the inertia.
    fn test_n_runs<D: Distance<f64>>(dist_fn: D) {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let xt = Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1));
        let yt = function_test_1d(&xt);
        let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap();

        for init in &[
            KMeansInit::Random,
            KMeansInit::KMeansPlusPlus,
            KMeansInit::KMeansPara,
        ] {
            // First clustering: a single run.
            let dataset = DatasetBase::from(data.clone());
            let model = KMeans::params_with(3, rng.clone(), dist_fn.clone())
                .n_runs(1)
                .init_method(init.clone())
                .fit(&dataset)
                .expect("KMeans fitted");
            let clusters = model.predict(dataset);
            let inertia = calc_inertia!(
                dist_fn,
                model.centroids(),
                clusters.records,
                clusters.targets
            );
            let total_dist = model.transform(&clusters.records.view()).sum();
            assert_abs_diff_eq!(inertia, total_dist, epsilon = 1e-5);

            // Predicting a single record must agree with the batch prediction.
            let single_cluster: usize = model.predict(&data.row(0));
            assert_abs_diff_eq!(single_cluster, clusters.targets[0]);

            // Second clustering: default number of runs.
            let dataset2 = DatasetBase::from(clusters.records().clone());
            let model2 = KMeans::params_with(3, rng.clone(), dist_fn.clone())
                .init_method(init.clone())
                .fit(&dataset2)
                .expect("KMeans fitted");
            let clusters2 = model2.predict(dataset2);
            let inertia2 = calc_inertia!(
                dist_fn,
                model2.centroids(),
                clusters2.records,
                clusters2.targets
            );
            let total_dist2 = model2.transform(&clusters2.records.view()).sum();
            assert_abs_diff_eq!(inertia2, total_dist2, epsilon = 1e-5);

            // The improvement is only asserted for the random initialisation.
            if *init == KMeansInit::Random {
                assert!(inertia2 <= inertia);
            }
        }
    }

    #[test]
    fn test_n_runs_l2dist() {
        test_n_runs(L2Dist);
    }

    #[test]
    fn test_n_runs_l1dist() {
        test_n_runs(L1Dist);
    }

    /// With zero-valued old centroids, every new centroid is the sum of its cluster divided
    /// by the cluster size plus one (the old centroid counts as a member).
    #[test]
    fn compute_centroids_works() {
        let cluster_size = 100;
        let n_features = 4;

        // Two synthetic clusters with known sums.
        let cluster_1: Array2<f64> =
            Array::random((cluster_size, n_features), Uniform::new(-100., 100.));
        let memberships_1 = Array1::zeros(cluster_size);
        let expected_centroid_1 = cluster_1.sum_axis(Axis(0)) / (cluster_size + 1) as f64;

        let cluster_2: Array2<f64> =
            Array::random((cluster_size, n_features), Uniform::new(-100., 100.));
        let memberships_2 = Array1::ones(cluster_size);
        let expected_centroid_2 = cluster_2.sum_axis(Axis(0)) / (cluster_size + 1) as f64;

        // `concatenate` combines arrays along a given axis.
        let observations = concatenate(Axis(0), &[cluster_1.view(), cluster_2.view()]).unwrap();
        let memberships =
            concatenate(Axis(0), &[memberships_1.view(), memberships_2.view()]).unwrap();

        // Does it work?
        let old_centroids = Array2::zeros((2, n_features));
        let centroids = compute_centroids(&old_centroids, &observations, &memberships);
        assert_abs_diff_eq!(
            centroids.index_axis(Axis(0), 0),
            expected_centroid_1,
            epsilon = 1e-5
        );
        assert_abs_diff_eq!(
            centroids.index_axis(Axis(0), 1),
            expected_centroid_2,
            epsilon = 1e-5
        );

        assert_eq!(centroids.len_of(Axis(0)), 2);
    }

    /// A cluster without any observation keeps its old centroid.
    #[test]
    fn test_compute_extra_centroids() {
        let observations = array![[1.0, 2.0]];
        let memberships = array![0];
        // Should return an average of 2 for the first centroid and leave the second as is.
        let old_centroids = Array2::ones((2, 2));
        let centroids = compute_centroids(&old_centroids, &observations, &memberships);
        assert_abs_diff_eq!(centroids, array![[1.0, 1.5], [1.0, 1.0]]);
    }

    /// Each centroid, used as an observation, must be assigned to itself.
    #[test]
    fn nothing_is_closer_than_self() {
        let n_centroids = 20;
        let n_features = 5;
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let centroids: Array2<f64> = Array::random_using(
            (n_centroids, n_features),
            Uniform::new(-100., 100.),
            &mut rng,
        );

        let expected_memberships = (0..n_centroids).collect::<Array1<_>>();
        assert_eq!(
            calc_memberships!(L2Dist, centroids, centroids),
            expected_memberships
        );
        assert_eq!(
            calc_memberships!(L1Dist, centroids, centroids),
            expected_memberships
        );
    }

    /// Hand-computed memberships; the first observation differs between L2 and L1.
    #[test]
    fn oracle_test_for_closest_centroid() {
        let centroids = array![[0., 0.], [1., 2.], [20., 0.], [0., 20.],];
        let observations = array![[1., 0.6], [20., 2.], [20., 0.], [7., 20.],];
        let l2_memberships = array![0, 2, 2, 3];
        let l1_memberships = array![1, 2, 2, 3];

        assert_eq!(
            calc_memberships!(L2Dist, centroids, observations),
            l2_memberships
        );
        assert_eq!(
            calc_memberships!(L1Dist, centroids, observations),
            l1_memberships
        );
    }

    /// The incremental update is a running mean weighted by the pre-existing counts.
    #[test]
    fn test_compute_centroids_incremental() {
        let observations = array![[-1.0, -3.0], [0., 0.], [3., 5.], [5., 5.]];
        let memberships = array![0, 0, 1, 1];
        let centroids = array![[-1., -1.], [3., 4.], [7., 8.]];
        let mut counts = array![3.0, 0.0, 1.0];
        let centroids =
            compute_centroids_incremental(&observations, &memberships, &centroids, &mut counts);

        assert_abs_diff_eq!(centroids, array![[-4. / 5., -6. / 5.], [4., 5.], [7., 8.]]);
        assert_abs_diff_eq!(counts, array![5., 2., 1.]);
    }

    /// Two consecutive batch updates; the large tolerance makes each call report
    /// convergence so that `unwrap` succeeds.
    #[test]
    fn test_incremental_kmeans() {
        let dataset1 = DatasetBase::from(array![[-1.0, -3.0], [0., 0.], [3., 5.], [5., 5.]]);
        let dataset2 = DatasetBase::from(array![[-5.0, -5.0], [0., 0.], [10., 10.]]);
        let model = KMeans {
            centroids: array![[-1., -1.], [3., 4.], [7., 8.]],
            cluster_count: array![0., 0., 0.],
            inertia: 0.0,
            dist_fn: L2Dist,
        };
        let rng = Xoshiro256Plus::seed_from_u64(45);
        let params = KMeans::params_with_rng(3, rng).tolerance(100.0);

        // Should converge in one call due to the high tolerance.
        let model = params.fit_with(Some(model), &dataset1).unwrap();
        assert_abs_diff_eq!(model.centroids(), &array![[-0.5, -1.5], [4., 5.], [7., 8.]]);

        let model = params.fit_with(Some(model), &dataset2).unwrap();
        assert_abs_diff_eq!(
            model.centroids(),
            &array![[-6. / 4., -8. / 4.], [4., 5.], [10., 10.]]
        );
    }

    /// The single centroid moves from `[0, 0]` to `[6, 6]`, a shift of about 8.49, which is
    /// just below the tolerance of 8.5.
    #[test]
    fn test_tolerance() {
        let rng = Xoshiro256Plus::seed_from_u64(45);
        let params = KMeans::params_with_rng(1, rng)
            .tolerance(8.5)
            .init_method(KMeansInit::Precomputed(array![[0., 0.]]));
        let data = DatasetBase::from(array![[1., 1.], [11., 11.]]);
        assert!(params.fit_with(None, &data).is_ok());
    }

    /// Fitting must succeed when the iteration limit is set to a small value.
    #[test]
    fn test_max_n_iterations() {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let xt = Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1));
        let yt = function_test_1d(&xt);
        let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap();
        let dataset = DatasetBase::from(data.clone());
        let _model = KMeans::params_with(6, rng.clone(), L2Dist)
            .n_runs(1)
            .max_n_iterations(5)
            .init_method(KMeansInit::Random)
            .fit(&dataset)
            .expect("KMeans fitted");
    }

    /// Compile-time helper: accepts anything that can be fitted on an `Array2<f64>`.
    fn fittable<T: Fit<Array2<f64>, (), KMeansError>>(_: T) {}

    /// Parameters built with a `ThreadRng` must still implement `Fit`.
    #[test]
    fn thread_rng_fittable() {
        fittable(KMeans::params_with_rng(1, ThreadRng::default()));
    }
}