// diskann_quantization/algorithms/kmeans/common.rs
use diskann_utils::{
    strided::StridedView,
    views::{MatrixView, MutMatrixView},
};
use diskann_wide::{SIMDMulAdd, SIMDSumTree, SIMDVector};

/// Compute the squared Euclidean (L2) norm of `x`: `sum_i x[i] * x[i]`.
///
/// Uses 8-lane `f32` SIMD: a 4x-unrolled main loop over 32 elements at a
/// time with four independent FMA accumulators, a single-vector loop for
/// remaining 8-element chunks, and a partial load for the final tail.
pub(crate) fn square_norm(x: &[f32]) -> f32 {
    let px: *const f32 = x.as_ptr();
    let len = x.len();

    // 8-lane f32 vector type for the active architecture.
    diskann_wide::alias!(f32s = f32x8);

    let mut i = 0;
    let mut s = f32s::default(diskann_wide::ARCH);

    // Main loop: 32 elements per iteration. Four independent accumulators
    // break the dependency chain between consecutive FMAs.
    if i + 32 <= len {
        let mut s0 = f32s::default(diskann_wide::ARCH);
        let mut s1 = f32s::default(diskann_wide::ARCH);
        let mut s2 = f32s::default(diskann_wide::ARCH);
        let mut s3 = f32s::default(diskann_wide::ARCH);
        while i + 32 <= len {
            // SAFETY: `i + 32 <= len`, so each of the four 8-wide loads
            // below (offsets i, i+8, i+16, i+24) stays in bounds.
            let vx = unsafe { f32s::load_simd(diskann_wide::ARCH, px.add(i)) };
            s0 = vx.mul_add_simd(vx, s0);

            let vx = unsafe { f32s::load_simd(diskann_wide::ARCH, px.add(i + 8)) };
            s1 = vx.mul_add_simd(vx, s1);

            let vx = unsafe { f32s::load_simd(diskann_wide::ARCH, px.add(i + 16)) };
            s2 = vx.mul_add_simd(vx, s2);

            let vx = unsafe { f32s::load_simd(diskann_wide::ARCH, px.add(i + 24)) };
            s3 = vx.mul_add_simd(vx, s3);

            i += 32;
        }

        // Pairwise reduction of the four partial accumulators.
        s = (s0 + s1) + (s2 + s3)
    }

    // Drain any remaining full 8-element chunks.
    while i + 8 <= len {
        // SAFETY: `i + 8 <= len` keeps this 8-wide load in bounds.
        let vx = unsafe { f32s::load_simd(diskann_wide::ARCH, px.add(i)) };
        s = vx.mul_add_simd(vx, s);
        i += 8;
    }

    // Tail of fewer than 8 elements, handled with a partial load
    // (presumably the unused lanes are zero-filled by `load_simd_first`
    // so they contribute nothing — behavior defined by diskann_wide).
    let remainder = len - i;
    if remainder != 0 {
        // SAFETY: exactly `remainder` (1..=7) valid elements remain at
        // offset `i`, and `load_simd_first` is told to read only that many.
        let vx = unsafe { f32s::load_simd_first(diskann_wide::ARCH, px.add(i), remainder) };
        s = vx.mul_add_simd(vx, s);
    }

    // Horizontal sum across the 8 lanes.
    s.sum_tree()
}
70
/// A matrix stored as a sequence of blocks of `N` rows, where each block is
/// kept transposed: within a block the `N` values of one column are
/// contiguous in memory.
///
/// When the row count is not a multiple of `N`, a final padded block is
/// allocated; its unused slots stay zero-filled.
#[derive(Debug)]
pub struct BlockTranspose<const N: usize> {
    /// Backing storage: `N * block_size * num_blocks` elements.
    data: Box<[f32]>,
    /// Number of columns of the logical matrix.
    block_size: usize,
    /// Number of blocks containing exactly `N` logical rows (`nrows / N`).
    full_blocks: usize,
    /// Total number of logical rows.
    nrows: usize,
}
134
135impl<const N: usize> BlockTranspose<N> {
136 pub fn new_matrix(nrows: usize, ncols: usize) -> Self {
140 let block_size = ncols;
141 let full_blocks = nrows / N;
142 let remainder = nrows - full_blocks * N;
143
144 let num_blocks = if remainder == 0 {
145 full_blocks
146 } else {
147 full_blocks + 1
148 };
149
150 Self {
151 data: vec![0.0; N * block_size * num_blocks].into(),
152 block_size,
153 full_blocks,
154 nrows,
155 }
156 }
157
158 pub fn nrows(&self) -> usize {
160 self.nrows
161 }
162
163 pub fn ncols(&self) -> usize {
165 self.block_size
166 }
167
168 pub fn block_size(&self) -> usize {
172 self.block_size
173 }
174
175 pub fn group_size(&self) -> usize {
179 N
180 }
181
182 pub const fn const_group_size() -> usize {
184 N
185 }
186
187 pub fn full_blocks(&self) -> usize {
191 self.full_blocks
192 }
193
194 pub fn num_blocks(&self) -> usize {
199 if self.remainder() == 0 {
200 self.full_blocks()
201 } else {
202 self.full_blocks() + 1
203 }
204 }
205
206 pub fn remainder(&self) -> usize {
210 self.nrows % N
211 }
212
213 pub unsafe fn block_ptr_unchecked(&self, block: usize) -> *const f32 {
223 debug_assert!(block < self.num_blocks());
224 unsafe { self.data.as_ptr().add(self.block_offset(block)) }
231 }
232
233 pub fn as_ptr(&self) -> *const f32 {
235 self.data.as_ptr()
236 }
237
238 fn block_offset(&self, block: usize) -> usize {
240 self.block_stride() * block
241 }
242
243 fn block_stride(&self) -> usize {
246 N * self.block_size
247 }
248
249 #[allow(clippy::expect_used)]
255 pub fn block(&self, block: usize) -> MatrixView<'_, f32> {
256 assert!(block < self.full_blocks());
257 let offset = self.block_offset(block);
258 let stride = self.block_stride();
259 let block_size = self.block_size();
260 MatrixView::try_from(&self.data[offset..offset + stride], block_size, N)
261 .expect("base data should have been sized correctly")
262 }
263
264 #[allow(clippy::expect_used)]
266 pub fn remainder_block(&self) -> Option<MatrixView<'_, f32>> {
267 if self.remainder() == 0 {
268 None
269 } else {
270 let offset = self.block_offset(self.full_blocks());
271 let stride = self.block_stride();
272 Some(
273 MatrixView::try_from(&self.data[offset..offset + stride], self.block_size, N)
274 .expect("base data should have been sized correctly"),
275 )
276 }
277 }
278
279 #[allow(clippy::expect_used)]
285 pub fn block_mut(&mut self, block: usize) -> MutMatrixView<'_, f32> {
286 assert!(block < self.full_blocks());
287 let offset = self.block_offset(block);
288 let stride = self.block_stride();
289 let block_size = self.block_size();
290 MutMatrixView::try_from(&mut self.data[offset..offset + stride], block_size, N)
291 .expect("base data should have been sized correctly")
292 }
293
294 #[allow(clippy::expect_used)]
296 pub fn remainder_block_mut(&mut self) -> Option<MutMatrixView<'_, f32>> {
297 if self.remainder() == 0 {
298 None
299 } else {
300 let offset = self.block_offset(self.full_blocks());
301 let stride = self.block_stride();
302 let block_size = self.block_size();
303 Some(
304 MutMatrixView::try_from(&mut self.data[offset..offset + stride], block_size, N)
305 .expect("base data should have been sized correctly"),
306 )
307 }
308 }
309
310 pub fn from_strided(v: StridedView<'_, f32>) -> Self {
316 let mut data = BlockTranspose::<N>::new_matrix(v.nrows(), v.ncols());
317
318 let full_blocks = data.full_blocks();
320 for block_index in 0..full_blocks {
321 let mut block = data.block_mut(block_index);
322 for col in 0..v.ncols() {
323 for row in 0..N {
324 block[(col, row)] = v[(N * block_index + row, col)]
325 }
326 }
327 }
328
329 let remaining_rows = data.remainder();
331 if let Some(mut block) = data.remainder_block_mut() {
332 for col in 0..v.ncols() {
333 for row in 0..remaining_rows {
334 block[(col, row)] = v[(N * full_blocks + row, col)]
335 }
336 }
337 }
338
339 data
340 }
341
342 pub fn from_matrix_view(v: MatrixView<'_, f32>) -> Self {
344 Self::from_strided(v.into())
345 }
346}
347
348impl<const N: usize> std::ops::Index<(usize, usize)> for BlockTranspose<N> {
349 type Output = f32;
350
351 fn index(&self, (row, col): (usize, usize)) -> &Self::Output {
357 assert!(row < self.nrows());
358 assert!(col < self.ncols());
359
360 let block = row / N;
361 let offset = row % N;
362 &self.data[self.block_offset(block) + col * N + offset]
363 }
364}
365
#[cfg(test)]
mod tests {
    use diskann_utils::{lazy_format, views::Matrix};
    use rand::{
        Rng, SeedableRng,
        distr::{Distribution, Uniform},
        rngs::StdRng,
    };

    use super::*;
    use crate::utils::div_round_up;

    /// Scalar reference implementation of the squared L2 norm.
    fn square_norm_reference(x: &[f32]) -> f32 {
        x.iter().map(|&i| i * i).sum()
    }

    /// Compare `square_norm` against the scalar reference on `ntrials`
    /// random vectors of length `dim`. A trial passes when either the
    /// absolute or the relative error bound is satisfied.
    fn test_square_norm_impl<R: Rng>(
        dim: usize,
        ntrials: usize,
        relative_error: f32,
        absolute_error: f32,
        rng: &mut R,
    ) {
        let distribution = Uniform::<f32>::new(-1.0, 1.0).unwrap();
        let mut x: Vec<f32> = vec![0.0; dim];
        for trial in 0..ntrials {
            x.iter_mut().for_each(|i| *i = distribution.sample(rng));
            let expected = square_norm_reference(&x);
            let got = square_norm(&x);

            let this_absolute_error = (expected - got).abs();
            let this_relative_error = this_absolute_error / expected.abs();

            let absolute_ok = this_absolute_error <= absolute_error;
            let relative_ok = this_relative_error <= relative_error;

            // Either bound is sufficient: the relative bound is meaningless
            // near zero, the absolute bound is too strict for large norms.
            if !absolute_ok && !relative_ok {
                panic!(
                    "received absolute/relative errors of {}/{} when the bounds were {}/{}\n\
                     dim = {}, trial = {} of {}",
                    this_absolute_error,
                    this_relative_error,
                    absolute_error,
                    relative_error,
                    dim,
                    trial,
                    ntrials,
                )
            }
        }
    }

    cfg_if::cfg_if! {
        if #[cfg(miri)] {
            // Miri is slow -- shrink the workload to keep runtime sane.
            const NTRIALS: usize = 1;
            const MAX_DIM: usize = 80;
        } else {
            const NTRIALS: usize = 100;
            const MAX_DIM: usize = 128;
        }
    }

    #[test]
    fn test_square_norm() {
        let mut rng = StdRng::seed_from_u64(0x71d00ad8c7105273);
        for dim in 0..MAX_DIM {
            let relative_error = 8.0e-7;
            let absolute_error = 1.0e-5;

            test_square_norm_impl(dim, NTRIALS, relative_error, absolute_error, &mut rng);
        }
    }

    /// Exhaustively check `BlockTranspose` accessors, logical indexing,
    /// block views, and mutable block views for an `nrows x ncols` matrix.
    fn test_block_transpose<const N: usize>(nrows: usize, ncols: usize) {
        let context = lazy_format!("N = {}, nrows = {}, ncols = {}", N, nrows, ncols);

        // Fill the source matrix with distinct values so any misplaced
        // element is detectable.
        let mut data = Matrix::new(0.0, nrows, ncols);
        data.as_mut_slice()
            .iter_mut()
            .enumerate()
            .for_each(|(i, d)| *d = i as f32);

        let mut transpose = BlockTranspose::<N>::from_matrix_view(data.as_view());

        // Basic accessors.
        assert_eq!(transpose.nrows(), nrows, "{}", context);
        assert_eq!(transpose.ncols(), ncols, "{}", context);
        assert_eq!(transpose.block_size(), ncols, "{}", context);
        assert_eq!(transpose.group_size(), N, "{}", context);
        assert_eq!(transpose.full_blocks(), nrows / N, "{}", context);
        assert_eq!(
            transpose.num_blocks(),
            div_round_up(nrows, N),
            "{}",
            context
        );
        assert_eq!(transpose.remainder(), nrows % N, "{}", context);

        // Logical (row, col) indexing must round-trip through the transpose.
        for row in 0..nrows {
            for col in 0..ncols {
                assert_eq!(
                    data[(row, col)],
                    transpose[(row, col)],
                    "failed for (row, col) = ({}, {})",
                    row,
                    col
                );
            }
        }

        // Full blocks: verify contents, pointer identity, and mutable views.
        for b in 0..transpose.full_blocks() {
            let block = transpose.block(b);
            assert_eq!(block.nrows(), ncols);
            assert_eq!(block.ncols(), N);

            for i in 0..block.nrows() {
                for j in 0..block.ncols() {
                    assert_eq!(
                        block[(i, j)],
                        data[(N * b + j, i)],
                        "failed in block {}, row {}, col {} -- {}",
                        b,
                        i,
                        j,
                        context
                    );
                }
            }

            let ptr = unsafe { transpose.block_ptr_unchecked(b) };
            assert_eq!(ptr, block.as_slice().as_ptr());

            let mut block_mut = transpose.block_mut(b);
            assert_eq!(ptr, block_mut.as_slice().as_ptr());
            assert_eq!(block_mut.nrows(), ncols);
            assert_eq!(block_mut.ncols(), N);
            // Zero the block so the final all-zeros check proves every
            // element was visited through the mutable views.
            block_mut.as_mut_slice().fill(0.0);
        }

        // Remainder block: present exactly when nrows is not a multiple of N.
        let expected_remainder = nrows % N;
        if expected_remainder != 0 {
            let b = transpose.full_blocks();
            let block = transpose.remainder_block().unwrap();
            assert_eq!(block.nrows(), ncols);
            assert_eq!(block.ncols(), N);

            // Only the first `expected_remainder` columns carry data.
            for i in 0..block.nrows() {
                for j in 0..expected_remainder {
                    assert_eq!(
                        block[(i, j)],
                        data[(N * b + j, i)],
                        "failed in block {}, row {}, col {} -- {}",
                        b,
                        i,
                        j,
                        context
                    );
                }
            }

            let ptr = unsafe { transpose.block_ptr_unchecked(b) };
            assert_eq!(ptr, block.as_slice().as_ptr());

            let mut block_mut = transpose.remainder_block_mut().unwrap();
            assert_eq!(ptr, block_mut.as_slice().as_ptr());
            assert_eq!(block_mut.nrows(), ncols);
            assert_eq!(block_mut.ncols(), N);
            block_mut.as_mut_slice().fill(0.0);
        } else {
            assert!(transpose.remainder_block().is_none());
            assert!(transpose.remainder_block_mut().is_none());
        }

        // All blocks (including zero-initialized padding) were zeroed above.
        assert!(transpose.data.iter().all(|i| *i == 0.0));
    }

    #[test]
    fn test_block_transpose_16() {
        for nrows in 0..128 {
            for ncols in 0..5 {
                test_block_transpose::<16>(nrows, ncols);
            }
        }
    }

    #[test]
    fn test_block_transpose_8() {
        for nrows in 0..128 {
            for ncols in 0..5 {
                test_block_transpose::<8>(nrows, ncols);
            }
        }
    }
}