Skip to main content

diskann_quantization/algorithms/kmeans/
plusplus.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::{collections::HashSet, fmt};
7
8use diskann_utils::{
9    strided::StridedView,
10    views::{MatrixView, MutMatrixView},
11};
12use diskann_wide::{SIMDMulAdd, SIMDPartialOrd, SIMDSelect, SIMDVector};
13use rand::{
14    RngCore,
15    distr::{Distribution, Uniform},
16};
17use thiserror::Error;
18
19use super::common::{BlockTranspose, square_norm};
20
/// Fused kernel operations used by kmeans++ to refresh per-point minimum distances.
///
/// For one newly chosen center, a single pass over the block-transposed dataset performs:
///
/// 1. the inner-product work between the center and every dataset row,
/// 2. the update of each row's running minimum squared distance, and
/// 3. the accumulation of the sum of all those minimum distances.
///
/// Fusing the three steps avoids making separate passes over the data.
///
/// Only `BlockTranspose` is expected to implement this trait.
pub(crate) trait MicroKernel {
    /// Per-block partial result produced by `accum_full` and consumed by `finish` /
    /// `finish_last`.
    type Intermediate;

    /// Accumulator for the running sum of minimum squared distances.
    type RollingSum: Default + Copy;

    /// A (possibly broadcast) representation of a single `f32`.
    type Splat: Copy;

    /// Broadcast `x` into its `Splat` representation.
    fn splat(x: f32) -> Self::Splat;

    /// Compute the partial distance terms between `this` and every row stored in the block
    /// whose base pointer is `block_ptr`.
    ///
    /// The same routine is used for full blocks and for the trailing partial block.
    ///
    /// # Safety
    ///
    /// `block_ptr` must be the base pointer of a data block inside a `BlockTranspose`, and
    /// that transpose's block size must equal `this.len()`.
    unsafe fn accum_full(block_ptr: *const f32, this: &[f32]) -> Self::Intermediate;

    /// Turn `intermediate` into squared distances using the row norms in `norms` and the
    /// broadcast center norm `splat`. Lower each entry of `mins` to that distance when it is
    /// smaller, and return `rolling_sum` plus the sum of the updated `mins`.
    ///
    /// Used for full blocks.
    fn finish(
        intermediate: Self::Intermediate,
        splat: Self::Splat,
        rolling_sum: Self::RollingSum,
        norms: &[f32],
        mins: &mut [f32],
    ) -> Self::RollingSum;

    /// Same as `finish`, but only the first `first` rows of the block are valid; `norms` and
    /// `mins` hold exactly those rows.
    fn finish_last(
        intermediate: Self::Intermediate,
        splat: Self::Splat,
        rolling_sum: Self::RollingSum,
        norms: &[f32],
        mins: &mut [f32],
        first: usize,
    ) -> Self::RollingSum;

    /// Convert the rolling sum from its internal representation into the final `f64` total.
    fn complete_sum(x: Self::RollingSum) -> f64;
}
79
80diskann_wide::alias!(f32s = f32x8);
81
82impl MicroKernel for BlockTranspose<16> {
83    // Process 16-dimensions concurrently, split across two `Wide`s.
84    type Intermediate = (f32s, f32s);
85    type RollingSum = f64;
86    type Splat = f32s;
87
88    fn splat(x: f32) -> Self::Splat {
89        Self::Splat::splat(diskann_wide::ARCH, x)
90    }
91
92    #[inline(always)]
93    unsafe fn accum_full(block_ptr: *const f32, this: &[f32]) -> Self::Intermediate {
94        let mut s0 = f32s::default(diskann_wide::ARCH);
95        let mut s1 = f32s::default(diskann_wide::ARCH);
96
97        this.iter().enumerate().for_each(|(i, b)| {
98            let b = f32s::splat(diskann_wide::ARCH, *b);
99
100            // SAFETY: From the requirement that `self.block_size() == this.len()`, then
101            // `this.len() * 16` elements are readible from `block_ptr` and `i < this.len()`.
102            let a = unsafe { f32s::load_simd(diskann_wide::ARCH, block_ptr.add(16 * i)) };
103            s0 = a.mul_add_simd(b, s0);
104
105            // SAFETY: From the requirement that `self.block_size() == this.len()`, then
106            // `this.len() * 16` elements are readible from `block_ptr` and `i < this.len()`.
107            let a = unsafe { f32s::load_simd(diskann_wide::ARCH, block_ptr.add(16 * i + 8)) };
108            s1 = a.mul_add_simd(b, s1);
109        });
110
111        // Apply the final -2.0 transformation.
112        let negative = f32s::splat(diskann_wide::ARCH, -2.0);
113        (s0 * negative, s1 * negative)
114    }
115
116    #[inline(always)]
117    fn finish(
118        intermediate: Self::Intermediate,
119        splat: Self::Splat,
120        rolling_sum: Self::RollingSum,
121        norms: &[f32],
122        mins: &mut [f32],
123    ) -> Self::RollingSum {
124        assert_eq!(norms.len(), 16);
125        assert_eq!(mins.len(), 16);
126
127        // SAFETY: `norms` has length 16 - this loads the first 8.
128        let norms0 = unsafe { f32s::load_simd(diskann_wide::ARCH, norms.as_ptr()) };
129        // SAFETY: `norms` has length 16 - this loads the last 8.
130        let norms1 = unsafe { f32s::load_simd(diskann_wide::ARCH, norms.as_ptr().add(8)) };
131
132        let distances0 = norms0 + splat + intermediate.0;
133        let distances1 = norms1 + splat + intermediate.1;
134
135        // SAFETY: `mins` has length 16 - this loads the first 8.
136        let current_distances0 = unsafe { f32s::load_simd(diskann_wide::ARCH, mins.as_ptr()) };
137        // SAFETY: `mins` has length 16 - this loads the last 8.
138        let current_distances1 =
139            unsafe { f32s::load_simd(diskann_wide::ARCH, mins.as_ptr().add(8)) };
140
141        let mask0 = distances0.lt_simd(current_distances0);
142        let mask1 = distances1.lt_simd(current_distances1);
143
144        let current_distances0 = mask0.select(distances0, current_distances0);
145        let current_distances1 = mask1.select(distances1, current_distances1);
146
147        // SAFETY: `mins` has length 16 - this stores the first 8.
148        unsafe { current_distances0.store_simd(mins.as_mut_ptr()) };
149        // SAFETY: `mins` has length 16 - this stores the last 8.
150        unsafe { current_distances1.store_simd(mins.as_mut_ptr().add(8)) };
151
152        rolling_sum
153            + std::iter::zip(
154                current_distances0.to_array().iter(),
155                current_distances1.to_array().iter(),
156            )
157            .map(|(d0, d1)| (*d0 as f64) + (*d1 as f64))
158            .sum::<f64>()
159    }
160
161    #[inline(always)]
162    fn finish_last(
163        intermediate: Self::Intermediate,
164        splat: Self::Splat,
165        rolling_sum: Self::RollingSum,
166        norms: &[f32],
167        mins: &mut [f32],
168        first: usize,
169    ) -> Self::RollingSum {
170        // Check 1.
171        assert_eq!(norms.len(), first);
172        // Check 2.
173        assert_eq!(mins.len(), first);
174
175        let lo = first.min(8);
176        let hi = first - lo;
177
178        // SAFETY: This loads `first.min(8)` elements from `norms`, which is valid
179        // by check 1.
180        let norms0 = unsafe { f32s::load_simd_first(diskann_wide::ARCH, norms.as_ptr(), lo) };
181        let norms1 = if hi == 0 {
182            f32s::default(diskann_wide::ARCH)
183        } else {
184            // SAFETY: This is only called if `first > 8`, which means `norms.len() > 8` by
185            // check 1. Therefore, we can load `first - 8` elements from `norms.as_ptr() + 8`.
186            unsafe { f32s::load_simd_first(diskann_wide::ARCH, norms.as_ptr().add(8), hi) }
187        };
188
189        let distances0 = norms0 + splat + intermediate.0;
190        let distances1 = norms1 + splat + intermediate.1;
191
192        // SAFETY: Same logic as the load for `norms0`.
193        let current_distances0 =
194            unsafe { f32s::load_simd_first(diskann_wide::ARCH, mins.as_ptr(), lo) };
195        let current_distances1 = if hi == 0 {
196            f32s::default(diskann_wide::ARCH)
197        } else {
198            // SAFETY: Same logic as the load for `norms1`.
199            unsafe { f32s::load_simd_first(diskann_wide::ARCH, mins.as_ptr().add(8), hi) }
200        };
201
202        let mask0 = distances0.lt_simd(current_distances0);
203        let mask1 = distances1.lt_simd(current_distances1);
204
205        let current_distances0 = mask0.select(distances0, current_distances0);
206        let current_distances1 = mask1.select(distances1, current_distances1);
207
208        // SAFETY: As per the logic for the load above, it is safe to store at least `lo`
209        // elements from the base pointer.
210        unsafe { current_distances0.store_simd_first(mins.as_mut_ptr(), lo) };
211        if hi != 0 {
212            // SAFETY: If `hi != 0`, then `first` must be at least 9. Therefore, adding 8
213            // to the base pointer is valid, as is storing `first - 8` elements to that
214            // pointer.
215            unsafe { current_distances1.store_simd_first(mins.as_mut_ptr().add(8), hi) };
216        }
217
218        rolling_sum
219            + std::iter::zip(
220                current_distances0.to_array().iter(),
221                current_distances1.to_array().iter(),
222            )
223            .map(|(d0, d1)| (*d0 as f64) + (*d1 as f64))
224            .sum::<f64>()
225    }
226
227    fn complete_sum(x: Self::RollingSum) -> f64 {
228        x
229    }
230}
231
232/// Update `square_distances` to contain the minimum of its current value and the distance
233/// between each element in `transpose` and `this`.
234///
235/// Return the sum of the new `square_distances`.
236fn update_distances<const N: usize>(
237    square_distances: &mut [f32],
238    transpose: &BlockTranspose<N>,
239    norms: &[f32],
240    this: &[f32],
241    this_square_norm: f32,
242) -> f64
243where
244    BlockTranspose<N>: MicroKernel,
245{
246    // Establish our safety requirements.
247    // Check 1.
248    assert_eq!(
249        this.len(),
250        transpose.ncols(),
251        "new point and dataset must have the same dimension",
252    );
253    // Check 2.
254    assert_eq!(
255        square_distances.len(),
256        transpose.nrows(),
257        "distances buffer and dataset must have the same length",
258    );
259    // Check 3.
260    assert_eq!(
261        norms.len(),
262        transpose.nrows(),
263        "norms and dataset must have the same length",
264    );
265
266    let splat = BlockTranspose::<N>::splat(this_square_norm);
267    let mut rolling_sum = <BlockTranspose<N> as MicroKernel>::RollingSum::default();
268
269    let iter =
270        std::iter::zip(norms.chunks_exact(N), square_distances.chunks_exact_mut(N)).enumerate();
271    iter.for_each(|(block, (these_norms, these_distances))| {
272        debug_assert!(block < transpose.num_blocks());
273        // SAFETY: Because `transpose.nrows() == norms.len()`, the number of full blocks in
274        // `transpose` is `norms.nrows() / N` and therefore the induction variable `block`
275        // is less that `transpose.full_blocks()`.
276        let base = unsafe { transpose.block_ptr_unchecked(block) };
277
278        // SAFETY: The pointer `base` does point to a full block and by Check 1,
279        // `transpose.ncols() == this.len()`.
280        let intermediate = unsafe { BlockTranspose::<N>::accum_full(base, this) };
281
282        rolling_sum = BlockTranspose::<N>::finish(
283            intermediate,
284            splat,
285            rolling_sum,
286            these_norms,
287            these_distances,
288        );
289    });
290
291    // Do the last iteration if there is an un-even number of rows.
292    let remainder = transpose.remainder();
293    if remainder != 0 {
294        // SAFETY: We've checked that there is a `remainder` block. Therefore,
295        // `transpose.full_blocks() < transpose.num_blocks()`.
296        let base = unsafe { transpose.block_ptr_unchecked(transpose.full_blocks()) };
297
298        // A full accumulation is fine because `BlockTranspose` allocates at the granularity
299        // of blocks. We will just ignore the extra lanes.
300        // SAFETY: The pointer `base` does point to a full block and by Check 1,
301        // `transpose.ncols() == this.len()`.
302        let intermediate = unsafe { BlockTranspose::<N>::accum_full(base, this) };
303
304        let start = N * transpose.full_blocks();
305        rolling_sum = BlockTranspose::<N>::finish_last(
306            intermediate,
307            splat,
308            rolling_sum,
309            &norms[start..],
310            &mut square_distances[start..],
311            remainder,
312        );
313    }
314
315    BlockTranspose::<N>::complete_sum(rolling_sum)
316}
317
/// The reason kmeans++ seeding could not produce every requested center.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FailureReason {
    /// The dataset contains fewer points than the requested number of centers (this
    /// includes an empty dataset).
    DatasetTooSmall,
    /// No further center could be drawn: no point that has not already been picked has a
    /// positive distance to the centers selected so far.
    InsufficientDiversity,
    /// Infinity was observed (this also happens when a NaN is present in the input data).
    SawInfinity,
}
328
329impl FailureReason {
330    pub fn is_numerically_recoverable(self) -> bool {
331        match self {
332            // Datasets being too small is recoverable from a `kmeans` perspective, we can
333            // simply proceed with fewer points.
334            Self::DatasetTooSmall | Self::InsufficientDiversity => true,
335
336            // If we see Infinity, downstream algorithms will likely just break.
337            // Don't expect to recover.
338            Self::SawInfinity => false,
339        }
340    }
341}
342
343impl fmt::Display for FailureReason {
344    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
345        let reason: &str = match self {
346            Self::DatasetTooSmall => "dataset does not have enough points",
347            Self::InsufficientDiversity => "dataset is insufficiently diverse",
348            Self::SawInfinity => "a value of infinity or NaN was observed",
349        };
350        f.write_str(reason)
351    }
352}
353
354#[derive(Debug, Clone, Copy, Error)]
355#[error("only populated {selected} of {expected} points because {reason}")]
356pub struct KMeansPlusPlusError {
357    /// The number of points that were selected.
358    pub selected: usize,
359    /// The number of points that were expected.
360    pub expected: usize,
361    /// A concrete reason for the failure.
362    pub reason: FailureReason,
363}
364
365impl KMeansPlusPlusError {
366    fn new(selected: usize, expected: usize, reason: FailureReason) -> Self {
367        Self {
368            selected,
369            expected,
370            reason,
371        }
372    }
373
374    pub fn is_numerically_recoverable(&self) -> bool {
375        self.reason.is_numerically_recoverable() && self.selected > 0
376    }
377}
378
379pub(crate) fn kmeans_plusplus_into_inner<const N: usize>(
380    mut points: MutMatrixView<'_, f32>,
381    data: StridedView<'_, f32>,
382    transpose: &BlockTranspose<N>,
383    norms: &[f32],
384    rng: &mut dyn RngCore,
385) -> Result<(), KMeansPlusPlusError>
386where
387    BlockTranspose<N>: MicroKernel,
388{
389    assert_eq!(norms.len(), data.nrows());
390    assert_eq!(transpose.nrows(), data.nrows());
391    assert_eq!(transpose.block_size(), data.ncols());
392    assert_eq!(points.ncols(), data.ncols());
393
394    // Zero the argument
395    points.as_mut_slice().fill(0.0);
396    let expected = points.nrows();
397
398    // Is someone trying to operate on an empty dataset?
399    //
400    // We can determine this by constructing a `Uniform` distribution over the rows and
401    // checking if the resulting range is Empty.
402    let all_rows = match Uniform::new(0, data.nrows()) {
403        Ok(dist) => dist,
404        Err(_) => {
405            // If they want 0 points, then I guess this is okay.
406            return if expected == 0 {
407                Ok(())
408            } else {
409                Err(KMeansPlusPlusError::new(
410                    0,
411                    expected,
412                    FailureReason::DatasetTooSmall,
413                ))
414            };
415        }
416    };
417
418    let mut min_distances: Vec<f32> = vec![f32::INFINITY; data.nrows()];
419    let mut picked = HashSet::with_capacity(expected);
420
421    // Pick the first point randomly.
422    let mut previous_square_norm = {
423        let i = all_rows.sample(rng);
424        points.row_mut(0).copy_from_slice(data.row(i));
425        picked.insert(i);
426        norms[i]
427    };
428
429    let mut selected = 1;
430    for current in 1..expected.min(data.nrows()) {
431        let last = points.row(current - 1);
432        let s = update_distances(
433            &mut min_distances,
434            transpose,
435            norms,
436            last,
437            previous_square_norm,
438        );
439
440        // Pick a threshold.
441        // Due to the way we compute distances, values less than 0.0 are technically
442        // possible.
443        match Uniform::<f64>::new(0.0, s) {
444            Ok(distribution) => {
445                let threshold = distribution.sample(rng);
446                let mut rolling_sum: f64 = 0.0;
447                for (i, d) in min_distances.iter().enumerate() {
448                    rolling_sum += <f32 as Into<f64>>::into(*d);
449                    if rolling_sum >= threshold && (*d > 0.0) && !picked.contains(&i) {
450                        // This point is the winner.
451                        // Copy it over and update our scratch variables.
452                        points.row_mut(current).clone_from_slice(data.row(i));
453                        picked.insert(i);
454                        previous_square_norm = norms[i];
455                        selected = current + 1;
456                        break;
457                    }
458                }
459            }
460            // If the range is empty, this implies that `s == 0.0`.
461            //
462            // In this case, we skip and try again,
463            Err(rand::distr::uniform::Error::EmptyRange) => {}
464            // The upper bound is infinite - this is an error.
465            Err(rand::distr::uniform::Error::NonFinite) => {
466                return Err(KMeansPlusPlusError::new(
467                    selected,
468                    expected,
469                    FailureReason::SawInfinity,
470                ));
471            }
472        }
473
474        // If we successfully picked a row, than `selected == current`.
475        //
476        // If this is not the case, then we failed due to insufficient diversity.
477        if selected != (current + 1) {
478            return Err(KMeansPlusPlusError::new(
479                selected,
480                expected,
481                FailureReason::InsufficientDiversity,
482            ));
483        }
484    }
485
486    // We may have terminated early due to the dataset being too small.
487    if selected != expected {
488        Err(KMeansPlusPlusError::new(
489            selected,
490            expected,
491            FailureReason::DatasetTooSmall,
492        ))
493    } else {
494        Ok(())
495    }
496}
497
498pub fn kmeans_plusplus_into(
499    centers: MutMatrixView<'_, f32>,
500    data: MatrixView<'_, f32>,
501    rng: &mut dyn RngCore,
502) -> Result<(), KMeansPlusPlusError> {
503    assert_eq!(
504        centers.ncols(),
505        data.ncols(),
506        "centers output matrix should have the same dimensionality as the dataset"
507    );
508
509    const GROUPSIZE: usize = 16;
510    let mut norms: Vec<f32> = vec![0.0; data.nrows()];
511
512    for (n, d) in std::iter::zip(norms.iter_mut(), data.row_iter()) {
513        *n = square_norm(d);
514    }
515
516    let transpose = BlockTranspose::<GROUPSIZE>::from_matrix_view(data);
517    kmeans_plusplus_into_inner(centers, data.into(), &transpose, &norms, rng)
518}
519
520#[cfg(test)]
521mod tests {
522    use diskann_utils::{lazy_format, views::Matrix};
523    use diskann_vector::{PureDistanceFunction, distance::SquaredL2};
524    use rand::{Rng, SeedableRng, rngs::StdRng, seq::SliceRandom};
525
526    use super::*;
527    use crate::utils;
528
529    fn is_in(needle: &[f32], haystack: MatrixView<'_, f32>) -> bool {
530        assert_eq!(needle.len(), haystack.ncols());
531        haystack.row_iter().any(|row| row == needle)
532    }
533
534    fn check_post_conditions(
535        centers: MatrixView<'_, f32>,
536        data: MatrixView<'_, f32>,
537        err: &KMeansPlusPlusError,
538    ) {
539        assert_eq!(err.expected, centers.nrows());
540        assert!(err.expected > err.selected);
541        for i in 0..err.selected {
542            assert!(is_in(centers.row(i), data.as_view()));
543        }
544        for i in err.selected..centers.nrows() {
545            assert!(centers.row(i).iter().all(|j| *j == 0.0));
546        }
547    }
548
549    ///////////////////
550    // Error Display //
551    ///////////////////
552
553    #[test]
554    fn test_error_display() {
555        assert_eq!(
556            format!("{}", FailureReason::DatasetTooSmall),
557            "dataset does not have enough points"
558        );
559
560        assert_eq!(
561            format!("{}", FailureReason::InsufficientDiversity),
562            "dataset is insufficiently diverse"
563        );
564
565        assert_eq!(
566            format!("{}", FailureReason::SawInfinity),
567            "a value of infinity or NaN was observed"
568        );
569    }
570
571    //////////////////////
572    // Update Distances //
573    //////////////////////
574
575    /// Seed `x` (a KxN matrix) with the following:
576    /// ```text
577    /// 0,   1,   2,   3 ...   N-1
578    /// 1,   2,   3,   4 ...   N
579    /// ...
580    /// K-1, K,   K+1, K+3 ... N+K-2
581    /// ```
582    fn set_default_values(mut x: MutMatrixView<'_, f32>) {
583        for (i, row) in x.row_iter_mut().enumerate() {
584            for (j, r) in row.iter_mut().enumerate() {
585                *r = (i + j) as f32;
586            }
587        }
588    }
589
590    // The implementation of `update_distances` is not particularly unsafe with relation
591    // to `dim`, but *is* potentially unsafe with respect to `num_points`.
592    //
593    // Therefore, we need to be sure we sweep sufficient values of `num_points` to get full
594    // coverage with Miri.
595    //
596    // This works our to our advantage because smaller `dims` mean we can have precise
597    // floating point values.
598    fn test_update_distances_impl<const N: usize, R>(num_points: usize, dim: usize, rng: &mut R)
599    where
600        BlockTranspose<N>: MicroKernel,
601        R: Rng,
602    {
603        let context = lazy_format!(
604            "setup: N = {}, num_points = {}, dim = {}",
605            N,
606            num_points,
607            dim
608        );
609
610        let mut data = Matrix::<f32>::new(0.0, num_points, dim);
611        set_default_values(data.as_mut_view());
612
613        let square_norms: Vec<f32> = data.row_iter().map(square_norm).collect();
614
615        // The sample points we are computing the distances against.
616        let num_samples = 3;
617        let mut samples = Matrix::<f32>::new(0.0, num_samples, dim);
618        let mut distances = vec![f32::INFINITY; num_points];
619        let distribution = Uniform::<u32>::new(0, (num_points + dim) as u32).unwrap();
620        let transpose = BlockTranspose::<N>::from_matrix_view(data.as_view());
621
622        let mut last_residual = f64::INFINITY;
623        for i in 0..num_samples {
624            // Pick a sample.
625            {
626                let row = samples.row_mut(i);
627                row.iter_mut().for_each(|r| {
628                    *r = distribution.sample(rng) as f32;
629                });
630            }
631            let row = samples.row(i);
632            let norm = square_norm(row);
633
634            let residual = update_distances(&mut distances, &transpose, &square_norms, row, norm);
635
636            // Make sure all the distances are correct.
637            for (n, (d, data)) in std::iter::zip(distances.iter(), data.row_iter()).enumerate() {
638                let mut min_distance = f32::INFINITY;
639                for j in 0..=i {
640                    let distance = SquaredL2::evaluate(samples.row(j), data);
641                    min_distance = min_distance.min(distance);
642                }
643                assert_eq!(
644                    min_distance, *d,
645                    "failed on row {n} on iteration {i}. {}",
646                    context
647                );
648            }
649
650            // The distances match - ensure that the residual was computed properly.
651            assert_eq!(
652                residual,
653                distances.iter().sum::<f32>() as f64,
654                "residual sum failed on iteration {i} - {}",
655                context
656            );
657
658            // Finally - make sure the residual is dropping.
659            assert!(
660                residual <= last_residual,
661                "residual check failed on iteration {}, last = {}, this = {} - {}",
662                i,
663                last_residual,
664                residual,
665                context
666            );
667
668            last_residual = residual;
669        }
670    }
671
672    /// A note on testing methodology:
673    ///
674    /// This function targets running under Miri to test the indexing logic in
675    /// `update_distances`.
676    ///
677    /// It is appropriate for implementations what are not "unsafe" in `dim` but are
678    /// "unsafe" in `num_points`.
679    ///
680    /// In other words, the SIMD operations we are tracking block along `num_points` and
681    /// not `dim`. This lets us run at a much smaller `dim` to help Miri finish more quickly.
682    #[test]
683    fn test_update_distances() {
684        let mut rng = StdRng::seed_from_u64(0x56c94b53c73e4fd9);
685        for num_points in 0..48 {
686            for dim in 1..4 {
687                test_update_distances_impl(num_points, dim, &mut rng);
688            }
689        }
690    }
691
692    //////////////
693    // Kmeans++ //
694    //////////////
695
696    // Kmeans++ sanity checks - if there are only `N` distinct and we want `N` centers,
697    // then all `N` should be selected without repeats.
698    fn sanity_check_impl<R: Rng>(ncenters: usize, dim: usize, rng: &mut R) {
699        let repeats_per_center = 3;
700        let context = lazy_format!(
701            "dim = {}, ncenters = {}, repeats_per_center = {}",
702            dim,
703            ncenters,
704            repeats_per_center
705        );
706
707        let ndata = repeats_per_center * ncenters;
708        let mut values: Vec<f32> = (0..ncenters)
709            .flat_map(|i| (0..repeats_per_center).map(move |_| i as f32))
710            .collect();
711        assert_eq!(values.len(), ndata);
712
713        values.shuffle(rng);
714        let mut data = Matrix::new(0.0, ndata, dim);
715        for (r, v) in std::iter::zip(data.row_iter_mut(), values.iter()) {
716            r.fill(*v);
717        }
718
719        let mut centers = Matrix::new(f32::INFINITY, ncenters, dim);
720        kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), rng).unwrap();
721
722        // Make sure that each value was selected for a center.
723        let mut seen = HashSet::<usize>::new();
724        for c in centers.row_iter() {
725            let first = c[0];
726            assert!(c.iter().all(|i| *i == first));
727
728            let v: usize = first.round() as usize;
729            assert_eq!(v as f32, first, "conversion was not lossless - {}", context);
730
731            if !seen.insert(v) {
732                panic!("value {first} seen more than oncex - {}", context);
733            }
734        }
735        assert_eq!(
736            seen.len(),
737            ncenters,
738            "not all points were seen - {}",
739            context
740        );
741    }
742
743    #[test]
744    #[cfg(not(miri))]
745    fn sanity_check() {
746        let dims = [1, 16];
747        let ncenters = [1, 5, 20, 255];
748        let mut rng = StdRng::seed_from_u64(0x68c2080f2ea36f5a);
749
750        for ncenters in ncenters {
751            for dim in dims {
752                sanity_check_impl(ncenters, dim, &mut rng);
753            }
754        }
755    }
756
757    // This test is like the sanity check - but instead of exact repeats, we use slightly
758    // perturbed values to test that the proportionality is of distances is respected.
759    fn fuzzy_sanity_check_impl<R: Rng>(ncenters: usize, dim: usize, rng: &mut R) {
760        let repeats_per_center = 3;
761
762        // A spreading coefficient to space-out points.
763        let spreading_multiplier: usize = 16;
764        // Purturbation distribution to apply to the input data.
765        let perturbation_distribution = Uniform::new(-0.125, 0.125).unwrap();
766
767        let context = lazy_format!(
768            "dim = {}, ncenters = {}, repeats_per_center = {}, multiplier = {}",
769            dim,
770            ncenters,
771            repeats_per_center,
772            spreading_multiplier,
773        );
774
775        let ndata = repeats_per_center * ncenters;
776        let mut values: Vec<f32> = (0..ncenters)
777            .flat_map(|i| {
778                // We need to bounce through a vec to avoid borrowing issues.
779                let v: Vec<f32> = (0..repeats_per_center)
780                    .map(|_| {
781                        (spreading_multiplier * i) as f32 + perturbation_distribution.sample(rng)
782                    })
783                    .collect();
784
785                v.into_iter()
786            })
787            .collect();
788        assert_eq!(values.len(), ndata);
789
790        values.shuffle(rng);
791        let mut data = Matrix::new(0.0, ndata, dim);
792        for (r, v) in std::iter::zip(data.row_iter_mut(), values.iter()) {
793            r.fill(*v);
794        }
795
796        let mut centers = Matrix::new(f32::INFINITY, ncenters, dim);
797        kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), rng).unwrap();
798
799        // Make sure that each value was selected for a center.
800        let mut seen = HashSet::<usize>::new();
801        for (i, c) in centers.row_iter().enumerate() {
802            let first = c[0];
803            let v: usize = first.round() as usize;
804            assert_eq!(
805                v % spreading_multiplier,
806                0,
807                "expected row value to be close to a multiple of the spreading multiplier, \
808                 instead got {} - {}",
809                v,
810                context
811            );
812            seen.insert(v);
813
814            // Make sure the center is equal to one of the data points.
815            let mut found = false;
816            for r in data.row_iter() {
817                if r == c {
818                    found = true;
819                    break;
820                }
821            }
822            if !found {
823                panic!(
824                    "center {} was not found in the original dataset - {}",
825                    i, context,
826                );
827            }
828        }
829        assert!(
830            seen.len() as f32 >= 0.95 * (ncenters as f32),
831            "expected the distribution of centers to be wide, \
832             instead {} unique values were found - {}",
833            seen.len(),
834            context
835        );
836    }
837
838    #[test]
839    #[cfg(not(miri))]
840    fn fuzzy_sanity_check() {
841        let dims = [1, 16];
842        // Apparently passing in `0` for `ncenters` works in a well-defined way.
843        let ncenters = [0, 1, 5, 20, 255];
844        let mut rng = StdRng::seed_from_u64(0x68c2080f2ea36f5a);
845
846        for ncenters in ncenters {
847            for dim in dims {
848                fuzzy_sanity_check_impl(ncenters, dim, &mut rng);
849            }
850        }
851    }
852
853    // Failure modes
854    #[test]
855    fn fail_empty_dataset() {
856        let data = Matrix::new(0.0, 0, 5);
857        let mut centers = Matrix::new(0.0, 10, data.ncols());
858
859        let mut rng = StdRng::seed_from_u64(0xa9eae150d30845a1);
860
861        let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
862        assert!(
863            result.is_err(),
864            "kmeans++ on an empty dataset with non-empty centers should be an error"
865        );
866        let err = result.unwrap_err();
867        assert_eq!(err.selected, 0);
868        assert_eq!(err.expected, centers.nrows());
869        assert_eq!(err.reason, FailureReason::DatasetTooSmall);
870        assert!(!err.is_numerically_recoverable());
871
872        check_post_conditions(centers.as_view(), data.as_view(), &err);
873    }
874
875    #[test]
876    fn both_empty_is_okay() {
877        let data = Matrix::new(0.0, 0, 5);
878        let mut centers = Matrix::new(0.0, 0, data.ncols());
879        let mut rng = StdRng::seed_from_u64(0x6f7031afd9b5aa18);
880        let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
881        assert!(
882            result.is_ok(),
883            "selecting 0 points from an empty dataset is okay"
884        );
885    }
886
887    #[test]
888    fn fail_dataset_not_big_enough() {
889        let ndata = 5;
890        let ncenters = 10;
891        let dim = 5;
892
893        let mut data = Matrix::new(0.0, ndata, dim);
894        set_default_values(data.as_mut_view());
895        let mut centers = Matrix::new(f32::INFINITY, ncenters, data.ncols());
896
897        let mut rng = StdRng::seed_from_u64(0xa9eae150d30845a1);
898
899        let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
900        assert!(
901            result.is_err(),
902            "kmeans++ on an empty dataset with non-empty centers should be an error"
903        );
904        let err = result.unwrap_err();
905        assert_eq!(err.selected, data.nrows());
906        assert_eq!(err.expected, centers.nrows());
907        assert_eq!(err.reason, FailureReason::DatasetTooSmall);
908        assert!(err.is_numerically_recoverable());
909
910        check_post_conditions(centers.as_view(), data.as_view(), &err);
911    }
912
913    // In this test - we ensure that we process as much of a non-diverse dataset as we can
914    // before returning an error.
915    #[test]
916    fn fail_diversity_check() {
917        let ncenters = 10;
918        let ndata = 50;
919        let dim = 3;
920        let mut rng = StdRng::seed_from_u64(0xca57b032c21bf4bb);
921
922        // Make sure the dataset only contains 5 unique values.
923        let repeats_per_center = 10;
924        assert!(ncenters * repeats_per_center > ndata);
925        let mut values: Vec<f32> = (0..utils::div_round_up(ndata, repeats_per_center))
926            .flat_map(|i| (0..repeats_per_center).map(move |_| i as f32))
927            .collect();
928        assert!(values.len() >= ndata);
929
930        values.shuffle(&mut rng);
931        let mut data = Matrix::new(0.0, ndata, dim);
932        for (r, v) in std::iter::zip(data.row_iter_mut(), values.iter()) {
933            r.fill(*v);
934        }
935
936        let mut centers = Matrix::new(f32::INFINITY, ncenters, dim);
937        let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
938        assert!(
939            result.is_err(),
940            "dataset should not have enough unique points"
941        );
942        let err = result.unwrap_err();
943        assert_eq!(err.selected, utils::div_round_up(ndata, repeats_per_center));
944        assert_eq!(err.expected, centers.nrows());
945        assert_eq!(err.reason, FailureReason::InsufficientDiversity);
946        assert!(err.is_numerically_recoverable());
947
948        check_post_conditions(centers.as_view(), data.as_view(), &err);
949    }
950
951    #[test]
952    fn fail_intinity_check() {
953        let mut data = Matrix::new(0.0, 10, 1);
954        set_default_values(data.as_mut_view());
955
956        // A very large value that will overflow to infinity when computing the norm.
957        data[(6, 0)] = -3.4028235e38;
958        let mut centers = Matrix::new(0.0, 2, 1);
959
960        let mut rng = StdRng::seed_from_u64(0xc0449b2aa4e12f05);
961
962        let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
963        assert!(result.is_err(), "result should complain about infinity");
964        let err = result.unwrap_err();
965        assert_eq!(err.selected, 1);
966        assert_eq!(err.expected, centers.nrows());
967        assert_eq!(err.reason, FailureReason::SawInfinity);
968        assert!(!err.is_numerically_recoverable());
969
970        check_post_conditions(centers.as_view(), data.as_view(), &err);
971    }
972
973    #[test]
974    fn fail_nan_check() {
975        let mut data = Matrix::new(0.0, 10, 1);
976        set_default_values(data.as_mut_view());
977
978        // A very large value that will overflow to infinity when computing the norm.
979        data[(6, 0)] = f32::NAN;
980        let mut centers = Matrix::new(0.0, 2, 1);
981
982        let mut rng = StdRng::seed_from_u64(0x55808c6c728c8473);
983
984        let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
985        assert!(result.is_err(), "result should complain about NaN");
986        let err = result.unwrap_err();
987        assert_eq!(err.selected, 1);
988        assert_eq!(err.expected, centers.nrows());
989        assert_eq!(err.reason, FailureReason::SawInfinity);
990        assert!(!err.is_numerically_recoverable());
991
992        check_post_conditions(centers.as_view(), data.as_view(), &err);
993    }
994
995    ///////////////////////////////
996    // Panics - update_distances //
997    ///////////////////////////////
998
999    #[test]
1000    #[should_panic(expected = "new point and dataset must have the same dimension")]
1001    fn update_distances_panics_dim_mismatch() {
1002        let npoints = 5;
1003        let dim = 8;
1004        let mut square_distances = vec![0.0; npoints];
1005        let data = Matrix::new(0.0, npoints, dim);
1006        let norms = vec![0.0; npoints];
1007        let this = vec![0.0; dim + 1]; // Incorrect
1008        let this_square_norm = 0.0;
1009        update_distances::<16>(
1010            &mut square_distances,
1011            &BlockTranspose::from_matrix_view(data.as_view()),
1012            &norms,
1013            &this,
1014            this_square_norm,
1015        );
1016    }
1017
1018    #[test]
1019    #[should_panic(expected = "distances buffer and dataset must have the same length")]
1020    fn update_distances_panics_distances_length_mismatch() {
1021        let npoints = 5;
1022        let dim = 8;
1023        let mut square_distances = vec![0.0; npoints + 1]; // Incorrect
1024        let data = Matrix::new(0.0, npoints, dim);
1025        let norms = vec![0.0; npoints];
1026        let this = vec![0.0; dim];
1027        let this_square_norm = 0.0;
1028        update_distances::<16>(
1029            &mut square_distances,
1030            &BlockTranspose::from_matrix_view(data.as_view()),
1031            &norms,
1032            &this,
1033            this_square_norm,
1034        );
1035    }
1036
1037    #[test]
1038    #[should_panic(expected = "norms and dataset must have the same length")]
1039    fn update_distances_panics_norms_length_mismatch() {
1040        let npoints = 5;
1041        let dim = 8;
1042        let mut square_distances = vec![0.0; npoints];
1043        let data = Matrix::new(0.0, npoints, dim);
1044        let norms = vec![0.0; npoints + 1]; // Incorrect
1045        let this = vec![0.0; dim];
1046        let this_square_norm = 0.0;
1047        update_distances::<16>(
1048            &mut square_distances,
1049            &BlockTranspose::from_matrix_view(data.as_view()),
1050            &norms,
1051            &this,
1052            this_square_norm,
1053        );
1054    }
1055
1056    ///////////////////////////////////
1057    // Panics - kmeans_plusplus_into //
1058    ///////////////////////////////////
1059
1060    #[test]
1061    #[should_panic(
1062        expected = "centers output matrix should have the same dimensionality as the dataset"
1063    )]
1064    fn kmeans_plusplus_into_panics_dim_mismatch() {
1065        let mut centers = Matrix::new(0.0, 2, 10);
1066        let data = Matrix::new(0.0, 2, 9);
1067        kmeans_plusplus_into(
1068            centers.as_mut_view(),
1069            data.as_view(),
1070            &mut rand::rngs::ThreadRng::default(),
1071        )
1072        .unwrap();
1073    }
1074}