1use std::{collections::HashSet, fmt};
7
8use diskann_utils::{
9 strided::StridedView,
10 views::{MatrixView, MutMatrixView},
11};
12use diskann_wide::{SIMDMulAdd, SIMDPartialOrd, SIMDSelect, SIMDVector};
13use rand::{
14 RngCore,
15 distr::{Distribution, Uniform},
16};
17use thiserror::Error;
18
19use super::common::{BlockTranspose, square_norm};
20
/// Block-level kernel used by [`update_distances`] to refresh the per-row minimum
/// squared distances against one newly chosen point.
///
/// [`update_distances`] drives an implementation as follows: it calls [`Self::splat`]
/// once, then for every full block calls [`Self::accum_full`] followed by
/// [`Self::finish`], then (if the dataset has a partial trailing block) calls
/// [`Self::accum_full`] followed by [`Self::finish_last`], and finally converts the
/// accumulated value with [`Self::complete_sum`].
pub(crate) trait MicroKernel {
    /// Per-block partial result produced by [`Self::accum_full`] and consumed by
    /// [`Self::finish`] / [`Self::finish_last`].
    type Intermediate;

    /// Accumulator threaded through successive `finish` / `finish_last` calls.
    /// It starts at `Default::default()` and is turned into an `f64` by
    /// [`Self::complete_sum`].
    type RollingSum: Default + Copy;

    /// Kernel-specific representation of a scalar, computed once per call to
    /// [`update_distances`] and reused for every block.
    type Splat: Copy;

    /// Convert the scalar `x` into the kernel's [`Self::Splat`] representation.
    fn splat(x: f32) -> Self::Splat;

    /// Compute the partial result for the block starting at `block` against the
    /// query point `this`.
    ///
    /// # Safety
    ///
    /// `block` must be valid for every read the implementation performs. For the
    /// `BlockTranspose<16>` implementation in this file that is `16 * this.len()`
    /// contiguous `f32` values starting at `block`.
    unsafe fn accum_full(block: *const f32, this: &[f32]) -> Self::Intermediate;

    /// Fold one full block into the running state.
    ///
    /// `norms` and `mins` are the slices of per-row norms and current minimum
    /// distances belonging to this block; `mins` is updated in place. Returns the
    /// updated rolling sum.
    fn finish(
        intermediate: Self::Intermediate,
        splat: Self::Splat,
        rolling_sum: Self::RollingSum,
        norms: &[f32],
        mins: &mut [f32],
    ) -> Self::RollingSum;

    /// Same as [`Self::finish`], but for a trailing partial block in which only the
    /// first `first` rows are real. `norms` and `mins` hold exactly those rows.
    fn finish_last(
        intermediate: Self::Intermediate,
        splat: Self::Splat,
        rolling_sum: Self::RollingSum,
        norms: &[f32],
        mins: &mut [f32],
        first: usize,
    ) -> Self::RollingSum;

    /// Convert the final rolling sum into the `f64` returned by [`update_distances`].
    fn complete_sum(x: Self::RollingSum) -> f64;
}
79
// Local SIMD vector type used by the kernel below.
// NOTE(review): `f32x8` is presumably an 8-lane `f32` vector; the kernel below treats
// it that way (loads at offsets 0 and 8 of a 16-element block) - confirm in `diskann_wide`.
diskann_wide::alias!(f32s = f32x8);
81
/// Kernel for blocks of 16 rows, processed as two `f32s` halves (rows 0..8 and 8..16).
impl MicroKernel for BlockTranspose<16> {
    /// One partial result per half of the block.
    type Intermediate = (f32s, f32s);
    /// Running total of the updated minimum distances, accumulated in `f64`.
    type RollingSum = f64;
    /// The scalar broadcast across every lane of an `f32s`.
    type Splat = f32s;

    /// Broadcast `x` into every lane.
    fn splat(x: f32) -> Self::Splat {
        Self::Splat::splat(diskann_wide::ARCH, x)
    }

    /// For every element `this[i]`, broadcast it and multiply-accumulate it against the
    /// 16 values stored at `block_ptr + 16 * i`, then scale both accumulators by `-2`.
    ///
    /// NOTE(review): assuming the block stores the `i`-th coordinate of its 16 rows
    /// contiguously, the result is `-2 * dot(row, this)` for each row - confirm against
    /// the `BlockTranspose` layout.
    ///
    /// # Safety
    ///
    /// `block_ptr` must be valid for reads of `16 * this.len()` contiguous `f32` values.
    #[inline(always)]
    unsafe fn accum_full(block_ptr: *const f32, this: &[f32]) -> Self::Intermediate {
        // Accumulators for the lower and upper half of the block.
        let mut s0 = f32s::default(diskann_wide::ARCH);
        let mut s1 = f32s::default(diskann_wide::ARCH);

        this.iter().enumerate().for_each(|(i, b)| {
            let b = f32s::splat(diskann_wide::ARCH, *b);

            // SAFETY: relies on the caller's guarantee that `16 * this.len()` floats are
            // readable from `block_ptr`; this load starts at offset `16 * i`.
            let a = unsafe { f32s::load_simd(diskann_wide::ARCH, block_ptr.add(16 * i)) };
            s0 = a.mul_add_simd(b, s0);

            // SAFETY: same guarantee as above; this load starts at offset `16 * i + 8`.
            let a = unsafe { f32s::load_simd(diskann_wide::ARCH, block_ptr.add(16 * i + 8)) };
            s1 = a.mul_add_simd(b, s1);
        });

        let negative = f32s::splat(diskann_wide::ARCH, -2.0);
        (s0 * negative, s1 * negative)
    }

    /// Update the 16 entries of `mins` with `norms + splat + intermediate` wherever that
    /// value is strictly smaller than the current entry, and add all 16 resulting
    /// minimums (as `f64`) to `rolling_sum`.
    ///
    /// # Panics
    ///
    /// Panics if `norms` or `mins` does not have exactly 16 elements.
    #[inline(always)]
    fn finish(
        intermediate: Self::Intermediate,
        splat: Self::Splat,
        rolling_sum: Self::RollingSum,
        norms: &[f32],
        mins: &mut [f32],
    ) -> Self::RollingSum {
        // These length checks are what make the raw-pointer loads/stores below in-bounds.
        assert_eq!(norms.len(), 16);
        assert_eq!(mins.len(), 16);

        // SAFETY: `norms` has 16 elements (asserted above); this reads elements 0..8.
        let norms0 = unsafe { f32s::load_simd(diskann_wide::ARCH, norms.as_ptr()) };
        // SAFETY: `norms` has 16 elements (asserted above); this reads elements 8..16.
        let norms1 = unsafe { f32s::load_simd(diskann_wide::ARCH, norms.as_ptr().add(8)) };

        // Candidate distance for every row of the block.
        let distances0 = norms0 + splat + intermediate.0;
        let distances1 = norms1 + splat + intermediate.1;

        // SAFETY: `mins` has 16 elements (asserted above); this reads elements 0..8.
        let current_distances0 = unsafe { f32s::load_simd(diskann_wide::ARCH, mins.as_ptr()) };
        let current_distances1 =
            // SAFETY: `mins` has 16 elements (asserted above); this reads elements 8..16.
            unsafe { f32s::load_simd(diskann_wide::ARCH, mins.as_ptr().add(8)) };

        // Lanes where the candidate beats the stored minimum.
        let mask0 = distances0.lt_simd(current_distances0);
        let mask1 = distances1.lt_simd(current_distances1);

        let current_distances0 = mask0.select(distances0, current_distances0);
        let current_distances1 = mask1.select(distances1, current_distances1);

        // SAFETY: `mins` has 16 elements (asserted above); this writes elements 0..8.
        unsafe { current_distances0.store_simd(mins.as_mut_ptr()) };
        // SAFETY: `mins` has 16 elements (asserted above); this writes elements 8..16.
        unsafe { current_distances1.store_simd(mins.as_mut_ptr().add(8)) };

        // Widen to `f64` before summing; lanes are added pairwise (lane k of each half).
        rolling_sum
            + std::iter::zip(
                current_distances0.to_array().iter(),
                current_distances1.to_array().iter(),
            )
            .map(|(d0, d1)| (*d0 as f64) + (*d1 as f64))
            .sum::<f64>()
    }

    /// Variant of [`MicroKernel::finish`] for a partial block holding `first` rows.
    ///
    /// Only the first `first` entries of `norms` and `mins` are read and written.
    ///
    /// NOTE(review): the returned sum adds all 16 lanes, including lanes at or beyond
    /// `first`. This appears to rely on those lanes evaluating to `0.0` (partial loads
    /// and `default` presumably yield zeros there) - confirm.
    ///
    /// # Panics
    ///
    /// Panics if `norms` or `mins` does not have exactly `first` elements.
    #[inline(always)]
    fn finish_last(
        intermediate: Self::Intermediate,
        splat: Self::Splat,
        rolling_sum: Self::RollingSum,
        norms: &[f32],
        mins: &mut [f32],
        first: usize,
    ) -> Self::RollingSum {
        assert_eq!(norms.len(), first);
        assert_eq!(mins.len(), first);

        // Number of valid rows in the lower (`lo`, at most 8) and upper (`hi`) half.
        let lo = first.min(8);
        let hi = first - lo;

        // SAFETY: `norms` has `first >= lo` elements; only the first `lo` are requested.
        let norms0 = unsafe { f32s::load_simd_first(diskann_wide::ARCH, norms.as_ptr(), lo) };
        // Skip the upper half entirely when it is empty so that no pointer past the
        // end of `norms` is formed.
        let norms1 = if hi == 0 {
            f32s::default(diskann_wide::ARCH)
        } else {
            // SAFETY: `hi != 0` implies `lo == 8` and `norms.len() == 8 + hi`.
            unsafe { f32s::load_simd_first(diskann_wide::ARCH, norms.as_ptr().add(8), hi) }
        };

        let distances0 = norms0 + splat + intermediate.0;
        let distances1 = norms1 + splat + intermediate.1;

        let current_distances0 =
            // SAFETY: `mins` has `first >= lo` elements; only the first `lo` are requested.
            unsafe { f32s::load_simd_first(diskann_wide::ARCH, mins.as_ptr(), lo) };
        let current_distances1 = if hi == 0 {
            f32s::default(diskann_wide::ARCH)
        } else {
            // SAFETY: `hi != 0` implies `lo == 8` and `mins.len() == 8 + hi`.
            unsafe { f32s::load_simd_first(diskann_wide::ARCH, mins.as_ptr().add(8), hi) }
        };

        let mask0 = distances0.lt_simd(current_distances0);
        let mask1 = distances1.lt_simd(current_distances1);

        let current_distances0 = mask0.select(distances0, current_distances0);
        let current_distances1 = mask1.select(distances1, current_distances1);

        // SAFETY: `mins` has `first >= lo` elements; only the first `lo` are written.
        unsafe { current_distances0.store_simd_first(mins.as_mut_ptr(), lo) };
        if hi != 0 {
            // SAFETY: `hi != 0` implies `lo == 8` and `mins.len() == 8 + hi`.
            unsafe { current_distances1.store_simd_first(mins.as_mut_ptr().add(8), hi) };
        }

        rolling_sum
            + std::iter::zip(
                current_distances0.to_array().iter(),
                current_distances1.to_array().iter(),
            )
            .map(|(d0, d1)| (*d0 as f64) + (*d1 as f64))
            .sum::<f64>()
    }

    /// The rolling sum is already an `f64`; return it unchanged.
    fn complete_sum(x: Self::RollingSum) -> f64 {
        x
    }
}
231
232fn update_distances<const N: usize>(
237 square_distances: &mut [f32],
238 transpose: &BlockTranspose<N>,
239 norms: &[f32],
240 this: &[f32],
241 this_square_norm: f32,
242) -> f64
243where
244 BlockTranspose<N>: MicroKernel,
245{
246 assert_eq!(
249 this.len(),
250 transpose.ncols(),
251 "new point and dataset must have the same dimension",
252 );
253 assert_eq!(
255 square_distances.len(),
256 transpose.nrows(),
257 "distances buffer and dataset must have the same length",
258 );
259 assert_eq!(
261 norms.len(),
262 transpose.nrows(),
263 "norms and dataset must have the same length",
264 );
265
266 let splat = BlockTranspose::<N>::splat(this_square_norm);
267 let mut rolling_sum = <BlockTranspose<N> as MicroKernel>::RollingSum::default();
268
269 let iter =
270 std::iter::zip(norms.chunks_exact(N), square_distances.chunks_exact_mut(N)).enumerate();
271 iter.for_each(|(block, (these_norms, these_distances))| {
272 debug_assert!(block < transpose.num_blocks());
273 let base = unsafe { transpose.block_ptr_unchecked(block) };
277
278 let intermediate = unsafe { BlockTranspose::<N>::accum_full(base, this) };
281
282 rolling_sum = BlockTranspose::<N>::finish(
283 intermediate,
284 splat,
285 rolling_sum,
286 these_norms,
287 these_distances,
288 );
289 });
290
291 let remainder = transpose.remainder();
293 if remainder != 0 {
294 let base = unsafe { transpose.block_ptr_unchecked(transpose.full_blocks()) };
297
298 let intermediate = unsafe { BlockTranspose::<N>::accum_full(base, this) };
303
304 let start = N * transpose.full_blocks();
305 rolling_sum = BlockTranspose::<N>::finish_last(
306 intermediate,
307 splat,
308 rolling_sum,
309 &norms[start..],
310 &mut square_distances[start..],
311 remainder,
312 );
313 }
314
315 BlockTranspose::<N>::complete_sum(rolling_sum)
316}
317
/// Why k-means++ seeding stopped before filling every requested center.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FailureReason {
    /// The dataset does not have enough points for the requested number of centers.
    DatasetTooSmall,
    /// A selection round finished without finding a new, distinct point to pick.
    InsufficientDiversity,
    /// A non-finite value (infinity or NaN) was encountered during selection.
    SawInfinity,
}

impl FailureReason {
    /// Whether this failure is classified as numerically recoverable.
    ///
    /// Running out of points or out of distinct points is recoverable; observing a
    /// non-finite value is not.
    pub fn is_numerically_recoverable(self) -> bool {
        match self {
            Self::SawInfinity => false,
            Self::DatasetTooSmall => true,
            Self::InsufficientDiversity => true,
        }
    }
}

impl fmt::Display for FailureReason {
    /// Write a short human-readable description of the failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::DatasetTooSmall => "dataset does not have enough points",
            Self::InsufficientDiversity => "dataset is insufficiently diverse",
            Self::SawInfinity => "a value of infinity or NaN was observed",
        })
    }
}
353
354#[derive(Debug, Clone, Copy, Error)]
355#[error("only populated {selected} of {expected} points because {reason}")]
356pub struct KMeansPlusPlusError {
357 pub selected: usize,
359 pub expected: usize,
361 pub reason: FailureReason,
363}
364
365impl KMeansPlusPlusError {
366 fn new(selected: usize, expected: usize, reason: FailureReason) -> Self {
367 Self {
368 selected,
369 expected,
370 reason,
371 }
372 }
373
374 pub fn is_numerically_recoverable(&self) -> bool {
375 self.reason.is_numerically_recoverable() && self.selected > 0
376 }
377}
378
379pub(crate) fn kmeans_plusplus_into_inner<const N: usize>(
380 mut points: MutMatrixView<'_, f32>,
381 data: StridedView<'_, f32>,
382 transpose: &BlockTranspose<N>,
383 norms: &[f32],
384 rng: &mut dyn RngCore,
385) -> Result<(), KMeansPlusPlusError>
386where
387 BlockTranspose<N>: MicroKernel,
388{
389 assert_eq!(norms.len(), data.nrows());
390 assert_eq!(transpose.nrows(), data.nrows());
391 assert_eq!(transpose.block_size(), data.ncols());
392 assert_eq!(points.ncols(), data.ncols());
393
394 points.as_mut_slice().fill(0.0);
396 let expected = points.nrows();
397
398 let all_rows = match Uniform::new(0, data.nrows()) {
403 Ok(dist) => dist,
404 Err(_) => {
405 return if expected == 0 {
407 Ok(())
408 } else {
409 Err(KMeansPlusPlusError::new(
410 0,
411 expected,
412 FailureReason::DatasetTooSmall,
413 ))
414 };
415 }
416 };
417
418 let mut min_distances: Vec<f32> = vec![f32::INFINITY; data.nrows()];
419 let mut picked = HashSet::with_capacity(expected);
420
421 let mut previous_square_norm = {
423 let i = all_rows.sample(rng);
424 points.row_mut(0).copy_from_slice(data.row(i));
425 picked.insert(i);
426 norms[i]
427 };
428
429 let mut selected = 1;
430 for current in 1..expected.min(data.nrows()) {
431 let last = points.row(current - 1);
432 let s = update_distances(
433 &mut min_distances,
434 transpose,
435 norms,
436 last,
437 previous_square_norm,
438 );
439
440 match Uniform::<f64>::new(0.0, s) {
444 Ok(distribution) => {
445 let threshold = distribution.sample(rng);
446 let mut rolling_sum: f64 = 0.0;
447 for (i, d) in min_distances.iter().enumerate() {
448 rolling_sum += <f32 as Into<f64>>::into(*d);
449 if rolling_sum >= threshold && (*d > 0.0) && !picked.contains(&i) {
450 points.row_mut(current).clone_from_slice(data.row(i));
453 picked.insert(i);
454 previous_square_norm = norms[i];
455 selected = current + 1;
456 break;
457 }
458 }
459 }
460 Err(rand::distr::uniform::Error::EmptyRange) => {}
464 Err(rand::distr::uniform::Error::NonFinite) => {
466 return Err(KMeansPlusPlusError::new(
467 selected,
468 expected,
469 FailureReason::SawInfinity,
470 ));
471 }
472 }
473
474 if selected != (current + 1) {
478 return Err(KMeansPlusPlusError::new(
479 selected,
480 expected,
481 FailureReason::InsufficientDiversity,
482 ));
483 }
484 }
485
486 if selected != expected {
488 Err(KMeansPlusPlusError::new(
489 selected,
490 expected,
491 FailureReason::DatasetTooSmall,
492 ))
493 } else {
494 Ok(())
495 }
496}
497
498pub fn kmeans_plusplus_into(
499 centers: MutMatrixView<'_, f32>,
500 data: MatrixView<'_, f32>,
501 rng: &mut dyn RngCore,
502) -> Result<(), KMeansPlusPlusError> {
503 assert_eq!(
504 centers.ncols(),
505 data.ncols(),
506 "centers output matrix should have the same dimensionality as the dataset"
507 );
508
509 const GROUPSIZE: usize = 16;
510 let mut norms: Vec<f32> = vec![0.0; data.nrows()];
511
512 for (n, d) in std::iter::zip(norms.iter_mut(), data.row_iter()) {
513 *n = square_norm(d);
514 }
515
516 let transpose = BlockTranspose::<GROUPSIZE>::from_matrix_view(data);
517 kmeans_plusplus_into_inner(centers, data.into(), &transpose, &norms, rng)
518}
519
#[cfg(test)]
mod tests {
    use diskann_utils::{lazy_format, views::Matrix};
    use diskann_vector::{PureDistanceFunction, distance::SquaredL2};
    use rand::{Rng, SeedableRng, rngs::StdRng, seq::SliceRandom};

    use super::*;
    use crate::utils;

    /// Return `true` if some row of `haystack` is element-wise equal to `needle`.
    fn is_in(needle: &[f32], haystack: MatrixView<'_, f32>) -> bool {
        assert_eq!(needle.len(), haystack.ncols());
        haystack.row_iter().any(|row| row == needle)
    }

    /// Check the state of `centers` after a failed run: the first `err.selected` rows
    /// come from `data`, and every remaining row is all zeros.
    fn check_post_conditions(
        centers: MatrixView<'_, f32>,
        data: MatrixView<'_, f32>,
        err: &KMeansPlusPlusError,
    ) {
        assert_eq!(err.expected, centers.nrows());
        assert!(err.expected > err.selected);
        for i in 0..err.selected {
            assert!(is_in(centers.row(i), data.as_view()));
        }
        for i in err.selected..centers.nrows() {
            assert!(centers.row(i).iter().all(|j| *j == 0.0));
        }
    }

    /// The `Display` text of every failure reason is pinned.
    #[test]
    fn test_error_display() {
        assert_eq!(
            format!("{}", FailureReason::DatasetTooSmall),
            "dataset does not have enough points"
        );

        assert_eq!(
            format!("{}", FailureReason::InsufficientDiversity),
            "dataset is insufficiently diverse"
        );

        assert_eq!(
            format!("{}", FailureReason::SawInfinity),
            "a value of infinity or NaN was observed"
        );
    }

    /// Fill `x` so that entry `(i, j)` holds `i + j`. All values are small integers,
    /// which keeps the floating-point comparisons in the tests below exact.
    fn set_default_values(mut x: MutMatrixView<'_, f32>) {
        for (i, row) in x.row_iter_mut().enumerate() {
            for (j, r) in row.iter_mut().enumerate() {
                *r = (i + j) as f32;
            }
        }
    }

    /// Drive `update_distances` with a few random integer-valued samples and compare
    /// the distance buffer and the returned sum against a straightforward reference.
    fn test_update_distances_impl<const N: usize, R>(num_points: usize, dim: usize, rng: &mut R)
    where
        BlockTranspose<N>: MicroKernel,
        R: Rng,
    {
        let context = lazy_format!(
            "setup: N = {}, num_points = {}, dim = {}",
            N,
            num_points,
            dim
        );

        let mut data = Matrix::<f32>::new(0.0, num_points, dim);
        set_default_values(data.as_mut_view());

        let square_norms: Vec<f32> = data.row_iter().map(square_norm).collect();

        let num_samples = 3;
        let mut samples = Matrix::<f32>::new(0.0, num_samples, dim);
        let mut distances = vec![f32::INFINITY; num_points];
        let distribution = Uniform::<u32>::new(0, (num_points + dim) as u32).unwrap();
        let transpose = BlockTranspose::<N>::from_matrix_view(data.as_view());

        let mut last_residual = f64::INFINITY;
        for i in 0..num_samples {
            {
                // Draw the next sample with integer-valued coordinates.
                let row = samples.row_mut(i);
                row.iter_mut().for_each(|r| {
                    *r = distribution.sample(rng) as f32;
                });
            }
            let row = samples.row(i);
            let norm = square_norm(row);

            let residual = update_distances(&mut distances, &transpose, &square_norms, row, norm);

            // Every entry must equal the minimum squared distance to all samples so far.
            for (n, (d, data)) in std::iter::zip(distances.iter(), data.row_iter()).enumerate() {
                let mut min_distance = f32::INFINITY;
                for j in 0..=i {
                    let distance = SquaredL2::evaluate(samples.row(j), data);
                    min_distance = min_distance.min(distance);
                }
                assert_eq!(
                    min_distance, *d,
                    "failed on row {n} on iteration {i}. {}",
                    context
                );
            }

            // The returned value must be the sum of the updated buffer.
            assert_eq!(
                residual,
                distances.iter().sum::<f32>() as f64,
                "residual sum failed on iteration {i} - {}",
                context
            );

            // Minimums can only shrink, so the sum must never grow.
            assert!(
                residual <= last_residual,
                "residual check failed on iteration {}, last = {}, this = {} - {}",
                i,
                last_residual,
                residual,
                context
            );

            last_residual = residual;
        }
    }

    /// Sweep point counts across several block boundaries (including an empty dataset)
    /// and a few small dimensions.
    #[test]
    fn test_update_distances() {
        let mut rng = StdRng::seed_from_u64(0x56c94b53c73e4fd9);
        for num_points in 0..48 {
            for dim in 1..4 {
                test_update_distances_impl(num_points, dim, &mut rng);
            }
        }
    }

    /// Build a dataset with `ncenters` distinct values, each repeated three times, and
    /// check that seeding selects every distinct value exactly once.
    fn sanity_check_impl<R: Rng>(ncenters: usize, dim: usize, rng: &mut R) {
        let repeats_per_center = 3;
        let context = lazy_format!(
            "dim = {}, ncenters = {}, repeats_per_center = {}",
            dim,
            ncenters,
            repeats_per_center
        );

        let ndata = repeats_per_center * ncenters;
        let mut values: Vec<f32> = (0..ncenters)
            .flat_map(|i| (0..repeats_per_center).map(move |_| i as f32))
            .collect();
        assert_eq!(values.len(), ndata);

        values.shuffle(rng);
        let mut data = Matrix::new(0.0, ndata, dim);
        for (r, v) in std::iter::zip(data.row_iter_mut(), values.iter()) {
            r.fill(*v);
        }

        let mut centers = Matrix::new(f32::INFINITY, ncenters, dim);
        kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), rng).unwrap();

        let mut seen = HashSet::<usize>::new();
        for c in centers.row_iter() {
            // Every data row is constant, so every center row must be constant too.
            let first = c[0];
            assert!(c.iter().all(|i| *i == first));

            let v: usize = first.round() as usize;
            assert_eq!(v as f32, first, "conversion was not lossless - {}", context);

            if !seen.insert(v) {
                panic!("value {first} seen more than once - {}", context);
            }
        }
        assert_eq!(
            seen.len(),
            ncenters,
            "not all points were seen - {}",
            context
        );
    }

    #[test]
    #[cfg(not(miri))]
    fn sanity_check() {
        let dims = [1, 16];
        let ncenters = [1, 5, 20, 255];
        let mut rng = StdRng::seed_from_u64(0x68c2080f2ea36f5a);

        for ncenters in ncenters {
            for dim in dims {
                sanity_check_impl(ncenters, dim, &mut rng);
            }
        }
    }

    /// Like `sanity_check_impl`, but the clusters are spread far apart and each point is
    /// slightly perturbed, so the check is statistical: nearly all clusters must be hit
    /// and every center must be an actual row of the dataset.
    fn fuzzy_sanity_check_impl<R: Rng>(ncenters: usize, dim: usize, rng: &mut R) {
        let repeats_per_center = 3;

        // Distance between neighboring cluster values.
        let spreading_multiplier: usize = 16;
        // Small jitter applied to every point; much smaller than the cluster spacing.
        let perturbation_distribution = Uniform::new(-0.125, 0.125).unwrap();

        let context = lazy_format!(
            "dim = {}, ncenters = {}, repeats_per_center = {}, multiplier = {}",
            dim,
            ncenters,
            repeats_per_center,
            spreading_multiplier,
        );

        let ndata = repeats_per_center * ncenters;
        let mut values: Vec<f32> = (0..ncenters)
            .flat_map(|i| {
                let v: Vec<f32> = (0..repeats_per_center)
                    .map(|_| {
                        (spreading_multiplier * i) as f32 + perturbation_distribution.sample(rng)
                    })
                    .collect();

                v.into_iter()
            })
            .collect();
        assert_eq!(values.len(), ndata);

        values.shuffle(rng);
        let mut data = Matrix::new(0.0, ndata, dim);
        for (r, v) in std::iter::zip(data.row_iter_mut(), values.iter()) {
            r.fill(*v);
        }

        let mut centers = Matrix::new(f32::INFINITY, ncenters, dim);
        kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), rng).unwrap();

        let mut seen = HashSet::<usize>::new();
        for (i, c) in centers.row_iter().enumerate() {
            let first = c[0];
            let v: usize = first.round() as usize;
            assert_eq!(
                v % spreading_multiplier,
                0,
                "expected row value to be close to a multiple of the spreading multiplier, \
                 instead got {} - {}",
                v,
                context
            );
            seen.insert(v);

            // The center must be one of the dataset rows.
            let mut found = false;
            for r in data.row_iter() {
                if r == c {
                    found = true;
                    break;
                }
            }
            if !found {
                panic!(
                    "center {} was not found in the original dataset - {}",
                    i, context,
                );
            }
        }
        assert!(
            seen.len() as f32 >= 0.95 * (ncenters as f32),
            "expected the distribution of centers to be wide, \
             instead {} unique values were found - {}",
            seen.len(),
            context
        );
    }

    #[test]
    #[cfg(not(miri))]
    fn fuzzy_sanity_check() {
        let dims = [1, 16];
        // `0` exercises the empty-dataset / zero-centers path.
        let ncenters = [0, 1, 5, 20, 255];
        let mut rng = StdRng::seed_from_u64(0x68c2080f2ea36f5a);

        for ncenters in ncenters {
            for dim in dims {
                fuzzy_sanity_check_impl(ncenters, dim, &mut rng);
            }
        }
    }

    /// Requesting centers from an empty dataset fails without selecting anything.
    #[test]
    fn fail_empty_dataset() {
        let data = Matrix::new(0.0, 0, 5);
        let mut centers = Matrix::new(0.0, 10, data.ncols());

        let mut rng = StdRng::seed_from_u64(0xa9eae150d30845a1);

        let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
        assert!(
            result.is_err(),
            "kmeans++ on an empty dataset with non-empty centers should be an error"
        );
        let err = result.unwrap_err();
        assert_eq!(err.selected, 0);
        assert_eq!(err.expected, centers.nrows());
        assert_eq!(err.reason, FailureReason::DatasetTooSmall);
        assert!(!err.is_numerically_recoverable());

        check_post_conditions(centers.as_view(), data.as_view(), &err);
    }

    #[test]
    fn both_empty_is_okay() {
        let data = Matrix::new(0.0, 0, 5);
        let mut centers = Matrix::new(0.0, 0, data.ncols());
        let mut rng = StdRng::seed_from_u64(0x6f7031afd9b5aa18);
        let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
        assert!(
            result.is_ok(),
            "selecting 0 points from an empty dataset is okay"
        );
    }

    /// Requesting more centers than there are data points selects every point and then
    /// reports `DatasetTooSmall`.
    #[test]
    fn fail_dataset_not_big_enough() {
        let ndata = 5;
        let ncenters = 10;
        let dim = 5;

        let mut data = Matrix::new(0.0, ndata, dim);
        set_default_values(data.as_mut_view());
        let mut centers = Matrix::new(f32::INFINITY, ncenters, data.ncols());

        let mut rng = StdRng::seed_from_u64(0xa9eae150d30845a1);

        let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
        assert!(
            result.is_err(),
            "kmeans++ requesting more centers than there are data points should be an error"
        );
        let err = result.unwrap_err();
        assert_eq!(err.selected, data.nrows());
        assert_eq!(err.expected, centers.nrows());
        assert_eq!(err.reason, FailureReason::DatasetTooSmall);
        assert!(err.is_numerically_recoverable());

        check_post_conditions(centers.as_view(), data.as_view(), &err);
    }

    /// The dataset has enough rows but only five distinct values, so selection must stop
    /// after five centers with `InsufficientDiversity`.
    #[test]
    fn fail_diversity_check() {
        let ncenters = 10;
        let ndata = 50;
        let dim = 3;
        let mut rng = StdRng::seed_from_u64(0xca57b032c21bf4bb);

        let repeats_per_center = 10;
        // Guarantees fewer distinct values than requested centers.
        assert!(ncenters * repeats_per_center > ndata);
        let mut values: Vec<f32> = (0..utils::div_round_up(ndata, repeats_per_center))
            .flat_map(|i| (0..repeats_per_center).map(move |_| i as f32))
            .collect();
        assert!(values.len() >= ndata);

        values.shuffle(&mut rng);
        let mut data = Matrix::new(0.0, ndata, dim);
        for (r, v) in std::iter::zip(data.row_iter_mut(), values.iter()) {
            r.fill(*v);
        }

        let mut centers = Matrix::new(f32::INFINITY, ncenters, dim);
        let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
        assert!(
            result.is_err(),
            "dataset should not have enough unique points"
        );
        let err = result.unwrap_err();
        assert_eq!(err.selected, utils::div_round_up(ndata, repeats_per_center));
        assert_eq!(err.expected, centers.nrows());
        assert_eq!(err.reason, FailureReason::InsufficientDiversity);
        assert!(err.is_numerically_recoverable());

        check_post_conditions(centers.as_view(), data.as_view(), &err);
    }

    /// A coordinate of very large magnitude must surface as `SawInfinity`.
    #[test]
    fn fail_intinity_check() {
        let mut data = Matrix::new(0.0, 10, 1);
        set_default_values(data.as_mut_view());

        // Large enough in magnitude that squaring it is no longer finite in `f32`.
        data[(6, 0)] = -3.4028235e38;
        let mut centers = Matrix::new(0.0, 2, 1);

        let mut rng = StdRng::seed_from_u64(0xc0449b2aa4e12f05);

        let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
        assert!(result.is_err(), "result should complain about infinity");
        let err = result.unwrap_err();
        assert_eq!(err.selected, 1);
        assert_eq!(err.expected, centers.nrows());
        assert_eq!(err.reason, FailureReason::SawInfinity);
        assert!(!err.is_numerically_recoverable());

        check_post_conditions(centers.as_view(), data.as_view(), &err);
    }

    /// A NaN coordinate must surface as `SawInfinity` as well.
    #[test]
    fn fail_nan_check() {
        let mut data = Matrix::new(0.0, 10, 1);
        set_default_values(data.as_mut_view());

        data[(6, 0)] = f32::NAN;
        let mut centers = Matrix::new(0.0, 2, 1);

        let mut rng = StdRng::seed_from_u64(0x55808c6c728c8473);

        let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
        assert!(result.is_err(), "result should complain about NaN");
        let err = result.unwrap_err();
        assert_eq!(err.selected, 1);
        assert_eq!(err.expected, centers.nrows());
        assert_eq!(err.reason, FailureReason::SawInfinity);
        assert!(!err.is_numerically_recoverable());

        check_post_conditions(centers.as_view(), data.as_view(), &err);
    }

    /// `update_distances` rejects a query whose length differs from the dataset dimension.
    #[test]
    #[should_panic(expected = "new point and dataset must have the same dimension")]
    fn update_distances_panics_dim_mismatch() {
        let npoints = 5;
        let dim = 8;
        let mut square_distances = vec![0.0; npoints];
        let data = Matrix::new(0.0, npoints, dim);
        let norms = vec![0.0; npoints];
        // One element too many.
        let this = vec![0.0; dim + 1];
        let this_square_norm = 0.0;
        update_distances::<16>(
            &mut square_distances,
            &BlockTranspose::from_matrix_view(data.as_view()),
            &norms,
            &this,
            this_square_norm,
        );
    }

    /// `update_distances` rejects a distance buffer of the wrong length.
    #[test]
    #[should_panic(expected = "distances buffer and dataset must have the same length")]
    fn update_distances_panics_distances_length_mismatch() {
        let npoints = 5;
        let dim = 8;
        // One element too many.
        let mut square_distances = vec![0.0; npoints + 1];
        let data = Matrix::new(0.0, npoints, dim);
        let norms = vec![0.0; npoints];
        let this = vec![0.0; dim];
        let this_square_norm = 0.0;
        update_distances::<16>(
            &mut square_distances,
            &BlockTranspose::from_matrix_view(data.as_view()),
            &norms,
            &this,
            this_square_norm,
        );
    }

    /// `update_distances` rejects a norms slice of the wrong length.
    #[test]
    #[should_panic(expected = "norms and dataset must have the same length")]
    fn update_distances_panics_norms_length_mismatch() {
        let npoints = 5;
        let dim = 8;
        let mut square_distances = vec![0.0; npoints];
        let data = Matrix::new(0.0, npoints, dim);
        // One element too many.
        let norms = vec![0.0; npoints + 1];
        let this = vec![0.0; dim];
        let this_square_norm = 0.0;
        update_distances::<16>(
            &mut square_distances,
            &BlockTranspose::from_matrix_view(data.as_view()),
            &norms,
            &this,
            this_square_norm,
        );
    }

    /// `kmeans_plusplus_into` rejects a `centers` matrix whose dimension differs from
    /// the dataset's.
    #[test]
    #[should_panic(
        expected = "centers output matrix should have the same dimensionality as the dataset"
    )]
    fn kmeans_plusplus_into_panics_dim_mismatch() {
        let mut centers = Matrix::new(0.0, 2, 10);
        let data = Matrix::new(0.0, 2, 9);
        kmeans_plusplus_into(
            centers.as_mut_view(),
            data.as_view(),
            &mut rand::rngs::ThreadRng::default(),
        )
        .unwrap();
    }
}