1use std::{collections::HashSet, fmt};
7
8use diskann_utils::{
9 strided::StridedView,
10 views::{MatrixView, MutMatrixView},
11};
12use diskann_wide::{SIMDMulAdd, SIMDPartialOrd, SIMDSelect, SIMDVector};
13use rand::{
14 RngCore,
15 distr::{Distribution, Uniform},
16};
17use thiserror::Error;
18
19use super::common::square_norm;
20use crate::multi_vector::{BlockTransposed, BlockTransposedRef};
21
/// Per-block kernel driving `update_distances`.
///
/// `update_distances` calls `accum_full` once per block of the transposed dataset, then
/// hands the result to `finish` (blocks paired with a full chunk of `norms`/`mins`) or to
/// `finish_last` (the trailing partial chunk). The rolling sum threaded through those calls
/// is turned into the function's `f64` return value by `complete_sum`.
pub(crate) trait MicroKernel {
    /// Value produced by `accum_full` for one block and consumed by `finish`/`finish_last`.
    type Intermediate;

    /// Accumulator carried from block to block; it starts from `Default::default()`.
    type RollingSum: Default + Copy;

    /// Kernel-specific form of a scalar, built once by `splat` and reused for every block.
    type Splat: Copy;

    /// Converts the scalar `x` into the kernel's `Splat` representation.
    fn splat(x: f32) -> Self::Splat;

    /// Computes the intermediate for the block starting at `block` against the point `this`.
    ///
    /// # Safety
    ///
    /// NOTE(review): the contract is not written down in this file. The 16-wide
    /// implementation reads `16 * this.len()` consecutive `f32`s starting at `block`, so
    /// callers presumably must pass a pointer to a complete block with `this.len()` columns.
    unsafe fn accum_full(block: *const f32, this: &[f32]) -> Self::Intermediate;

    /// Folds `intermediate` into a full chunk of `mins` and returns the updated rolling sum.
    ///
    /// `norms` and `mins` are the chunks that belong to the block which produced
    /// `intermediate`; `splat` is the value returned by `splat` for the current point.
    fn finish(
        intermediate: Self::Intermediate,
        splat: Self::Splat,
        rolling_sum: Self::RollingSum,
        norms: &[f32],
        mins: &mut [f32],
    ) -> Self::RollingSum;

    /// Variant of `finish` for the trailing chunk, where `norms` and `mins` hold only
    /// `first` entries.
    fn finish_last(
        intermediate: Self::Intermediate,
        splat: Self::Splat,
        rolling_sum: Self::RollingSum,
        norms: &[f32],
        mins: &mut [f32],
        first: usize,
    ) -> Self::RollingSum;

    /// Converts the final rolling sum into the `f64` reported to the caller.
    fn complete_sum(x: Self::RollingSum) -> f64;
}
81
// `f32s` is the SIMD vector type used by the 16-wide kernel below, aliased to `f32x8`.
// NOTE(review): presumably eight `f32` lanes — the kernel handles a 16-row block as two
// loads at offsets 0 and 8.
diskann_wide::alias!(f32s = f32x8);
83
84impl MicroKernel for BlockTransposed<f32, 16> {
85 type Intermediate = (f32s, f32s);
87 type RollingSum = f64;
88 type Splat = f32s;
89
90 fn splat(x: f32) -> Self::Splat {
91 Self::Splat::splat(diskann_wide::ARCH, x)
92 }
93
94 #[inline(always)]
95 unsafe fn accum_full(block_ptr: *const f32, this: &[f32]) -> Self::Intermediate {
96 let mut s0 = f32s::default(diskann_wide::ARCH);
97 let mut s1 = f32s::default(diskann_wide::ARCH);
98
99 this.iter().enumerate().for_each(|(i, b)| {
100 let b = f32s::splat(diskann_wide::ARCH, *b);
101
102 let a = unsafe { f32s::load_simd(diskann_wide::ARCH, block_ptr.add(16 * i)) };
106 s0 = a.mul_add_simd(b, s0);
107
108 let a = unsafe { f32s::load_simd(diskann_wide::ARCH, block_ptr.add(16 * i + 8)) };
110 s1 = a.mul_add_simd(b, s1);
111 });
112
113 let negative = f32s::splat(diskann_wide::ARCH, -2.0);
115 (s0 * negative, s1 * negative)
116 }
117
118 #[inline(always)]
119 fn finish(
120 intermediate: Self::Intermediate,
121 splat: Self::Splat,
122 rolling_sum: Self::RollingSum,
123 norms: &[f32],
124 mins: &mut [f32],
125 ) -> Self::RollingSum {
126 assert_eq!(norms.len(), 16);
127 assert_eq!(mins.len(), 16);
128
129 let norms0 = unsafe { f32s::load_simd(diskann_wide::ARCH, norms.as_ptr()) };
131 let norms1 = unsafe { f32s::load_simd(diskann_wide::ARCH, norms.as_ptr().add(8)) };
133
134 let distances0 = norms0 + splat + intermediate.0;
135 let distances1 = norms1 + splat + intermediate.1;
136
137 let current_distances0 = unsafe { f32s::load_simd(diskann_wide::ARCH, mins.as_ptr()) };
139 let current_distances1 =
141 unsafe { f32s::load_simd(diskann_wide::ARCH, mins.as_ptr().add(8)) };
142
143 let mask0 = distances0.lt_simd(current_distances0);
144 let mask1 = distances1.lt_simd(current_distances1);
145
146 let current_distances0 = mask0.select(distances0, current_distances0);
147 let current_distances1 = mask1.select(distances1, current_distances1);
148
149 unsafe { current_distances0.store_simd(mins.as_mut_ptr()) };
151 unsafe { current_distances1.store_simd(mins.as_mut_ptr().add(8)) };
153
154 rolling_sum
155 + std::iter::zip(
156 current_distances0.to_array().iter(),
157 current_distances1.to_array().iter(),
158 )
159 .map(|(d0, d1)| (*d0 as f64) + (*d1 as f64))
160 .sum::<f64>()
161 }
162
163 #[inline(always)]
164 fn finish_last(
165 intermediate: Self::Intermediate,
166 splat: Self::Splat,
167 rolling_sum: Self::RollingSum,
168 norms: &[f32],
169 mins: &mut [f32],
170 first: usize,
171 ) -> Self::RollingSum {
172 assert_eq!(norms.len(), first);
174 assert_eq!(mins.len(), first);
176
177 let lo = first.min(8);
178 let hi = first - lo;
179
180 let norms0 = unsafe { f32s::load_simd_first(diskann_wide::ARCH, norms.as_ptr(), lo) };
183 let norms1 = if hi == 0 {
184 f32s::default(diskann_wide::ARCH)
185 } else {
186 unsafe { f32s::load_simd_first(diskann_wide::ARCH, norms.as_ptr().add(8), hi) }
189 };
190
191 let distances0 = norms0 + splat + intermediate.0;
192 let distances1 = norms1 + splat + intermediate.1;
193
194 let current_distances0 =
196 unsafe { f32s::load_simd_first(diskann_wide::ARCH, mins.as_ptr(), lo) };
197 let current_distances1 = if hi == 0 {
198 f32s::default(diskann_wide::ARCH)
199 } else {
200 unsafe { f32s::load_simd_first(diskann_wide::ARCH, mins.as_ptr().add(8), hi) }
202 };
203
204 let mask0 = distances0.lt_simd(current_distances0);
205 let mask1 = distances1.lt_simd(current_distances1);
206
207 let current_distances0 = mask0.select(distances0, current_distances0);
208 let current_distances1 = mask1.select(distances1, current_distances1);
209
210 unsafe { current_distances0.store_simd_first(mins.as_mut_ptr(), lo) };
213 if hi != 0 {
214 unsafe { current_distances1.store_simd_first(mins.as_mut_ptr().add(8), hi) };
218 }
219
220 rolling_sum
221 + std::iter::zip(
222 current_distances0.to_array().iter(),
223 current_distances1.to_array().iter(),
224 )
225 .map(|(d0, d1)| (*d0 as f64) + (*d1 as f64))
226 .sum::<f64>()
227 }
228
229 fn complete_sum(x: Self::RollingSum) -> f64 {
230 x
231 }
232}
233
234fn update_distances<const N: usize>(
239 square_distances: &mut [f32],
240 transpose: BlockTransposedRef<'_, f32, N>,
241 norms: &[f32],
242 this: &[f32],
243 this_square_norm: f32,
244) -> f64
245where
246 BlockTransposed<f32, N>: MicroKernel,
247{
248 assert_eq!(
251 this.len(),
252 transpose.ncols(),
253 "new point and dataset must have the same dimension",
254 );
255 assert_eq!(
257 square_distances.len(),
258 transpose.nrows(),
259 "distances buffer and dataset must have the same length",
260 );
261 assert_eq!(
263 norms.len(),
264 transpose.nrows(),
265 "norms and dataset must have the same length",
266 );
267
268 let splat = BlockTransposed::<f32, N>::splat(this_square_norm);
269 let mut rolling_sum = <BlockTransposed<f32, N> as MicroKernel>::RollingSum::default();
270
271 let iter =
272 std::iter::zip(norms.chunks_exact(N), square_distances.chunks_exact_mut(N)).enumerate();
273 iter.for_each(|(block, (these_norms, these_distances))| {
274 debug_assert!(block < transpose.num_blocks());
275 let base = unsafe { transpose.block_ptr_unchecked(block) };
279
280 let intermediate = unsafe { BlockTransposed::<f32, N>::accum_full(base, this) };
283
284 rolling_sum = BlockTransposed::<f32, N>::finish(
285 intermediate,
286 splat,
287 rolling_sum,
288 these_norms,
289 these_distances,
290 );
291 });
292
293 let remainder = transpose.remainder();
295 if remainder != 0 {
296 let base = unsafe { transpose.block_ptr_unchecked(transpose.full_blocks()) };
299
300 let intermediate = unsafe { BlockTransposed::<f32, N>::accum_full(base, this) };
305
306 let start = N * transpose.full_blocks();
307 rolling_sum = BlockTransposed::<f32, N>::finish_last(
308 intermediate,
309 splat,
310 rolling_sum,
311 &norms[start..],
312 &mut square_distances[start..],
313 remainder,
314 );
315 }
316
317 BlockTransposed::<f32, N>::complete_sum(rolling_sum)
318}
319
/// Why k-means++ seeding stopped before populating every requested point.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FailureReason {
    /// The dataset has fewer rows than the number of points requested.
    DatasetTooSmall,
    /// A sampling round could not find another row to select.
    InsufficientDiversity,
    /// The accumulated distances were not finite (infinity or NaN).
    SawInfinity,
}

impl FailureReason {
    /// Returns `true` for every reason except [`FailureReason::SawInfinity`].
    pub fn is_numerically_recoverable(self) -> bool {
        match self {
            Self::SawInfinity => false,
            Self::DatasetTooSmall | Self::InsufficientDiversity => true,
        }
    }

    /// Human-readable text used by the `Display` implementation.
    fn description(self) -> &'static str {
        match self {
            Self::DatasetTooSmall => "dataset does not have enough points",
            Self::InsufficientDiversity => "dataset is insufficiently diverse",
            Self::SawInfinity => "a value of infinity or NaN was observed",
        }
    }
}

impl fmt::Display for FailureReason {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(self.description())
    }
}
355
356#[derive(Debug, Clone, Copy, Error)]
357#[error("only populated {selected} of {expected} points because {reason}")]
358pub struct KMeansPlusPlusError {
359 pub selected: usize,
361 pub expected: usize,
363 pub reason: FailureReason,
365}
366
367impl KMeansPlusPlusError {
368 fn new(selected: usize, expected: usize, reason: FailureReason) -> Self {
369 Self {
370 selected,
371 expected,
372 reason,
373 }
374 }
375
376 pub fn is_numerically_recoverable(&self) -> bool {
377 self.reason.is_numerically_recoverable() && self.selected > 0
378 }
379}
380
381pub(crate) fn kmeans_plusplus_into_inner<const N: usize>(
382 mut points: MutMatrixView<'_, f32>,
383 data: StridedView<'_, f32>,
384 transpose: BlockTransposedRef<'_, f32, N>,
385 norms: &[f32],
386 rng: &mut dyn RngCore,
387) -> Result<(), KMeansPlusPlusError>
388where
389 BlockTransposed<f32, N>: MicroKernel,
390{
391 assert_eq!(norms.len(), data.nrows());
392 assert_eq!(transpose.nrows(), data.nrows());
393 assert_eq!(transpose.ncols(), data.ncols());
394 assert_eq!(points.ncols(), data.ncols());
395
396 points.as_mut_slice().fill(0.0);
398 let expected = points.nrows();
399
400 let all_rows = match Uniform::new(0, data.nrows()) {
405 Ok(dist) => dist,
406 Err(_) => {
407 return if expected == 0 {
409 Ok(())
410 } else {
411 Err(KMeansPlusPlusError::new(
412 0,
413 expected,
414 FailureReason::DatasetTooSmall,
415 ))
416 };
417 }
418 };
419
420 let mut min_distances: Vec<f32> = vec![f32::INFINITY; data.nrows()];
421 let mut picked = HashSet::with_capacity(expected);
422
423 let mut previous_square_norm = {
425 let i = all_rows.sample(rng);
426 points.row_mut(0).copy_from_slice(data.row(i));
427 picked.insert(i);
428 norms[i]
429 };
430
431 let mut selected = 1;
432 for current in 1..expected.min(data.nrows()) {
433 let last = points.row(current - 1);
434 let s = update_distances(
435 &mut min_distances,
436 transpose,
437 norms,
438 last,
439 previous_square_norm,
440 );
441
442 match Uniform::<f64>::new(0.0, s) {
446 Ok(distribution) => {
447 let threshold = distribution.sample(rng);
448 let mut rolling_sum: f64 = 0.0;
449 for (i, d) in min_distances.iter().enumerate() {
450 rolling_sum += <f32 as Into<f64>>::into(*d);
451 if rolling_sum >= threshold && (*d > 0.0) && !picked.contains(&i) {
452 points.row_mut(current).clone_from_slice(data.row(i));
455 picked.insert(i);
456 previous_square_norm = norms[i];
457 selected = current + 1;
458 break;
459 }
460 }
461 }
462 Err(rand::distr::uniform::Error::EmptyRange) => {}
466 Err(rand::distr::uniform::Error::NonFinite) => {
468 return Err(KMeansPlusPlusError::new(
469 selected,
470 expected,
471 FailureReason::SawInfinity,
472 ));
473 }
474 }
475
476 if selected != (current + 1) {
480 return Err(KMeansPlusPlusError::new(
481 selected,
482 expected,
483 FailureReason::InsufficientDiversity,
484 ));
485 }
486 }
487
488 if selected != expected {
490 Err(KMeansPlusPlusError::new(
491 selected,
492 expected,
493 FailureReason::DatasetTooSmall,
494 ))
495 } else {
496 Ok(())
497 }
498}
499
500pub fn kmeans_plusplus_into(
501 centers: MutMatrixView<'_, f32>,
502 data: MatrixView<'_, f32>,
503 rng: &mut dyn RngCore,
504) -> Result<(), KMeansPlusPlusError> {
505 assert_eq!(
506 centers.ncols(),
507 data.ncols(),
508 "centers output matrix should have the same dimensionality as the dataset"
509 );
510
511 const GROUPSIZE: usize = 16;
512 let mut norms: Vec<f32> = vec![0.0; data.nrows()];
513
514 for (n, d) in std::iter::zip(norms.iter_mut(), data.row_iter()) {
515 *n = square_norm(d);
516 }
517
518 let transpose = BlockTransposed::<f32, GROUPSIZE>::from_matrix_view(data);
519 kmeans_plusplus_into_inner(centers, data.into(), transpose.as_view(), &norms, rng)
520}
521
522#[cfg(test)]
523mod tests {
524 use diskann_utils::{lazy_format, views::Matrix};
525 use diskann_vector::{PureDistanceFunction, distance::SquaredL2};
526 use rand::{Rng, SeedableRng, rngs::StdRng, seq::SliceRandom};
527
528 use super::*;
529 use crate::utils;
530
531 fn is_in(needle: &[f32], haystack: MatrixView<'_, f32>) -> bool {
532 assert_eq!(needle.len(), haystack.ncols());
533 haystack.row_iter().any(|row| row == needle)
534 }
535
536 fn check_post_conditions(
537 centers: MatrixView<'_, f32>,
538 data: MatrixView<'_, f32>,
539 err: &KMeansPlusPlusError,
540 ) {
541 assert_eq!(err.expected, centers.nrows());
542 assert!(err.expected > err.selected);
543 for i in 0..err.selected {
544 assert!(is_in(centers.row(i), data.as_view()));
545 }
546 for i in err.selected..centers.nrows() {
547 assert!(centers.row(i).iter().all(|j| *j == 0.0));
548 }
549 }
550
551 #[test]
556 fn test_error_display() {
557 assert_eq!(
558 format!("{}", FailureReason::DatasetTooSmall),
559 "dataset does not have enough points"
560 );
561
562 assert_eq!(
563 format!("{}", FailureReason::InsufficientDiversity),
564 "dataset is insufficiently diverse"
565 );
566
567 assert_eq!(
568 format!("{}", FailureReason::SawInfinity),
569 "a value of infinity or NaN was observed"
570 );
571 }
572
573 fn set_default_values(mut x: MutMatrixView<'_, f32>) {
585 for (i, row) in x.row_iter_mut().enumerate() {
586 for (j, r) in row.iter_mut().enumerate() {
587 *r = (i + j) as f32;
588 }
589 }
590 }
591
592 fn test_update_distances_impl<const N: usize, R>(num_points: usize, dim: usize, rng: &mut R)
601 where
602 BlockTransposed<f32, N>: MicroKernel,
603 R: Rng,
604 {
605 let context = lazy_format!(
606 "setup: N = {}, num_points = {}, dim = {}",
607 N,
608 num_points,
609 dim
610 );
611
612 let mut data = Matrix::<f32>::new(0.0, num_points, dim);
613 set_default_values(data.as_mut_view());
614
615 let square_norms: Vec<f32> = data.row_iter().map(square_norm).collect();
616
617 let num_samples = 3;
619 let mut samples = Matrix::<f32>::new(0.0, num_samples, dim);
620 let mut distances = vec![f32::INFINITY; num_points];
621 let distribution = Uniform::<u32>::new(0, (num_points + dim) as u32).unwrap();
622 let transpose = BlockTransposed::<f32, N>::from_matrix_view(data.as_view());
623
624 let mut last_residual = f64::INFINITY;
625 for i in 0..num_samples {
626 {
628 let row = samples.row_mut(i);
629 row.iter_mut().for_each(|r| {
630 *r = distribution.sample(rng) as f32;
631 });
632 }
633 let row = samples.row(i);
634 let norm = square_norm(row);
635
636 let residual = update_distances(
637 &mut distances,
638 transpose.as_view(),
639 &square_norms,
640 row,
641 norm,
642 );
643
644 for (n, (d, data)) in std::iter::zip(distances.iter(), data.row_iter()).enumerate() {
646 let mut min_distance = f32::INFINITY;
647 for j in 0..=i {
648 let distance = SquaredL2::evaluate(samples.row(j), data);
649 min_distance = min_distance.min(distance);
650 }
651 assert_eq!(
652 min_distance, *d,
653 "failed on row {n} on iteration {i}. {}",
654 context
655 );
656 }
657
658 assert_eq!(
660 residual,
661 distances.iter().sum::<f32>() as f64,
662 "residual sum failed on iteration {i} - {}",
663 context
664 );
665
666 assert!(
668 residual <= last_residual,
669 "residual check failed on iteration {}, last = {}, this = {} - {}",
670 i,
671 last_residual,
672 residual,
673 context
674 );
675
676 last_residual = residual;
677 }
678 }
679
680 #[test]
691 fn test_update_distances() {
692 let mut rng = StdRng::seed_from_u64(0x56c94b53c73e4fd9);
693 for num_points in 0..48 {
694 #[cfg(miri)]
695 if num_points % 7 != 0 {
696 continue;
697 }
698
699 for dim in 1..4 {
700 test_update_distances_impl(num_points, dim, &mut rng);
701 }
702 }
703 }
704
705 #[cfg(not(miri))]
712 fn sanity_check_impl<R: Rng>(ncenters: usize, dim: usize, rng: &mut R) {
713 let repeats_per_center = 3;
714 let context = lazy_format!(
715 "dim = {}, ncenters = {}, repeats_per_center = {}",
716 dim,
717 ncenters,
718 repeats_per_center
719 );
720
721 let ndata = repeats_per_center * ncenters;
722 let mut values: Vec<f32> = (0..ncenters)
723 .flat_map(|i| (0..repeats_per_center).map(move |_| i as f32))
724 .collect();
725 assert_eq!(values.len(), ndata);
726
727 values.shuffle(rng);
728 let mut data = Matrix::new(0.0, ndata, dim);
729 for (r, v) in std::iter::zip(data.row_iter_mut(), values.iter()) {
730 r.fill(*v);
731 }
732
733 let mut centers = Matrix::new(f32::INFINITY, ncenters, dim);
734 kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), rng).unwrap();
735
736 let mut seen = HashSet::<usize>::new();
738 for c in centers.row_iter() {
739 let first = c[0];
740 assert!(c.iter().all(|i| *i == first));
741
742 let v: usize = first.round() as usize;
743 assert_eq!(v as f32, first, "conversion was not lossless - {}", context);
744
745 if !seen.insert(v) {
746 panic!("value {first} seen more than oncex - {}", context);
747 }
748 }
749 assert_eq!(
750 seen.len(),
751 ncenters,
752 "not all points were seen - {}",
753 context
754 );
755 }
756
757 #[test]
758 #[cfg(not(miri))]
759 fn sanity_check() {
760 let dims = [1, 16];
761 let ncenters = [1, 5, 20, 255];
762 let mut rng = StdRng::seed_from_u64(0x68c2080f2ea36f5a);
763
764 for ncenters in ncenters {
765 for dim in dims {
766 sanity_check_impl(ncenters, dim, &mut rng);
767 }
768 }
769 }
770
771 #[cfg(not(miri))]
774 fn fuzzy_sanity_check_impl<R: Rng>(ncenters: usize, dim: usize, rng: &mut R) {
775 let repeats_per_center = 3;
776
777 let spreading_multiplier: usize = 16;
779 let perturbation_distribution = Uniform::new(-0.125, 0.125).unwrap();
781
782 let context = lazy_format!(
783 "dim = {}, ncenters = {}, repeats_per_center = {}, multiplier = {}",
784 dim,
785 ncenters,
786 repeats_per_center,
787 spreading_multiplier,
788 );
789
790 let ndata = repeats_per_center * ncenters;
791 let mut values: Vec<f32> = (0..ncenters)
792 .flat_map(|i| {
793 let v: Vec<f32> = (0..repeats_per_center)
795 .map(|_| {
796 (spreading_multiplier * i) as f32 + perturbation_distribution.sample(rng)
797 })
798 .collect();
799
800 v.into_iter()
801 })
802 .collect();
803 assert_eq!(values.len(), ndata);
804
805 values.shuffle(rng);
806 let mut data = Matrix::new(0.0, ndata, dim);
807 for (r, v) in std::iter::zip(data.row_iter_mut(), values.iter()) {
808 r.fill(*v);
809 }
810
811 let mut centers = Matrix::new(f32::INFINITY, ncenters, dim);
812 kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), rng).unwrap();
813
814 let mut seen = HashSet::<usize>::new();
816 for (i, c) in centers.row_iter().enumerate() {
817 let first = c[0];
818 let v: usize = first.round() as usize;
819 assert_eq!(
820 v % spreading_multiplier,
821 0,
822 "expected row value to be close to a multiple of the spreading multiplier, \
823 instead got {} - {}",
824 v,
825 context
826 );
827 seen.insert(v);
828
829 let mut found = false;
831 for r in data.row_iter() {
832 if r == c {
833 found = true;
834 break;
835 }
836 }
837 if !found {
838 panic!(
839 "center {} was not found in the original dataset - {}",
840 i, context,
841 );
842 }
843 }
844 assert!(
845 seen.len() as f32 >= 0.95 * (ncenters as f32),
846 "expected the distribution of centers to be wide, \
847 instead {} unique values were found - {}",
848 seen.len(),
849 context
850 );
851 }
852
853 #[test]
854 #[cfg(not(miri))]
855 fn fuzzy_sanity_check() {
856 let dims = [1, 16];
857 let ncenters = [0, 1, 5, 20, 255];
859 let mut rng = StdRng::seed_from_u64(0x68c2080f2ea36f5a);
860
861 for ncenters in ncenters {
862 for dim in dims {
863 fuzzy_sanity_check_impl(ncenters, dim, &mut rng);
864 }
865 }
866 }
867
868 #[test]
870 fn fail_empty_dataset() {
871 let data = Matrix::new(0.0, 0, 5);
872 let mut centers = Matrix::new(0.0, 10, data.ncols());
873
874 let mut rng = StdRng::seed_from_u64(0xa9eae150d30845a1);
875
876 let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
877 assert!(
878 result.is_err(),
879 "kmeans++ on an empty dataset with non-empty centers should be an error"
880 );
881 let err = result.unwrap_err();
882 assert_eq!(err.selected, 0);
883 assert_eq!(err.expected, centers.nrows());
884 assert_eq!(err.reason, FailureReason::DatasetTooSmall);
885 assert!(!err.is_numerically_recoverable());
886
887 check_post_conditions(centers.as_view(), data.as_view(), &err);
888 }
889
890 #[test]
891 fn both_empty_is_okay() {
892 let data = Matrix::new(0.0, 0, 5);
893 let mut centers = Matrix::new(0.0, 0, data.ncols());
894 let mut rng = StdRng::seed_from_u64(0x6f7031afd9b5aa18);
895 let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
896 assert!(
897 result.is_ok(),
898 "selecting 0 points from an empty dataset is okay"
899 );
900 }
901
902 #[test]
903 fn fail_dataset_not_big_enough() {
904 let ndata = 5;
905 let ncenters = 10;
906 let dim = 5;
907
908 let mut data = Matrix::new(0.0, ndata, dim);
909 set_default_values(data.as_mut_view());
910 let mut centers = Matrix::new(f32::INFINITY, ncenters, data.ncols());
911
912 let mut rng = StdRng::seed_from_u64(0xa9eae150d30845a1);
913
914 let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
915 assert!(
916 result.is_err(),
917 "kmeans++ on an empty dataset with non-empty centers should be an error"
918 );
919 let err = result.unwrap_err();
920 assert_eq!(err.selected, data.nrows());
921 assert_eq!(err.expected, centers.nrows());
922 assert_eq!(err.reason, FailureReason::DatasetTooSmall);
923 assert!(err.is_numerically_recoverable());
924
925 check_post_conditions(centers.as_view(), data.as_view(), &err);
926 }
927
928 #[test]
931 fn fail_diversity_check() {
932 let ncenters = 10;
933 let ndata = 50;
934 let dim = 3;
935 let mut rng = StdRng::seed_from_u64(0xca57b032c21bf4bb);
936
937 let repeats_per_center = 10;
939 assert!(ncenters * repeats_per_center > ndata);
940 let mut values: Vec<f32> = (0..utils::div_round_up(ndata, repeats_per_center))
941 .flat_map(|i| (0..repeats_per_center).map(move |_| i as f32))
942 .collect();
943 assert!(values.len() >= ndata);
944
945 values.shuffle(&mut rng);
946 let mut data = Matrix::new(0.0, ndata, dim);
947 for (r, v) in std::iter::zip(data.row_iter_mut(), values.iter()) {
948 r.fill(*v);
949 }
950
951 let mut centers = Matrix::new(f32::INFINITY, ncenters, dim);
952 let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
953 assert!(
954 result.is_err(),
955 "dataset should not have enough unique points"
956 );
957 let err = result.unwrap_err();
958 assert_eq!(err.selected, utils::div_round_up(ndata, repeats_per_center));
959 assert_eq!(err.expected, centers.nrows());
960 assert_eq!(err.reason, FailureReason::InsufficientDiversity);
961 assert!(err.is_numerically_recoverable());
962
963 check_post_conditions(centers.as_view(), data.as_view(), &err);
964 }
965
966 #[test]
967 fn fail_intinity_check() {
968 let mut data = Matrix::new(0.0, 10, 1);
969 set_default_values(data.as_mut_view());
970
971 data[(6, 0)] = -3.4028235e38;
973 let mut centers = Matrix::new(0.0, 2, 1);
974
975 let mut rng = StdRng::seed_from_u64(0xc0449b2aa4e12f05);
976
977 let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
978 assert!(result.is_err(), "result should complain about infinity");
979 let err = result.unwrap_err();
980 assert_eq!(err.selected, 1);
981 assert_eq!(err.expected, centers.nrows());
982 assert_eq!(err.reason, FailureReason::SawInfinity);
983 assert!(!err.is_numerically_recoverable());
984
985 check_post_conditions(centers.as_view(), data.as_view(), &err);
986 }
987
988 #[test]
989 fn fail_nan_check() {
990 let mut data = Matrix::new(0.0, 10, 1);
991 set_default_values(data.as_mut_view());
992
993 data[(6, 0)] = f32::NAN;
995 let mut centers = Matrix::new(0.0, 2, 1);
996
997 let mut rng = StdRng::seed_from_u64(0x55808c6c728c8473);
998
999 let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
1000 assert!(result.is_err(), "result should complain about NaN");
1001 let err = result.unwrap_err();
1002 assert_eq!(err.selected, 1);
1003 assert_eq!(err.expected, centers.nrows());
1004 assert_eq!(err.reason, FailureReason::SawInfinity);
1005 assert!(!err.is_numerically_recoverable());
1006
1007 check_post_conditions(centers.as_view(), data.as_view(), &err);
1008 }
1009
1010 #[test]
1015 #[should_panic(expected = "new point and dataset must have the same dimension")]
1016 fn update_distances_panics_dim_mismatch() {
1017 let npoints = 5;
1018 let dim = 8;
1019 let mut square_distances = vec![0.0; npoints];
1020 let data = Matrix::new(0.0, npoints, dim);
1021 let norms = vec![0.0; npoints];
1022 let this = vec![0.0; dim + 1]; let this_square_norm = 0.0;
1024 update_distances::<16>(
1025 &mut square_distances,
1026 BlockTransposed::<f32, 16>::from_matrix_view(data.as_view()).as_view(),
1027 &norms,
1028 &this,
1029 this_square_norm,
1030 );
1031 }
1032
1033 #[test]
1034 #[should_panic(expected = "distances buffer and dataset must have the same length")]
1035 fn update_distances_panics_distances_length_mismatch() {
1036 let npoints = 5;
1037 let dim = 8;
1038 let mut square_distances = vec![0.0; npoints + 1]; let data = Matrix::new(0.0, npoints, dim);
1040 let norms = vec![0.0; npoints];
1041 let this = vec![0.0; dim];
1042 let this_square_norm = 0.0;
1043 update_distances::<16>(
1044 &mut square_distances,
1045 BlockTransposed::<f32, 16>::from_matrix_view(data.as_view()).as_view(),
1046 &norms,
1047 &this,
1048 this_square_norm,
1049 );
1050 }
1051
1052 #[test]
1053 #[should_panic(expected = "norms and dataset must have the same length")]
1054 fn update_distances_panics_norms_length_mismatch() {
1055 let npoints = 5;
1056 let dim = 8;
1057 let mut square_distances = vec![0.0; npoints];
1058 let data = Matrix::new(0.0, npoints, dim);
1059 let norms = vec![0.0; npoints + 1]; let this = vec![0.0; dim];
1061 let this_square_norm = 0.0;
1062 update_distances::<16>(
1063 &mut square_distances,
1064 BlockTransposed::<f32, 16>::from_matrix_view(data.as_view()).as_view(),
1065 &norms,
1066 &this,
1067 this_square_norm,
1068 );
1069 }
1070
1071 #[test]
1076 #[should_panic(
1077 expected = "centers output matrix should have the same dimensionality as the dataset"
1078 )]
1079 fn kmeans_plusplus_into_panics_dim_mismatch() {
1080 let mut centers = Matrix::new(0.0, 2, 10);
1081 let data = Matrix::new(0.0, 2, 9);
1082 kmeans_plusplus_into(
1083 centers.as_mut_view(),
1084 data.as_view(),
1085 &mut rand::rngs::ThreadRng::default(),
1086 )
1087 .unwrap();
1088 }
1089}