1use std::{collections::HashSet, fmt};
7
8use diskann_utils::{
9 strided::StridedView,
10 views::{MatrixView, MutMatrixView},
11};
12use diskann_wide::{SIMDMulAdd, SIMDPartialOrd, SIMDSelect, SIMDVector};
13use rand::{
14 RngCore,
15 distr::{Distribution, Uniform},
16};
17use thiserror::Error;
18
19use super::common::{BlockTranspose, square_norm};
20
/// A kernel that refreshes nearest-center squared distances one block of dataset rows at a time.
///
/// `update_distances` drives an implementation as follows:
///
/// 1. [`splat`](Self::splat) is called once with the new center's squared norm.
/// 2. For every block, the value returned by [`accum_full`](Self::accum_full) is handed to
///    [`finish`](Self::finish) (full blocks) or [`finish_last`](Self::finish_last) (the
///    trailing partial block). A [`RollingSum`](Self::RollingSum) is threaded through
///    these calls.
/// 3. [`complete_sum`](Self::complete_sum) turns the final running total into an `f64`.
pub(crate) trait MicroKernel {
    /// Per-block partial result produced by `accum_full` and consumed by `finish`/`finish_last`.
    type Intermediate;

    /// Running total carried from block to block; callers start it at `Default::default()`.
    type RollingSum: Default + Copy;

    /// The new center's squared norm in whatever form `finish`/`finish_last` consume.
    type Splat: Copy;

    /// Prepares the new center's squared norm `x` for `finish`/`finish_last`.
    fn splat(x: f32) -> Self::Splat;

    /// Accumulates the contribution of `query` against every row of the block at `block_ptr`.
    ///
    /// # Safety
    ///
    /// `block_ptr` must point to a block of the transposed dataset whose dimension equals
    /// `query.len()`, and it must be valid for every read the implementation performs.
    unsafe fn accum_full(block_ptr: *const f32, query: &[f32]) -> Self::Intermediate;

    /// Completes a full block.
    ///
    /// Folds `partial` and `norm_splat` together with the rows' squared norms
    /// (`block_norms`) into candidate distances. Lowers `block_mins` where a candidate is
    /// smaller, and returns `running` advanced by this block's minima.
    fn finish(
        partial: Self::Intermediate,
        norm_splat: Self::Splat,
        running: Self::RollingSum,
        block_norms: &[f32],
        block_mins: &mut [f32],
    ) -> Self::RollingSum;

    /// Like [`finish`](Self::finish), but for a trailing block holding only `valid_rows` rows.
    fn finish_last(
        partial: Self::Intermediate,
        norm_splat: Self::Splat,
        running: Self::RollingSum,
        block_norms: &[f32],
        block_mins: &mut [f32],
        valid_rows: usize,
    ) -> Self::RollingSum;

    /// Converts the final running total into an `f64`.
    fn complete_sum(running: Self::RollingSum) -> f64;
}
79
80diskann_wide::alias!(f32s = f32x8);
81
82impl MicroKernel for BlockTranspose<16> {
83 type Intermediate = (f32s, f32s);
85 type RollingSum = f64;
86 type Splat = f32s;
87
88 fn splat(x: f32) -> Self::Splat {
89 Self::Splat::splat(diskann_wide::ARCH, x)
90 }
91
92 #[inline(always)]
93 unsafe fn accum_full(block_ptr: *const f32, this: &[f32]) -> Self::Intermediate {
94 let mut s0 = f32s::default(diskann_wide::ARCH);
95 let mut s1 = f32s::default(diskann_wide::ARCH);
96
97 this.iter().enumerate().for_each(|(i, b)| {
98 let b = f32s::splat(diskann_wide::ARCH, *b);
99
100 let a = unsafe { f32s::load_simd(diskann_wide::ARCH, block_ptr.add(16 * i)) };
103 s0 = a.mul_add_simd(b, s0);
104
105 let a = unsafe { f32s::load_simd(diskann_wide::ARCH, block_ptr.add(16 * i + 8)) };
108 s1 = a.mul_add_simd(b, s1);
109 });
110
111 let negative = f32s::splat(diskann_wide::ARCH, -2.0);
113 (s0 * negative, s1 * negative)
114 }
115
116 #[inline(always)]
117 fn finish(
118 intermediate: Self::Intermediate,
119 splat: Self::Splat,
120 rolling_sum: Self::RollingSum,
121 norms: &[f32],
122 mins: &mut [f32],
123 ) -> Self::RollingSum {
124 assert_eq!(norms.len(), 16);
125 assert_eq!(mins.len(), 16);
126
127 let norms0 = unsafe { f32s::load_simd(diskann_wide::ARCH, norms.as_ptr()) };
129 let norms1 = unsafe { f32s::load_simd(diskann_wide::ARCH, norms.as_ptr().add(8)) };
131
132 let distances0 = norms0 + splat + intermediate.0;
133 let distances1 = norms1 + splat + intermediate.1;
134
135 let current_distances0 = unsafe { f32s::load_simd(diskann_wide::ARCH, mins.as_ptr()) };
137 let current_distances1 =
139 unsafe { f32s::load_simd(diskann_wide::ARCH, mins.as_ptr().add(8)) };
140
141 let mask0 = distances0.lt_simd(current_distances0);
142 let mask1 = distances1.lt_simd(current_distances1);
143
144 let current_distances0 = mask0.select(distances0, current_distances0);
145 let current_distances1 = mask1.select(distances1, current_distances1);
146
147 unsafe { current_distances0.store_simd(mins.as_mut_ptr()) };
149 unsafe { current_distances1.store_simd(mins.as_mut_ptr().add(8)) };
151
152 rolling_sum
153 + std::iter::zip(
154 current_distances0.to_array().iter(),
155 current_distances1.to_array().iter(),
156 )
157 .map(|(d0, d1)| (*d0 as f64) + (*d1 as f64))
158 .sum::<f64>()
159 }
160
161 #[inline(always)]
162 fn finish_last(
163 intermediate: Self::Intermediate,
164 splat: Self::Splat,
165 rolling_sum: Self::RollingSum,
166 norms: &[f32],
167 mins: &mut [f32],
168 first: usize,
169 ) -> Self::RollingSum {
170 assert_eq!(norms.len(), first);
172 assert_eq!(mins.len(), first);
174
175 let lo = first.min(8);
176 let hi = first - lo;
177
178 let norms0 = unsafe { f32s::load_simd_first(diskann_wide::ARCH, norms.as_ptr(), lo) };
181 let norms1 = if hi == 0 {
182 f32s::default(diskann_wide::ARCH)
183 } else {
184 unsafe { f32s::load_simd_first(diskann_wide::ARCH, norms.as_ptr().add(8), hi) }
187 };
188
189 let distances0 = norms0 + splat + intermediate.0;
190 let distances1 = norms1 + splat + intermediate.1;
191
192 let current_distances0 =
194 unsafe { f32s::load_simd_first(diskann_wide::ARCH, mins.as_ptr(), lo) };
195 let current_distances1 = if hi == 0 {
196 f32s::default(diskann_wide::ARCH)
197 } else {
198 unsafe { f32s::load_simd_first(diskann_wide::ARCH, mins.as_ptr().add(8), hi) }
200 };
201
202 let mask0 = distances0.lt_simd(current_distances0);
203 let mask1 = distances1.lt_simd(current_distances1);
204
205 let current_distances0 = mask0.select(distances0, current_distances0);
206 let current_distances1 = mask1.select(distances1, current_distances1);
207
208 unsafe { current_distances0.store_simd_first(mins.as_mut_ptr(), lo) };
211 if hi != 0 {
212 unsafe { current_distances1.store_simd_first(mins.as_mut_ptr().add(8), hi) };
216 }
217
218 rolling_sum
219 + std::iter::zip(
220 current_distances0.to_array().iter(),
221 current_distances1.to_array().iter(),
222 )
223 .map(|(d0, d1)| (*d0 as f64) + (*d1 as f64))
224 .sum::<f64>()
225 }
226
227 fn complete_sum(x: Self::RollingSum) -> f64 {
228 x
229 }
230}
231
232fn update_distances<const N: usize>(
237 square_distances: &mut [f32],
238 transpose: &BlockTranspose<N>,
239 norms: &[f32],
240 this: &[f32],
241 this_square_norm: f32,
242) -> f64
243where
244 BlockTranspose<N>: MicroKernel,
245{
246 assert_eq!(
249 this.len(),
250 transpose.ncols(),
251 "new point and dataset must have the same dimension",
252 );
253 assert_eq!(
255 square_distances.len(),
256 transpose.nrows(),
257 "distances buffer and dataset must have the same length",
258 );
259 assert_eq!(
261 norms.len(),
262 transpose.nrows(),
263 "norms and dataset must have the same length",
264 );
265
266 let splat = BlockTranspose::<N>::splat(this_square_norm);
267 let mut rolling_sum = <BlockTranspose<N> as MicroKernel>::RollingSum::default();
268
269 let iter =
270 std::iter::zip(norms.chunks_exact(N), square_distances.chunks_exact_mut(N)).enumerate();
271 iter.for_each(|(block, (these_norms, these_distances))| {
272 debug_assert!(block < transpose.num_blocks());
273 let base = unsafe { transpose.block_ptr_unchecked(block) };
277
278 let intermediate = unsafe { BlockTranspose::<N>::accum_full(base, this) };
281
282 rolling_sum = BlockTranspose::<N>::finish(
283 intermediate,
284 splat,
285 rolling_sum,
286 these_norms,
287 these_distances,
288 );
289 });
290
291 let remainder = transpose.remainder();
293 if remainder != 0 {
294 let base = unsafe { transpose.block_ptr_unchecked(transpose.full_blocks()) };
297
298 let intermediate = unsafe { BlockTranspose::<N>::accum_full(base, this) };
303
304 let start = N * transpose.full_blocks();
305 rolling_sum = BlockTranspose::<N>::finish_last(
306 intermediate,
307 splat,
308 rolling_sum,
309 &norms[start..],
310 &mut square_distances[start..],
311 remainder,
312 );
313 }
314
315 BlockTranspose::<N>::complete_sum(rolling_sum)
316}
317
/// Why k-means++ seeding stopped before every requested center was filled.
///
/// The enum has no fields, so it derives `Eq` and `Hash` in addition to `PartialEq`.
/// This lets it serve as a key in sets and maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FailureReason {
    /// The dataset has fewer rows than the number of centers requested.
    DatasetTooSmall,
    /// No further center could be drawn because no remaining row had a positive distance
    /// to the centers already chosen.
    InsufficientDiversity,
    /// The total of the nearest-center distances was not finite. This can come from
    /// infinite or NaN values, or from values whose squares overflow `f32`.
    SawInfinity,
}
328
329impl FailureReason {
330 pub fn is_numerically_recoverable(self) -> bool {
331 match self {
332 Self::DatasetTooSmall | Self::InsufficientDiversity => true,
335
336 Self::SawInfinity => false,
339 }
340 }
341}
342
343impl fmt::Display for FailureReason {
344 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
345 let reason: &str = match self {
346 Self::DatasetTooSmall => "dataset does not have enough points",
347 Self::InsufficientDiversity => "dataset is insufficiently diverse",
348 Self::SawInfinity => "a value of infinity or NaN was observed",
349 };
350 f.write_str(reason)
351 }
352}
353
354#[derive(Debug, Clone, Copy, Error)]
355#[error("only populated {selected} of {expected} points because {reason}")]
356pub struct KMeansPlusPlusError {
357 pub selected: usize,
359 pub expected: usize,
361 pub reason: FailureReason,
363}
364
365impl KMeansPlusPlusError {
366 fn new(selected: usize, expected: usize, reason: FailureReason) -> Self {
367 Self {
368 selected,
369 expected,
370 reason,
371 }
372 }
373
374 pub fn is_numerically_recoverable(&self) -> bool {
375 self.reason.is_numerically_recoverable() && self.selected > 0
376 }
377}
378
379pub(crate) fn kmeans_plusplus_into_inner<const N: usize>(
380 mut points: MutMatrixView<'_, f32>,
381 data: StridedView<'_, f32>,
382 transpose: &BlockTranspose<N>,
383 norms: &[f32],
384 rng: &mut dyn RngCore,
385) -> Result<(), KMeansPlusPlusError>
386where
387 BlockTranspose<N>: MicroKernel,
388{
389 assert_eq!(norms.len(), data.nrows());
390 assert_eq!(transpose.nrows(), data.nrows());
391 assert_eq!(transpose.block_size(), data.ncols());
392 assert_eq!(points.ncols(), data.ncols());
393
394 points.as_mut_slice().fill(0.0);
396 let expected = points.nrows();
397
398 let all_rows = match Uniform::new(0, data.nrows()) {
403 Ok(dist) => dist,
404 Err(_) => {
405 return if expected == 0 {
407 Ok(())
408 } else {
409 Err(KMeansPlusPlusError::new(
410 0,
411 expected,
412 FailureReason::DatasetTooSmall,
413 ))
414 };
415 }
416 };
417
418 let mut min_distances: Vec<f32> = vec![f32::INFINITY; data.nrows()];
419 let mut picked = HashSet::with_capacity(expected);
420
421 let mut previous_square_norm = {
423 let i = all_rows.sample(rng);
424 points.row_mut(0).copy_from_slice(data.row(i));
425 picked.insert(i);
426 norms[i]
427 };
428
429 let mut selected = 1;
430 for current in 1..expected.min(data.nrows()) {
431 let last = points.row(current - 1);
432 let s = update_distances(
433 &mut min_distances,
434 transpose,
435 norms,
436 last,
437 previous_square_norm,
438 );
439
440 match Uniform::<f64>::new(0.0, s) {
444 Ok(distribution) => {
445 let threshold = distribution.sample(rng);
446 let mut rolling_sum: f64 = 0.0;
447 for (i, d) in min_distances.iter().enumerate() {
448 rolling_sum += <f32 as Into<f64>>::into(*d);
449 if rolling_sum >= threshold && (*d > 0.0) && !picked.contains(&i) {
450 points.row_mut(current).clone_from_slice(data.row(i));
453 picked.insert(i);
454 previous_square_norm = norms[i];
455 selected = current + 1;
456 break;
457 }
458 }
459 }
460 Err(rand::distr::uniform::Error::EmptyRange) => {}
464 Err(rand::distr::uniform::Error::NonFinite) => {
466 return Err(KMeansPlusPlusError::new(
467 selected,
468 expected,
469 FailureReason::SawInfinity,
470 ));
471 }
472 }
473
474 if selected != (current + 1) {
478 return Err(KMeansPlusPlusError::new(
479 selected,
480 expected,
481 FailureReason::InsufficientDiversity,
482 ));
483 }
484 }
485
486 if selected != expected {
488 Err(KMeansPlusPlusError::new(
489 selected,
490 expected,
491 FailureReason::DatasetTooSmall,
492 ))
493 } else {
494 Ok(())
495 }
496}
497
498pub fn kmeans_plusplus_into(
499 centers: MutMatrixView<'_, f32>,
500 data: MatrixView<'_, f32>,
501 rng: &mut dyn RngCore,
502) -> Result<(), KMeansPlusPlusError> {
503 assert_eq!(
504 centers.ncols(),
505 data.ncols(),
506 "centers output matrix should have the same dimensionality as the dataset"
507 );
508
509 const GROUPSIZE: usize = 16;
510 let mut norms: Vec<f32> = vec![0.0; data.nrows()];
511
512 for (n, d) in std::iter::zip(norms.iter_mut(), data.row_iter()) {
513 *n = square_norm(d);
514 }
515
516 let transpose = BlockTranspose::<GROUPSIZE>::from_matrix_view(data);
517 kmeans_plusplus_into_inner(centers, data.into(), &transpose, &norms, rng)
518}
519
#[cfg(test)]
mod tests {
    use diskann_utils::{lazy_format, views::Matrix};
    use diskann_vector::{PureDistanceFunction, distance::SquaredL2};
    use rand::{Rng, SeedableRng, rngs::StdRng, seq::SliceRandom};

    use super::*;
    use crate::utils;

    /// Returns `true` if `needle` is exactly equal to some row of `haystack`.
    fn is_in(needle: &[f32], haystack: MatrixView<'_, f32>) -> bool {
        assert_eq!(needle.len(), haystack.ncols());
        haystack.row_iter().any(|row| row == needle)
    }

    /// Checks the state of `centers` after a failed seeding.
    ///
    /// The first `err.selected` rows must be rows of `data`, and every later row must be zero.
    fn check_post_conditions(
        centers: MatrixView<'_, f32>,
        data: MatrixView<'_, f32>,
        err: &KMeansPlusPlusError,
    ) {
        assert_eq!(err.expected, centers.nrows());
        assert!(err.expected > err.selected);
        for i in 0..err.selected {
            assert!(is_in(centers.row(i), data.as_view()));
        }
        for i in err.selected..centers.nrows() {
            assert!(centers.row(i).iter().all(|j| *j == 0.0));
        }
    }

    /// Pins the user-facing message of every `FailureReason`.
    #[test]
    fn test_error_display() {
        assert_eq!(
            format!("{}", FailureReason::DatasetTooSmall),
            "dataset does not have enough points"
        );

        assert_eq!(
            format!("{}", FailureReason::InsufficientDiversity),
            "dataset is insufficiently diverse"
        );

        assert_eq!(
            format!("{}", FailureReason::SawInfinity),
            "a value of infinity or NaN was observed"
        );
    }

    /// Fills `x` so that element `(i, j)` equals `i + j`. These are small integers, so the
    /// distance arithmetic in the tests is exact.
    fn set_default_values(mut x: MutMatrixView<'_, f32>) {
        for (i, row) in x.row_iter_mut().enumerate() {
            for (j, r) in row.iter_mut().enumerate() {
                *r = (i + j) as f32;
            }
        }
    }

    /// Feeds three random samples through `update_distances` and checks each step.
    ///
    /// Per step it checks that:
    /// * every running minimum matches a brute-force `SquaredL2` evaluation;
    /// * the returned residual equals the sum of the minima;
    /// * the residual never increases.
    fn test_update_distances_impl<const N: usize, R>(num_points: usize, dim: usize, rng: &mut R)
    where
        BlockTranspose<N>: MicroKernel,
        R: Rng,
    {
        let context = lazy_format!(
            "setup: N = {}, num_points = {}, dim = {}",
            N,
            num_points,
            dim
        );

        let mut data = Matrix::<f32>::new(0.0, num_points, dim);
        set_default_values(data.as_mut_view());

        let square_norms: Vec<f32> = data.row_iter().map(square_norm).collect();

        let num_samples = 3;
        let mut samples = Matrix::<f32>::new(0.0, num_samples, dim);
        let mut distances = vec![f32::INFINITY; num_points];
        let distribution = Uniform::<u32>::new(0, (num_points + dim) as u32).unwrap();
        let transpose = BlockTranspose::<N>::from_matrix_view(data.as_view());

        let mut last_residual = f64::INFINITY;
        for i in 0..num_samples {
            {
                // Scope the mutable borrow of `samples` to the fill.
                let row = samples.row_mut(i);
                row.iter_mut().for_each(|r| {
                    *r = distribution.sample(rng) as f32;
                });
            }
            let row = samples.row(i);
            let norm = square_norm(row);

            let residual = update_distances(&mut distances, &transpose, &square_norms, row, norm);

            for (n, (d, data)) in std::iter::zip(distances.iter(), data.row_iter()).enumerate() {
                // Brute-force minimum over every sample seen so far.
                let mut min_distance = f32::INFINITY;
                for j in 0..=i {
                    let distance = SquaredL2::evaluate(samples.row(j), data);
                    min_distance = min_distance.min(distance);
                }
                assert_eq!(
                    min_distance, *d,
                    "failed on row {n} on iteration {i}. {}",
                    context
                );
            }

            assert_eq!(
                residual,
                distances.iter().sum::<f32>() as f64,
                "residual sum failed on iteration {i} - {}",
                context
            );

            // Adding a center can only shrink each minimum, so the total cannot grow.
            assert!(
                residual <= last_residual,
                "residual check failed on iteration {}, last = {}, this = {} - {}",
                i,
                last_residual,
                residual,
                context
            );

            last_residual = residual;
        }
    }

    /// Covers empty datasets, single partial blocks, and several full blocks plus a remainder.
    #[test]
    fn test_update_distances() {
        let mut rng = StdRng::seed_from_u64(0x56c94b53c73e4fd9);
        for num_points in 0..48 {
            // Miri is slow; only exercise a subset of sizes there.
            #[cfg(miri)]
            if num_points % 7 != 0 {
                continue;
            }

            for dim in 1..4 {
                test_update_distances_impl(num_points, dim, &mut rng);
            }
        }
    }

    /// Builds `ncenters` distinct values, each repeated several times. Seeding must select
    /// exactly one copy of every distinct value.
    #[cfg(not(miri))]
    fn sanity_check_impl<R: Rng>(ncenters: usize, dim: usize, rng: &mut R) {
        let repeats_per_center = 3;
        let context = lazy_format!(
            "dim = {}, ncenters = {}, repeats_per_center = {}",
            dim,
            ncenters,
            repeats_per_center
        );

        let ndata = repeats_per_center * ncenters;
        let mut values: Vec<f32> = (0..ncenters)
            .flat_map(|i| (0..repeats_per_center).map(move |_| i as f32))
            .collect();
        assert_eq!(values.len(), ndata);

        values.shuffle(rng);
        let mut data = Matrix::new(0.0, ndata, dim);
        for (r, v) in std::iter::zip(data.row_iter_mut(), values.iter()) {
            r.fill(*v);
        }

        let mut centers = Matrix::new(f32::INFINITY, ncenters, dim);
        kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), rng).unwrap();

        let mut seen = HashSet::<usize>::new();
        for c in centers.row_iter() {
            let first = c[0];
            assert!(c.iter().all(|i| *i == first));

            let v: usize = first.round() as usize;
            assert_eq!(v as f32, first, "conversion was not lossless - {}", context);

            if !seen.insert(v) {
                panic!("value {first} seen more than once - {}", context);
            }
        }
        assert_eq!(
            seen.len(),
            ncenters,
            "not all points were seen - {}",
            context
        );
    }

    #[test]
    #[cfg(not(miri))]
    fn sanity_check() {
        let dims = [1, 16];
        let ncenters = [1, 5, 20, 255];
        let mut rng = StdRng::seed_from_u64(0x68c2080f2ea36f5a);

        for ncenters in ncenters {
            for dim in dims {
                sanity_check_impl(ncenters, dim, &mut rng);
            }
        }
    }

    /// Like `sanity_check_impl`, but with well-separated clusters of slightly perturbed
    /// points. Nearly every cluster should receive a center, and each center must be an
    /// actual dataset row.
    #[cfg(not(miri))]
    fn fuzzy_sanity_check_impl<R: Rng>(ncenters: usize, dim: usize, rng: &mut R) {
        let repeats_per_center = 3;

        // Cluster `i` sits near `spreading_multiplier * i`; perturbations stay well inside
        // half the spacing.
        let spreading_multiplier: usize = 16;
        let perturbation_distribution = Uniform::new(-0.125, 0.125).unwrap();

        let context = lazy_format!(
            "dim = {}, ncenters = {}, repeats_per_center = {}, multiplier = {}",
            dim,
            ncenters,
            repeats_per_center,
            spreading_multiplier,
        );

        let ndata = repeats_per_center * ncenters;
        let mut values: Vec<f32> = (0..ncenters)
            .flat_map(|i| {
                // Collect eagerly so the `rng` borrow ends before the next cluster.
                let v: Vec<f32> = (0..repeats_per_center)
                    .map(|_| {
                        (spreading_multiplier * i) as f32 + perturbation_distribution.sample(rng)
                    })
                    .collect();

                v.into_iter()
            })
            .collect();
        assert_eq!(values.len(), ndata);

        values.shuffle(rng);
        let mut data = Matrix::new(0.0, ndata, dim);
        for (r, v) in std::iter::zip(data.row_iter_mut(), values.iter()) {
            r.fill(*v);
        }

        let mut centers = Matrix::new(f32::INFINITY, ncenters, dim);
        kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), rng).unwrap();

        let mut seen = HashSet::<usize>::new();
        for (i, c) in centers.row_iter().enumerate() {
            let first = c[0];
            let v: usize = first.round() as usize;
            assert_eq!(
                v % spreading_multiplier,
                0,
                "expected row value to be close to a multiple of the spreading multiplier, \
                 instead got {} - {}",
                v,
                context
            );
            seen.insert(v);

            let found = data.row_iter().any(|r| r == c);
            if !found {
                panic!(
                    "center {} was not found in the original dataset - {}",
                    i, context,
                );
            }
        }
        assert!(
            seen.len() as f32 >= 0.95 * (ncenters as f32),
            "expected the distribution of centers to be wide, \
             instead {} unique values were found - {}",
            seen.len(),
            context
        );
    }

    #[test]
    #[cfg(not(miri))]
    fn fuzzy_sanity_check() {
        let dims = [1, 16];
        let ncenters = [0, 1, 5, 20, 255];
        let mut rng = StdRng::seed_from_u64(0x68c2080f2ea36f5a);

        for ncenters in ncenters {
            for dim in dims {
                fuzzy_sanity_check_impl(ncenters, dim, &mut rng);
            }
        }
    }

    /// Requesting centers from an empty dataset is a non-recoverable `DatasetTooSmall`.
    #[test]
    fn fail_empty_dataset() {
        let data = Matrix::new(0.0, 0, 5);
        let mut centers = Matrix::new(0.0, 10, data.ncols());

        let mut rng = StdRng::seed_from_u64(0xa9eae150d30845a1);

        let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
        assert!(
            result.is_err(),
            "kmeans++ on an empty dataset with non-empty centers should be an error"
        );
        let err = result.unwrap_err();
        assert_eq!(err.selected, 0);
        assert_eq!(err.expected, centers.nrows());
        assert_eq!(err.reason, FailureReason::DatasetTooSmall);
        assert!(!err.is_numerically_recoverable());

        check_post_conditions(centers.as_view(), data.as_view(), &err);
    }

    #[test]
    fn both_empty_is_okay() {
        let data = Matrix::new(0.0, 0, 5);
        let mut centers = Matrix::new(0.0, 0, data.ncols());
        let mut rng = StdRng::seed_from_u64(0x6f7031afd9b5aa18);
        let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
        assert!(
            result.is_ok(),
            "selecting 0 points from an empty dataset is okay"
        );
    }

    /// Every row is selected, then seeding reports a recoverable `DatasetTooSmall`.
    #[test]
    fn fail_dataset_not_big_enough() {
        let ndata = 5;
        let ncenters = 10;
        let dim = 5;

        let mut data = Matrix::new(0.0, ndata, dim);
        set_default_values(data.as_mut_view());
        let mut centers = Matrix::new(f32::INFINITY, ncenters, data.ncols());

        let mut rng = StdRng::seed_from_u64(0xa9eae150d30845a1);

        let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
        assert!(
            result.is_err(),
            "kmeans++ on a dataset with fewer points than centers should be an error"
        );
        let err = result.unwrap_err();
        assert_eq!(err.selected, data.nrows());
        assert_eq!(err.expected, centers.nrows());
        assert_eq!(err.reason, FailureReason::DatasetTooSmall);
        assert!(err.is_numerically_recoverable());

        check_post_conditions(centers.as_view(), data.as_view(), &err);
    }

    /// The dataset has more rows than requested centers but fewer distinct values, so
    /// seeding stops with `InsufficientDiversity` once every distinct value is chosen.
    #[test]
    fn fail_diversity_check() {
        let ncenters = 10;
        let ndata = 50;
        let dim = 3;
        let mut rng = StdRng::seed_from_u64(0xca57b032c21bf4bb);

        let repeats_per_center = 10;
        assert!(ncenters * repeats_per_center > ndata);
        let mut values: Vec<f32> = (0..utils::div_round_up(ndata, repeats_per_center))
            .flat_map(|i| (0..repeats_per_center).map(move |_| i as f32))
            .collect();
        assert!(values.len() >= ndata);

        values.shuffle(&mut rng);
        let mut data = Matrix::new(0.0, ndata, dim);
        for (r, v) in std::iter::zip(data.row_iter_mut(), values.iter()) {
            r.fill(*v);
        }

        let mut centers = Matrix::new(f32::INFINITY, ncenters, dim);
        let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
        assert!(
            result.is_err(),
            "dataset should not have enough unique points"
        );
        let err = result.unwrap_err();
        assert_eq!(err.selected, utils::div_round_up(ndata, repeats_per_center));
        assert_eq!(err.expected, centers.nrows());
        assert_eq!(err.reason, FailureReason::InsufficientDiversity);
        assert!(err.is_numerically_recoverable());

        check_post_conditions(centers.as_view(), data.as_view(), &err);
    }

    /// A value whose square overflows `f32` drives the distance total to infinity.
    #[test]
    fn fail_infinity_check() {
        let mut data = Matrix::new(0.0, 10, 1);
        set_default_values(data.as_mut_view());

        data[(6, 0)] = f32::MIN;
        let mut centers = Matrix::new(0.0, 2, 1);

        let mut rng = StdRng::seed_from_u64(0xc0449b2aa4e12f05);

        let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
        assert!(result.is_err(), "result should complain about infinity");
        let err = result.unwrap_err();
        assert_eq!(err.selected, 1);
        assert_eq!(err.expected, centers.nrows());
        assert_eq!(err.reason, FailureReason::SawInfinity);
        assert!(!err.is_numerically_recoverable());

        check_post_conditions(centers.as_view(), data.as_view(), &err);
    }

    /// A NaN in the dataset must also surface as `SawInfinity`.
    #[test]
    fn fail_nan_check() {
        let mut data = Matrix::new(0.0, 10, 1);
        set_default_values(data.as_mut_view());

        data[(6, 0)] = f32::NAN;
        let mut centers = Matrix::new(0.0, 2, 1);

        let mut rng = StdRng::seed_from_u64(0x55808c6c728c8473);

        let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
        assert!(result.is_err(), "result should complain about NaN");
        let err = result.unwrap_err();
        assert_eq!(err.selected, 1);
        assert_eq!(err.expected, centers.nrows());
        assert_eq!(err.reason, FailureReason::SawInfinity);
        assert!(!err.is_numerically_recoverable());

        check_post_conditions(centers.as_view(), data.as_view(), &err);
    }

    #[test]
    #[should_panic(expected = "new point and dataset must have the same dimension")]
    fn update_distances_panics_dim_mismatch() {
        let npoints = 5;
        let dim = 8;
        let mut square_distances = vec![0.0; npoints];
        let data = Matrix::new(0.0, npoints, dim);
        let norms = vec![0.0; npoints];
        let this = vec![0.0; dim + 1];
        let this_square_norm = 0.0;
        update_distances::<16>(
            &mut square_distances,
            &BlockTranspose::from_matrix_view(data.as_view()),
            &norms,
            &this,
            this_square_norm,
        );
    }

    #[test]
    #[should_panic(expected = "distances buffer and dataset must have the same length")]
    fn update_distances_panics_distances_length_mismatch() {
        let npoints = 5;
        let dim = 8;
        let mut square_distances = vec![0.0; npoints + 1];
        let data = Matrix::new(0.0, npoints, dim);
        let norms = vec![0.0; npoints];
        let this = vec![0.0; dim];
        let this_square_norm = 0.0;
        update_distances::<16>(
            &mut square_distances,
            &BlockTranspose::from_matrix_view(data.as_view()),
            &norms,
            &this,
            this_square_norm,
        );
    }

    #[test]
    #[should_panic(expected = "norms and dataset must have the same length")]
    fn update_distances_panics_norms_length_mismatch() {
        let npoints = 5;
        let dim = 8;
        let mut square_distances = vec![0.0; npoints];
        let data = Matrix::new(0.0, npoints, dim);
        let norms = vec![0.0; npoints + 1];
        let this = vec![0.0; dim];
        let this_square_norm = 0.0;
        update_distances::<16>(
            &mut square_distances,
            &BlockTranspose::from_matrix_view(data.as_view()),
            &norms,
            &this,
            this_square_norm,
        );
    }

    #[test]
    #[should_panic(
        expected = "centers output matrix should have the same dimensionality as the dataset"
    )]
    fn kmeans_plusplus_into_panics_dim_mismatch() {
        let mut centers = Matrix::new(0.0, 2, 10);
        let data = Matrix::new(0.0, 2, 9);
        kmeans_plusplus_into(
            centers.as_mut_view(),
            data.as_view(),
            &mut rand::rngs::ThreadRng::default(),
        )
        .unwrap();
    }
}