Skip to main content

diskann_quantization/algorithms/kmeans/
plusplus.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::{collections::HashSet, fmt};
7
8use diskann_utils::{
9    strided::StridedView,
10    views::{MatrixView, MutMatrixView},
11};
12use diskann_wide::{SIMDMulAdd, SIMDPartialOrd, SIMDSelect, SIMDVector};
13use rand::{
14    RngCore,
15    distr::{Distribution, Uniform},
16};
17use thiserror::Error;
18
19use super::common::{BlockTranspose, square_norm};
20
/// Fused kernel behind the kmeans++ distance update, implemented by `BlockTranspose`.
///
/// For every row of the dataset, one pass of the kernel:
///
/// 1. computes the squared distance from the row to a newly selected kmeans++ center,
/// 2. lowers the row's running minimum distance if the new center is closer, and
/// 3. adds the (possibly lowered) minimum to a running total.
///
/// Fusing the three steps avoids making separate passes over the data. Nothing other than
/// `BlockTranspose` is meant to implement this trait.
pub(crate) trait MicroKernel {
    /// A possibly-broadcast `f32`; used to carry the new center's squared norm.
    type Splat: Copy;

    /// Per-row partial results handed from `accum_full` to `finish` / `finish_last`.
    type Intermediate;

    /// Running total of the minimum squared distances, in the implementation's own format.
    type RollingSum: Default + Copy;

    /// Broadcast `x` into [`Self::Splat`].
    fn splat(x: f32) -> Self::Splat;

    /// Compute the center-dependent part of the distance between `this` and each row of the
    /// block starting at `block` (for `BlockTranspose<16>` this is `-2 * <row, this>`).
    ///
    /// Usable for full and partial blocks alike.
    ///
    /// # Safety
    ///
    /// `block` must be the base pointer of a data block inside a `BlockTranspose` whose
    /// block size equals `this.len()`.
    unsafe fn accum_full(block: *const f32, this: &[f32]) -> Self::Intermediate;

    /// Finish a full block: combine `intermediate` with the rows' squared `norms` and the
    /// center's broadcast squared norm `splat`, lower `mins` element-wise wherever the new
    /// distance is smaller, and return `rolling_sum` plus the updated `mins`.
    fn finish(
        intermediate: Self::Intermediate,
        splat: Self::Splat,
        rolling_sum: Self::RollingSum,
        norms: &[f32],
        mins: &mut [f32],
    ) -> Self::RollingSum;

    /// Same as [`Self::finish`], but for a trailing partial block in which only the first
    /// `first` rows are real; `norms` and `mins` cover exactly those rows.
    fn finish_last(
        intermediate: Self::Intermediate,
        splat: Self::Splat,
        rolling_sum: Self::RollingSum,
        norms: &[f32],
        mins: &mut [f32],
        first: usize,
    ) -> Self::RollingSum;

    /// Convert the running total from its internal representation into a plain `f64`.
    fn complete_sum(x: Self::RollingSum) -> f64;
}
79
80diskann_wide::alias!(f32s = f32x8);
81
82impl MicroKernel for BlockTranspose<16> {
83    // Process 16-dimensions concurrently, split across two `Wide`s.
84    type Intermediate = (f32s, f32s);
85    type RollingSum = f64;
86    type Splat = f32s;
87
88    fn splat(x: f32) -> Self::Splat {
89        Self::Splat::splat(diskann_wide::ARCH, x)
90    }
91
92    #[inline(always)]
93    unsafe fn accum_full(block_ptr: *const f32, this: &[f32]) -> Self::Intermediate {
94        let mut s0 = f32s::default(diskann_wide::ARCH);
95        let mut s1 = f32s::default(diskann_wide::ARCH);
96
97        this.iter().enumerate().for_each(|(i, b)| {
98            let b = f32s::splat(diskann_wide::ARCH, *b);
99
100            // SAFETY: From the requirement that `self.block_size() == this.len()`, then
101            // `this.len() * 16` elements are readible from `block_ptr` and `i < this.len()`.
102            let a = unsafe { f32s::load_simd(diskann_wide::ARCH, block_ptr.add(16 * i)) };
103            s0 = a.mul_add_simd(b, s0);
104
105            // SAFETY: From the requirement that `self.block_size() == this.len()`, then
106            // `this.len() * 16` elements are readible from `block_ptr` and `i < this.len()`.
107            let a = unsafe { f32s::load_simd(diskann_wide::ARCH, block_ptr.add(16 * i + 8)) };
108            s1 = a.mul_add_simd(b, s1);
109        });
110
111        // Apply the final -2.0 transformation.
112        let negative = f32s::splat(diskann_wide::ARCH, -2.0);
113        (s0 * negative, s1 * negative)
114    }
115
116    #[inline(always)]
117    fn finish(
118        intermediate: Self::Intermediate,
119        splat: Self::Splat,
120        rolling_sum: Self::RollingSum,
121        norms: &[f32],
122        mins: &mut [f32],
123    ) -> Self::RollingSum {
124        assert_eq!(norms.len(), 16);
125        assert_eq!(mins.len(), 16);
126
127        // SAFETY: `norms` has length 16 - this loads the first 8.
128        let norms0 = unsafe { f32s::load_simd(diskann_wide::ARCH, norms.as_ptr()) };
129        // SAFETY: `norms` has length 16 - this loads the last 8.
130        let norms1 = unsafe { f32s::load_simd(diskann_wide::ARCH, norms.as_ptr().add(8)) };
131
132        let distances0 = norms0 + splat + intermediate.0;
133        let distances1 = norms1 + splat + intermediate.1;
134
135        // SAFETY: `mins` has length 16 - this loads the first 8.
136        let current_distances0 = unsafe { f32s::load_simd(diskann_wide::ARCH, mins.as_ptr()) };
137        // SAFETY: `mins` has length 16 - this loads the last 8.
138        let current_distances1 =
139            unsafe { f32s::load_simd(diskann_wide::ARCH, mins.as_ptr().add(8)) };
140
141        let mask0 = distances0.lt_simd(current_distances0);
142        let mask1 = distances1.lt_simd(current_distances1);
143
144        let current_distances0 = mask0.select(distances0, current_distances0);
145        let current_distances1 = mask1.select(distances1, current_distances1);
146
147        // SAFETY: `mins` has length 16 - this stores the first 8.
148        unsafe { current_distances0.store_simd(mins.as_mut_ptr()) };
149        // SAFETY: `mins` has length 16 - this stores the last 8.
150        unsafe { current_distances1.store_simd(mins.as_mut_ptr().add(8)) };
151
152        rolling_sum
153            + std::iter::zip(
154                current_distances0.to_array().iter(),
155                current_distances1.to_array().iter(),
156            )
157            .map(|(d0, d1)| (*d0 as f64) + (*d1 as f64))
158            .sum::<f64>()
159    }
160
161    #[inline(always)]
162    fn finish_last(
163        intermediate: Self::Intermediate,
164        splat: Self::Splat,
165        rolling_sum: Self::RollingSum,
166        norms: &[f32],
167        mins: &mut [f32],
168        first: usize,
169    ) -> Self::RollingSum {
170        // Check 1.
171        assert_eq!(norms.len(), first);
172        // Check 2.
173        assert_eq!(mins.len(), first);
174
175        let lo = first.min(8);
176        let hi = first - lo;
177
178        // SAFETY: This loads `first.min(8)` elements from `norms`, which is valid
179        // by check 1.
180        let norms0 = unsafe { f32s::load_simd_first(diskann_wide::ARCH, norms.as_ptr(), lo) };
181        let norms1 = if hi == 0 {
182            f32s::default(diskann_wide::ARCH)
183        } else {
184            // SAFETY: This is only called if `first > 8`, which means `norms.len() > 8` by
185            // check 1. Therefore, we can load `first - 8` elements from `norms.as_ptr() + 8`.
186            unsafe { f32s::load_simd_first(diskann_wide::ARCH, norms.as_ptr().add(8), hi) }
187        };
188
189        let distances0 = norms0 + splat + intermediate.0;
190        let distances1 = norms1 + splat + intermediate.1;
191
192        // SAFETY: Same logic as the load for `norms0`.
193        let current_distances0 =
194            unsafe { f32s::load_simd_first(diskann_wide::ARCH, mins.as_ptr(), lo) };
195        let current_distances1 = if hi == 0 {
196            f32s::default(diskann_wide::ARCH)
197        } else {
198            // SAFETY: Same logic as the load for `norms1`.
199            unsafe { f32s::load_simd_first(diskann_wide::ARCH, mins.as_ptr().add(8), hi) }
200        };
201
202        let mask0 = distances0.lt_simd(current_distances0);
203        let mask1 = distances1.lt_simd(current_distances1);
204
205        let current_distances0 = mask0.select(distances0, current_distances0);
206        let current_distances1 = mask1.select(distances1, current_distances1);
207
208        // SAFETY: As per the logic for the load above, it is safe to store at least `lo`
209        // elements from the base pointer.
210        unsafe { current_distances0.store_simd_first(mins.as_mut_ptr(), lo) };
211        if hi != 0 {
212            // SAFETY: If `hi != 0`, then `first` must be at least 9. Therefore, adding 8
213            // to the base pointer is valid, as is storing `first - 8` elements to that
214            // pointer.
215            unsafe { current_distances1.store_simd_first(mins.as_mut_ptr().add(8), hi) };
216        }
217
218        rolling_sum
219            + std::iter::zip(
220                current_distances0.to_array().iter(),
221                current_distances1.to_array().iter(),
222            )
223            .map(|(d0, d1)| (*d0 as f64) + (*d1 as f64))
224            .sum::<f64>()
225    }
226
227    fn complete_sum(x: Self::RollingSum) -> f64 {
228        x
229    }
230}
231
232/// Update `square_distances` to contain the minimum of its current value and the distance
233/// between each element in `transpose` and `this`.
234///
235/// Return the sum of the new `square_distances`.
236fn update_distances<const N: usize>(
237    square_distances: &mut [f32],
238    transpose: &BlockTranspose<N>,
239    norms: &[f32],
240    this: &[f32],
241    this_square_norm: f32,
242) -> f64
243where
244    BlockTranspose<N>: MicroKernel,
245{
246    // Establish our safety requirements.
247    // Check 1.
248    assert_eq!(
249        this.len(),
250        transpose.ncols(),
251        "new point and dataset must have the same dimension",
252    );
253    // Check 2.
254    assert_eq!(
255        square_distances.len(),
256        transpose.nrows(),
257        "distances buffer and dataset must have the same length",
258    );
259    // Check 3.
260    assert_eq!(
261        norms.len(),
262        transpose.nrows(),
263        "norms and dataset must have the same length",
264    );
265
266    let splat = BlockTranspose::<N>::splat(this_square_norm);
267    let mut rolling_sum = <BlockTranspose<N> as MicroKernel>::RollingSum::default();
268
269    let iter =
270        std::iter::zip(norms.chunks_exact(N), square_distances.chunks_exact_mut(N)).enumerate();
271    iter.for_each(|(block, (these_norms, these_distances))| {
272        debug_assert!(block < transpose.num_blocks());
273        // SAFETY: Because `transpose.nrows() == norms.len()`, the number of full blocks in
274        // `transpose` is `norms.nrows() / N` and therefore the induction variable `block`
275        // is less that `transpose.full_blocks()`.
276        let base = unsafe { transpose.block_ptr_unchecked(block) };
277
278        // SAFETY: The pointer `base` does point to a full block and by Check 1,
279        // `transpose.ncols() == this.len()`.
280        let intermediate = unsafe { BlockTranspose::<N>::accum_full(base, this) };
281
282        rolling_sum = BlockTranspose::<N>::finish(
283            intermediate,
284            splat,
285            rolling_sum,
286            these_norms,
287            these_distances,
288        );
289    });
290
291    // Do the last iteration if there is an un-even number of rows.
292    let remainder = transpose.remainder();
293    if remainder != 0 {
294        // SAFETY: We've checked that there is a `remainder` block. Therefore,
295        // `transpose.full_blocks() < transpose.num_blocks()`.
296        let base = unsafe { transpose.block_ptr_unchecked(transpose.full_blocks()) };
297
298        // A full accumulation is fine because `BlockTranspose` allocates at the granularity
299        // of blocks. We will just ignore the extra lanes.
300        // SAFETY: The pointer `base` does point to a full block and by Check 1,
301        // `transpose.ncols() == this.len()`.
302        let intermediate = unsafe { BlockTranspose::<N>::accum_full(base, this) };
303
304        let start = N * transpose.full_blocks();
305        rolling_sum = BlockTranspose::<N>::finish_last(
306            intermediate,
307            splat,
308            rolling_sum,
309            &norms[start..],
310            &mut square_distances[start..],
311            remainder,
312        );
313    }
314
315    BlockTranspose::<N>::complete_sum(rolling_sum)
316}
317
/// Why kmeans++ seeding stopped before every requested center was chosen.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FailureReason {
    /// The dataset contains fewer points than the requested number of centers.
    DatasetTooSmall,
    /// The dataset has enough points, but no remaining point lies at a positive distance
    /// from the centers already chosen (for example, because rows are duplicated).
    InsufficientDiversity,
    /// Infinity was observed (this also happens when a NaN is present in the input data).
    SawInfinity,
}

impl FailureReason {
    /// Whether a caller can reasonably continue after this failure.
    ///
    /// Running out of points (`DatasetTooSmall`) or of distinct points
    /// (`InsufficientDiversity`) is recoverable from a kmeans perspective: it can simply
    /// proceed with fewer centers. `SawInfinity` is not, since downstream algorithms would
    /// most likely break as well.
    #[must_use]
    pub fn is_numerically_recoverable(self) -> bool {
        match self {
            Self::DatasetTooSmall | Self::InsufficientDiversity => true,
            Self::SawInfinity => false,
        }
    }
}

impl fmt::Display for FailureReason {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let reason: &str = match self {
            Self::DatasetTooSmall => "dataset does not have enough points",
            Self::InsufficientDiversity => "dataset is insufficiently diverse",
            Self::SawInfinity => "a value of infinity or NaN was observed",
        };
        f.write_str(reason)
    }
}
353
354#[derive(Debug, Clone, Copy, Error)]
355#[error("only populated {selected} of {expected} points because {reason}")]
356pub struct KMeansPlusPlusError {
357    /// The number of points that were selected.
358    pub selected: usize,
359    /// The number of points that were expected.
360    pub expected: usize,
361    /// A concrete reason for the failure.
362    pub reason: FailureReason,
363}
364
365impl KMeansPlusPlusError {
366    fn new(selected: usize, expected: usize, reason: FailureReason) -> Self {
367        Self {
368            selected,
369            expected,
370            reason,
371        }
372    }
373
374    pub fn is_numerically_recoverable(&self) -> bool {
375        self.reason.is_numerically_recoverable() && self.selected > 0
376    }
377}
378
379pub(crate) fn kmeans_plusplus_into_inner<const N: usize>(
380    mut points: MutMatrixView<'_, f32>,
381    data: StridedView<'_, f32>,
382    transpose: &BlockTranspose<N>,
383    norms: &[f32],
384    rng: &mut dyn RngCore,
385) -> Result<(), KMeansPlusPlusError>
386where
387    BlockTranspose<N>: MicroKernel,
388{
389    assert_eq!(norms.len(), data.nrows());
390    assert_eq!(transpose.nrows(), data.nrows());
391    assert_eq!(transpose.block_size(), data.ncols());
392    assert_eq!(points.ncols(), data.ncols());
393
394    // Zero the argument
395    points.as_mut_slice().fill(0.0);
396    let expected = points.nrows();
397
398    // Is someone trying to operate on an empty dataset?
399    //
400    // We can determine this by constructing a `Uniform` distribution over the rows and
401    // checking if the resulting range is Empty.
402    let all_rows = match Uniform::new(0, data.nrows()) {
403        Ok(dist) => dist,
404        Err(_) => {
405            // If they want 0 points, then I guess this is okay.
406            return if expected == 0 {
407                Ok(())
408            } else {
409                Err(KMeansPlusPlusError::new(
410                    0,
411                    expected,
412                    FailureReason::DatasetTooSmall,
413                ))
414            };
415        }
416    };
417
418    let mut min_distances: Vec<f32> = vec![f32::INFINITY; data.nrows()];
419    let mut picked = HashSet::with_capacity(expected);
420
421    // Pick the first point randomly.
422    let mut previous_square_norm = {
423        let i = all_rows.sample(rng);
424        points.row_mut(0).copy_from_slice(data.row(i));
425        picked.insert(i);
426        norms[i]
427    };
428
429    let mut selected = 1;
430    for current in 1..expected.min(data.nrows()) {
431        let last = points.row(current - 1);
432        let s = update_distances(
433            &mut min_distances,
434            transpose,
435            norms,
436            last,
437            previous_square_norm,
438        );
439
440        // Pick a threshold.
441        // Due to the way we compute distances, values less than 0.0 are technically
442        // possible.
443        match Uniform::<f64>::new(0.0, s) {
444            Ok(distribution) => {
445                let threshold = distribution.sample(rng);
446                let mut rolling_sum: f64 = 0.0;
447                for (i, d) in min_distances.iter().enumerate() {
448                    rolling_sum += <f32 as Into<f64>>::into(*d);
449                    if rolling_sum >= threshold && (*d > 0.0) && !picked.contains(&i) {
450                        // This point is the winner.
451                        // Copy it over and update our scratch variables.
452                        points.row_mut(current).clone_from_slice(data.row(i));
453                        picked.insert(i);
454                        previous_square_norm = norms[i];
455                        selected = current + 1;
456                        break;
457                    }
458                }
459            }
460            // If the range is empty, this implies that `s == 0.0`.
461            //
462            // In this case, we skip and try again,
463            Err(rand::distr::uniform::Error::EmptyRange) => {}
464            // The upper bound is infinite - this is an error.
465            Err(rand::distr::uniform::Error::NonFinite) => {
466                return Err(KMeansPlusPlusError::new(
467                    selected,
468                    expected,
469                    FailureReason::SawInfinity,
470                ));
471            }
472        }
473
474        // If we successfully picked a row, than `selected == current`.
475        //
476        // If this is not the case, then we failed due to insufficient diversity.
477        if selected != (current + 1) {
478            return Err(KMeansPlusPlusError::new(
479                selected,
480                expected,
481                FailureReason::InsufficientDiversity,
482            ));
483        }
484    }
485
486    // We may have terminated early due to the dataset being too small.
487    if selected != expected {
488        Err(KMeansPlusPlusError::new(
489            selected,
490            expected,
491            FailureReason::DatasetTooSmall,
492        ))
493    } else {
494        Ok(())
495    }
496}
497
498pub fn kmeans_plusplus_into(
499    centers: MutMatrixView<'_, f32>,
500    data: MatrixView<'_, f32>,
501    rng: &mut dyn RngCore,
502) -> Result<(), KMeansPlusPlusError> {
503    assert_eq!(
504        centers.ncols(),
505        data.ncols(),
506        "centers output matrix should have the same dimensionality as the dataset"
507    );
508
509    const GROUPSIZE: usize = 16;
510    let mut norms: Vec<f32> = vec![0.0; data.nrows()];
511
512    for (n, d) in std::iter::zip(norms.iter_mut(), data.row_iter()) {
513        *n = square_norm(d);
514    }
515
516    let transpose = BlockTranspose::<GROUPSIZE>::from_matrix_view(data);
517    kmeans_plusplus_into_inner(centers, data.into(), &transpose, &norms, rng)
518}
519
520#[cfg(test)]
521mod tests {
522    use diskann_utils::{lazy_format, views::Matrix};
523    use diskann_vector::{PureDistanceFunction, distance::SquaredL2};
524    use rand::{Rng, SeedableRng, rngs::StdRng, seq::SliceRandom};
525
526    use super::*;
527    use crate::utils;
528
529    fn is_in(needle: &[f32], haystack: MatrixView<'_, f32>) -> bool {
530        assert_eq!(needle.len(), haystack.ncols());
531        haystack.row_iter().any(|row| row == needle)
532    }
533
534    fn check_post_conditions(
535        centers: MatrixView<'_, f32>,
536        data: MatrixView<'_, f32>,
537        err: &KMeansPlusPlusError,
538    ) {
539        assert_eq!(err.expected, centers.nrows());
540        assert!(err.expected > err.selected);
541        for i in 0..err.selected {
542            assert!(is_in(centers.row(i), data.as_view()));
543        }
544        for i in err.selected..centers.nrows() {
545            assert!(centers.row(i).iter().all(|j| *j == 0.0));
546        }
547    }
548
549    ///////////////////
550    // Error Display //
551    ///////////////////
552
553    #[test]
554    fn test_error_display() {
555        assert_eq!(
556            format!("{}", FailureReason::DatasetTooSmall),
557            "dataset does not have enough points"
558        );
559
560        assert_eq!(
561            format!("{}", FailureReason::InsufficientDiversity),
562            "dataset is insufficiently diverse"
563        );
564
565        assert_eq!(
566            format!("{}", FailureReason::SawInfinity),
567            "a value of infinity or NaN was observed"
568        );
569    }
570
571    //////////////////////
572    // Update Distances //
573    //////////////////////
574
575    /// Seed `x` (a KxN matrix) with the following:
576    /// ```text
577    /// 0,   1,   2,   3 ...   N-1
578    /// 1,   2,   3,   4 ...   N
579    /// ...
580    /// K-1, K,   K+1, K+3 ... N+K-2
581    /// ```
582    fn set_default_values(mut x: MutMatrixView<'_, f32>) {
583        for (i, row) in x.row_iter_mut().enumerate() {
584            for (j, r) in row.iter_mut().enumerate() {
585                *r = (i + j) as f32;
586            }
587        }
588    }
589
590    // The implementation of `update_distances` is not particularly unsafe with relation
591    // to `dim`, but *is* potentially unsafe with respect to `num_points`.
592    //
593    // Therefore, we need to be sure we sweep sufficient values of `num_points` to get full
594    // coverage with Miri.
595    //
596    // This works our to our advantage because smaller `dims` mean we can have precise
597    // floating point values.
598    fn test_update_distances_impl<const N: usize, R>(num_points: usize, dim: usize, rng: &mut R)
599    where
600        BlockTranspose<N>: MicroKernel,
601        R: Rng,
602    {
603        let context = lazy_format!(
604            "setup: N = {}, num_points = {}, dim = {}",
605            N,
606            num_points,
607            dim
608        );
609
610        let mut data = Matrix::<f32>::new(0.0, num_points, dim);
611        set_default_values(data.as_mut_view());
612
613        let square_norms: Vec<f32> = data.row_iter().map(square_norm).collect();
614
615        // The sample points we are computing the distances against.
616        let num_samples = 3;
617        let mut samples = Matrix::<f32>::new(0.0, num_samples, dim);
618        let mut distances = vec![f32::INFINITY; num_points];
619        let distribution = Uniform::<u32>::new(0, (num_points + dim) as u32).unwrap();
620        let transpose = BlockTranspose::<N>::from_matrix_view(data.as_view());
621
622        let mut last_residual = f64::INFINITY;
623        for i in 0..num_samples {
624            // Pick a sample.
625            {
626                let row = samples.row_mut(i);
627                row.iter_mut().for_each(|r| {
628                    *r = distribution.sample(rng) as f32;
629                });
630            }
631            let row = samples.row(i);
632            let norm = square_norm(row);
633
634            let residual = update_distances(&mut distances, &transpose, &square_norms, row, norm);
635
636            // Make sure all the distances are correct.
637            for (n, (d, data)) in std::iter::zip(distances.iter(), data.row_iter()).enumerate() {
638                let mut min_distance = f32::INFINITY;
639                for j in 0..=i {
640                    let distance = SquaredL2::evaluate(samples.row(j), data);
641                    min_distance = min_distance.min(distance);
642                }
643                assert_eq!(
644                    min_distance, *d,
645                    "failed on row {n} on iteration {i}. {}",
646                    context
647                );
648            }
649
650            // The distances match - ensure that the residual was computed properly.
651            assert_eq!(
652                residual,
653                distances.iter().sum::<f32>() as f64,
654                "residual sum failed on iteration {i} - {}",
655                context
656            );
657
658            // Finally - make sure the residual is dropping.
659            assert!(
660                residual <= last_residual,
661                "residual check failed on iteration {}, last = {}, this = {} - {}",
662                i,
663                last_residual,
664                residual,
665                context
666            );
667
668            last_residual = residual;
669        }
670    }
671
672    /// A note on testing methodology:
673    ///
674    /// This function targets running under Miri to test the indexing logic in
675    /// `update_distances`.
676    ///
677    /// It is appropriate for implementations what are not "unsafe" in `dim` but are
678    /// "unsafe" in `num_points`.
679    ///
680    /// In other words, the SIMD operations we are tracking block along `num_points` and
681    /// not `dim`. This lets us run at a much smaller `dim` to help Miri finish more quickly.
682    #[test]
683    fn test_update_distances() {
684        let mut rng = StdRng::seed_from_u64(0x56c94b53c73e4fd9);
685        for num_points in 0..48 {
686            #[cfg(miri)]
687            if num_points % 7 != 0 {
688                continue;
689            }
690
691            for dim in 1..4 {
692                test_update_distances_impl(num_points, dim, &mut rng);
693            }
694        }
695    }
696
697    //////////////
698    // Kmeans++ //
699    //////////////
700
701    // Kmeans++ sanity checks - if there are only `N` distinct and we want `N` centers,
702    // then all `N` should be selected without repeats.
703    #[cfg(not(miri))]
704    fn sanity_check_impl<R: Rng>(ncenters: usize, dim: usize, rng: &mut R) {
705        let repeats_per_center = 3;
706        let context = lazy_format!(
707            "dim = {}, ncenters = {}, repeats_per_center = {}",
708            dim,
709            ncenters,
710            repeats_per_center
711        );
712
713        let ndata = repeats_per_center * ncenters;
714        let mut values: Vec<f32> = (0..ncenters)
715            .flat_map(|i| (0..repeats_per_center).map(move |_| i as f32))
716            .collect();
717        assert_eq!(values.len(), ndata);
718
719        values.shuffle(rng);
720        let mut data = Matrix::new(0.0, ndata, dim);
721        for (r, v) in std::iter::zip(data.row_iter_mut(), values.iter()) {
722            r.fill(*v);
723        }
724
725        let mut centers = Matrix::new(f32::INFINITY, ncenters, dim);
726        kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), rng).unwrap();
727
728        // Make sure that each value was selected for a center.
729        let mut seen = HashSet::<usize>::new();
730        for c in centers.row_iter() {
731            let first = c[0];
732            assert!(c.iter().all(|i| *i == first));
733
734            let v: usize = first.round() as usize;
735            assert_eq!(v as f32, first, "conversion was not lossless - {}", context);
736
737            if !seen.insert(v) {
738                panic!("value {first} seen more than oncex - {}", context);
739            }
740        }
741        assert_eq!(
742            seen.len(),
743            ncenters,
744            "not all points were seen - {}",
745            context
746        );
747    }
748
749    #[test]
750    #[cfg(not(miri))]
751    fn sanity_check() {
752        let dims = [1, 16];
753        let ncenters = [1, 5, 20, 255];
754        let mut rng = StdRng::seed_from_u64(0x68c2080f2ea36f5a);
755
756        for ncenters in ncenters {
757            for dim in dims {
758                sanity_check_impl(ncenters, dim, &mut rng);
759            }
760        }
761    }
762
763    // This test is like the sanity check - but instead of exact repeats, we use slightly
764    // perturbed values to test that the proportionality is of distances is respected.
765    #[cfg(not(miri))]
766    fn fuzzy_sanity_check_impl<R: Rng>(ncenters: usize, dim: usize, rng: &mut R) {
767        let repeats_per_center = 3;
768
769        // A spreading coefficient to space-out points.
770        let spreading_multiplier: usize = 16;
771        // Purturbation distribution to apply to the input data.
772        let perturbation_distribution = Uniform::new(-0.125, 0.125).unwrap();
773
774        let context = lazy_format!(
775            "dim = {}, ncenters = {}, repeats_per_center = {}, multiplier = {}",
776            dim,
777            ncenters,
778            repeats_per_center,
779            spreading_multiplier,
780        );
781
782        let ndata = repeats_per_center * ncenters;
783        let mut values: Vec<f32> = (0..ncenters)
784            .flat_map(|i| {
785                // We need to bounce through a vec to avoid borrowing issues.
786                let v: Vec<f32> = (0..repeats_per_center)
787                    .map(|_| {
788                        (spreading_multiplier * i) as f32 + perturbation_distribution.sample(rng)
789                    })
790                    .collect();
791
792                v.into_iter()
793            })
794            .collect();
795        assert_eq!(values.len(), ndata);
796
797        values.shuffle(rng);
798        let mut data = Matrix::new(0.0, ndata, dim);
799        for (r, v) in std::iter::zip(data.row_iter_mut(), values.iter()) {
800            r.fill(*v);
801        }
802
803        let mut centers = Matrix::new(f32::INFINITY, ncenters, dim);
804        kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), rng).unwrap();
805
806        // Make sure that each value was selected for a center.
807        let mut seen = HashSet::<usize>::new();
808        for (i, c) in centers.row_iter().enumerate() {
809            let first = c[0];
810            let v: usize = first.round() as usize;
811            assert_eq!(
812                v % spreading_multiplier,
813                0,
814                "expected row value to be close to a multiple of the spreading multiplier, \
815                 instead got {} - {}",
816                v,
817                context
818            );
819            seen.insert(v);
820
821            // Make sure the center is equal to one of the data points.
822            let mut found = false;
823            for r in data.row_iter() {
824                if r == c {
825                    found = true;
826                    break;
827                }
828            }
829            if !found {
830                panic!(
831                    "center {} was not found in the original dataset - {}",
832                    i, context,
833                );
834            }
835        }
836        assert!(
837            seen.len() as f32 >= 0.95 * (ncenters as f32),
838            "expected the distribution of centers to be wide, \
839             instead {} unique values were found - {}",
840            seen.len(),
841            context
842        );
843    }
844
845    #[test]
846    #[cfg(not(miri))]
847    fn fuzzy_sanity_check() {
848        let dims = [1, 16];
849        // Apparently passing in `0` for `ncenters` works in a well-defined way.
850        let ncenters = [0, 1, 5, 20, 255];
851        let mut rng = StdRng::seed_from_u64(0x68c2080f2ea36f5a);
852
853        for ncenters in ncenters {
854            for dim in dims {
855                fuzzy_sanity_check_impl(ncenters, dim, &mut rng);
856            }
857        }
858    }
859
860    // Failure modes
861    #[test]
862    fn fail_empty_dataset() {
863        let data = Matrix::new(0.0, 0, 5);
864        let mut centers = Matrix::new(0.0, 10, data.ncols());
865
866        let mut rng = StdRng::seed_from_u64(0xa9eae150d30845a1);
867
868        let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
869        assert!(
870            result.is_err(),
871            "kmeans++ on an empty dataset with non-empty centers should be an error"
872        );
873        let err = result.unwrap_err();
874        assert_eq!(err.selected, 0);
875        assert_eq!(err.expected, centers.nrows());
876        assert_eq!(err.reason, FailureReason::DatasetTooSmall);
877        assert!(!err.is_numerically_recoverable());
878
879        check_post_conditions(centers.as_view(), data.as_view(), &err);
880    }
881
882    #[test]
883    fn both_empty_is_okay() {
884        let data = Matrix::new(0.0, 0, 5);
885        let mut centers = Matrix::new(0.0, 0, data.ncols());
886        let mut rng = StdRng::seed_from_u64(0x6f7031afd9b5aa18);
887        let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
888        assert!(
889            result.is_ok(),
890            "selecting 0 points from an empty dataset is okay"
891        );
892    }
893
894    #[test]
895    fn fail_dataset_not_big_enough() {
896        let ndata = 5;
897        let ncenters = 10;
898        let dim = 5;
899
900        let mut data = Matrix::new(0.0, ndata, dim);
901        set_default_values(data.as_mut_view());
902        let mut centers = Matrix::new(f32::INFINITY, ncenters, data.ncols());
903
904        let mut rng = StdRng::seed_from_u64(0xa9eae150d30845a1);
905
906        let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
907        assert!(
908            result.is_err(),
909            "kmeans++ on an empty dataset with non-empty centers should be an error"
910        );
911        let err = result.unwrap_err();
912        assert_eq!(err.selected, data.nrows());
913        assert_eq!(err.expected, centers.nrows());
914        assert_eq!(err.reason, FailureReason::DatasetTooSmall);
915        assert!(err.is_numerically_recoverable());
916
917        check_post_conditions(centers.as_view(), data.as_view(), &err);
918    }
919
920    // In this test - we ensure that we process as much of a non-diverse dataset as we can
921    // before returning an error.
922    #[test]
923    fn fail_diversity_check() {
924        let ncenters = 10;
925        let ndata = 50;
926        let dim = 3;
927        let mut rng = StdRng::seed_from_u64(0xca57b032c21bf4bb);
928
929        // Make sure the dataset only contains 5 unique values.
930        let repeats_per_center = 10;
931        assert!(ncenters * repeats_per_center > ndata);
932        let mut values: Vec<f32> = (0..utils::div_round_up(ndata, repeats_per_center))
933            .flat_map(|i| (0..repeats_per_center).map(move |_| i as f32))
934            .collect();
935        assert!(values.len() >= ndata);
936
937        values.shuffle(&mut rng);
938        let mut data = Matrix::new(0.0, ndata, dim);
939        for (r, v) in std::iter::zip(data.row_iter_mut(), values.iter()) {
940            r.fill(*v);
941        }
942
943        let mut centers = Matrix::new(f32::INFINITY, ncenters, dim);
944        let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
945        assert!(
946            result.is_err(),
947            "dataset should not have enough unique points"
948        );
949        let err = result.unwrap_err();
950        assert_eq!(err.selected, utils::div_round_up(ndata, repeats_per_center));
951        assert_eq!(err.expected, centers.nrows());
952        assert_eq!(err.reason, FailureReason::InsufficientDiversity);
953        assert!(err.is_numerically_recoverable());
954
955        check_post_conditions(centers.as_view(), data.as_view(), &err);
956    }
957
958    #[test]
959    fn fail_intinity_check() {
960        let mut data = Matrix::new(0.0, 10, 1);
961        set_default_values(data.as_mut_view());
962
963        // A very large value that will overflow to infinity when computing the norm.
964        data[(6, 0)] = -3.4028235e38;
965        let mut centers = Matrix::new(0.0, 2, 1);
966
967        let mut rng = StdRng::seed_from_u64(0xc0449b2aa4e12f05);
968
969        let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
970        assert!(result.is_err(), "result should complain about infinity");
971        let err = result.unwrap_err();
972        assert_eq!(err.selected, 1);
973        assert_eq!(err.expected, centers.nrows());
974        assert_eq!(err.reason, FailureReason::SawInfinity);
975        assert!(!err.is_numerically_recoverable());
976
977        check_post_conditions(centers.as_view(), data.as_view(), &err);
978    }
979
980    #[test]
981    fn fail_nan_check() {
982        let mut data = Matrix::new(0.0, 10, 1);
983        set_default_values(data.as_mut_view());
984
985        // A very large value that will overflow to infinity when computing the norm.
986        data[(6, 0)] = f32::NAN;
987        let mut centers = Matrix::new(0.0, 2, 1);
988
989        let mut rng = StdRng::seed_from_u64(0x55808c6c728c8473);
990
991        let result = kmeans_plusplus_into(centers.as_mut_view(), data.as_view(), &mut rng);
992        assert!(result.is_err(), "result should complain about NaN");
993        let err = result.unwrap_err();
994        assert_eq!(err.selected, 1);
995        assert_eq!(err.expected, centers.nrows());
996        assert_eq!(err.reason, FailureReason::SawInfinity);
997        assert!(!err.is_numerically_recoverable());
998
999        check_post_conditions(centers.as_view(), data.as_view(), &err);
1000    }
1001
1002    ///////////////////////////////
1003    // Panics - update_distances //
1004    ///////////////////////////////
1005
1006    #[test]
1007    #[should_panic(expected = "new point and dataset must have the same dimension")]
1008    fn update_distances_panics_dim_mismatch() {
1009        let npoints = 5;
1010        let dim = 8;
1011        let mut square_distances = vec![0.0; npoints];
1012        let data = Matrix::new(0.0, npoints, dim);
1013        let norms = vec![0.0; npoints];
1014        let this = vec![0.0; dim + 1]; // Incorrect
1015        let this_square_norm = 0.0;
1016        update_distances::<16>(
1017            &mut square_distances,
1018            &BlockTranspose::from_matrix_view(data.as_view()),
1019            &norms,
1020            &this,
1021            this_square_norm,
1022        );
1023    }
1024
1025    #[test]
1026    #[should_panic(expected = "distances buffer and dataset must have the same length")]
1027    fn update_distances_panics_distances_length_mismatch() {
1028        let npoints = 5;
1029        let dim = 8;
1030        let mut square_distances = vec![0.0; npoints + 1]; // Incorrect
1031        let data = Matrix::new(0.0, npoints, dim);
1032        let norms = vec![0.0; npoints];
1033        let this = vec![0.0; dim];
1034        let this_square_norm = 0.0;
1035        update_distances::<16>(
1036            &mut square_distances,
1037            &BlockTranspose::from_matrix_view(data.as_view()),
1038            &norms,
1039            &this,
1040            this_square_norm,
1041        );
1042    }
1043
1044    #[test]
1045    #[should_panic(expected = "norms and dataset must have the same length")]
1046    fn update_distances_panics_norms_length_mismatch() {
1047        let npoints = 5;
1048        let dim = 8;
1049        let mut square_distances = vec![0.0; npoints];
1050        let data = Matrix::new(0.0, npoints, dim);
1051        let norms = vec![0.0; npoints + 1]; // Incorrect
1052        let this = vec![0.0; dim];
1053        let this_square_norm = 0.0;
1054        update_distances::<16>(
1055            &mut square_distances,
1056            &BlockTranspose::from_matrix_view(data.as_view()),
1057            &norms,
1058            &this,
1059            this_square_norm,
1060        );
1061    }
1062
1063    ///////////////////////////////////
1064    // Panics - kmeans_plusplus_into //
1065    ///////////////////////////////////
1066
1067    #[test]
1068    #[should_panic(
1069        expected = "centers output matrix should have the same dimensionality as the dataset"
1070    )]
1071    fn kmeans_plusplus_into_panics_dim_mismatch() {
1072        let mut centers = Matrix::new(0.0, 2, 10);
1073        let data = Matrix::new(0.0, 2, 9);
1074        kmeans_plusplus_into(
1075            centers.as_mut_view(),
1076            data.as_view(),
1077            &mut rand::rngs::ThreadRng::default(),
1078        )
1079        .unwrap();
1080    }
1081}