Skip to main content

diskann_quantization/algorithms/kmeans/
common.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use diskann_utils::{
7    strided::StridedView,
8    views::{MatrixView, MutMatrixView},
9};
10use diskann_wide::{SIMDMulAdd, SIMDSumTree, SIMDVector};
11
12/// Compute the squared L2 norm of the argument.
13pub(crate) fn square_norm(x: &[f32]) -> f32 {
14    let px: *const f32 = x.as_ptr();
15    let len = x.len();
16
17    diskann_wide::alias!(f32s = f32x8);
18
19    let mut i = 0;
20    let mut s = f32s::default(diskann_wide::ARCH);
21
22    // The number of 32-bit blocks over the underlying slice.
23    if i + 32 <= len {
24        let mut s0 = f32s::default(diskann_wide::ARCH);
25        let mut s1 = f32s::default(diskann_wide::ARCH);
26        let mut s2 = f32s::default(diskann_wide::ARCH);
27        let mut s3 = f32s::default(diskann_wide::ARCH);
28        while i + 32 <= len {
29            // SAFETY: The memory range `[i, i + 32)` is valid by the loop bounds.
30            let vx = unsafe { f32s::load_simd(diskann_wide::ARCH, px.add(i)) };
31            s0 = vx.mul_add_simd(vx, s0);
32
33            // SAFETY: The memory range `[i, i + 32)` is valid by the loop bounds.
34            let vx = unsafe { f32s::load_simd(diskann_wide::ARCH, px.add(i + 8)) };
35            s1 = vx.mul_add_simd(vx, s1);
36
37            // SAFETY: The memory range `[i, i + 32)` is valid by the loop bounds.
38            let vx = unsafe { f32s::load_simd(diskann_wide::ARCH, px.add(i + 16)) };
39            s2 = vx.mul_add_simd(vx, s2);
40
41            // SAFETY: The memory range `[i, i + 32)` is valid by the loop bounds.
42            let vx = unsafe { f32s::load_simd(diskann_wide::ARCH, px.add(i + 24)) };
43            s3 = vx.mul_add_simd(vx, s3);
44
45            i += 32;
46        }
47
48        s = (s0 + s1) + (s2 + s3)
49    }
50
51    while i + 8 <= len {
52        // SAFETY: The memory range `[i, i + 8)` is valid by the loop bounds.
53        let vx = unsafe { f32s::load_simd(diskann_wide::ARCH, px.add(i)) };
54        s = vx.mul_add_simd(vx, s);
55        i += 8;
56    }
57
58    let remainder = len - i;
59    if remainder != 0 {
60        // SAFETY: The pointer add is valid because `i < len` (strict inequality), so the
61        // base pointer belongs to the memory owned by `x`.
62        //
63        // Furthermore, the load is valid for the first `remainder` items.
64        let vx = unsafe { f32s::load_simd_first(diskann_wide::ARCH, px.add(i), remainder) };
65        s = vx.mul_add_simd(vx, s);
66    }
67
68    s.sum_tree()
69}
70
71////////////////////
72// BlockTranspose //
73////////////////////
74
/// A representation of 2D data consisting of blocks of transposes.
///
/// The generic parameter `N` denotes how many source rows are stored in each block.
///
/// For example, if the original data is in a row-major layout like the following:
/// ```text
/// a0 a1 a2 a3 ... aK
/// b0 b1 b2 b3 ... bK
/// c0 c1 c2 c3 ... cK
/// d0 d1 d2 d3 ... dK
/// e0 e1 e2 e3 ... eK
/// ```
/// and the blocking parameter is `N = 3`, then the blocked-transpose layout (still row
/// major) will be as follows:
/// ```text
///           Group Size (3)
///            <---------->
///
///            +----------+    ^
///            | a0 b0 c0 |    |
///            | a1 b1 c1 |    |
///            | a2 b2 c2 |    | Block Size (K + 1)
///  Block 0   | a3 b3 c3 |    |
///  (Full)    | ...      |    |
///            | aK bK cK |    |
///            +----------+    v
///            +----------+
///            | d0 e0 XX |
///            | d1 e1 XX |
///            | d2 e2 XX |
///  Block 1   | d3 e3 XX |
///  (Partial) | ...      |
///            | dK eK XX |
///            +----------+
/// ```
/// Note the following characteristics:
///
/// * The same dimension of different source rows is stored contiguously (this helps with
///   SIMD algorithms).
///
/// * The groups for subsequent dimensions follow one another contiguously.
///
/// * Blocks are stored contiguously, so the entire `BlockTranspose` consists of a single
///   allocation.
///
/// * Allocation is done at block-level granularity. The last block is only partially
///   filled if the number of rows is not a multiple of `N`.
///
///   Padding (the `XX` entries in the diagram) is zero-initialized. SIMD algorithms are
///   free to load entire rows of a block, provided any bookkeeping tracks the partially
///   full status.
#[derive(Debug)]
pub struct BlockTranspose<const N: usize> {
    /// Backing storage for every block, including padding, laid out one block after
    /// another.
    data: Box<[f32]>,
    /// The number of columns in the source data. This is also the number of physical rows
    /// in each block.
    block_size: usize,
    /// How many blocks are completely filled with data.
    full_blocks: usize,
    /// The total number of data rows stored in this representation.
    nrows: usize,
}
134
135impl<const N: usize> BlockTranspose<N> {
136    /// Construct a new `BlockTranspose` sized to contain a matrix of size `nrows x ncols`.
137    ///
138    /// Data will be zero initialized.
139    pub fn new_matrix(nrows: usize, ncols: usize) -> Self {
140        let block_size = ncols;
141        let full_blocks = nrows / N;
142        let remainder = nrows - full_blocks * N;
143
144        let num_blocks = if remainder == 0 {
145            full_blocks
146        } else {
147            full_blocks + 1
148        };
149
150        Self {
151            data: vec![0.0; N * block_size * num_blocks].into(),
152            block_size,
153            full_blocks,
154            nrows,
155        }
156    }
157
158    /// Return the number of rows of data stored in `self`.
159    pub fn nrows(&self) -> usize {
160        self.nrows
161    }
162
163    /// Return the number of columns in each rows of the data in `self`.
164    pub fn ncols(&self) -> usize {
165        self.block_size
166    }
167
168    /// Return the number of **physical** rows in each data block.
169    ///
170    /// Conceptually, this is the same as the number of columns.
171    pub fn block_size(&self) -> usize {
172        self.block_size
173    }
174
175    /// Return the length of each row in a block.
176    ///
177    /// Conceptually, this is the number of source-data rows stored in a block.
178    pub fn group_size(&self) -> usize {
179        N
180    }
181
182    /// Return the length of each row in a block.
183    pub const fn const_group_size() -> usize {
184        N
185    }
186
187    /// Return the number of completely full data blocks.
188    ///
189    /// This will always be equal to `self.nrows() / self.group_size()`.
190    pub fn full_blocks(&self) -> usize {
191        self.full_blocks
192    }
193
194    /// Return the total number of data blocks including any partially full terminal block.
195    ///
196    /// This will always be equal to:
197    /// `crate::utils::div_round_up(self.nrows(), self.group_size())`
198    pub fn num_blocks(&self) -> usize {
199        if self.remainder() == 0 {
200            self.full_blocks()
201        } else {
202            self.full_blocks() + 1
203        }
204    }
205
206    /// Return the number of elements in the last partially full block.
207    ///
208    /// A return value of 0 indicates that all blocks are full.
209    pub fn remainder(&self) -> usize {
210        self.nrows % N
211    }
212
213    /// Return a pointer to the beginning of `block`.
214    ///
215    /// The caller may assume that for the returned pointer `ptr`,
216    /// `[ptr, ptr + self.block_stride())` points to valid memory, even for the remainder
217    /// block.
218    ///
219    /// # Safety
220    ///
221    /// Block must be in-bounds (i.e., `block < self.num_blocks()`).
222    pub unsafe fn block_ptr_unchecked(&self, block: usize) -> *const f32 {
223        debug_assert!(block < self.num_blocks());
224        // SAFETY: If we assume `block < self.num_blocks()`, (which the caller attests) then
225        //
226        // 1. Our base pointer was allocated in the first place, so this computed offset
227        //    must fit within an `isize`.
228        // 2. This pointer (and an offset `self.block_stride()`) higher all live within
229        //    a single allocated object.
230        unsafe { self.data.as_ptr().add(self.block_offset(block)) }
231    }
232
233    /// Return a pointer to the start of data segment.
234    pub fn as_ptr(&self) -> *const f32 {
235        self.data.as_ptr()
236    }
237
238    /// The linear offset of the beginning of `block`.
239    fn block_offset(&self, block: usize) -> usize {
240        self.block_stride() * block
241    }
242
243    /// The number of elements of type `f32` in each block (i.e., the spacing between the
244    /// starts of blocks).
245    fn block_stride(&self) -> usize {
246        N * self.block_size
247    }
248
249    /// Return a view over a full `block`.
250    ///
251    /// # Panics
252    ///
253    /// Panics if `block >= self.full_blocks()`.
254    #[allow(clippy::expect_used)]
255    pub fn block(&self, block: usize) -> MatrixView<'_, f32> {
256        assert!(block < self.full_blocks());
257        let offset = self.block_offset(block);
258        let stride = self.block_stride();
259        let block_size = self.block_size();
260        MatrixView::try_from(&self.data[offset..offset + stride], block_size, N)
261            .expect("base data should have been sized correctly")
262    }
263
264    /// Return a view over the remainder block, or `None` if there is no remainder block.
265    #[allow(clippy::expect_used)]
266    pub fn remainder_block(&self) -> Option<MatrixView<'_, f32>> {
267        if self.remainder() == 0 {
268            None
269        } else {
270            let offset = self.block_offset(self.full_blocks());
271            let stride = self.block_stride();
272            Some(
273                MatrixView::try_from(&self.data[offset..offset + stride], self.block_size, N)
274                    .expect("base data should have been sized correctly"),
275            )
276        }
277    }
278
279    /// Return a mutable view over a full `block`.
280    ///
281    /// # Panics
282    ///
283    /// Panics if `block >= self.full_blocks()`.
284    #[allow(clippy::expect_used)]
285    pub fn block_mut(&mut self, block: usize) -> MutMatrixView<'_, f32> {
286        assert!(block < self.full_blocks());
287        let offset = self.block_offset(block);
288        let stride = self.block_stride();
289        let block_size = self.block_size();
290        MutMatrixView::try_from(&mut self.data[offset..offset + stride], block_size, N)
291            .expect("base data should have been sized correctly")
292    }
293
294    /// Return a mutable view over the remainder block, or `None` if there is no remainder block.
295    #[allow(clippy::expect_used)]
296    pub fn remainder_block_mut(&mut self) -> Option<MutMatrixView<'_, f32>> {
297        if self.remainder() == 0 {
298            None
299        } else {
300            let offset = self.block_offset(self.full_blocks());
301            let stride = self.block_stride();
302            let block_size = self.block_size();
303            Some(
304                MutMatrixView::try_from(&mut self.data[offset..offset + stride], block_size, N)
305                    .expect("base data should have been sized correctly"),
306            )
307        }
308    }
309
310    //////////////////
311    // Constructors //
312    //////////////////
313
314    /// Construct a copy of `v` inside a BlockTranspose.
315    pub fn from_strided(v: StridedView<'_, f32>) -> Self {
316        let mut data = BlockTranspose::<N>::new_matrix(v.nrows(), v.ncols());
317
318        // Pack full blocks
319        let full_blocks = data.full_blocks();
320        for block_index in 0..full_blocks {
321            let mut block = data.block_mut(block_index);
322            for col in 0..v.ncols() {
323                for row in 0..N {
324                    block[(col, row)] = v[(N * block_index + row, col)]
325                }
326            }
327        }
328
329        // Pack remainder
330        let remaining_rows = data.remainder();
331        if let Some(mut block) = data.remainder_block_mut() {
332            for col in 0..v.ncols() {
333                for row in 0..remaining_rows {
334                    block[(col, row)] = v[(N * full_blocks + row, col)]
335                }
336            }
337        }
338
339        data
340    }
341
342    /// Construct a copy of `v` inside a BlockTranspose.
343    pub fn from_matrix_view(v: MatrixView<'_, f32>) -> Self {
344        Self::from_strided(v.into())
345    }
346}
347
348impl<const N: usize> std::ops::Index<(usize, usize)> for BlockTranspose<N> {
349    type Output = f32;
350
351    /// Return a reference the the element at the logical `(row, col)`.
352    ///
353    /// # Panics
354    ///
355    /// Panics if `row >= self.nrows()` or `col >= self.ncols()`.
356    fn index(&self, (row, col): (usize, usize)) -> &Self::Output {
357        assert!(row < self.nrows());
358        assert!(col < self.ncols());
359
360        let block = row / N;
361        let offset = row % N;
362        &self.data[self.block_offset(block) + col * N + offset]
363    }
364}
365
#[cfg(test)]
mod tests {
    use diskann_utils::{lazy_format, views::Matrix};
    use rand::{
        Rng, SeedableRng,
        distr::{Distribution, Uniform},
        rngs::StdRng,
    };

    use super::*;
    use crate::utils::div_round_up;

    /////////////////
    // Square Norm //
    /////////////////

    /// Straightforward scalar implementation used as the ground truth for `square_norm`.
    fn square_norm_reference(x: &[f32]) -> f32 {
        x.iter().map(|&i| i * i).sum()
    }

    /// Check `square_norm` against `square_norm_reference` on `ntrials` random vectors of
    /// length `dim`, sampled uniformly from `[-1, 1)`.
    ///
    /// A trial passes if *either* the absolute or the relative error is within its bound.
    /// When the reference is `0.0` (e.g. `dim == 0`), the relative error is NaN and never
    /// passes, so the absolute bound must cover that case.
    fn test_square_norm_impl<R: Rng>(
        dim: usize,
        ntrials: usize,
        relative_error: f32,
        absolute_error: f32,
        rng: &mut R,
    ) {
        let distribution = Uniform::<f32>::new(-1.0, 1.0).unwrap();
        let mut x: Vec<f32> = vec![0.0; dim];
        for trial in 0..ntrials {
            x.iter_mut().for_each(|i| *i = distribution.sample(rng));
            let expected = square_norm_reference(&x);
            let got = square_norm(&x);

            let this_absolute_error = (expected - got).abs();
            let this_relative_error = this_absolute_error / expected.abs();

            let absolute_ok = this_absolute_error <= absolute_error;
            let relative_ok = this_relative_error <= relative_error;

            if !absolute_ok && !relative_ok {
                panic!(
                    "received absolute/relative errors of {}/{} when the bounds were {}/{}\n\
                     dim = {}, trial = {} of {}",
                    this_absolute_error,
                    this_relative_error,
                    absolute_error,
                    relative_error,
                    dim,
                    trial,
                    ntrials,
                )
            }
        }
    }

    cfg_if::cfg_if! {
        if #[cfg(miri)] {
            const NTRIALS: usize = 1;
            const MAX_DIM: usize = 80;
        } else {
            const NTRIALS: usize = 100;
            const MAX_DIM: usize = 128;
        }
    }

    #[test]
    fn test_square_norm() {
        let mut rng = StdRng::seed_from_u64(0x71d00ad8c7105273);
        let relative_error = 8.0e-7;
        let absolute_error = 1.0e-5;
        for dim in 0..MAX_DIM {
            test_square_norm_impl(dim, NTRIALS, relative_error, absolute_error, &mut rng);
        }
    }

    /////////////////////
    // Block Transpose //
    /////////////////////

    /// Build a `BlockTranspose::<N>` from an `nrows x ncols` matrix and check its
    /// accessors, element indexing, block views, pointer APIs, and zero padding.
    fn test_block_transpose<const N: usize>(nrows: usize, ncols: usize) {
        let context = lazy_format!("N = {}, nrows = {}, ncols = {}", N, nrows, ncols);

        // Create initial data with the following layout:
        //       0         1         2 ...   ncols-1
        //   ncols   ncols+1   ncols+2     2*ncols-1
        // 2*ncols 2*ncols+1 2*ncols+2     3*ncols-1
        // ...
        let mut data = Matrix::new(0.0, nrows, ncols);
        data.as_mut_slice()
            .iter_mut()
            .enumerate()
            .for_each(|(i, d)| *d = i as f32);

        let mut transpose = BlockTranspose::<N>::from_matrix_view(data.as_view());

        // Make sure the public methods return their advertised values.
        assert_eq!(transpose.nrows(), nrows, "{}", context);
        assert_eq!(transpose.ncols(), ncols, "{}", context);
        assert_eq!(transpose.block_size(), ncols, "{}", context);
        assert_eq!(transpose.group_size(), N, "{}", context);
        assert_eq!(BlockTranspose::<N>::const_group_size(), N, "{}", context);
        assert_eq!(transpose.full_blocks(), nrows / N, "{}", context);
        assert_eq!(
            transpose.num_blocks(),
            div_round_up(nrows, N),
            "{}",
            context
        );
        assert_eq!(transpose.remainder(), nrows % N, "{}", context);
        assert_eq!(transpose.as_ptr(), transpose.data.as_ptr(), "{}", context);

        // Check regular row-column based indexing.
        for row in 0..nrows {
            for col in 0..ncols {
                assert_eq!(
                    data[(row, col)],
                    transpose[(row, col)],
                    "failed for (row, col) = ({}, {}) -- {}",
                    row,
                    col,
                    context
                );
            }
        }

        // Check indexing on the block level.
        for b in 0..transpose.full_blocks() {
            let block = transpose.block(b);
            assert_eq!(block.nrows(), ncols, "{}", context);
            assert_eq!(block.ncols(), N, "{}", context);

            // Are the contents correct?
            for i in 0..block.nrows() {
                for j in 0..block.ncols() {
                    assert_eq!(
                        block[(i, j)],
                        data[(N * b + j, i)],
                        "failed in block {}, row {}, col {} -- {}",
                        b,
                        i,
                        j,
                        context
                    );
                }
            }

            // Make sure the pointer API is correct.
            // SAFETY: `b < transpose.full_blocks() <= transpose.num_blocks()`.
            let ptr = unsafe { transpose.block_ptr_unchecked(b) };
            assert_eq!(ptr, block.as_slice().as_ptr());

            // Construct a mutable version and zero it.
            let mut block_mut = transpose.block_mut(b);
            assert_eq!(ptr, block_mut.as_slice().as_ptr());
            assert_eq!(block_mut.nrows(), ncols, "{}", context);
            assert_eq!(block_mut.ncols(), N, "{}", context);
            block_mut.as_mut_slice().fill(0.0);
        }

        let expected_remainder = nrows % N;
        if expected_remainder != 0 {
            let b = transpose.full_blocks();
            let block = transpose.remainder_block().unwrap();
            assert_eq!(block.nrows(), ncols, "{}", context);
            assert_eq!(block.ncols(), N, "{}", context);

            // Are the contents correct? Columns past the remainder are padding and must
            // still be zero.
            for i in 0..block.nrows() {
                for j in 0..block.ncols() {
                    let expected = if j < expected_remainder {
                        data[(N * b + j, i)]
                    } else {
                        0.0
                    };
                    assert_eq!(
                        block[(i, j)],
                        expected,
                        "failed in block {}, row {}, col {} -- {}",
                        b,
                        i,
                        j,
                        context
                    );
                }
            }

            // Make sure the pointer API is correct.
            // SAFETY: A non-zero remainder means `b == transpose.full_blocks()` is the
            // last block, so `b < transpose.num_blocks()`.
            let ptr = unsafe { transpose.block_ptr_unchecked(b) };
            assert_eq!(ptr, block.as_slice().as_ptr());

            // Construct a mutable version and zero it.
            let mut block_mut = transpose.remainder_block_mut().unwrap();
            assert_eq!(ptr, block_mut.as_slice().as_ptr());
            assert_eq!(block_mut.nrows(), ncols, "{}", context);
            assert_eq!(block_mut.ncols(), N, "{}", context);
            block_mut.as_mut_slice().fill(0.0);
        } else {
            assert!(transpose.remainder_block().is_none(), "{}", context);
            assert!(transpose.remainder_block_mut().is_none(), "{}", context);
        }

        // Every block was zeroed through its mutable view, so the whole backing store
        // must now be zero. This also shows the views covered the full allocation.
        assert!(transpose.data.iter().all(|i| *i == 0.0), "{}", context);
    }

    #[test]
    fn test_block_transpose_16() {
        for nrows in 0..128 {
            for ncols in 0..5 {
                test_block_transpose::<16>(nrows, ncols);
            }
        }
    }

    #[test]
    fn test_block_transpose_8() {
        for nrows in 0..128 {
            for ncols in 0..5 {
                test_block_transpose::<8>(nrows, ncols);
            }
        }
    }
}