diskann_quantization/
distances.rs1use diskann_vector::{DistanceFunction, MathematicalValue, PureDistanceFunction};
20
21pub(crate) type MV<T> = MathematicalValue<T>;
22
/// Error produced when two vectors that are required to have the same length do not.
///
/// This is a zero-sized marker carrying no payload. The offending lengths are only
/// reported when the error is escalated with [`UnequalLengths::panic`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct UnequalLengths;

impl UnequalLengths {
    /// Abort with a panic message that reports both vector lengths.
    ///
    /// The function is kept out of line and marked cold because it never returns
    /// normally, so it should not bloat or pessimize its callers' hot paths.
    /// `#[track_caller]` makes the reported panic location the call site rather
    /// than this helper.
    ///
    /// # Panics
    ///
    /// Always, with the message
    /// `"vector lengths must be equal, instead got xlen = {xlen}, ylen = {ylen}"`.
    // Panicking is the sole purpose of this function, so opt out of `clippy::panic`.
    #[allow(clippy::panic)]
    #[cold]
    #[inline(never)]
    #[track_caller]
    pub fn panic(self, xlen: usize, ylen: usize) -> ! {
        panic!(
            "vector lengths must be equal, instead got xlen = {}, ylen = {}",
            xlen, ylen
        );
    }
}

impl std::fmt::Display for UnequalLengths {
    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(fmt, "vector lengths must be equal")
    }
}

/// `UnequalLengths` has no underlying cause, so the default `source()` (`None`) is used.
impl std::error::Error for UnequalLengths {}
50
51pub type Result<T> = std::result::Result<T, UnequalLengths>;
54
55pub type MathematicalResult<T> = std::result::Result<MathematicalValue<T>, UnequalLengths>;
58
/// Check that two length-bearing values (typically slices) have the same length.
///
/// Expands to an expression evaluating to:
/// * `Ok(len)` with the shared length when `$x.len() == $y.len()`;
/// * `Err(UnequalLengths)` otherwise.
///
/// Each argument is evaluated exactly once, `$x` before `$y`. Arguments may be
/// arbitrary expressions (e.g. `check_lengths!(a, &b[..n])`), not just identifiers,
/// and a trailing comma is accepted.
#[macro_export]
macro_rules! check_lengths {
    ($x:expr, $y:expr $(,)?) => {{
        let len = $x.len();
        if len != $y.len() {
            Err($crate::distances::UnequalLengths)
        } else {
            Ok(len)
        }
    }};
}
74
75pub use check_lengths;
76
/// Zero-sized marker selecting the squared-L2 metric.
///
/// The actual computation is provided by `PureDistanceFunction` implementations for
/// this type, which are presumably defined elsewhere in the crate.
#[derive(Clone, Copy, Debug)]
pub struct SquaredL2;

/// Zero-sized marker selecting the inner-product metric.
///
/// As with [`SquaredL2`], the arithmetic lives in this type's `PureDistanceFunction` impls.
#[derive(Clone, Copy, Debug)]
pub struct InnerProduct;

/// Zero-sized marker selecting the Hamming metric.
///
/// As with [`SquaredL2`], the arithmetic lives in this type's `PureDistanceFunction` impls.
#[derive(Clone, Copy, Debug)]
pub struct Hamming;
88
89macro_rules! auto_distance_function {
90 ($T:ty) => {
91 impl<A, B, To> DistanceFunction<A, B, To> for $T
92 where
93 $T: PureDistanceFunction<A, B, To>,
94 {
95 fn evaluate_similarity(&self, a: A, b: B) -> To {
96 <$T>::evaluate(a, b)
97 }
98 }
99 };
100}
101
102auto_distance_function!(SquaredL2);
103auto_distance_function!(InnerProduct);
104auto_distance_function!(Hamming);
105
#[cfg(test)]
mod test {
    use super::*;

    /// Assert that `x` has the `UnequalLengths` display text and no underlying source.
    fn test_error_impl<T>(x: T)
    where
        T: std::error::Error,
    {
        assert_eq!(x.to_string(), "vector lengths must be equal");
        assert!(x.source().is_none());
    }

    #[test]
    fn test_error() {
        test_error_impl(UnequalLengths);
    }

    /// Thin function wrapper so `check_lengths!` can be exercised with slice arguments.
    fn test_check_length_impl(x: &[f32], y: &[f32]) -> Result<usize> {
        check_lengths!(x, y)
    }

    #[test]
    fn test_check_length() {
        const MAX: usize = 10;
        let x = [0.0f32; MAX];
        let y = [0.0f32; MAX];

        // Inclusive bounds so the full-length prefixes (`i == MAX`) are also covered.
        for i in 0..=MAX {
            for j in 0..=MAX {
                match test_check_length_impl(&x[..i], &y[..j]) {
                    Ok(len) => {
                        assert_eq!(i, j, "Ok should only be returned when i == j");
                        assert_eq!(i, len, "`check_lengths` should return the final length");
                    }
                    Err(UnequalLengths) => {
                        assert_ne!(i, j, "An error should be returned for unequal lengths");
                    }
                }
            }
        }
    }

    #[test]
    #[should_panic(expected = "vector lengths must be equal, instead got xlen = 10, ylen = 20")]
    fn unequal_lengths_panic() {
        (UnequalLengths).panic(10, 20)
    }
}