Skip to main content

diskann_quantization/algorithms/kmeans/
plusplus.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::{collections::HashSet, fmt};
7
8use diskann_utils::{
9    strided::StridedView,
10    views::{MatrixView, MutMatrixView},
11};
12use diskann_wide::{SIMDMulAdd, SIMDPartialOrd, SIMDSelect, SIMDVector};
13use rand::{
14    RngCore,
15    distr::{Distribution, Uniform},
16};
17use thiserror::Error;
18
19use super::common::square_norm;
20use crate::multi_vector::{BlockTransposed, BlockTransposedRef};
21
/// An internal trait implemented for `BlockTransposed` used to accelerate the inner loop
/// of kmeans++ seeding:
///
/// 1. Computation of distances between a newly selected kmeans++ center and all elements
///    in the dataset.
///
/// 2. Updating of the minimum distance between the dataset and all candidates.
///
/// 3. Computing the sum of all the minimum squared distances in the dataset so far.
///
/// All of these operations are fused for efficiency.
///
/// This trait is only meant to be implemented by `BlockTransposed`.
pub(crate) trait MicroKernel {
    /// The value produced by [`MicroKernel::accum_full`] for one block: the part of each
    /// squared distance that depends on both the new center and the block's rows.
    type Intermediate;

    /// The type of the rolling sum used to accumulate the sum of all the squared minimum
    /// distances.
    type RollingSum: Default + Copy;

    /// A potentially broadcasted representation of a `f32` number.
    type Splat: Copy;

    /// Return the broadcasted representation of `x`.
    fn splat(x: f32) -> Self::Splat;

    /// Compute, for every row stored in `block`, the cross term between that row and
    /// `this` needed to form their squared L2 distance.
    ///
    /// [`MicroKernel::finish`] and [`MicroKernel::finish_last`] complete the distance by
    /// adding the rows' squared norms and the squared norm of `this`. For the
    /// `BlockTransposed<f32, 16>` implementation, the cross term is `-2 * <row, this>`.
    ///
    /// This implementation works for both full blocks and partial blocks.
    ///
    /// # Safety
    ///
    /// `block` must be the base pointer of a data block in a `BlockTransposed<f32, N>`
    /// whose column count equals `this.len()`. That is, `N * this.len()` elements must
    /// be readable from `block`.
    unsafe fn accum_full(block: *const f32, this: &[f32]) -> Self::Intermediate;

    /// Finish the squared distances for a full block and fold them into the running state.
    ///
    /// * `intermediate` - the result of [`MicroKernel::accum_full`] for this block.
    /// * `splat` - the broadcast squared norm of the new center (see [`MicroKernel::splat`]).
    /// * `rolling_sum` - the running sum of minimum distances so far.
    /// * `norms` - the squared norms of the rows in this block.
    /// * `mins` - the current minimum distances of the rows in this block; each entry is
    ///   lowered to the new distance when the new distance is smaller.
    ///
    /// Returns `rolling_sum` plus the sum of the updated `mins`.
    fn finish(
        intermediate: Self::Intermediate,
        splat: Self::Splat,
        rolling_sum: Self::RollingSum,
        norms: &[f32],
        mins: &mut [f32],
    ) -> Self::RollingSum;

    /// Like [`MicroKernel::finish`], but for a trailing partial block in which only the
    /// first `first` rows are valid. `norms` and `mins` have exactly `first` entries.
    fn finish_last(
        intermediate: Self::Intermediate,
        splat: Self::Splat,
        rolling_sum: Self::RollingSum,
        norms: &[f32],
        mins: &mut [f32],
        first: usize,
    ) -> Self::RollingSum;

    /// Turn the rolling sum from its internal representation into a final `f64`.
    fn complete_sum(x: Self::RollingSum) -> f64;
}
81
82diskann_wide::alias!(f32s = f32x8);
83
84impl MicroKernel for BlockTransposed<f32, 16> {
85    // Process 16-dimensions concurrently, split across two `Wide`s.
86    type Intermediate = (f32s, f32s);
87    type RollingSum = f64;
88    type Splat = f32s;
89
90    fn splat(x: f32) -> Self::Splat {
91        Self::Splat::splat(diskann_wide::ARCH, x)
92    }
93
94    #[inline(always)]
95    unsafe fn accum_full(block_ptr: *const f32, this: &[f32]) -> Self::Intermediate {
96        let mut s0 = f32s::default(diskann_wide::ARCH);
97        let mut s1 = f32s::default(diskann_wide::ARCH);
98
99        this.iter().enumerate().for_each(|(i, b)| {
100            let b = f32s::splat(diskann_wide::ARCH, *b);
101
102            // SAFETY: Each block stores `16 * ncols` contiguous f32s (GROUP=16).
103            // The caller guarantees `ncols == this.len()`, so `16 * this.len()`
104            // elements are readable from `block_ptr` and `i < this.len()`.
105            let a = unsafe { f32s::load_simd(diskann_wide::ARCH, block_ptr.add(16 * i)) };
106            s0 = a.mul_add_simd(b, s0);
107
108            // SAFETY: Same as above; offset `16 * i + 8 < 16 * this.len()`.
109            let a = unsafe { f32s::load_simd(diskann_wide::ARCH, block_ptr.add(16 * i + 8)) };
110            s1 = a.mul_add_simd(b, s1);
111        });
112
113        // Apply the final -2.0 transformation.
114        let negative = f32s::splat(diskann_wide::ARCH, -2.0);
115        (s0 * negative, s1 * negative)
116    }
117
118    #[inline(always)]
119    fn finish(
120        intermediate: Self::Intermediate,
121        splat: Self::Splat,
122        rolling_sum: Self::RollingSum,
123        norms: &[f32],
124        mins: &mut [f32],
125    ) -> Self::RollingSum {
126        assert_eq!(norms.len(), 16);
127        assert_eq!(mins.len(), 16);
128
129        // SAFETY: `norms` has length 16 - this loads the first 8.
130        let norms0 = unsafe { f32s::load_simd(diskann_wide::ARCH, norms.as_ptr()) };
131        // SAFETY: `norms` has length 16 - this loads the last 8.
132        let norms1 = unsafe { f32s::load_simd(diskann_wide::ARCH, norms.as_ptr().add(8)) };
133
134        let distances0 = norms0 + splat + intermediate.0;
135        let distances1 = norms1 + splat + intermediate.1;
136
137        // SAFETY: `mins` has length 16 - this loads the first 8.
138        let current_distances0 = unsafe { f32s::load_simd(diskann_wide::ARCH, mins.as_ptr()) };
139        // SAFETY: `mins` has length 16 - this loads the last 8.
140        let current_distances1 =
141            unsafe { f32s::load_simd(diskann_wide::ARCH, mins.as_ptr().add(8)) };
142
143        let mask0 = distances0.lt_simd(current_distances0);
144        let mask1 = distances1.lt_simd(current_distances1);
145
146        let current_distances0 = mask0.select(distances0, current_distances0);
147        let current_distances1 = mask1.select(distances1, current_distances1);
148
149        // SAFETY: `mins` has length 16 - this stores the first 8.
150        unsafe { current_distances0.store_simd(mins.as_mut_ptr()) };
151        // SAFETY: `mins` has length 16 - this stores the last 8.
152        unsafe { current_distances1.store_simd(mins.as_mut_ptr().add(8)) };
153
154        rolling_sum
155            + std::iter::zip(
156                current_distances0.to_array().iter(),
157                current_distances1.to_array().iter(),
158            )
159            .map(|(d0, d1)| (*d0 as f64) + (*d1 as f64))
160            .sum::<f64>()
161    }
162
163    #[inline(always)]
164    fn finish_last(
165        intermediate: Self::Intermediate,
166        splat: Self::Splat,
167        rolling_sum: Self::RollingSum,
168        norms: &[f32],
169        mins: &mut [f32],
170        first: usize,
171    ) -> Self::RollingSum {
172        // Check 1.
173        assert_eq!(norms.len(), first);
174        // Check 2.
175        assert_eq!(mins.len(), first);
176
177        let lo = first.min(8);
178        let hi = first - lo;
179
180        // SAFETY: This loads `first.min(8)` elements from `norms`, which is valid
181        // by check 1.
182        let norms0 = unsafe { f32s::load_simd_first(diskann_wide::ARCH, norms.as_ptr(), lo) };
183        let norms1 = if hi == 0 {
184            f32s::default(diskann_wide::ARCH)
185        } else {
186            // SAFETY: This is only called if `first > 8`, which means `norms.len() > 8` by
187            // check 1. Therefore, we can load `first - 8` elements from `norms.as_ptr() + 8`.
188            unsafe { f32s::load_simd_first(diskann_wide::ARCH, norms.as_ptr().add(8), hi) }
189        };
190
191        let distances0 = norms0 + splat + intermediate.0;
192        let distances1 = norms1 + splat + intermediate.1;
193
194        // SAFETY: Same logic as the load for `norms0`.
195        let current_distances0 =
196            unsafe { f32s::load_simd_first(diskann_wide::ARCH, mins.as_ptr(), lo) };
197        let current_distances1 = if hi == 0 {
198            f32s::default(diskann_wide::ARCH)
199        } else {
200            // SAFETY: Same logic as the load for `norms1`.
201            unsafe { f32s::load_simd_first(diskann_wide::ARCH, mins.as_ptr().add(8), hi) }
202        };
203
204        let mask0 = distances0.lt_simd(current_distances0);
205        let mask1 = distances1.lt_simd(current_distances1);
206
207        let current_distances0 = mask0.select(distances0, current_distances0);
208        let current_distances1 = mask1.select(distances1, current_distances1);
209
210        // SAFETY: As per the logic for the load above, it is safe to store at least `lo`
211        // elements from the base pointer.
212        unsafe { current_distances0.store_simd_first(mins.as_mut_ptr(), lo) };
213        if hi != 0 {
214            // SAFETY: If `hi != 0`, then `first` must be at least 9. Therefore, adding 8
215            // to the base pointer is valid, as is storing `first - 8` elements to that
216            // pointer.
217            unsafe { current_distances1.store_simd_first(mins.as_mut_ptr().add(8), hi) };
218        }
219
220        rolling_sum
221            + std::iter::zip(
222                current_distances0.to_array().iter(),
223                current_distances1.to_array().iter(),
224            )
225            .map(|(d0, d1)| (*d0 as f64) + (*d1 as f64))
226            .sum::<f64>()
227    }
228
229    fn complete_sum(x: Self::RollingSum) -> f64 {
230        x
231    }
232}
233
234/// Update `square_distances` to contain the minimum of its current value and the distance
235/// between each element in `transpose` and `this`.
236///
237/// Return the sum of the new `square_distances`.
238fn update_distances<const N: usize>(
239    square_distances: &mut [f32],
240    transpose: BlockTransposedRef<'_, f32, N>,
241    norms: &[f32],
242    this: &[f32],
243    this_square_norm: f32,
244) -> f64
245where
246    BlockTransposed<f32, N>: MicroKernel,
247{
248    // Establish our safety requirements.
249    // Check 1.
250    assert_eq!(
251        this.len(),
252        transpose.ncols(),
253        "new point and dataset must have the same dimension",
254    );
255    // Check 2.
256    assert_eq!(
257        square_distances.len(),
258        transpose.nrows(),
259        "distances buffer and dataset must have the same length",
260    );
261    // Check 3.
262    assert_eq!(
263        norms.len(),
264        transpose.nrows(),
265        "norms and dataset must have the same length",
266    );
267
268    let splat = BlockTransposed::<f32, N>::splat(this_square_norm);
269    let mut rolling_sum = <BlockTransposed<f32, N> as MicroKernel>::RollingSum::default();
270
271    let iter =
272        std::iter::zip(norms.chunks_exact(N), square_distances.chunks_exact_mut(N)).enumerate();
273    iter.for_each(|(block, (these_norms, these_distances))| {
274        debug_assert!(block < transpose.num_blocks());
275        // SAFETY: Because `transpose.nrows() == norms.len()`, the number of full blocks in
276        // `transpose` is `norms.nrows() / N` and therefore the induction variable `block`
277        // is less that `transpose.full_blocks()`.
278        let base = unsafe { transpose.block_ptr_unchecked(block) };
279
280        // SAFETY: The pointer `base` does point to a full block and by Check 1,
281        // `transpose.ncols() == this.len()`.
282        let intermediate = unsafe { BlockTransposed::<f32, N>::accum_full(base, this) };
283
284        rolling_sum = BlockTransposed::<f32, N>::finish(
285            intermediate,
286            splat,
287            rolling_sum,
288            these_norms,
289            these_distances,
290        );
291    });
292
293    // Do the last iteration if there is an un-even number of rows.
294    let remainder = transpose.remainder();
295    if remainder != 0 {
296        // SAFETY: We've checked that there is a `remainder` block. Therefore,
297        // `transpose.full_blocks() < transpose.num_blocks()`.
298        let base = unsafe { transpose.block_ptr_unchecked(transpose.full_blocks()) };
299
300        // A full accumulation is fine because `BlockTransposed` allocates at the granularity
301        // of blocks. We will just ignore the extra lanes.
302        // SAFETY: The pointer `base` does point to a full block and by Check 1,
303        // `transpose.ncols() == this.len()`.
304        let intermediate = unsafe { BlockTransposed::<f32, N>::accum_full(base, this) };
305
306        let start = N * transpose.full_blocks();
307        rolling_sum = BlockTransposed::<f32, N>::finish_last(
308            intermediate,
309            splat,
310            rolling_sum,
311            &norms[start..],
312            &mut square_distances[start..],
313            remainder,
314        );
315    }
316
317    BlockTransposed::<f32, N>::complete_sum(rolling_sum)
318}
319
/// The reason kmeans++ seeding stopped before selecting every requested center.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FailureReason {
    /// The dataset contains fewer points than the requested number of centers.
    DatasetTooSmall,
    /// The dataset contains fewer distinct points than the requested number of centers:
    /// no remaining point was found at a positive distance from the centers already
    /// selected.
    InsufficientDiversity,
    /// Infinity was observed (this also happens when a NaN is present in the input data).
    SawInfinity,
}
330
331impl FailureReason {
332    pub fn is_numerically_recoverable(self) -> bool {
333        match self {
334            // Datasets being too small is recoverable from a `kmeans` perspective, we can
335            // simply proceed with fewer points.
336            Self::DatasetTooSmall | Self::InsufficientDiversity => true,
337
338            // If we see Infinity, downstream algorithms will likely just break.
339            // Don't expect to recover.
340            Self::SawInfinity => false,
341        }
342    }
343}
344
345impl fmt::Display for FailureReason {
346    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
347        let reason: &str = match self {
348            Self::DatasetTooSmall => "dataset does not have enough points",
349            Self::InsufficientDiversity => "dataset is insufficiently diverse",
350            Self::SawInfinity => "a value of infinity or NaN was observed",
351        };
352        f.write_str(reason)
353    }
354}
355
356#[derive(Debug, Clone, Copy, Error)]
357#[error("only populated {selected} of {expected} points because {reason}")]
358pub struct KMeansPlusPlusError {
359    /// The number of points that were selected.
360    pub selected: usize,
361    /// The number of points that were expected.
362    pub expected: usize,
363    /// A concrete reason for the failure.
364    pub reason: FailureReason,
365}
366
367impl KMeansPlusPlusError {
368    fn new(selected: usize, expected: usize, reason: FailureReason) -> Self {
369        Self {
370            selected,
371            expected,
372            reason,
373        }
374    }
375
376    pub fn is_numerically_recoverable(&self) -> bool {
377        self.reason.is_numerically_recoverable() && self.selected > 0
378    }
379}
380
381pub(crate) fn kmeans_plusplus_into_inner<const N: usize>(
382    mut points: MutMatrixView<'_, f32>,
383    data: StridedView<'_, f32>,
384    transpose: BlockTransposedRef<'_, f32, N>,
385    norms: &[f32],
386    rng: &mut dyn RngCore,
387) -> Result<(), KMeansPlusPlusError>
388where
389    BlockTransposed<f32, N>: MicroKernel,
390{
391    assert_eq!(norms.len(), data.nrows());
392    assert_eq!(transpose.nrows(), data.nrows());
393    assert_eq!(transpose.ncols(), data.ncols());
394    assert_eq!(points.ncols(), data.ncols());
395
396    // Zero the argument
397    points.as_mut_slice().fill(0.0);
398    let expected = points.nrows();
399
400    // Is someone trying to operate on an empty dataset?
401    //
402    // We can determine this by constructing a `Uniform` distribution over the rows and
403    // checking if the resulting range is Empty.
404    let all_rows = match Uniform::new(0, data.nrows()) {
405        Ok(dist) => dist,
406        Err(_) => {
407            // If they want 0 points, then I guess this is okay.
408            return if expected == 0 {
409                Ok(())
410            } else {
411                Err(KMeansPlusPlusError::new(
412                    0,
413                    expected,
414                    FailureReason::DatasetTooSmall,
415                ))
416            };
417        }
418    };
419
420    let mut min_distances: Vec<f32> = vec![f32::INFINITY; data.nrows()];
421    let mut picked = HashSet::with_capacity(expected);
422
423    // Pick the first point randomly.
424    let mut previous_square_norm = {
425        let i = all_rows.sample(rng);
426        points.row_mut(0).copy_from_slice(data.row(i));
427        picked.insert(i);
428        norms[i]
429    };
430
431    let mut selected = 1;
432    for current in 1..expected.min(data.nrows()) {
433        let last = points.row(current - 1);
434        let s = update_distances(
435            &mut min_distances,
436            transpose,
437            norms,
438            last,
439            previous_square_norm,
440        );
441
442        // Pick a threshold.
443        // Due to the way we compute distances, values less than 0.0 are technically
444        // possible.
445        match Uniform::<f64>::new(0.0, s) {
446            Ok(distribution) => {
447                let threshold = distribution.sample(rng);
448                let mut rolling_sum: f64 = 0.0;
449                for (i, d) in min_distances.iter().enumerate() {
450                    rolling_sum += <f32 as Into<f64>>::into(*d);
451                    if rolling_sum >= threshold && (*d > 0.0) && !picked.contains(&i) {
452                        // This point is the winner.
453                        // Copy it over and update our scratch variables.
454                        points.row_mut(current).clone_from_slice(data.row(i));
455                        picked.insert(i);
456                        previous_square_norm = norms[i];
457                        selected = current + 1;
458                        break;
459                    }
460                }
461            }
462            // If the range is empty, this implies that `s == 0.0`.
463            //
464            // In this case, we skip and try again,
465            Err(rand::distr::uniform::Error::EmptyRange) => {}
466            // The upper bound is infinite - this is an error.
467            Err(rand::distr::uniform::Error::NonFinite) => {
468                return Err(KMeansPlusPlusError::new(
469                    selected,
470                    expected,
471                    FailureReason::SawInfinity,
472                ));
473            }
474        }
475
476        // If we successfully picked a row, than `selected == current`.
477        //
478        // If this is not the case, then we failed due to insufficient diversity.
479        if selected != (current + 1) {
480            return Err(KMeansPlusPlusError::new(
481                selected,
482                expected,
483                FailureReason::InsufficientDiversity,
484            ));
485        }
486    }
487
488    // We may have terminated early due to the dataset being too small.
489    if selected != expected {
490        Err(KMeansPlusPlusError::new(
491            selected,
492            expected,
493            FailureReason::DatasetTooSmall,
494        ))
495    } else {
496        Ok(())
497    }
498}
499
500pub fn kmeans_plusplus_into(
501    centers: MutMatrixView<'_, f32>,
502    data: MatrixView<'_, f32>,
503    rng: &mut dyn RngCore,
504) -> Result<(), KMeansPlusPlusError> {
505    assert_eq!(
506        centers.ncols(),
507        data.ncols(),
508        "centers output matrix should have the same dimensionality as the dataset"
509    );
510
511    const GROUPSIZE: usize = 16;
512    let mut norms: Vec<f32> = vec![0.0; data.nrows()];
513
514    for (n, d) in std::iter::zip(norms.iter_mut(), data.row_iter()) {
515        *n = square_norm(d);
516    }
517
518    let transpose = BlockTransposed::<f32, GROUPSIZE>::from_matrix_view(data);
519    kmeans_plusplus_into_inner(centers, data.into(), transpose.as_view(), &norms, rng)
520}
521
522#[cfg(test)]
523mod tests {
524    use diskann_utils::{lazy_format, views::Matrix};
525    use diskann_vector::{PureDistanceFunction, distance::SquaredL2};
526    use rand::{Rng, SeedableRng, rngs::StdRng, seq::SliceRandom};
527
528    use super::*;
529    use crate::utils;
530
531    fn is_in(needle: &[f32], haystack: MatrixView<'_, f32>) -> bool {
532        assert_eq!(needle.len(), haystack.ncols());
533        haystack.row_iter().any(|row| row == needle)
534    }
535
536    fn check_post_conditions(
537        centers: MatrixView<'_, f32>,
538        data: MatrixView<'_, f32>,
539        err: &KMeansPlusPlusError,
540    ) {
541        assert_eq!(err.expected, centers.nrows());
542        assert!(err.expected > err.selected);
543        for i in 0..err.selected {
544            assert!(is_in(centers.row(i), data.as_view()));
545        }
546        for i in err.selected..centers.nrows() {
547            assert!(centers.row(i).iter().all(|j| *j == 0.0));
548        }
549    }
550
551    ///////////////////
552    // Error Display //
553    ///////////////////
554
555    #[test]
556    fn test_error_display() {
557        assert_eq!(
558            format!("{}", FailureReason::DatasetTooSmall),
559            "dataset does not have enough points"
560        );
561
562        assert_eq!(
563            format!("{}", FailureReason::InsufficientDiversity),
564            "dataset is insufficiently diverse"
565        );
566
567        assert_eq!(
568            format!("{}", FailureReason::SawInfinity),
569            "a value of infinity or NaN was observed"
570        );
571    }
572
573    //////////////////////
574    // Update Distances //
575    //////////////////////
576
577    /// Seed `x` (a KxN matrix) with the following:
578    /// ```text
579    /// 0,   1,   2,   3 ...   N-1
580    /// 1,   2,   3,   4 ...   N
581    /// ...
582    /// K-1, K,   K+1, K+3 ... N+K-2
583    /// ```
584    fn set_default_values(mut x: MutMatrixView<'_, f32>) {
585        for (i, row) in x.row_iter_mut().enumerate() {
586            for (j, r) in row.iter_mut().enumerate() {
587                *r = (i + j) as f32;
588            }
589        }
590    }
591
592    // The implementation of `update_distances` is not particularly unsafe with relation
593    // to `dim`, but *is* potentially unsafe with respect to `num_points`.
594    //
595    // Therefore, we need to be sure we sweep sufficient values of `num_points` to get full
596    // coverage with Miri.
597    //
598    // This works our to our advantage because smaller `dims` mean we can have precise
599    // floating point values.
600    fn test_update_distances_impl<const N: usize, R>(num_points: usize, dim: usize, rng: &mut R)
601    where
602        BlockTransposed<f32, N>: MicroKernel,
603        R: Rng,
604    {
605        let context = lazy_format!(
606            "setup: N = {}, num_points = {}, dim = {}",
607            N,
608            num_points,
609            dim
610        );
611
612        let mut data = Matrix::<f32>::new(0.0, num_points, dim);
613        set_default_values(data.as_mut_view());
614
615        let square_norms: Vec<f32> = data.row_iter().map(square_norm).collect();
616
617        // The sample points we are computing the distances against.
618        let num_samples = 3;
619        let mut samples = Matrix::<f32>::new(0.0, num_samples, dim);
620        let mut distances = vec![f32::INFINITY; num_points];
621        let distribution = Uniform::<u32>::new(0, (num_points + dim) as u32).unwrap();
622        let transpose = BlockTransposed::<f32, N>::from_matrix_view(data.as_view());
623
624        let mut last_residual = f64::INFINITY;
625        for i in 0..num_samples {
626            // Pick a sample.
627            {
628                let row = samples.row_mut(i);
629                row.iter_mut().for_each(|r| {
630                    *r = distribution.sample(rng) as f32;
631                });
632            }
633            let row = samples.row(i);
634            let norm = square_norm(row);
635
636            let residual = update_distances(
637                &mut distances,
638                transpose.as_view(),
639                &square_norms,
640                row,
641                norm,
642            );
643
644            // Make sure all the distances are correct.
645            for (n, (d, data)) in std::iter::zip(distances.iter(), data.row_iter()).enumerate() {
646                let mut min_distance = f32::INFINITY;
647                for j in 0..=i {
648                    let distance = SquaredL2::evaluate(samples.row(j), data);
649                    min_distance = min_distance.min(distance);
650                }
651                assert_eq!(
652                    min_distance, *d,
653                    "failed on row {n} on iteration {i}. {}",
654                    context
655                );
656            }
657
658            // The distances match - ensure that the residual was computed properly.
659            assert_eq!(
660                residual,
661                distances.iter().sum::<f32>() as f64,
662                "residual sum failed on iteration {i} - {}",
663                context
664            );
665
666            // Finally - make sure the residual is dropping.
667            assert!(
668                residual <= last_residual,
669                "residual check failed on iteration {}, last = {}, this = {} - {}",
670                i,
671                last_residual,
672                residual,
673                context
674            );
675
676            last_residual = residual;
677        }
678    }
679
680    /// A note on testing methodology:
681    ///
682    /// This function targets running under Miri to test the indexing logic in
683    /// `update_distances`.
684    ///
685    /// It is appropriate for implementations what are not "unsafe" in `dim` but are
686    /// "unsafe" in `num_points`.
687    ///
688    /// In other words, the SIMD operations we are tracking block along `num_points` and
689    /// not `dim`. This lets us run at a much smaller `dim` to help Miri finish more quickly.
690    #[test]
691    fn test_update_distances() {
692        let mut rng = StdRng::seed_from_u64(0x56c94b53c73e4fd9);
693        for num_points in 0..48 {
694            #[cfg(miri)]
695            if num_points % 7 != 0 {
696                continue;
697            }
698
699            for dim in 1..4 {
700                test_update_distances_impl(num_points, dim, &mut rng);
701            }
702        }
703    }
704
705    //////////////
706    // Kmeans++ //
707    //////////////
708
709    // Kmeans++ sanity checks - if there are only `N` distinct and we want `N` centers,
710    // then all `N` should be selected without repeats.
711    #[cfg(not(miri))]
712    fn sanity_check_impl<R: Rng>(ncenters: usize, dim: usize, rng: &mut R) {
713        let repeats_per_center = 3;
714        let context = lazy_format!(
715            "dim = {}, ncenters = {}, repeats_per_center = {}",
716            dim,
717            ncenters,
718            repeats_per_center
719        );
720
721        let ndata = repeats_per_center * ncenters;
722        let mut values: Vec<f32> = (0..ncenters)
723            .flat_map(|i| (0..repeats_per_center).map(move |_| i as f32))
724            .collect();
725        assert_eq!(values.len(), ndata);
726
727        values.shuffle(rng);
728        let mut data = Matrix::new(0.0, ndata, dim);
729        for (r, v) in std::iter::zip(data.row_iter_mut(), values.iter()) {
730            r.fill(*v);
731        }
732
733        let mut centers = Matrix::new(f32::INFINITY, ncenters, dim);
734        kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), rng).unwrap();
735
736        // Make sure that each value was selected for a center.
737        let mut seen = HashSet::<usize>::new();
738        for c in centers.row_iter() {
739            let first = c[0];
740            assert!(c.iter().all(|i| *i == first));
741
742            let v: usize = first.round() as usize;
743            assert_eq!(v as f32, first, "conversion was not lossless - {}", context);
744
745            if !seen.insert(v) {
746                panic!("value {first} seen more than oncex - {}", context);
747            }
748        }
749        assert_eq!(
750            seen.len(),
751            ncenters,
752            "not all points were seen - {}",
753            context
754        );
755    }
756
757    #[test]
758    #[cfg(not(miri))]
759    fn sanity_check() {
760        let dims = [1, 16];
761        let ncenters = [1, 5, 20, 255];
762        let mut rng = StdRng::seed_from_u64(0x68c2080f2ea36f5a);
763
764        for ncenters in ncenters {
765            for dim in dims {
766                sanity_check_impl(ncenters, dim, &mut rng);
767            }
768        }
769    }
770
771    // This test is like the sanity check - but instead of exact repeats, we use slightly
772    // perturbed values to test that the proportionality is of distances is respected.
773    #[cfg(not(miri))]
774    fn fuzzy_sanity_check_impl<R: Rng>(ncenters: usize, dim: usize, rng: &mut R) {
775        let repeats_per_center = 3;
776
777        // A spreading coefficient to space-out points.
778        let spreading_multiplier: usize = 16;
779        // Purturbation distribution to apply to the input data.
780        let perturbation_distribution = Uniform::new(-0.125, 0.125).unwrap();
781
782        let context = lazy_format!(
783            "dim = {}, ncenters = {}, repeats_per_center = {}, multiplier = {}",
784            dim,
785            ncenters,
786            repeats_per_center,
787            spreading_multiplier,
788        );
789
790        let ndata = repeats_per_center * ncenters;
791        let mut values: Vec<f32> = (0..ncenters)
792            .flat_map(|i| {
793                // We need to bounce through a vec to avoid borrowing issues.
794                let v: Vec<f32> = (0..repeats_per_center)
795                    .map(|_| {
796                        (spreading_multiplier * i) as f32 + perturbation_distribution.sample(rng)
797                    })
798                    .collect();
799
800                v.into_iter()
801            })
802            .collect();
803        assert_eq!(values.len(), ndata);
804
805        values.shuffle(rng);
806        let mut data = Matrix::new(0.0, ndata, dim);
807        for (r, v) in std::iter::zip(data.row_iter_mut(), values.iter()) {
808            r.fill(*v);
809        }
810
811        let mut centers = Matrix::new(f32::INFINITY, ncenters, dim);
812        kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), rng).unwrap();
813
814        // Make sure that each value was selected for a center.
815        let mut seen = HashSet::<usize>::new();
816        for (i, c) in centers.row_iter().enumerate() {
817            let first = c[0];
818            let v: usize = first.round() as usize;
819            assert_eq!(
820                v % spreading_multiplier,
821                0,
822                "expected row value to be close to a multiple of the spreading multiplier, \
823                 instead got {} - {}",
824                v,
825                context
826            );
827            seen.insert(v);
828
829            // Make sure the center is equal to one of the data points.
830            let mut found = false;
831            for r in data.row_iter() {
832                if r == c {
833                    found = true;
834                    break;
835                }
836            }
837            if !found {
838                panic!(
839                    "center {} was not found in the original dataset - {}",
840                    i, context,
841                );
842            }
843        }
844        assert!(
845            seen.len() as f32 >= 0.95 * (ncenters as f32),
846            "expected the distribution of centers to be wide, \
847             instead {} unique values were found - {}",
848            seen.len(),
849            context
850        );
851    }
852
853    #[test]
854    #[cfg(not(miri))]
855    fn fuzzy_sanity_check() {
856        let dims = [1, 16];
857        // Apparently passing in `0` for `ncenters` works in a well-defined way.
858        let ncenters = [0, 1, 5, 20, 255];
859        let mut rng = StdRng::seed_from_u64(0x68c2080f2ea36f5a);
860
861        for ncenters in ncenters {
862            for dim in dims {
863                fuzzy_sanity_check_impl(ncenters, dim, &mut rng);
864            }
865        }
866    }
867
868    // Failure modes
869    #[test]
870    fn fail_empty_dataset() {
871        let data = Matrix::new(0.0, 0, 5);
872        let mut centers = Matrix::new(0.0, 10, data.ncols());
873
874        let mut rng = StdRng::seed_from_u64(0xa9eae150d30845a1);
875
876        let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
877        assert!(
878            result.is_err(),
879            "kmeans++ on an empty dataset with non-empty centers should be an error"
880        );
881        let err = result.unwrap_err();
882        assert_eq!(err.selected, 0);
883        assert_eq!(err.expected, centers.nrows());
884        assert_eq!(err.reason, FailureReason::DatasetTooSmall);
885        assert!(!err.is_numerically_recoverable());
886
887        check_post_conditions(centers.as_view(), data.as_view(), &err);
888    }
889
890    #[test]
891    fn both_empty_is_okay() {
892        let data = Matrix::new(0.0, 0, 5);
893        let mut centers = Matrix::new(0.0, 0, data.ncols());
894        let mut rng = StdRng::seed_from_u64(0x6f7031afd9b5aa18);
895        let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
896        assert!(
897            result.is_ok(),
898            "selecting 0 points from an empty dataset is okay"
899        );
900    }
901
902    #[test]
903    fn fail_dataset_not_big_enough() {
904        let ndata = 5;
905        let ncenters = 10;
906        let dim = 5;
907
908        let mut data = Matrix::new(0.0, ndata, dim);
909        set_default_values(data.as_mut_view());
910        let mut centers = Matrix::new(f32::INFINITY, ncenters, data.ncols());
911
912        let mut rng = StdRng::seed_from_u64(0xa9eae150d30845a1);
913
914        let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
915        assert!(
916            result.is_err(),
917            "kmeans++ on an empty dataset with non-empty centers should be an error"
918        );
919        let err = result.unwrap_err();
920        assert_eq!(err.selected, data.nrows());
921        assert_eq!(err.expected, centers.nrows());
922        assert_eq!(err.reason, FailureReason::DatasetTooSmall);
923        assert!(err.is_numerically_recoverable());
924
925        check_post_conditions(centers.as_view(), data.as_view(), &err);
926    }
927
928    // In this test - we ensure that we process as much of a non-diverse dataset as we can
929    // before returning an error.
930    #[test]
931    fn fail_diversity_check() {
932        let ncenters = 10;
933        let ndata = 50;
934        let dim = 3;
935        let mut rng = StdRng::seed_from_u64(0xca57b032c21bf4bb);
936
937        // Make sure the dataset only contains 5 unique values.
938        let repeats_per_center = 10;
939        assert!(ncenters * repeats_per_center > ndata);
940        let mut values: Vec<f32> = (0..utils::div_round_up(ndata, repeats_per_center))
941            .flat_map(|i| (0..repeats_per_center).map(move |_| i as f32))
942            .collect();
943        assert!(values.len() >= ndata);
944
945        values.shuffle(&mut rng);
946        let mut data = Matrix::new(0.0, ndata, dim);
947        for (r, v) in std::iter::zip(data.row_iter_mut(), values.iter()) {
948            r.fill(*v);
949        }
950
951        let mut centers = Matrix::new(f32::INFINITY, ncenters, dim);
952        let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
953        assert!(
954            result.is_err(),
955            "dataset should not have enough unique points"
956        );
957        let err = result.unwrap_err();
958        assert_eq!(err.selected, utils::div_round_up(ndata, repeats_per_center));
959        assert_eq!(err.expected, centers.nrows());
960        assert_eq!(err.reason, FailureReason::InsufficientDiversity);
961        assert!(err.is_numerically_recoverable());
962
963        check_post_conditions(centers.as_view(), data.as_view(), &err);
964    }
965
966    #[test]
967    fn fail_intinity_check() {
968        let mut data = Matrix::new(0.0, 10, 1);
969        set_default_values(data.as_mut_view());
970
971        // A very large value that will overflow to infinity when computing the norm.
972        data[(6, 0)] = -3.4028235e38;
973        let mut centers = Matrix::new(0.0, 2, 1);
974
975        let mut rng = StdRng::seed_from_u64(0xc0449b2aa4e12f05);
976
977        let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
978        assert!(result.is_err(), "result should complain about infinity");
979        let err = result.unwrap_err();
980        assert_eq!(err.selected, 1);
981        assert_eq!(err.expected, centers.nrows());
982        assert_eq!(err.reason, FailureReason::SawInfinity);
983        assert!(!err.is_numerically_recoverable());
984
985        check_post_conditions(centers.as_view(), data.as_view(), &err);
986    }
987
988    #[test]
989    fn fail_nan_check() {
990        let mut data = Matrix::new(0.0, 10, 1);
991        set_default_values(data.as_mut_view());
992
993        // A very large value that will overflow to infinity when computing the norm.
994        data[(6, 0)] = f32::NAN;
995        let mut centers = Matrix::new(0.0, 2, 1);
996
997        let mut rng = StdRng::seed_from_u64(0x55808c6c728c8473);
998
999        let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
1000        assert!(result.is_err(), "result should complain about NaN");
1001        let err = result.unwrap_err();
1002        assert_eq!(err.selected, 1);
1003        assert_eq!(err.expected, centers.nrows());
1004        assert_eq!(err.reason, FailureReason::SawInfinity);
1005        assert!(!err.is_numerically_recoverable());
1006
1007        check_post_conditions(centers.as_view(), data.as_view(), &err);
1008    }
1009
1010    ///////////////////////////////
1011    // Panics - update_distances //
1012    ///////////////////////////////
1013
1014    #[test]
1015    #[should_panic(expected = "new point and dataset must have the same dimension")]
1016    fn update_distances_panics_dim_mismatch() {
1017        let npoints = 5;
1018        let dim = 8;
1019        let mut square_distances = vec![0.0; npoints];
1020        let data = Matrix::new(0.0, npoints, dim);
1021        let norms = vec![0.0; npoints];
1022        let this = vec![0.0; dim + 1]; // Incorrect
1023        let this_square_norm = 0.0;
1024        update_distances::<16>(
1025            &mut square_distances,
1026            BlockTransposed::<f32, 16>::from_matrix_view(data.as_view()).as_view(),
1027            &norms,
1028            &this,
1029            this_square_norm,
1030        );
1031    }
1032
1033    #[test]
1034    #[should_panic(expected = "distances buffer and dataset must have the same length")]
1035    fn update_distances_panics_distances_length_mismatch() {
1036        let npoints = 5;
1037        let dim = 8;
1038        let mut square_distances = vec![0.0; npoints + 1]; // Incorrect
1039        let data = Matrix::new(0.0, npoints, dim);
1040        let norms = vec![0.0; npoints];
1041        let this = vec![0.0; dim];
1042        let this_square_norm = 0.0;
1043        update_distances::<16>(
1044            &mut square_distances,
1045            BlockTransposed::<f32, 16>::from_matrix_view(data.as_view()).as_view(),
1046            &norms,
1047            &this,
1048            this_square_norm,
1049        );
1050    }
1051
1052    #[test]
1053    #[should_panic(expected = "norms and dataset must have the same length")]
1054    fn update_distances_panics_norms_length_mismatch() {
1055        let npoints = 5;
1056        let dim = 8;
1057        let mut square_distances = vec![0.0; npoints];
1058        let data = Matrix::new(0.0, npoints, dim);
1059        let norms = vec![0.0; npoints + 1]; // Incorrect
1060        let this = vec![0.0; dim];
1061        let this_square_norm = 0.0;
1062        update_distances::<16>(
1063            &mut square_distances,
1064            BlockTransposed::<f32, 16>::from_matrix_view(data.as_view()).as_view(),
1065            &norms,
1066            &this,
1067            this_square_norm,
1068        );
1069    }
1070
1071    ///////////////////////////////////
1072    // Panics - kmeans_plusplus_into //
1073    ///////////////////////////////////
1074
1075    #[test]
1076    #[should_panic(
1077        expected = "centers output matrix should have the same dimensionality as the dataset"
1078    )]
1079    fn kmeans_plusplus_into_panics_dim_mismatch() {
1080        let mut centers = Matrix::new(0.0, 2, 10);
1081        let data = Matrix::new(0.0, 2, 9);
1082        kmeans_plusplus_into(
1083            centers.as_mut_view(),
1084            data.as_view(),
1085            &mut rand::rngs::ThreadRng::default(),
1086        )
1087        .unwrap();
1088    }
1089}