Skip to main content

diskann_quantization/
distances.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6//! Zero-sized types representing distance functions.
7//!
8//! The types defined here are zero-sized types that represent logical distance function
9//! operations.
10//!
11//! Some types like bit-vectors and [`crate::bits::BitSlice`]s can be efficiently
12//! implemented as pure functions and therefore can use [`PureDistanceFunction`] semantics
13//! for these types directly.
14//!
15//! Other quantization schemes such as scalar quantization (e.g.
16//! [`crate::scalar::CompensatedSquaredL2`]) need auxiliary state and therefore need
17//! customized, stateful distance function types.
18
19use diskann_vector::{DistanceFunction, MathematicalValue, PureDistanceFunction};
20
/// Crate-internal shorthand for [`MathematicalValue`].
pub(crate) type MV<T> = MathematicalValue<T>;
22
/// A marker type indicating that a distance computation failed because its two arguments
/// had unequal lengths.
///
/// This is intentionally a zero-sized type so that error return paths can be as efficient
/// as possible. Since it carries no data it is `Copy`, and all values compare equal, which
/// makes assertions such as `assert_eq!(result, Err(UnequalLengths))` possible.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct UnequalLengths;

impl UnequalLengths {
    /// Escalate the unequal length error to a full-blown panic.
    ///
    /// `xlen` and `ylen` are the two mismatched lengths; they are only used to build the
    /// panic message.
    ///
    /// # Panics
    ///
    /// Always panics; this function never returns.
    // Panicking is the whole purpose of this method, so the lint is silenced deliberately.
    #[allow(clippy::panic)]
    // Hint that this is an unlikely path and keep the formatting code out of line.
    #[cold]
    #[inline(never)]
    pub fn panic(self, xlen: usize, ylen: usize) -> ! {
        panic!(
            "vector lengths must be equal, instead got xlen = {}, ylen = {}",
            xlen, ylen
        );
    }
}

impl std::fmt::Display for UnequalLengths {
    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The message is a constant, so no formatting machinery is needed.
        fmt.write_str("vector lengths must be equal")
    }
}

impl std::error::Error for UnequalLengths {}
50
/// A [`std::result::Result`] whose error type is fixed to [`UnequalLengths`].
///
/// Upcoming return type for distance functions to allow graceful failure instead of
/// panicking when incorrect dimensions are provided.
pub type Result<T> = std::result::Result<T, UnequalLengths>;
54
/// A [`std::result::Result`] whose success value is a [`MathematicalValue`] and whose error
/// type is fixed to [`UnequalLengths`].
///
/// Upcoming return type for distance functions to allow graceful failure instead of
/// panicking when incorrect dimensions are provided.
pub type MathematicalResult<T> = std::result::Result<MathematicalValue<T>, UnequalLengths>;
58
/// Check that `x.len() == y.len()`.
///
/// Expands to an expression that evaluates to `Ok(len)`, where `len` is the common length,
/// when the two lengths are equal, and to `Err(UnequalLengths)` (this crate's
/// [`crate::distances::UnequalLengths`]) when the lengths are different.
///
/// Each argument may be any expression whose value provides a `len()` method, not just a
/// plain identifier. Both arguments are evaluated exactly once, `x` before `y`. A trailing
/// comma is accepted.
#[macro_export]
macro_rules! check_lengths {
    ($x:expr, $y:expr $(,)?) => {{
        let len = $x.len();
        if len != $y.len() {
            Err($crate::distances::UnequalLengths)
        } else {
            Ok(len)
        }
    }};
}
74
75pub use check_lengths;
76
/// Zero-sized marker representing the squared Euclidean distance between vector-like
/// types.
#[derive(Clone, Copy, Debug)]
pub struct SquaredL2;
80
/// Zero-sized marker representing the inner product between vector-like types.
#[derive(Clone, Copy, Debug)]
pub struct InnerProduct;
84
/// Zero-sized marker representing the Hamming distance between bit-vectors.
#[derive(Clone, Copy, Debug)]
pub struct Hamming;
88
89macro_rules! auto_distance_function {
90    ($T:ty) => {
91        impl<A, B, To> DistanceFunction<A, B, To> for $T
92        where
93            $T: PureDistanceFunction<A, B, To>,
94        {
95            fn evaluate_similarity(&self, a: A, b: B) -> To {
96                <$T>::evaluate(a, b)
97            }
98        }
99    };
100}
101
102auto_distance_function!(SquaredL2);
103auto_distance_function!(InnerProduct);
104auto_distance_function!(Hamming);
105
106///////////
107// Tests //
108///////////
109
#[cfg(test)]
mod test {
    use super::*;

    /// Check the `Display` text of an error and that it reports no underlying source.
    fn test_error_impl<T>(x: T)
    where
        T: std::error::Error,
    {
        assert_eq!(x.to_string(), "vector lengths must be equal");
        assert!(x.source().is_none());
    }

    #[test]
    fn test_error() {
        test_error_impl(UnequalLengths);
    }

    /// Thin wrapper giving the `check_lengths!` expansion a concrete return type.
    fn test_check_length_impl(x: &[f32], y: &[f32]) -> Result<usize> {
        check_lengths!(x, y)
    }

    #[test]
    fn test_check_length() {
        let x = [0.0f32; 10];
        let y = [0.0f32; 10];

        // Inclusive upper bounds so the full-length slices are exercised as well.
        for i in 0..=x.len() {
            for j in 0..=y.len() {
                match test_check_length_impl(&x[..i], &y[..j]) {
                    Ok(len) => {
                        assert_eq!(i, j, "Ok should only be returned when i == j");
                        assert_eq!(i, len, "`check_lengths` should return the final length");
                    }
                    Err(UnequalLengths) => {
                        assert_ne!(i, j, "An error should be returned for unequal lengths");
                    }
                }
            }
        }
    }

    #[test]
    #[should_panic(expected = "vector lengths must be equal, instead got xlen = 10, ylen = 20")]
    fn unequal_lengths_panic() {
        (UnequalLengths).panic(10, 20)
    }
}