Skip to main content

diskann_linalg/
common.rs

/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */

/// Indicate whether a matrix should be implicitly transposed for an operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Transpose {
    /// Use a provided matrix directly.
    None,
    /// Use the transpose of a matrix.
    Ordinary,
}

impl Transpose {
    /// Return whether or not the enum is [`Transpose::Ordinary`].
    pub fn is_transpose(&self) -> bool {
        matches!(self, Self::Ordinary)
    }

    /// Forward one of the arguments, depending on the value of `self`.
    ///
    /// Returns `if_none` for [`Transpose::None`] and `if_transpose` for
    /// [`Transpose::Ordinary`]. Both arguments are evaluated by the caller before this
    /// method runs; use [`Transpose::call`] when only the selected value should be computed.
    pub fn forward<T>(&self, if_none: T, if_transpose: T) -> T {
        match self {
            Self::None => if_none,
            Self::Ordinary => if_transpose,
        }
    }

    /// Call exactly one of the arguments depending on the value of `self` and return the
    /// result.
    ///
    /// `if_none` is invoked for [`Transpose::None`] and `if_transpose` for
    /// [`Transpose::Ordinary`]; the other closure is dropped without being called.
    /// Because the selected closure runs exactly once, both only need to be [`FnOnce`],
    /// which also admits closures that move values out of their captured state.
    pub fn call<F, G, T>(&self, if_none: F, if_transpose: G) -> T
    where
        F: FnOnce() -> T,
        G: FnOnce() -> T,
    {
        match self {
            Self::None => if_none(),
            Self::Ordinary => if_transpose(),
        }
    }
}

#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};

    use super::*;

    #[test]
    fn test_is_transpose() {
        assert!(!Transpose::None.is_transpose());
        assert!(Transpose::Ordinary.is_transpose());
    }

    #[test]
    fn test_forward() {
        assert_eq!(Transpose::None.forward(1, 2), 1);
        assert_eq!(Transpose::Ordinary.forward(1, 2), 2);
    }

    /// Invoke `transpose.call` with two closures that record whether they ran.
    ///
    /// Returns `(result, a_called, b_called)`, where `a` is the first closure (returning
    /// `1`) and `b` is the second closure (returning `2`).
    fn call_and_record(transpose: Transpose) -> (i32, bool, bool) {
        let a_called = AtomicBool::new(false);
        let b_called = AtomicBool::new(false);

        let result = transpose.call(
            || {
                a_called.store(true, Ordering::Relaxed);
                1
            },
            || {
                b_called.store(true, Ordering::Relaxed);
                2
            },
        );

        (
            result,
            a_called.load(Ordering::Relaxed),
            b_called.load(Ordering::Relaxed),
        )
    }

    #[test]
    fn test_call() {
        // `Transpose::None`: make sure *only* `a` was called.
        assert_eq!(call_and_record(Transpose::None), (1, true, false));

        // `Transpose::Ordinary`: make sure *only* `b` was called.
        assert_eq!(call_and_record(Transpose::Ordinary), (2, false, true));
    }
}