Skip to main content

diskann_quantization/algorithms/kmeans/
lloyds.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use diskann_wide::{SIMDMask, SIMDMulAdd, SIMDPartialOrd, SIMDSelect, SIMDSumTree, SIMDVector};
7
8use super::common::square_norm;
9use crate::multi_vector::{BlockTransposed, BlockTransposedRef};
10use diskann_utils::{
11    strided::StridedView,
12    views::{Matrix, MatrixView, MutMatrixView},
13};
14
15////////////////////////////////
16// Closest Centers Algorithms //
17////////////////////////////////
18
// SIMD vector types shared by the kernels in this file. Going by the names (and by the
// "loads 8" comments in `distances_in_place`), these are 8-lane `f32` and `u32` vectors;
// NOTE(review): the concrete type is chosen by `diskann_wide::alias!` — confirm against
// its definition.
diskann_wide::alias!(f32s = f32x8);
diskann_wide::alias!(u32s = u32x8);
21
/// Assigns each row of `dataset` to its nearest row of `centers` under squared L2
/// distance, writing the index of the chosen center into `nearest`, and returns the sum
/// of the selected distances (the residual).
///
/// Distances use the expansion `‖x - c‖² = ‖c‖² - 2⟨x, c⟩ + ‖x‖²`. The kernel only
/// accumulates the inner products `⟨x, c⟩` and combines them with the caller-provided
/// `data_norms` and `center_norms`. Those slices are therefore expected to hold the
/// *squared* L2 norm of each row, which is what [`lloyds`] passes (via `square_norm`).
///
/// This is a computation strategy where, for each 16-row block of `dataset`, the inner
/// products are accumulated in-place in SIMD registers over the whole dimension. This is
/// sufficient for low-dimensional clusterings but suffers when the dimensionality
/// increases.
///
/// Centers are visited in increasing index order. A candidate only replaces the current
/// minimum when its distance is strictly smaller, so ties resolve to the lowest center
/// index. If `centers` has no rows, every entry of `nearest` is set to `u32::MAX`.
///
/// # Panics
///
/// Panics unless all of the following hold:
/// * `data_norms.len() == dataset.nrows()`
/// * `centers.ncols() == dataset.ncols()`
/// * `center_norms.len() == centers.nrows()`
/// * `nearest.len() == dataset.nrows()`
///
/// The unchecked memory accesses below rely on these checks.
pub fn distances_in_place(
    dataset: BlockTransposedRef<'_, f32, 16>,
    data_norms: &[f32],
    centers: MatrixView<'_, f32>,
    center_norms: &[f32],
    nearest: &mut [u32],
) -> f32 {
    // Safety Checks!
    // Our unchecked-loads rely on these invariants holding.

    // Check 1: Same number of norms as dataset elements.
    assert_eq!(
        dataset.nrows(),
        data_norms.len(),
        "dataset and data norms should have the same length"
    );
    // Check 2: `dataset` and `centers` have the same dimension.
    assert_eq!(
        centers.ncols(),
        dataset.ncols(),
        "dataset and centers should have the same dimensions"
    );
    // Check 3: Same number of center norms as centers.
    assert_eq!(
        centers.nrows(),
        center_norms.len(),
        "centers and center norms should have the same length"
    );
    // Check 4: The `nearest` output's length matches the input dataset.
    assert_eq!(
        nearest.len(),
        dataset.nrows(),
        "dataset and nearest-buffer should have the same length"
    );

    // `N` is the number of dataset rows per transposed block. Each block is processed as
    // two 8-lane halves of `N2` rows each.
    const N: usize = 16;
    const N2: usize = N / 2;

    diskann_wide::alias!(m32s = mask_f32x8);

    // Per-lane running sum of the selected minimum distances, reduced at the very end.
    let mut residual = f32s::default(diskann_wide::ARCH);

    // Compute the inner products between all vectors in the block with index `block` and
    // two consecutive centers starting at `center_row_start`.
    //
    // Returns `(s00, s01, s10, s11)`, where `sRH` holds the inner products of center
    // `center_row_start + R` with the first (`H = 0`) or last (`H = 1`) 8 lanes of the
    // block.
    //
    // SAFETY: The following must hold:
    // * `block < dataset.num_blocks()` (this is safe to call on the remainder block).
    // * `center_row_start + 1 < centers.nrows()`: This unrolls by a factor of 2, so reading
    //    two rows must be valid.
    let process_block_unroll_2 = |block: usize, center_row_start: usize| {
        debug_assert!(block < dataset.num_blocks());
        debug_assert!(center_row_start + 1 < centers.nrows());

        let mut s00 = f32s::default(diskann_wide::ARCH);
        let mut s01 = f32s::default(diskann_wide::ARCH);
        let mut s10 = f32s::default(diskann_wide::ARCH);
        let mut s11 = f32s::default(diskann_wide::ARCH);

        // SAFETY: Closure pre-conditions mean that this access is in-bounds.
        let block_ptr = unsafe { dataset.block_ptr_unchecked(block) };
        for dim in 0..dataset.ncols() {
            // SAFETY: Each block stores `N * ncols` contiguous f32s (N=16).
            // `dim < dataset.ncols()`, so `N * dim + 7 < N * ncols`. Loads first 8.
            let d0 = unsafe { f32s::load_simd(diskann_wide::ARCH, block_ptr.add(N * dim)) };
            // SAFETY: Same reasoning; `N * dim + N2 + 7 < N * ncols`. Loads last 8.
            let d1 = unsafe { f32s::load_simd(diskann_wide::ARCH, block_ptr.add(N * dim + N2)) };

            // SAFETY: Closure pre-conditions and Check 2 make this a valid access.
            let c0 = f32s::splat(diskann_wide::ARCH, unsafe {
                *centers.get_unchecked(center_row_start, dim)
            });
            // SAFETY: Closure pre-conditions and Check 2 make this a valid access.
            let c1 = f32s::splat(diskann_wide::ARCH, unsafe {
                *centers.get_unchecked(center_row_start + 1, dim)
            });

            s00 = c0.mul_add_simd(d0, s00);
            s01 = c0.mul_add_simd(d1, s01);
            s10 = c1.mul_add_simd(d0, s10);
            s11 = c1.mul_add_simd(d1, s11);
        }
        (s00, s01, s10, s11)
    };

    // Compute the inner products between all vectors in the block with index `block` and
    // the single center at row `center_row_start`. Returns `(s00, s01)` for the first and
    // last 8 lanes of the block.
    //
    // SAFETY: The following must hold:
    // * `block < dataset.num_blocks()` (this is safe to call on the remainder block).
    // * `center_row_start < centers.nrows()`: only this one center row is read. In this
    //    function the closure is only used for the final row of an odd number of centers,
    //    which the debug assertion below checks (`center_row_start + 1 == centers.nrows()`).
    let process_block_no_unroll = |block: usize, center_row_start: usize| {
        debug_assert!(block < dataset.num_blocks());
        debug_assert!(center_row_start + 1 == centers.nrows());

        let mut s00 = f32s::default(diskann_wide::ARCH);
        let mut s01 = f32s::default(diskann_wide::ARCH);

        // SAFETY: Closure pre-conditions mean that this access is in-bounds.
        let block_ptr = unsafe { dataset.block_ptr_unchecked(block) };
        for dim in 0..dataset.ncols() {
            // SAFETY: Each block stores `N * ncols` contiguous f32s (N=16).
            // `dim < dataset.ncols()`, so `N * dim + 7 < N * ncols`. Loads first 8.
            let d0 = unsafe { f32s::load_simd(diskann_wide::ARCH, block_ptr.add(N * dim)) };
            // SAFETY: Same reasoning; `N * dim + N2 + 7 < N * ncols`. Loads last 8.
            let d1 = unsafe { f32s::load_simd(diskann_wide::ARCH, block_ptr.add(N * dim + N2)) };

            // SAFETY: Closure pre-conditions and Check 2 make this a valid access.
            let c0 = f32s::splat(diskann_wide::ARCH, unsafe {
                *centers.get_unchecked(center_row_start, dim)
            });

            s00 = c0.mul_add_simd(d0, s00);
            s01 = c0.mul_add_simd(d1, s01);
        }
        (s00, s01)
    };

    // Figure out if the number of centers to process is even or not.
    // If it's even, we can work on centers two rows at a time.
    //
    // Otherwise, we need to deal with the last `centers` row independently.
    let last_pair = if centers.nrows().is_multiple_of(2) {
        centers.nrows()
    } else {
        centers.nrows() - 1
    };

    for i in 0..dataset.full_blocks() {
        // Running (minimum distance, center index) pairs for the first (`t0`) and last
        // (`t1`) 8 rows of the block. They start out as "infinitely far, no center".
        let mut t0 = (
            f32s::splat(diskann_wide::ARCH, f32::INFINITY),
            u32s::splat(diskann_wide::ARCH, u32::MAX),
        );
        let mut t1 = (
            f32s::splat(diskann_wide::ARCH, f32::INFINITY),
            u32s::splat(diskann_wide::ARCH, u32::MAX),
        );

        // SAFETY: Check 1 means this access is in-bounds.
        let data_norm_ptr = unsafe { data_norms.as_ptr().add(N * i) };

        // SAFETY: By Check 1 and by being in a full block, all 16 values starting at
        // `data_norm_ptr` are valid to read. This loads the first 8.
        let d0 = unsafe { f32s::load_simd(diskann_wide::ARCH, data_norm_ptr) };

        // SAFETY: By Check 1 and by being in a full block, all 16 values starting at
        // `data_norm_ptr` are valid to read. This loads the last 8.
        let d1 = unsafe { f32s::load_simd(diskann_wide::ARCH, data_norm_ptr.add(N2)) };
        for row_start in (0..last_pair).step_by(2) {
            // SAFETY: By construction, `i < dataset.num_blocks()` and
            // `row_start + 1 < centers.nrows()`.
            let (s00, s01, s10, s11) = process_block_unroll_2(i, row_start);

            // Turn the inner products into squared L2 distances:
            // `‖c‖² - 2⟨x, c⟩ + ‖x‖²`.
            // SAFETY: By Check 3, this access is in-bounds.
            let n0 = f32s::splat(diskann_wide::ARCH, *unsafe {
                center_norms.get_unchecked(row_start)
            });
            // SAFETY: By Check 3 and loop construction, this access is in-bounds.
            let n1 = f32s::splat(diskann_wide::ARCH, *unsafe {
                center_norms.get_unchecked(row_start + 1)
            });

            let s00 = n0 - s00 - s00 + d0;
            let s01 = n0 - s01 - s01 + d1;
            let s10 = n1 - s10 - s10 + d0;
            let s11 = n1 - s11 - s11 + d1;

            // Fold in `row_start` before `row_start + 1`, so that ties keep the lower index.
            let r0 = u32s::splat(diskann_wide::ARCH, row_start as u32);
            let r1 = u32s::splat(diskann_wide::ARCH, (row_start + 1) as u32);
            t0 = update(update(t0, (s00, r0)), (s10, r1));
            t1 = update(update(t1, (s01, r0)), (s11, r1));
        }

        // If there is an odd-number of centers, we need to handle that individually.
        if !centers.nrows().is_multiple_of(2) {
            // SAFETY: By construction, `i < dataset.num_blocks()` and
            // `last_pair < centers.nrows()`.
            let (s00, s01) = process_block_no_unroll(i, last_pair);
            // SAFETY: by Check 3, this access is in-bounds.
            let n0 = f32s::splat(diskann_wide::ARCH, unsafe {
                *center_norms.get_unchecked(last_pair)
            });

            let s00 = n0 - s00 - s00 + d0;
            let s01 = n0 - s01 - s01 + d1;

            let r = u32s::splat(diskann_wide::ARCH, last_pair as u32);
            t0 = update(t0, (s00, r));
            t1 = update(t1, (s01, r));
        }

        // Write back.
        // SAFETY: By Check 4, at least 16 elements are valid and mutable beginning at the
        // offset `N * i`. This writes the first 8.
        unsafe { t0.1.store_simd(nearest.as_mut_ptr().add(N * i)) }
        // SAFETY: By Check 4, at least 16 elements are valid and mutable beginning at the
        // offset `N * i`. This writes the last 8.
        unsafe { t1.1.store_simd(nearest.as_mut_ptr().add(N * i + N2)) }

        // Update the residual.
        residual = residual + t0.0 + t1.0;
    }

    // If there is a remainder block, we can do pretty much exactly the same thing we did
    // for the full blocks. We just need to be more careful when loading the data norms
    // and writing back the results, because only `remainder` rows of the block are real.
    let remainder = dataset.remainder();
    if remainder != 0 {
        let i = dataset.full_blocks();
        // Number of real rows in the first (`lo`) and last (`hi`) 8 lanes of the block.
        let lo = remainder.min(N2);
        let hi = remainder - lo;

        let mut t0 = (
            f32s::splat(diskann_wide::ARCH, f32::INFINITY),
            u32s::splat(diskann_wide::ARCH, u32::MAX),
        );
        let mut t1 = (
            f32s::splat(diskann_wide::ARCH, f32::INFINITY),
            u32s::splat(diskann_wide::ARCH, u32::MAX),
        );

        // SAFETY: Check 1 means this access is in-bounds.
        let data_norm_ptr = unsafe { data_norms.as_ptr().add(N * i) };

        // SAFETY: By Check 1, being in a partial block means that `remainder` elements
        // starting at `data_norm_ptr` are valid. This loads up to the first 8.
        let d0 = unsafe { f32s::load_simd_first(diskann_wide::ARCH, data_norm_ptr, lo) };
        let d1 = if hi == 0 {
            // The last 8 lanes hold no real rows. These lanes are never written back and
            // never accumulated into the residual, so any value works here.
            f32s::default(diskann_wide::ARCH)
        } else {
            // SAFETY: By Check 1, being in a partial block means that `remainder`
            // elements are valid. By taking this branch, we know that `remainder` is
            // at least 9. So it's okay to add 8 to `data_norm_ptr` and load `hi` elements.
            unsafe { f32s::load_simd_first(diskann_wide::ARCH, data_norm_ptr.add(N2), hi) }
        };

        for row_start in (0..last_pair).step_by(2) {
            // SAFETY: By construction, `i < dataset.num_blocks()` and
            // `row_start + 1 < centers.nrows()`.
            let (s00, s01, s10, s11) = process_block_unroll_2(i, row_start);

            // Turn the inner products into squared L2 distances:
            // `‖c‖² - 2⟨x, c⟩ + ‖x‖²`.
            // SAFETY: By Check 3, this access is in-bounds.
            let n0 = f32s::splat(diskann_wide::ARCH, *unsafe {
                center_norms.get_unchecked(row_start)
            });
            // SAFETY: By Check 3 and loop construction, this access is in-bounds.
            let n1 = f32s::splat(diskann_wide::ARCH, *unsafe {
                center_norms.get_unchecked(row_start + 1)
            });

            let s00 = n0 - s00 - s00 + d0;
            let s01 = n0 - s01 - s01 + d1;
            let s10 = n1 - s10 - s10 + d0;
            let s11 = n1 - s11 - s11 + d1;

            let r0 = u32s::splat(diskann_wide::ARCH, row_start as u32);
            let r1 = u32s::splat(diskann_wide::ARCH, (row_start + 1) as u32);
            t0 = update(update(t0, (s00, r0)), (s10, r1));
            t1 = update(update(t1, (s01, r0)), (s11, r1));
        }

        // Handle the final center individually when there is an odd number of centers.
        if !centers.nrows().is_multiple_of(2) {
            // SAFETY: By construction, `i < dataset.num_blocks()` and
            // `last_pair < centers.nrows()`.
            let (s00, s01) = process_block_no_unroll(i, last_pair);
            // SAFETY: by Check 3, this access is in-bounds.
            let n0 = f32s::splat(diskann_wide::ARCH, unsafe {
                *center_norms.get_unchecked(last_pair)
            });

            let s00 = n0 - s00 - s00 + d0;
            let s01 = n0 - s01 - s01 + d1;

            let r = u32s::splat(diskann_wide::ARCH, last_pair as u32);
            t0 = update(t0, (s00, r));
            t1 = update(t1, (s01, r));
        }

        // Write back.
        // SAFETY: By Check 4, at least 1 and up to 16 elements are valid and mutable
        // beginning at the offset `N * i`. This writes the first `min(8, remainder)`.
        unsafe { t0.1.store_simd_first(nearest.as_mut_ptr().add(N * i), lo) };
        if hi != 0 {
            // SAFETY: By Check 4, at least 1 and up to 16 elements are valid and mutable
            // beginning at the offset `N * i`. If `hi != 0`, then `remainder` is at
            // least 9. So it's okay to add `8` to `nearest.as_mut_ptr()` and store `hi`
            // elements.
            unsafe {
                t1.1.store_simd_first(nearest.as_mut_ptr().add(N * i + N2), hi)
            };
        }

        // Update the residual
        // Use a masked select to only accumulate lanes that are in-bounds.
        residual = m32s::keep_first(diskann_wide::ARCH, lo).select(residual + t0.0, residual);
        residual = m32s::keep_first(diskann_wide::ARCH, hi).select(residual + t1.0, residual);
    }
    // Horizontally reduce the per-lane residuals to a single total.
    residual.sum_tree()
}
329
330#[inline(always)]
331fn update((d0, i0): (f32s, u32s), (d1, i1): (f32s, u32s)) -> (f32s, u32s) {
332    // Generate a mask with lanes set if a computed distance is less that one of theH
333    // current minimum distances.
334    let mask = d1.lt_simd(d0);
335    (
336        mask.select(d1, d0),
337        <u32s as SIMDVector>::Mask::from(mask).select(i1, i0),
338    )
339}
340
341/////////////////
342// Update Step //
343/////////////////
344
345fn update_centroids(mut centers: MutMatrixView<'_, f32>, data: StridedView<'_, f32>, map: &[u32]) {
346    let mut sums = Matrix::<f64>::new(0.0, centers.nrows(), centers.ncols());
347    let mut counts: Vec<u32> = vec![0; centers.nrows()];
348    data.row_iter().zip(map.iter()).for_each(|(row, &center)| {
349        counts[center as usize] += 1;
350        let sum = sums.row_mut(center as usize);
351        std::iter::zip(sum.iter_mut(), row.iter()).for_each(|(s, r)| {
352            *s += <f32 as Into<f64>>::into(*r);
353        });
354    });
355
356    std::iter::zip(counts.iter(), sums.row_iter())
357        .zip(centers.row_iter_mut())
358        .for_each(|((count, sum), center)| {
359            // If the count is zero - we do not want to divide by it because that will
360            // result in `NaN`.
361            let count = (*count).max(1);
362            std::iter::zip(sum.iter(), center.iter_mut()).for_each(|(s, c)| {
363                *c = (*s / (count as f64)) as f32;
364            });
365        });
366}
367
368////////////
369// Lloyds //
370////////////
371
372pub(crate) fn lloyds_inner(
373    data: StridedView<'_, f32>,
374    square_norms: &[f32],
375    transpose: BlockTransposedRef<'_, f32, 16>,
376    mut centers: MutMatrixView<'_, f32>,
377    max_reps: usize,
378) -> (Vec<u32>, f32) {
379    // Check our requirements.
380    let num_data = data.nrows();
381    assert_eq!(
382        num_data,
383        square_norms.len(),
384        "data and norms should have the same length"
385    );
386    assert_eq!(
387        num_data,
388        transpose.nrows(),
389        "data and transpose should have the same length"
390    );
391
392    let dim = data.ncols();
393    assert_eq!(
394        dim,
395        transpose.ncols(),
396        "data and transpose should have the same dimensions"
397    );
398    assert_eq!(
399        dim,
400        centers.ncols(),
401        "data and centers should have the same dimensions"
402    );
403
404    let mut center_square_norms: Vec<f32> = centers.row_iter().map(square_norm).collect();
405    let mut assignments: Vec<u32> = vec![0; num_data];
406    let mut residual = 0.0;
407
408    for i in 0..max_reps {
409        residual = distances_in_place(
410            transpose,
411            square_norms,
412            centers.as_view(),
413            &center_square_norms,
414            &mut assignments,
415        );
416        update_centroids(centers.as_mut_view(), data, &assignments);
417        if i != max_reps - 1 {
418            std::iter::zip(center_square_norms.iter_mut(), centers.row_iter()).for_each(
419                |(c, center)| {
420                    *c = square_norm(center);
421                },
422            );
423        }
424    }
425    (assignments, residual)
426}
427
428/// Run `max_reps` of Lloyd's algorithm over `data` and `centers`, updating the `centers`
429/// argument with the result.
430///
431/// # Returns
432///
433/// Returns a tuple `x = (Vec<u32>, f32)` where
434/// * `x.0` is the position-wise assignments of each data rows nearest center.
435/// * `x.1` is the final squared-l2 residual of the clustered dataset.
436///
437/// # Panics
438///
439/// Panics if `data.ncols() != centers.ncols()`. The data and centers must have the same
440/// dimension.
441pub fn lloyds(
442    data: MatrixView<'_, f32>,
443    centers: MutMatrixView<'_, f32>,
444    max_reps: usize,
445) -> (Vec<u32>, f32) {
446    assert_eq!(
447        data.ncols(),
448        centers.ncols(),
449        "data and centers must have the same dimension",
450    );
451
452    let transpose = BlockTransposed::<f32, 16>::from_matrix_view(data);
453    let square_norms: Vec<f32> = data.row_iter().map(square_norm).collect();
454    lloyds_inner(
455        data.into(),
456        &square_norms,
457        transpose.as_view(),
458        centers,
459        max_reps,
460    )
461}
462
463#[cfg(test)]
464mod tests {
465    #[cfg(not(miri))]
466    use diskann_utils::lazy_format;
467    use diskann_utils::views::Matrix;
468    use diskann_vector::{PureDistanceFunction, distance::SquaredL2};
469    use rand::{Rng, SeedableRng, rngs::StdRng, seq::SliceRandom};
470    #[cfg(not(miri))]
471    use rand::{
472        distr::{Distribution, Uniform},
473        seq::IndexedRandom,
474    };
475
476    use super::*;
477
478    ////////////////////////
479    // Distances in Place //
480    ////////////////////////
481
    // The strategy here is that we need to test a wide range of dimensions, dataset sizes,
    // and numbers of centers ... and have the dimensions be small enough that this can run
    // relatively quickly.
    //
    // Center `i` is the constant vector `i` plus a small random offset per coordinate. Each
    // data point copies a randomly chosen center's base value plus its own offsets, so the
    // expected nearest center of every point is known up front.
    //
    // Outside of rare validations, Miri tests go through a different path for speed purposes.
    #[cfg(not(miri))]
    fn test_distances_in_place_impl<R: Rng>(
        ndata: usize,
        ncenters: usize,
        dim: usize,
        trials: usize,
        rng: &mut R,
    ) {
        let context = lazy_format!("ncenters = {}, ndata = {}, dim = {}", ncenters, ndata, dim,);

        let mut centers = Matrix::new(0.0, ncenters, dim);
        let mut data = Matrix::new(0.0, ndata, dim);

        // A list of random "nice" offsets that get applied to each center and data point
        // to ensure proper visitation during computation. They are (negated) powers of
        // two, so they are exactly representable in f32. Their magnitude (at most 0.125)
        // keeps every coordinate much closer to its own integer base than to a neighbour's.
        let offsets = [-0.125, -0.0625, -0.03125, 0.03125, 0.0625, 0.125];

        // Initialize `centers` uniformly but with random offsets applied to each dimension.
        for (i, row) in centers.row_iter_mut().enumerate() {
            for c in row {
                *c = (i as f32) + *offsets.choose(rng).unwrap();
            }
        }

        let center_norms: Vec<f32> = centers.row_iter().map(square_norm).collect();

        // This is the distribution of how we assign data points to centers.
        let assignment_distribution = Uniform::<usize>::new(0, centers.nrows()).unwrap();
        let mut nearest: Vec<u32> = vec![0; ndata];
        for trial in 0..trials {
            // The centers stay fixed; only the data is regenerated on every trial.
            let assignments: Vec<_> = (0..ndata)
                .map(|_| assignment_distribution.sample(rng))
                .collect();

            for (assignment, row) in std::iter::zip(assignments.iter(), data.row_iter_mut()) {
                for c in row.iter_mut() {
                    *c = (*assignment as f32) + offsets.choose(rng).unwrap()
                }
            }

            let data_norms: Vec<f32> = data.row_iter().map(square_norm).collect();

            let residual = distances_in_place(
                BlockTransposed::<f32, 16>::from_matrix_view(data.as_view()).as_view(),
                &data_norms,
                centers.as_view(),
                &center_norms,
                &mut nearest,
            );

            // Check that the assignments are correct.
            for (i, (got, expected)) in
                std::iter::zip(nearest.iter(), assignments.iter()).enumerate()
            {
                assert_eq!(
                    *got as usize,
                    *expected,
                    "failed for data index {} on trial {} -- {}\n\
                     row = {:?}\n\
                     expected = {:?}\n\
                     got = {:?}",
                    i,
                    trial,
                    context,
                    data.row(i),
                    centers.row(*expected),
                    centers.row(*got as usize),
                );
            }

            // Check that the residual computation is correct. The comparison is exact;
            // presumably this relies on the "nice" offsets above keeping the arithmetic
            // exact in f32 — NOTE(review): confirm if this ever becomes flaky.
            let mut sum: f32 = 0.0;
            for (a, row) in std::iter::zip(assignments.iter(), data.row_iter()) {
                let distance: f32 = SquaredL2::evaluate(row, centers.row(*a));
                sum += distance;
            }
            assert_eq!(sum, residual, "failed on trial {} -- {}", trial, context);
        }
    }
566
    // Number of randomized data sets generated for each (ndata, ncenters, dim) combination
    // in `test_distances_in_place`.
    #[cfg(not(miri))]
    const TRIALS: usize = 100;
569
570    #[test]
571    #[cfg(not(miri))]
572    fn test_distances_in_place() {
573        let mut rng = StdRng::seed_from_u64(0xece88a9c6cd86a8a);
574        for ndata in 1..=31 {
575            for ncenters in 1..=5 {
576                for dim in 1..=4 {
577                    test_distances_in_place_impl(ndata, ncenters, dim, TRIALS, &mut rng);
578                }
579            }
580        }
581    }
582
583    // We do not perform any value-dependent control-flow for memory accesses.
584    // Therefore, the miri tests don't require any setup (this helps everything run faseter).
585    fn test_miri_distances_in_place_impl(ndata: usize, ncenters: usize, dim: usize) {
586        let centers = Matrix::new(0.0, ncenters, dim);
587        let data = Matrix::new(0.0, ndata, dim);
588        let data_norms = vec![0.0; ndata];
589        let center_norms = vec![0.0; ncenters];
590        let mut nearest = vec![0; ndata];
591
592        let _ = distances_in_place(
593            BlockTransposed::<f32, 16>::from_matrix_view(data.as_view()).as_view(),
594            &data_norms,
595            centers.as_view(),
596            &center_norms,
597            &mut nearest,
598        );
599    }
600
601    #[test]
602    fn test_miri_distances_in_place() {
603        // We need to hit each dimension between 0 and a low-multiple of the tiling size
604        // of 16.
605        //
606        // Set the upper-bound to 32.
607        //
608        // The implementation is not sensitive to the dimension, so we can keep that low.
609        //
610        // Similarly, we need to ensure we have both an even and odd number of centers,
611        // so bound this up to 5.
612        for ndata in 1..=35 {
613            for ncenters in 1..=5 {
614                for dim in 1..=4 {
615                    test_miri_distances_in_place_impl(ndata, ncenters, dim);
616                }
617            }
618        }
619    }
620
    // End-to-end test.
    // The strategy is to initialize a dataset as a shuffled version of the following
    // (illustrated with `step_between_clusters = 100`):
    // ```text
    //   0   0   0   0 ...
    //   1   1   1   1 ...
    //   2   2   2   2 ...
    //
    // 100 100 100 100 ...
    // 101 101 101 101 ...
    // 102 102 102 102 ...
    //
    // 200 200 200 200 ...
    // ...
    // ```
    // And to initialize centers as
    // ```text
    //  -1  -1  -1  -1 ...
    //  99  99  99  99 ...
    // 199 199 199 199 ...
    // ```
    // After one round of Lloyd's algorithm, each center should be updated to the mean of
    // its respective cluster.
    #[derive(Debug)]
    struct EndToEndSetup {
        // Number of clusters, and therefore of centers.
        ncenters: usize,
        // Dimension of every data point and center.
        ndim: usize,
        // Number of data rows generated for each cluster.
        data_per_center: usize,
        // Offset between the first values of consecutive clusters.
        step_between_clusters: usize,
        // Number of independently shuffled trials to run.
        ntrials: usize,
    }
651
    // Runs `setup.ntrials` shuffled instances of the end-to-end scenario described above.
    // It checks the returned assignments, the updated centers and the returned loss.
    fn end_to_end_test_impl<R: Rng>(setup: &EndToEndSetup, rng: &mut R) {
        // The scalar value of each data row: cluster `i` contributes the values
        // `step_between_clusters * i + j` for `j` in `0..data_per_center`.
        let mut values: Vec<usize> = (0..setup.ncenters)
            .flat_map(|i| {
                (0..setup.data_per_center).map(move |j| setup.step_between_clusters * i + j)
            })
            .collect();

        // `center_order[r]` is the cluster that center row `r` starts next to.
        let mut center_order: Vec<usize> = (0..setup.ncenters).collect();
        let mut data = Matrix::new(0.0, setup.ncenters * setup.data_per_center, setup.ndim);
        let mut centers = Matrix::new(0.0, setup.ncenters, setup.ndim);

        for trial in 0..setup.ntrials {
            values.shuffle(rng);
            center_order.shuffle(rng);

            // Populate centers: one unit below the first value of their cluster.
            assert_eq!(center_order.len(), centers.nrows());
            for (c, row) in std::iter::zip(center_order.iter(), centers.row_iter_mut()) {
                row.fill((setup.step_between_clusters * c) as f32 - 1.0);
            }

            // Populate data.
            assert_eq!(values.len(), data.nrows());
            for (d, row) in std::iter::zip(values.iter(), data.row_iter_mut()) {
                row.fill(*d as f32);
            }

            // Run 2 iterations of Lloyd's algorithm.
            // The second iteration ensures that we recompute norms properly.
            let lloyds_iter = 2;
            let (assignments, loss) = lloyds(data.as_view(), centers.as_mut_view(), lloyds_iter);

            // Make sure all the assignments are returned correctly.
            assert_eq!(assignments.len(), values.len());
            for (i, (&got, v)) in std::iter::zip(assignments.iter(), values.iter()).enumerate() {
                let expected: usize = v / setup.step_between_clusters;
                assert_eq!(
                    center_order[got as usize], expected,
                    "failed at position {} in trial {} - prevalue: {} -- {:?}",
                    i, trial, v, setup
                );
            }

            // Make sure `centers` were properly set to their mean value:
            // (step * data_per_center * o + (0 + 1 + ... + (data_per_center - 1))) / data_per_center.
            let triangle_sum = setup.data_per_center * (setup.data_per_center - 1) / 2;
            center_order.iter().enumerate().for_each(|(i, o)| {
                let expected = (setup.step_between_clusters * setup.data_per_center * o
                    + triangle_sum) as f32
                    / setup.data_per_center as f32;
                assert!(
                    centers.row(i).iter().all(|v| *v == expected),
                    "at index {}, expected {}, got {:?} -- {:?}",
                    i,
                    expected,
                    centers.row(i),
                    setup,
                );
            });

            // Verify the loss is correct. `lloyds` reports the residual from before its
            // final update. The centers already sit at the cluster means after the first
            // iteration, so the second update leaves them unchanged, and measuring against
            // the final centers gives the same value.
            let expected_loss: f32 = std::iter::zip(assignments.iter(), data.row_iter())
                .map(|(a, row)| -> f32 {
                    let c = centers.row(*a as usize);
                    SquaredL2::evaluate(row, c)
                })
                .sum::<f32>();
            assert_eq!(loss, expected_loss);
        }
    }
722
723    #[test]
724    fn end_to_end_test() {
725        let mut rng = StdRng::seed_from_u64(0xff22c38d0f0531bf);
726        let setup = if cfg!(miri) {
727            EndToEndSetup {
728                ncenters: 3,
729                ndim: 4,
730                data_per_center: 2,
731                step_between_clusters: 20,
732                ntrials: 2,
733            }
734        } else {
735            EndToEndSetup {
736                ncenters: 11,
737                ndim: 4,
738                data_per_center: 8,
739                step_between_clusters: 20,
740                ntrials: 10,
741            }
742        };
743        end_to_end_test_impl(&setup, &mut rng);
744    }
745
746    /////////////////////////////////
747    // Panics - distances_in_place //
748    /////////////////////////////////
749
750    // Verify that our panic safety-checks are in-place.
751    #[test]
752    #[should_panic(expected = "dataset and data norms should have the same length")]
753    fn distances_in_place_panics_data_norms() {
754        let data = Matrix::new(0.0, 5, 8);
755        let data_norms = vec![0.0; data.nrows() + 1]; // Incorrect
756        let centers = Matrix::new(0.0, 2, 8);
757        let center_norms = vec![0.0; centers.nrows()];
758        let mut nearest = vec![0; data.nrows()];
759        distances_in_place(
760            BlockTransposed::<f32, 16>::from_matrix_view(data.as_view()).as_view(),
761            &data_norms,
762            centers.as_view(),
763            &center_norms,
764            &mut nearest,
765        );
766    }
767
768    #[test]
769    #[should_panic(expected = "dataset and centers should have the same dimension")]
770    fn distances_in_place_panics_different_dim() {
771        let data = Matrix::new(0.0, 5, 8);
772        let data_norms = vec![0.0; data.nrows()];
773        let centers = Matrix::new(0.0, 2, 9); // Incorrect
774        let center_norms = vec![0.0; centers.nrows()];
775        let mut nearest = vec![0; data.nrows()];
776        distances_in_place(
777            BlockTransposed::<f32, 16>::from_matrix_view(data.as_view()).as_view(),
778            &data_norms,
779            centers.as_view(),
780            &center_norms,
781            &mut nearest,
782        );
783    }
784
785    #[test]
786    #[should_panic(expected = "centers and center norms should have the same length")]
787    fn distances_in_place_panics_center_norms() {
788        let data = Matrix::new(0.0, 5, 8);
789        let data_norms = vec![0.0; data.nrows()];
790        let centers = Matrix::new(0.0, 2, 8);
791        let center_norms = vec![0.0; centers.nrows() + 1]; // Incorrect
792        let mut nearest = vec![0; data.nrows()];
793        distances_in_place(
794            BlockTransposed::<f32, 16>::from_matrix_view(data.as_view()).as_view(),
795            &data_norms,
796            centers.as_view(),
797            &center_norms,
798            &mut nearest,
799        );
800    }
801
802    #[test]
803    #[should_panic(expected = "dataset and nearest-buffer should have the same length")]
804    fn distances_in_place_panics_nearest() {
805        let data = Matrix::new(0.0, 5, 8);
806        let data_norms = vec![0.0; data.nrows()];
807        let centers = Matrix::new(0.0, 2, 8);
808        let center_norms = vec![0.0; centers.nrows()];
809        let mut nearest = vec![0; data.nrows() + 1]; // Incorrect
810        distances_in_place(
811            BlockTransposed::<f32, 16>::from_matrix_view(data.as_view()).as_view(),
812            &data_norms,
813            centers.as_view(),
814            &center_norms,
815            &mut nearest,
816        );
817    }
818
819    ///////////////////////////
820    // Panics - lloyds_inner //
821    ///////////////////////////
822
823    #[test]
824    #[should_panic(expected = "data and norms should have the same length")]
825    fn lloyds_inner_panics_norms_length() {
826        let data = Matrix::new(0.0, 5, 8);
827        let square_norms = vec![0.0; data.nrows() + 1]; // Incorrect
828        let mut centers = Matrix::new(0.0, 2, 8);
829        lloyds_inner(
830            data.as_view().into(),
831            &square_norms,
832            BlockTransposed::<f32, 16>::from_matrix_view(data.as_view()).as_view(),
833            centers.as_mut_view(),
834            1,
835        );
836    }
837
838    #[test]
839    #[should_panic(expected = "data and transpose should have the same length")]
840    fn lloyds_inner_panics_transpose_length() {
841        let data = Matrix::new(0.0, 5, 8);
842        let data_incorrect = Matrix::new(0.0, 5 + 1, 8); // Incorrect
843        let square_norms = vec![0.0; data.nrows()];
844        let mut centers = Matrix::new(0.0, 2, 8);
845        lloyds_inner(
846            data.as_view().into(),
847            &square_norms,
848            BlockTransposed::<f32, 16>::from_matrix_view(data_incorrect.as_view()).as_view(),
849            centers.as_mut_view(),
850            1,
851        );
852    }
853
854    #[test]
855    #[should_panic(expected = "data and transpose should have the same dimensions")]
856    fn lloyds_inner_panics_transpose_dim() {
857        let data = Matrix::new(0.0, 5, 8);
858        let data_incorrect = Matrix::new(0.0, 5, 8 + 1); // Incorrect
859        let square_norms = vec![0.0; data.nrows()];
860        let mut centers = Matrix::new(0.0, 2, 8);
861        lloyds_inner(
862            data.as_view().into(),
863            &square_norms,
864            BlockTransposed::<f32, 16>::from_matrix_view(data_incorrect.as_view()).as_view(), // Incorrect
865            centers.as_mut_view(),
866            1,
867        );
868    }
869
870    #[test]
871    #[should_panic(expected = "data and centers should have the same dimensions")]
872    fn lloyds_inner_panics_centers_dim() {
873        let data = Matrix::new(0.0, 5, 8);
874        let square_norms = vec![0.0; data.nrows()];
875        let mut centers = Matrix::new(0.0, 2, 8 + 1); // Incorrect
876        lloyds_inner(
877            data.as_view().into(),
878            &square_norms,
879            BlockTransposed::<f32, 16>::from_matrix_view(data.as_view()).as_view(),
880            centers.as_mut_view(),
881            1,
882        );
883    }
884
885    ////////////////////
886    // Panics - lloyds//
887    ////////////////////
888
889    #[test]
890    #[should_panic(expected = "data and centers must have the same dimension")]
891    fn lloyds_panics_dim_mismatch() {
892        let data = Matrix::new(0.0, 5, 8);
893        let mut centers = Matrix::new(0.0, 5, 8 + 1); // Incorrect
894        lloyds(data.as_view(), centers.as_mut_view(), 1);
895    }
896}