1use diskann_wide::{SIMDMask, SIMDMulAdd, SIMDPartialOrd, SIMDSelect, SIMDSumTree, SIMDVector};
7
8use super::common::square_norm;
9use crate::multi_vector::{BlockTransposed, BlockTransposedRef};
10use diskann_utils::{
11 strided::StridedView,
12 views::{Matrix, MatrixView, MutMatrixView},
13};
14
15diskann_wide::alias!(f32s = f32x8);
20diskann_wide::alias!(u32s = u32x8);
21
22pub fn distances_in_place(
29 dataset: BlockTransposedRef<'_, f32, 16>,
30 data_norms: &[f32],
31 centers: MatrixView<'_, f32>,
32 center_norms: &[f32],
33 nearest: &mut [u32],
34) -> f32 {
35 assert_eq!(
40 dataset.nrows(),
41 data_norms.len(),
42 "dataset and data norms should have the same length"
43 );
44 assert_eq!(
46 centers.ncols(),
47 dataset.ncols(),
48 "dataset and centers should have the same dimensions"
49 );
50 assert_eq!(
52 centers.nrows(),
53 center_norms.len(),
54 "centers and center norms should have the same length"
55 );
56 assert_eq!(
58 nearest.len(),
59 dataset.nrows(),
60 "dataset and nearest-buffer should have the same length"
61 );
62
63 const N: usize = 16;
64 const N2: usize = N / 2;
65
66 diskann_wide::alias!(m32s = mask_f32x8);
67
68 let mut residual = f32s::default(diskann_wide::ARCH);
69
70 let process_block_unroll_2 = |block: usize, center_row_start: usize| {
78 debug_assert!(block < dataset.num_blocks());
79 debug_assert!(center_row_start + 1 < centers.nrows());
80
81 let mut s00 = f32s::default(diskann_wide::ARCH);
82 let mut s01 = f32s::default(diskann_wide::ARCH);
83 let mut s10 = f32s::default(diskann_wide::ARCH);
84 let mut s11 = f32s::default(diskann_wide::ARCH);
85
86 let block_ptr = unsafe { dataset.block_ptr_unchecked(block) };
88 for dim in 0..dataset.ncols() {
89 let d0 = unsafe { f32s::load_simd(diskann_wide::ARCH, block_ptr.add(N * dim)) };
92 let d1 = unsafe { f32s::load_simd(diskann_wide::ARCH, block_ptr.add(N * dim + N2)) };
94
95 let c0 = f32s::splat(diskann_wide::ARCH, unsafe {
97 *centers.get_unchecked(center_row_start, dim)
98 });
99 let c1 = f32s::splat(diskann_wide::ARCH, unsafe {
101 *centers.get_unchecked(center_row_start + 1, dim)
102 });
103
104 s00 = c0.mul_add_simd(d0, s00);
105 s01 = c0.mul_add_simd(d1, s01);
106 s10 = c1.mul_add_simd(d0, s10);
107 s11 = c1.mul_add_simd(d1, s11);
108 }
109 (s00, s01, s10, s11)
110 };
111
112 let process_block_no_unroll = |block: usize, center_row_start: usize| {
120 debug_assert!(block < dataset.num_blocks());
121 debug_assert!(center_row_start + 1 == centers.nrows());
122
123 let mut s00 = f32s::default(diskann_wide::ARCH);
124 let mut s01 = f32s::default(diskann_wide::ARCH);
125
126 let block_ptr = unsafe { dataset.block_ptr_unchecked(block) };
128 for dim in 0..dataset.ncols() {
129 let d0 = unsafe { f32s::load_simd(diskann_wide::ARCH, block_ptr.add(N * dim)) };
132 let d1 = unsafe { f32s::load_simd(diskann_wide::ARCH, block_ptr.add(N * dim + N2)) };
134
135 let c0 = f32s::splat(diskann_wide::ARCH, unsafe {
137 *centers.get_unchecked(center_row_start, dim)
138 });
139
140 s00 = c0.mul_add_simd(d0, s00);
141 s01 = c0.mul_add_simd(d1, s01);
142 }
143 (s00, s01)
144 };
145
146 let last_pair = if centers.nrows().is_multiple_of(2) {
151 centers.nrows()
152 } else {
153 centers.nrows() - 1
154 };
155
156 for i in 0..dataset.full_blocks() {
157 let mut t0 = (
158 f32s::splat(diskann_wide::ARCH, f32::INFINITY),
159 u32s::splat(diskann_wide::ARCH, u32::MAX),
160 );
161 let mut t1 = (
162 f32s::splat(diskann_wide::ARCH, f32::INFINITY),
163 u32s::splat(diskann_wide::ARCH, u32::MAX),
164 );
165
166 let data_norm_ptr = unsafe { data_norms.as_ptr().add(N * i) };
168
169 let d0 = unsafe { f32s::load_simd(diskann_wide::ARCH, data_norm_ptr) };
172
173 let d1 = unsafe { f32s::load_simd(diskann_wide::ARCH, data_norm_ptr.add(N2)) };
176 for row_start in (0..last_pair).step_by(2) {
177 let (s00, s01, s10, s11) = process_block_unroll_2(i, row_start);
180
181 let n0 = f32s::splat(diskann_wide::ARCH, *unsafe {
184 center_norms.get_unchecked(row_start)
185 });
186 let n1 = f32s::splat(diskann_wide::ARCH, *unsafe {
188 center_norms.get_unchecked(row_start + 1)
189 });
190
191 let s00 = n0 - s00 - s00 + d0;
192 let s01 = n0 - s01 - s01 + d1;
193 let s10 = n1 - s10 - s10 + d0;
194 let s11 = n1 - s11 - s11 + d1;
195
196 let r0 = u32s::splat(diskann_wide::ARCH, row_start as u32);
197 let r1 = u32s::splat(diskann_wide::ARCH, (row_start + 1) as u32);
198 t0 = update(update(t0, (s00, r0)), (s10, r1));
199 t1 = update(update(t1, (s01, r0)), (s11, r1));
200 }
201
202 if !centers.nrows().is_multiple_of(2) {
204 let (s00, s01) = process_block_no_unroll(i, last_pair);
207 let n0 = f32s::splat(diskann_wide::ARCH, unsafe {
209 *center_norms.get_unchecked(last_pair)
210 });
211
212 let s00 = n0 - s00 - s00 + d0;
213 let s01 = n0 - s01 - s01 + d1;
214
215 let r = u32s::splat(diskann_wide::ARCH, last_pair as u32);
216 t0 = update(t0, (s00, r));
217 t1 = update(t1, (s01, r));
218 }
219
220 unsafe { t0.1.store_simd(nearest.as_mut_ptr().add(N * i)) }
224 unsafe { t1.1.store_simd(nearest.as_mut_ptr().add(N * i + N2)) }
227
228 residual = residual + t0.0 + t1.0;
230 }
231
232 let remainder = dataset.remainder();
236 if remainder != 0 {
237 let i = dataset.full_blocks();
238 let lo = remainder.min(N2);
239 let hi = remainder - lo;
240
241 let mut t0 = (
242 f32s::splat(diskann_wide::ARCH, f32::INFINITY),
243 u32s::splat(diskann_wide::ARCH, u32::MAX),
244 );
245 let mut t1 = (
246 f32s::splat(diskann_wide::ARCH, f32::INFINITY),
247 u32s::splat(diskann_wide::ARCH, u32::MAX),
248 );
249
250 let data_norm_ptr = unsafe { data_norms.as_ptr().add(N * i) };
252
253 let d0 = unsafe { f32s::load_simd_first(diskann_wide::ARCH, data_norm_ptr, lo) };
256 let d1 = if hi == 0 {
257 f32s::default(diskann_wide::ARCH)
258 } else {
259 unsafe { f32s::load_simd_first(diskann_wide::ARCH, data_norm_ptr.add(N2), hi) }
263 };
264
265 for row_start in (0..last_pair).step_by(2) {
266 let (s00, s01, s10, s11) = process_block_unroll_2(i, row_start);
269
270 let n0 = f32s::splat(diskann_wide::ARCH, *unsafe {
273 center_norms.get_unchecked(row_start)
274 });
275 let n1 = f32s::splat(diskann_wide::ARCH, *unsafe {
277 center_norms.get_unchecked(row_start + 1)
278 });
279
280 let s00 = n0 - s00 - s00 + d0;
281 let s01 = n0 - s01 - s01 + d1;
282 let s10 = n1 - s10 - s10 + d0;
283 let s11 = n1 - s11 - s11 + d1;
284
285 let r0 = u32s::splat(diskann_wide::ARCH, row_start as u32);
286 let r1 = u32s::splat(diskann_wide::ARCH, (row_start + 1) as u32);
287 t0 = update(update(t0, (s00, r0)), (s10, r1));
288 t1 = update(update(t1, (s01, r0)), (s11, r1));
289 }
290
291 if !centers.nrows().is_multiple_of(2) {
292 let (s00, s01) = process_block_no_unroll(i, last_pair);
295 let n0 = f32s::splat(diskann_wide::ARCH, unsafe {
297 *center_norms.get_unchecked(last_pair)
298 });
299
300 let s00 = n0 - s00 - s00 + d0;
301 let s01 = n0 - s01 - s01 + d1;
302
303 let r = u32s::splat(diskann_wide::ARCH, last_pair as u32);
304 t0 = update(t0, (s00, r));
305 t1 = update(t1, (s01, r));
306 }
307
308 unsafe { t0.1.store_simd_first(nearest.as_mut_ptr().add(N * i), lo) };
312 if hi != 0 {
313 unsafe {
318 t1.1.store_simd_first(nearest.as_mut_ptr().add(N * i + N2), hi)
319 };
320 }
321
322 residual = m32s::keep_first(diskann_wide::ARCH, lo).select(residual + t0.0, residual);
325 residual = m32s::keep_first(diskann_wide::ARCH, hi).select(residual + t1.0, residual);
326 }
327 residual.sum_tree()
328}
329
330#[inline(always)]
331fn update((d0, i0): (f32s, u32s), (d1, i1): (f32s, u32s)) -> (f32s, u32s) {
332 let mask = d1.lt_simd(d0);
335 (
336 mask.select(d1, d0),
337 <u32s as SIMDVector>::Mask::from(mask).select(i1, i0),
338 )
339}
340
341fn update_centroids(mut centers: MutMatrixView<'_, f32>, data: StridedView<'_, f32>, map: &[u32]) {
346 let mut sums = Matrix::<f64>::new(0.0, centers.nrows(), centers.ncols());
347 let mut counts: Vec<u32> = vec![0; centers.nrows()];
348 data.row_iter().zip(map.iter()).for_each(|(row, ¢er)| {
349 counts[center as usize] += 1;
350 let sum = sums.row_mut(center as usize);
351 std::iter::zip(sum.iter_mut(), row.iter()).for_each(|(s, r)| {
352 *s += <f32 as Into<f64>>::into(*r);
353 });
354 });
355
356 std::iter::zip(counts.iter(), sums.row_iter())
357 .zip(centers.row_iter_mut())
358 .for_each(|((count, sum), center)| {
359 let count = (*count).max(1);
362 std::iter::zip(sum.iter(), center.iter_mut()).for_each(|(s, c)| {
363 *c = (*s / (count as f64)) as f32;
364 });
365 });
366}
367
368pub(crate) fn lloyds_inner(
373 data: StridedView<'_, f32>,
374 square_norms: &[f32],
375 transpose: BlockTransposedRef<'_, f32, 16>,
376 mut centers: MutMatrixView<'_, f32>,
377 max_reps: usize,
378) -> (Vec<u32>, f32) {
379 let num_data = data.nrows();
381 assert_eq!(
382 num_data,
383 square_norms.len(),
384 "data and norms should have the same length"
385 );
386 assert_eq!(
387 num_data,
388 transpose.nrows(),
389 "data and transpose should have the same length"
390 );
391
392 let dim = data.ncols();
393 assert_eq!(
394 dim,
395 transpose.ncols(),
396 "data and transpose should have the same dimensions"
397 );
398 assert_eq!(
399 dim,
400 centers.ncols(),
401 "data and centers should have the same dimensions"
402 );
403
404 let mut center_square_norms: Vec<f32> = centers.row_iter().map(square_norm).collect();
405 let mut assignments: Vec<u32> = vec![0; num_data];
406 let mut residual = 0.0;
407
408 for i in 0..max_reps {
409 residual = distances_in_place(
410 transpose,
411 square_norms,
412 centers.as_view(),
413 ¢er_square_norms,
414 &mut assignments,
415 );
416 update_centroids(centers.as_mut_view(), data, &assignments);
417 if i != max_reps - 1 {
418 std::iter::zip(center_square_norms.iter_mut(), centers.row_iter()).for_each(
419 |(c, center)| {
420 *c = square_norm(center);
421 },
422 );
423 }
424 }
425 (assignments, residual)
426}
427
428pub fn lloyds(
442 data: MatrixView<'_, f32>,
443 centers: MutMatrixView<'_, f32>,
444 max_reps: usize,
445) -> (Vec<u32>, f32) {
446 assert_eq!(
447 data.ncols(),
448 centers.ncols(),
449 "data and centers must have the same dimension",
450 );
451
452 let transpose = BlockTransposed::<f32, 16>::from_matrix_view(data);
453 let square_norms: Vec<f32> = data.row_iter().map(square_norm).collect();
454 lloyds_inner(
455 data.into(),
456 &square_norms,
457 transpose.as_view(),
458 centers,
459 max_reps,
460 )
461}
462
463#[cfg(test)]
464mod tests {
465 #[cfg(not(miri))]
466 use diskann_utils::lazy_format;
467 use diskann_utils::views::Matrix;
468 use diskann_vector::{PureDistanceFunction, distance::SquaredL2};
469 use rand::{Rng, SeedableRng, rngs::StdRng, seq::SliceRandom};
470 #[cfg(not(miri))]
471 use rand::{
472 distr::{Distribution, Uniform},
473 seq::IndexedRandom,
474 };
475
476 use super::*;
477
478 #[cfg(not(miri))]
488 fn test_distances_in_place_impl<R: Rng>(
489 ndata: usize,
490 ncenters: usize,
491 dim: usize,
492 trials: usize,
493 rng: &mut R,
494 ) {
495 let context = lazy_format!("ncenters = {}, ndata = {}, dim = {}", ncenters, ndata, dim,);
496
497 let mut centers = Matrix::new(0.0, ncenters, dim);
498 let mut data = Matrix::new(0.0, ndata, dim);
499
500 let offsets = [-0.125, -0.0625, -0.03125, 0.03125, 0.0625, 0.125];
503
504 for (i, row) in centers.row_iter_mut().enumerate() {
506 for c in row {
507 *c = (i as f32) + *offsets.choose(rng).unwrap();
508 }
509 }
510
511 let center_norms: Vec<f32> = centers.row_iter().map(square_norm).collect();
512
513 let assignment_distribution = Uniform::<usize>::new(0, centers.nrows()).unwrap();
515 let mut nearest: Vec<u32> = vec![0; ndata];
516 for trial in 0..trials {
517 let assignments: Vec<_> = (0..ndata)
518 .map(|_| assignment_distribution.sample(rng))
519 .collect();
520
521 for (assignment, row) in std::iter::zip(assignments.iter(), data.row_iter_mut()) {
522 for c in row.iter_mut() {
523 *c = (*assignment as f32) + offsets.choose(rng).unwrap()
524 }
525 }
526
527 let data_norms: Vec<f32> = data.row_iter().map(square_norm).collect();
528
529 let residual = distances_in_place(
530 BlockTransposed::<f32, 16>::from_matrix_view(data.as_view()).as_view(),
531 &data_norms,
532 centers.as_view(),
533 ¢er_norms,
534 &mut nearest,
535 );
536
537 for (i, (got, expected)) in
539 std::iter::zip(nearest.iter(), assignments.iter()).enumerate()
540 {
541 assert_eq!(
542 *got as usize,
543 *expected,
544 "failed for data index {} on trial {} -- {}\n\
545 row = {:?}\n\
546 expected = {:?}\n\
547 got = {:?}",
548 i,
549 trial,
550 context,
551 data.row(i),
552 centers.row(*expected),
553 centers.row(*got as usize),
554 );
555 }
556
557 let mut sum: f32 = 0.0;
559 for (a, row) in std::iter::zip(assignments.iter(), data.row_iter()) {
560 let distance: f32 = SquaredL2::evaluate(row, centers.row(*a));
561 sum += distance;
562 }
563 assert_eq!(sum, residual, "failed on trial {} -- {}", trial, context);
564 }
565 }
566
567 #[cfg(not(miri))]
568 const TRIALS: usize = 100;
569
570 #[test]
571 #[cfg(not(miri))]
572 fn test_distances_in_place() {
573 let mut rng = StdRng::seed_from_u64(0xece88a9c6cd86a8a);
574 for ndata in 1..=31 {
575 for ncenters in 1..=5 {
576 for dim in 1..=4 {
577 test_distances_in_place_impl(ndata, ncenters, dim, TRIALS, &mut rng);
578 }
579 }
580 }
581 }
582
583 fn test_miri_distances_in_place_impl(ndata: usize, ncenters: usize, dim: usize) {
586 let centers = Matrix::new(0.0, ncenters, dim);
587 let data = Matrix::new(0.0, ndata, dim);
588 let data_norms = vec![0.0; ndata];
589 let center_norms = vec![0.0; ncenters];
590 let mut nearest = vec![0; ndata];
591
592 let _ = distances_in_place(
593 BlockTransposed::<f32, 16>::from_matrix_view(data.as_view()).as_view(),
594 &data_norms,
595 centers.as_view(),
596 ¢er_norms,
597 &mut nearest,
598 );
599 }
600
601 #[test]
602 fn test_miri_distances_in_place() {
603 for ndata in 1..=35 {
613 for ncenters in 1..=5 {
614 for dim in 1..=4 {
615 test_miri_distances_in_place_impl(ndata, ncenters, dim);
616 }
617 }
618 }
619 }
620
621 #[derive(Debug)]
644 struct EndToEndSetup {
645 ncenters: usize,
646 ndim: usize,
647 data_per_center: usize,
648 step_between_clusters: usize,
649 ntrials: usize,
650 }
651
652 fn end_to_end_test_impl<R: Rng>(setup: &EndToEndSetup, rng: &mut R) {
653 let mut values: Vec<usize> = (0..setup.ncenters)
655 .flat_map(|i| {
656 (0..setup.data_per_center).map(move |j| setup.step_between_clusters * i + j)
657 })
658 .collect();
659
660 let mut center_order: Vec<usize> = (0..setup.ncenters).collect();
661 let mut data = Matrix::new(0.0, setup.ncenters * setup.data_per_center, setup.ndim);
662 let mut centers = Matrix::new(0.0, setup.ncenters, setup.ndim);
663
664 for trial in 0..setup.ntrials {
665 values.shuffle(rng);
666 center_order.shuffle(rng);
667
668 assert_eq!(center_order.len(), centers.nrows());
670 for (c, row) in std::iter::zip(center_order.iter(), centers.row_iter_mut()) {
671 row.fill((setup.step_between_clusters * c) as f32 - 1.0);
672 }
673
674 assert_eq!(values.len(), data.nrows());
676 for (d, row) in std::iter::zip(values.iter(), data.row_iter_mut()) {
677 row.fill(*d as f32);
678 }
679
680 let lloyds_iter = 2;
683 let (assignments, loss) = lloyds(data.as_view(), centers.as_mut_view(), lloyds_iter);
684
685 assert_eq!(assignments.len(), values.len());
687 for (i, (&got, v)) in std::iter::zip(assignments.iter(), values.iter()).enumerate() {
688 let expected: usize = v / setup.step_between_clusters;
689 assert_eq!(
690 center_order[got as usize], expected,
691 "failed at position {} in trial {} - prevalue: {} -- {:?}",
692 i, trial, v, setup
693 );
694 }
695
696 let triangle_sum = setup.data_per_center * (setup.data_per_center - 1) / 2;
698 center_order.iter().enumerate().for_each(|(i, o)| {
699 let expected = (setup.step_between_clusters * setup.data_per_center * o
700 + triangle_sum) as f32
701 / setup.data_per_center as f32;
702 assert!(
703 centers.row(i).iter().all(|v| *v == expected),
704 "at index {}, expected {}, got {:?} -- {:?}",
705 i,
706 expected,
707 centers.row(i),
708 setup,
709 );
710 });
711
712 let expected_loss: f32 = std::iter::zip(assignments.iter(), data.row_iter())
714 .map(|(a, row)| -> f32 {
715 let c = centers.row(*a as usize);
716 SquaredL2::evaluate(row, c)
717 })
718 .sum::<f32>();
719 assert_eq!(loss, expected_loss);
720 }
721 }
722
723 #[test]
724 fn end_to_end_test() {
725 let mut rng = StdRng::seed_from_u64(0xff22c38d0f0531bf);
726 let setup = if cfg!(miri) {
727 EndToEndSetup {
728 ncenters: 3,
729 ndim: 4,
730 data_per_center: 2,
731 step_between_clusters: 20,
732 ntrials: 2,
733 }
734 } else {
735 EndToEndSetup {
736 ncenters: 11,
737 ndim: 4,
738 data_per_center: 8,
739 step_between_clusters: 20,
740 ntrials: 10,
741 }
742 };
743 end_to_end_test_impl(&setup, &mut rng);
744 }
745
746 #[test]
752 #[should_panic(expected = "dataset and data norms should have the same length")]
753 fn distances_in_place_panics_data_norms() {
754 let data = Matrix::new(0.0, 5, 8);
755 let data_norms = vec![0.0; data.nrows() + 1]; let centers = Matrix::new(0.0, 2, 8);
757 let center_norms = vec![0.0; centers.nrows()];
758 let mut nearest = vec![0; data.nrows()];
759 distances_in_place(
760 BlockTransposed::<f32, 16>::from_matrix_view(data.as_view()).as_view(),
761 &data_norms,
762 centers.as_view(),
763 ¢er_norms,
764 &mut nearest,
765 );
766 }
767
768 #[test]
769 #[should_panic(expected = "dataset and centers should have the same dimension")]
770 fn distances_in_place_panics_different_dim() {
771 let data = Matrix::new(0.0, 5, 8);
772 let data_norms = vec![0.0; data.nrows()];
773 let centers = Matrix::new(0.0, 2, 9); let center_norms = vec![0.0; centers.nrows()];
775 let mut nearest = vec![0; data.nrows()];
776 distances_in_place(
777 BlockTransposed::<f32, 16>::from_matrix_view(data.as_view()).as_view(),
778 &data_norms,
779 centers.as_view(),
780 ¢er_norms,
781 &mut nearest,
782 );
783 }
784
785 #[test]
786 #[should_panic(expected = "centers and center norms should have the same length")]
787 fn distances_in_place_panics_center_norms() {
788 let data = Matrix::new(0.0, 5, 8);
789 let data_norms = vec![0.0; data.nrows()];
790 let centers = Matrix::new(0.0, 2, 8);
791 let center_norms = vec![0.0; centers.nrows() + 1]; let mut nearest = vec![0; data.nrows()];
793 distances_in_place(
794 BlockTransposed::<f32, 16>::from_matrix_view(data.as_view()).as_view(),
795 &data_norms,
796 centers.as_view(),
797 ¢er_norms,
798 &mut nearest,
799 );
800 }
801
802 #[test]
803 #[should_panic(expected = "dataset and nearest-buffer should have the same length")]
804 fn distances_in_place_panics_nearest() {
805 let data = Matrix::new(0.0, 5, 8);
806 let data_norms = vec![0.0; data.nrows()];
807 let centers = Matrix::new(0.0, 2, 8);
808 let center_norms = vec![0.0; centers.nrows()];
809 let mut nearest = vec![0; data.nrows() + 1]; distances_in_place(
811 BlockTransposed::<f32, 16>::from_matrix_view(data.as_view()).as_view(),
812 &data_norms,
813 centers.as_view(),
814 ¢er_norms,
815 &mut nearest,
816 );
817 }
818
819 #[test]
824 #[should_panic(expected = "data and norms should have the same length")]
825 fn lloyds_inner_panics_norms_length() {
826 let data = Matrix::new(0.0, 5, 8);
827 let square_norms = vec![0.0; data.nrows() + 1]; let mut centers = Matrix::new(0.0, 2, 8);
829 lloyds_inner(
830 data.as_view().into(),
831 &square_norms,
832 BlockTransposed::<f32, 16>::from_matrix_view(data.as_view()).as_view(),
833 centers.as_mut_view(),
834 1,
835 );
836 }
837
838 #[test]
839 #[should_panic(expected = "data and transpose should have the same length")]
840 fn lloyds_inner_panics_transpose_length() {
841 let data = Matrix::new(0.0, 5, 8);
842 let data_incorrect = Matrix::new(0.0, 5 + 1, 8); let square_norms = vec![0.0; data.nrows()];
844 let mut centers = Matrix::new(0.0, 2, 8);
845 lloyds_inner(
846 data.as_view().into(),
847 &square_norms,
848 BlockTransposed::<f32, 16>::from_matrix_view(data_incorrect.as_view()).as_view(),
849 centers.as_mut_view(),
850 1,
851 );
852 }
853
854 #[test]
855 #[should_panic(expected = "data and transpose should have the same dimensions")]
856 fn lloyds_inner_panics_transpose_dim() {
857 let data = Matrix::new(0.0, 5, 8);
858 let data_incorrect = Matrix::new(0.0, 5, 8 + 1); let square_norms = vec![0.0; data.nrows()];
860 let mut centers = Matrix::new(0.0, 2, 8);
861 lloyds_inner(
862 data.as_view().into(),
863 &square_norms,
864 BlockTransposed::<f32, 16>::from_matrix_view(data_incorrect.as_view()).as_view(), centers.as_mut_view(),
866 1,
867 );
868 }
869
870 #[test]
871 #[should_panic(expected = "data and centers should have the same dimensions")]
872 fn lloyds_inner_panics_centers_dim() {
873 let data = Matrix::new(0.0, 5, 8);
874 let square_norms = vec![0.0; data.nrows()];
875 let mut centers = Matrix::new(0.0, 2, 8 + 1); lloyds_inner(
877 data.as_view().into(),
878 &square_norms,
879 BlockTransposed::<f32, 16>::from_matrix_view(data.as_view()).as_view(),
880 centers.as_mut_view(),
881 1,
882 );
883 }
884
885 #[test]
890 #[should_panic(expected = "data and centers must have the same dimension")]
891 fn lloyds_panics_dim_mismatch() {
892 let data = Matrix::new(0.0, 5, 8);
893 let mut centers = Matrix::new(0.0, 5, 8 + 1); lloyds(data.as_view(), centers.as_mut_view(), 1);
895 }
896}