Skip to main content

diskann_linalg/
lib.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::fmt;
7
8pub mod common;
9pub use common::Transpose;
10
11mod faer;
12use faer::{random_distance_preserving_matrix_impl, sgemm_impl, svd_into_impl};
13use rand::Rng;
14
/// Identifies one of the three operands taking part in an SGEMM call.
///
/// Used by [`SgemmError`] to report which operand failed validation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MatrixName {
    /// The left-hand input `a`, logically `m x k`.
    A,
    /// The right-hand input `b`, logically `k x n`.
    B,
    /// The destination `c`, logically `m x n`.
    C,
}
22
23impl fmt::Display for MatrixName {
24    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
25        match self {
26            MatrixName::A => write!(f, "a (m * k)"),
27            MatrixName::B => write!(f, "b (k * n)"),
28            MatrixName::C => write!(f, "c (m * n)"),
29        }
30    }
31}
32
33/// Error type for SGEMM operations.
34#[derive(Debug, Clone, PartialEq, Eq)]
35pub enum SgemmError {
36    /// Matrix has incorrect dimensions.
37    InvalidMatrixDimensions {
38        matrix_name: MatrixName,
39        expected_rows: usize,
40        expected_cols: usize,
41        actual_len: usize,
42    },
43    /// Dimension overflow when computing matrix size.
44    DimensionOverflow {
45        matrix_name: MatrixName,
46        rows: usize,
47        cols: usize,
48    },
49}
50
51impl fmt::Display for SgemmError {
52    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53        match self {
54            SgemmError::InvalidMatrixDimensions {
55                matrix_name,
56                expected_rows,
57                expected_cols,
58                actual_len,
59            } => write!(
60                f,
61                "expected {}x{} matrix {} to have length {}, instead got {}",
62                expected_rows,
63                expected_cols,
64                matrix_name,
65                expected_rows * expected_cols,
66                actual_len
67            ),
68            SgemmError::DimensionOverflow {
69                matrix_name,
70                rows,
71                cols,
72            } => write!(
73                f,
74                "dimension overflow in matrix {}: {} * {} would overflow usize",
75                matrix_name, rows, cols
76            ),
77        }
78    }
79}
80
81impl std::error::Error for SgemmError {}
82
83// Make the reference implementation available for internal testing.
84#[cfg(test)]
85mod reference;
86
87/// Matrix-matrix multiplication for implicit row-major matrices `a` and `b` using the
88/// implicit row-major matrix `c` as the destination.
89///
90/// Performs one of the following operations:
91/// ```ignore
92/// 1. c = [beta * c] + alpha * a * b
93/// 2. c = [beta * c] + alpha * a' * b
94/// 3. c = [beta * c] + alpha * a * b'
95/// 3. c = [beta * c] + alpha * a' * b'
96/// ```
97/// Where `x'` indicates the ordinary transpose of `x`.
98///
99/// If `beta` is `None`, the destination `c` is completely over-written.
100///
101/// * `atranspose`: Whether `a` should be interpreted as an in-place transpose.
102/// * `btranspose`: Whether `b` should be interpreted as an in-place transpose.
103/// * `m`: The number of rows in `c`. Additionally:
104///     - If `!atranspose.is_transpose()`, this is the number of rows in `a`.
105///     - If `atranspose.is_transpose()`, this is the number of rows in `a`.
106/// * `n`: The number of columns `c`. Additionally:
107///     - If `!btranspose.is_transpose()`, this is the number of columns in `b`.
108///     - If `btranspose.is_transpose()`, this is the number of columns in `b`.
109/// * `k`: The number of columns in matrix `a` and the number of rows in matrix `b`.
110/// * `k`: Refer to the following:
111///     - If `!atranspose.is_transpose()`, this is the number of columns in `a`.
112///       Otherwise, this is the number of rows in `a`.
113///     - If `!btranspose.is_transpose()`, this is the number of rows in `b`.
114///       Otherwise, this is the number of columns in `b`.
115/// * `alpha`: Scaling parameter for the operation `a * b`.
116/// * `a`: The matrix `a` with dimension `m x k` (potentially after transposing).
117/// * `b`: The matrix `b` with dimension `k x n` (potentially after transposing).
118/// * `beta`: Optional scaling parameter for the matrix `c`. If `None`, then `c` will be
119///   overwritten entirely.
120/// * `c`: The output matrix with dimension `m x n`.
121///
122/// # Note
123///
124/// This inteface is a simplified version of the full cblas `sgemm` interface, namely that it
125///
126/// 1. Does not support column-major layouts
127/// 2. Does not allow for arbitrary strides in the leading dimension of the matrices.
128///
129/// This is to support the common-case in DiskANN that uses a Row-Major layout and always
130/// uses dense matrices.
131///
132/// If the more esoteric features of the cblas `sgemm` API are needed, we can provide
133/// that as an interface extension.
134///
135/// # Errors
136///
137/// Returns an error if:
138/// * `m * k` would overflow `usize`
139/// * `k * n` would overflow `usize`
140/// * `m * n` would overflow `usize`
141/// * `a.len() != m * k`
142/// * `b.len() != k * n`
143/// * `c.len() != m * n`
144#[allow(clippy::too_many_arguments)]
145pub fn sgemm(
146    atranspose: Transpose,
147    btranspose: Transpose,
148    m: usize,
149    n: usize,
150    k: usize,
151    alpha: f32,
152    a: &[f32],
153    b: &[f32],
154    beta: Option<f32>,
155    c: &mut [f32],
156) -> Result<(), SgemmError> {
157    // Check size requirements with overflow protection.
158    let expected_a_len = m.checked_mul(k).ok_or(SgemmError::DimensionOverflow {
159        matrix_name: MatrixName::A,
160        rows: m,
161        cols: k,
162    })?;
163
164    if a.len() != expected_a_len {
165        return Err(SgemmError::InvalidMatrixDimensions {
166            matrix_name: MatrixName::A,
167            expected_rows: m,
168            expected_cols: k,
169            actual_len: a.len(),
170        });
171    }
172
173    let expected_b_len = k.checked_mul(n).ok_or(SgemmError::DimensionOverflow {
174        matrix_name: MatrixName::B,
175        rows: k,
176        cols: n,
177    })?;
178
179    if b.len() != expected_b_len {
180        return Err(SgemmError::InvalidMatrixDimensions {
181            matrix_name: MatrixName::B,
182            expected_rows: k,
183            expected_cols: n,
184            actual_len: b.len(),
185        });
186    }
187
188    let expected_c_len = m.checked_mul(n).ok_or(SgemmError::DimensionOverflow {
189        matrix_name: MatrixName::C,
190        rows: m,
191        cols: n,
192    })?;
193
194    if c.len() != expected_c_len {
195        return Err(SgemmError::InvalidMatrixDimensions {
196            matrix_name: MatrixName::C,
197            expected_rows: m,
198            expected_cols: n,
199            actual_len: c.len(),
200        });
201    }
202
203    // Invoke the actual implementation.
204    sgemm_impl(atranspose, btranspose, m, n, k, alpha, a, b, beta, c);
205    Ok(())
206}
207
208/// Compute the SVD of the provided matrix implicit row-major matrix `data`.
209///
210/// * `m`: The number of rows in `a`.
211/// * `n`: The number of columns in `a`.
212/// * `a`: The data matrix to decompose with dimensiuon `m x n` stored in Row-Major order [Note 1].
213/// * `singular_values`: Contains the singular values of `a` sorted so that
214///   `singular_values[i] ≥ singular_values[i+1]`.
215/// * `u`: Contains the `m x m` unitary matrix in Row-Major order.
216/// * `vt`: Contains the `n x n` unitary matrix in Column-Major order [Note 2].
217///
218/// # Notes
219///
220/// 1. Due to the contract offered by `lapacke`, callers of this function must assume that
221///    the contents of `a` are left in an undefined state after this function.
222///
223///    See: https://netlib.org/lapack/explore-html//df/d22/group__gesdd_gab9ffdde22b38f0cc442e44cbea23818f.html
224///
225/// 2. Similar to #1, the restriction that `vt` is transposed is a lapack byproduct.
226///
227/// # Panics
228///
229/// Panics if
230///
231/// * `a.len() != m * n`
232/// * `singular_values.len() != min(m, n)`
233/// * `u.len() != m * m`.
234/// * `vt.len() != n * n`.
235///
236/// Additionally, if MKL is used, panics if any either `m` or `n` is not representable
237/// as a signed 32-bit integer due to `cblas` limitations.
238pub fn svd_into(
239    m: usize,
240    n: usize,
241    a: &mut [f32],
242    singular_values: &mut [f32],
243    u: &mut [f32],
244    vt: &mut [f32],
245) -> Result<(), impl std::error::Error + 'static> {
246    // Check size requirements.
247    assert_eq!(a.len(), m * n);
248    assert_eq!(singular_values.len(), m.min(n));
249    assert_eq!(u.len(), m * m);
250    assert_eq!(vt.len(), n * n);
251
252    // Invoke the actual implementation.
253    svd_into_impl(m, n, a, singular_values, u, vt)
254}
255
256/// Construct a random `dim x dim` distance preserving matrix.
257///
258/// Practically speaking, the returned matrix should be orthogonal with a determinant of
259/// either +1 or -1.
260pub fn random_distance_preserving_matrix<T: Rng + ?Sized>(dim: usize, rng: &mut T) -> Vec<f32> {
261    random_distance_preserving_matrix_impl(dim, rng)
262}
263
#[cfg(test)]
mod tests {
    use approx::{assert_abs_diff_eq, assert_relative_eq};
    use rand::{distr::Distribution, rngs::StdRng, SeedableRng};
    use rand_distr::StandardNormal;
    use serde::Deserialize;

    use super::*;
    use crate::reference;

    ////////////////////////
    // Simple SGEMM tests //
    ////////////////////////

    #[test]
    fn test_reference_implementation() {
        let problems = reference::test_sgemm_problems();
        for (i, problem) in problems.iter().enumerate() {
            let result = problem.check(sgemm);
            if let Err(err) = result {
                panic!("{} on iteration {}. Problem: {:?}", err, i, problem);
            }
        }
    }

    #[test]
    fn test_sgemm_invalid_matrix_a_dimensions() {
        let mut c = [0.0f32; 6];
        let err = sgemm(
            Transpose::None,
            Transpose::None,
            2,
            3,
            4,
            1.0,
            &[0.0; 5], // Should be 2 * 4 = 8
            &[0.0; 12],
            None,
            &mut c,
        )
        .unwrap_err();

        assert_eq!(
            err.to_string(),
            "expected 2x4 matrix a (m * k) to have length 8, instead got 5"
        );
    }

    #[test]
    fn test_sgemm_invalid_matrix_b_dimensions() {
        let mut c = [0.0f32; 6];
        let err = sgemm(
            Transpose::None,
            Transpose::None,
            2,
            3,
            4,
            1.0,
            &[0.0; 8],
            &[0.0; 10], // Should be 4 * 3 = 12
            None,
            &mut c,
        )
        .unwrap_err();

        assert_eq!(
            err.to_string(),
            "expected 4x3 matrix b (k * n) to have length 12, instead got 10"
        );
    }

    #[test]
    fn test_sgemm_invalid_matrix_c_dimensions() {
        let mut c = [0.0f32; 5]; // Should be 2 * 3 = 6
        let err = sgemm(
            Transpose::None,
            Transpose::None,
            2,
            3,
            4,
            1.0,
            &[0.0; 8],
            &[0.0; 12],
            None,
            &mut c,
        )
        .unwrap_err();

        assert_eq!(
            err.to_string(),
            "expected 2x3 matrix c (m * n) to have length 6, instead got 5"
        );
    }

    #[test]
    fn test_sgemm_m_times_k_overflow() {
        let mut c = [0.0f32];
        let err = sgemm(
            Transpose::None,
            Transpose::None,
            usize::MAX,
            1,
            2,
            1.0,
            &[],
            &[0.0],
            None,
            &mut c,
        )
        .unwrap_err();

        assert_eq!(
            err.to_string(),
            format!(
                "dimension overflow in matrix a (m * k): {} * 2 would overflow usize",
                usize::MAX
            )
        );
    }

    #[test]
    fn test_sgemm_k_times_n_overflow() {
        let mut c = vec![0.0f32; 10];
        let err = sgemm(
            Transpose::None,
            Transpose::None,
            1,
            usize::MAX,
            10,
            1.0,
            &[0.0f32; 10],
            &[],
            None,
            &mut c,
        )
        .unwrap_err();

        assert_eq!(
            err.to_string(),
            format!(
                "dimension overflow in matrix b (k * n): 10 * {} would overflow usize",
                usize::MAX
            )
        );
    }

    #[test]
    fn test_sgemm_m_times_n_overflow() {
        let mut c = [];
        let err = sgemm(
            Transpose::None,
            Transpose::None,
            2,
            usize::MAX,
            0,
            1.0,
            &[],
            &[],
            None,
            &mut c,
        )
        .unwrap_err();

        assert_eq!(
            err.to_string(),
            format!(
                "dimension overflow in matrix c (m * n): 2 * {} would overflow usize",
                usize::MAX
            )
        );
    }

    /// This test ensures the Result type doesn't grow unexpectedly.
    /// A large Result size increases stack usage even for the Ok(()) case.
    #[test]
    fn test_sgemm_result_size() {
        let mut c = [0.0f32; 6];
        let result = sgemm(
            Transpose::None,
            Transpose::None,
            2,
            3,
            4,
            1.0,
            &[0.0; 5],
            &[0.0; 12],
            None,
            &mut c,
        );

        let result_size = std::mem::size_of_val(&result);
        const EXPECTED_RESULT_SIZE: usize = 32;
        assert_eq!(
            result_size, EXPECTED_RESULT_SIZE,
            "Result size is {} bytes, does not match the expected size of {} bytes.",
            result_size, EXPECTED_RESULT_SIZE
        );
    }

    ///////////////
    // SVD Tests //
    ///////////////

    /// Absolute path of a file inside this crate's `test_data` directory.
    fn test_file_path(name: &str) -> String {
        format!("{}/test_data/{}", env!("CARGO_MANIFEST_DIR"), name)
    }

    /// Name of the file (under `test_data/`) holding the generated reference SVD inputs.
    const SVD_INPUT_FILE: &str = "reference_svd_inputs.json";

    /// One reference case: an `m x n` row-major matrix and its expected singular values.
    #[derive(Deserialize, Debug)]
    struct SVDTestCase {
        m: usize,
        n: usize,
        matrix: Vec<f32>,
        singular_values: Vec<f32>,
    }

    impl SVDTestCase {
        fn summary(&self) -> String {
            format!("svd test case with dimension {}x{}", self.m, self.n)
        }
    }

    /// Acceptance thresholds for comparing computed values against references.
    ///
    /// A value passes if *either* its absolute or its relative error is within bounds.
    struct SVDTolerance {
        absolute: f32,
        relative: f32,
    }

    impl SVDTolerance {
        fn check(&self, absolute: f32, relative: f32) -> bool {
            absolute <= self.absolute || relative <= self.relative
        }
    }

    /// Return `(absolute_error, relative_error)` of `got` with respect to `expected`.
    ///
    /// The relative error divides by the *magnitude* of `expected`. Dividing by a negative
    /// `expected` would produce a negative relative error that trivially satisfies any
    /// non-negative relative tolerance, silently disabling the check.
    fn compute_errors(got: f32, expected: f32) -> (f32, f32) {
        let diff = (got - expected).abs();
        (diff, diff / expected.abs())
    }

    /// Expand `singular_values` into a dense row-major `m x n` matrix with the values
    /// on the main diagonal.
    fn materialize_singular_values(singular_values: &[f32], m: usize, n: usize) -> Vec<f32> {
        assert_eq!(singular_values.len(), m.min(n));
        let mut output = vec![0.0; m * n];

        for (i, &s) in singular_values.iter().enumerate() {
            output[n * i + i] = s;
        }
        output
    }

    /// Decompose `case.matrix`, compare the singular values against the reference, and
    /// check that `u * diag(s) * vt` reconstructs the input.
    fn test_svd(
        case: &SVDTestCase,
        singular_value_tolerance: &SVDTolerance,
        reconstructed_tolerance: &SVDTolerance,
        context: &dyn std::fmt::Display,
    ) {
        // Create the output matrices.
        let mut singular_values = vec![0.0; case.m.min(case.n)];
        let mut u = vec![0.0; case.m * case.m];
        let mut vt = vec![0.0; case.n * case.n];

        svd_into(
            case.m,
            case.n,
            &mut case.matrix.clone(),
            &mut singular_values,
            &mut u,
            &mut vt,
        )
        .unwrap();

        // Guard against `zip` silently truncating the comparison below.
        assert_eq!(
            case.singular_values.len(),
            singular_values.len(),
            "reference singular value count does not match min(m, n): {}",
            context
        );

        // Check the resulting singular values.
        for (i, (&got, &expected)) in
            std::iter::zip(singular_values.iter(), case.singular_values.iter()).enumerate()
        {
            let (diff, relative) = compute_errors(got, expected);
            assert!(
                singular_value_tolerance.check(diff, relative),
                "got {} but expected {} (diff: {}, relative: {}) at position {}: {}",
                got,
                expected,
                diff,
                relative,
                i,
                context
            );
        }

        // Test the reconstruction.
        let full_singular_values = materialize_singular_values(&singular_values, case.m, case.n);
        let mut temp = vec![0.0; case.m * case.n];

        // Multiply `u * singular_values`.
        sgemm(
            Transpose::None,
            Transpose::None,
            case.m,
            case.n,
            case.m,
            1.0,
            &u,
            &full_singular_values,
            None,
            &mut temp,
        )
        .unwrap();

        // Multiply `(u * singular_values) * vt`.
        let mut output = vec![0.0; case.m * case.n];
        sgemm(
            Transpose::None,
            Transpose::None,
            case.m,
            case.n,
            case.n,
            1.0,
            &temp,
            &vt,
            None,
            &mut output,
        )
        .unwrap();

        for row in 0..case.m {
            for col in 0..case.n {
                let got = output[case.n * row + col];
                let expected = case.matrix[case.n * row + col];
                let (diff, relative) = compute_errors(got, expected);
                assert!(
                    reconstructed_tolerance.check(diff, relative),
                    "mismatch in reconstructed matrix at (row, col) = ({}, {}). \
                     Got {}, expected {} (diff: {}, relative: {}). {}",
                    row,
                    col,
                    got,
                    expected,
                    diff,
                    relative,
                    context
                );
            }
        }
    }

    #[test]
    fn test_svd_implementation() {
        let path = test_file_path(SVD_INPUT_FILE);
        let file = std::fs::File::open(&path)
            .unwrap_or_else(|err| panic!("failed to open file {path}: {err}"));

        let reader = std::io::BufReader::new(file);
        let cases: Vec<SVDTestCase> = serde_json::from_reader(reader).unwrap();

        let singular_values_tolerance = SVDTolerance {
            absolute: 2.0e-6,
            relative: 3.0e-6,
        };

        let reconstructed_tolerance = SVDTolerance {
            absolute: 5.0e-5,
            relative: 0.0,
        };

        for (i, case) in cases.iter().enumerate() {
            let context = format!(
                "while processing case {} of {}: {}",
                i + 1,
                cases.len(),
                case.summary()
            );
            test_svd(
                case,
                &singular_values_tolerance,
                &reconstructed_tolerance,
                &context,
            );
        }
    }

    ///////////////////////////
    // Rotation Matrix Tests //
    ///////////////////////////

    const EPSILON: f32 = 1e-5;

    fn test_distance_preserving_matrix_impl(dim: usize, rng: &mut StdRng) {
        // Construct the distance preserving matrix.
        let q = random_distance_preserving_matrix(dim, rng);

        // Check that `q * q'` is close to the identity matrix.
        let qm = ::faer::mat::MatRef::from_row_major_slice(&q, dim, dim);
        let m = qm * qm.transpose();

        for j in 0..dim {
            for i in 0..dim {
                if i == j {
                    assert_abs_diff_eq!(m[(i, j)], 1.0, epsilon = EPSILON);
                } else {
                    assert_abs_diff_eq!(m[(i, j)], 0.0, epsilon = EPSILON);
                }
            }
        }

        // Instead of explicitly checking the determinant, we sample using 100 randomly
        // generated vectors, verifying that the norms are unchanged.
        const RANDOM_TRIALS: usize = 100;
        let mut v = vec![0.0f32; dim];
        for _ in 0..RANDOM_TRIALS {
            v.iter_mut()
                .for_each(|i| *i = StandardNormal {}.sample(rng));
            let vm = ::faer::mat::MatRef::from_row_major_slice(&v, dim, 1);
            let v_norm = vm.squared_norm_l2();
            let t = qm * vm;
            let t_norm = t.squared_norm_l2();

            assert_relative_eq!(v_norm, t_norm, epsilon = EPSILON, max_relative = EPSILON);
            assert_ne!(vm, t);
        }
    }

    #[test]
    fn test_rotation_matrix() {
        let mut rng = StdRng::seed_from_u64(0xc0ff33);
        let num_trials = 5;
        for dim in [2, 100, 256] {
            for _ in 0..num_trials {
                test_distance_preserving_matrix_impl(dim, &mut rng);
            }
        }
    }
}