diskann_vector/
distance.rs

1/*
2 * Copyright (c) Microsoft Corporation. All rights reserved.
3 * Licensed under the MIT license.
4 */
5use crate::l2_float_distance::{
6    distance_cosine_vector_f32, distance_l2_vector_f16, distance_l2_vector_f32,
7};
8use crate::{Half, Metric};
9
10/// Distance contract for full-precision vertex
11pub trait FullPrecisionDistance<T, const N: usize> {
12    /// Get the distance between vertex a and vertex b
13    fn distance_compare(a: &[T; N], b: &[T; N], vec_type: Metric) -> f32;
14}
15
16// reason = "Not supported Metric type Metric::Cosine"
17#[allow(clippy::panic)]
18impl<const N: usize> FullPrecisionDistance<f32, N> for [f32; N] {
19    /// Calculate distance between two f32 Vertex
20    #[inline(always)]
21    fn distance_compare(a: &[f32; N], b: &[f32; N], metric: Metric) -> f32 {
22        match metric {
23            Metric::L2 => distance_l2_vector_f32::<N>(a, b),
24            Metric::Cosine => distance_cosine_vector_f32::<N>(a, b),
25            //_ => panic!("Not supported Metric type {:?}", metric),
26        }
27    }
28}
29
30// reason = "Not supported Metric type Metric::Cosine"
31#[allow(clippy::panic)]
32impl<const N: usize> FullPrecisionDistance<Half, N> for [Half; N] {
33    fn distance_compare(a: &[Half; N], b: &[Half; N], metric: Metric) -> f32 {
34        match metric {
35            Metric::L2 => distance_l2_vector_f16::<N>(a, b),
36            _ => panic!("Not supported Metric type {:?}", metric),
37        }
38    }
39}
40
41// reason = "Not yet supported Vector i8"
42#[allow(clippy::panic)]
43impl<const N: usize> FullPrecisionDistance<i8, N> for [i8; N] {
44    fn distance_compare(_a: &[i8; N], _b: &[i8; N], _metric: Metric) -> f32 {
45        panic!("Not supported VectorType i8")
46    }
47}
48
49// reason = "Not yet supported Vector u8"
50#[allow(clippy::panic)]
51impl<const N: usize> FullPrecisionDistance<u8, N> for [u8; N] {
52    fn distance_compare(_a: &[u8; N], _b: &[u8; N], _metric: Metric) -> f32 {
53        panic!("Not supported VectorType u8")
54    }
55}
56
#[cfg(test)]
mod distance_test {
    use super::*;
    use crate::test_util::*;
    use approx::assert_abs_diff_eq;

    /// 112-element f32 test vector, forced to 32-byte alignment.
    // NOTE(review): the alignment presumably lets the vectorized kernels use aligned loads — confirm.
    #[repr(C, align(32))]
    pub struct F32Slice112([f32; 112]);

    /// 112-element `Half` test vector, forced to 32-byte alignment.
    #[repr(C, align(32))]
    pub struct F16Slice112([Half; 112]);

    /// Returns two fixed 112-element f32 vectors; in each, the final 12 entries are `0.0`.
    fn get_turing_test_data() -> (F32Slice112, F32Slice112) {
        let a_slice: [f32; 112] = [
            0.13961786,
            -0.031577103,
            -0.09567415,
            0.06695563,
            -0.1588727,
            0.089852564,
            -0.019837005,
            0.07497972,
            0.010418192,
            -0.054594643,
            0.08613386,
            -0.05103466,
            0.16568437,
            -0.02703799,
            0.00728657,
            -0.15313251,
            0.16462992,
            -0.030570814,
            0.11635703,
            0.23938893,
            0.018022912,
            -0.12646551,
            0.018048918,
            -0.035986554,
            0.031986624,
            -0.015286017,
            0.010117953,
            -0.032691937,
            0.12163067,
            -0.04746277,
            0.010213069,
            -0.043672588,
            -0.099362016,
            0.06599016,
            -0.19397286,
            -0.13285528,
            -0.22040887,
            0.017690737,
            -0.104262285,
            -0.0044555613,
            -0.07383778,
            -0.108652934,
            0.13399786,
            0.054912474,
            0.20181285,
            0.1795591,
            -0.05425621,
            -0.10765217,
            0.1405377,
            -0.14101997,
            -0.12017701,
            0.011565498,
            0.06952187,
            0.060136646,
            0.0023214167,
            0.04204699,
            0.048470616,
            0.17398086,
            0.024218207,
            -0.15626553,
            -0.11291045,
            -0.09688122,
            0.14393932,
            -0.14713104,
            -0.108876854,
            0.035279203,
            -0.05440188,
            0.017205412,
            0.011413814,
            0.04009471,
            0.11070237,
            -0.058998976,
            0.07260045,
            -0.057893746,
            -0.0036240944,
            -0.0064988653,
            -0.13842176,
            -0.023219328,
            0.0035885905,
            -0.0719257,
            -0.21335067,
            0.11415403,
            -0.0059823603,
            0.12091869,
            0.08136634,
            -0.10769281,
            0.024518685,
            0.0009200326,
            -0.11628049,
            0.07448965,
            0.13736208,
            -0.04144517,
            -0.16426727,
            -0.06380103,
            -0.21386267,
            0.022373492,
            -0.05874115,
            0.017314062,
            -0.040344074,
            0.01059176,
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
        ];
        let b_slice: [f32; 112] = [
            -0.07209058,
            -0.17755842,
            -0.030627966,
            0.163028,
            -0.2233766,
            0.057412963,
            0.0076995124,
            -0.017121306,
            -0.015759075,
            -0.026947778,
            -0.010282468,
            -0.23968373,
            -0.021486737,
            -0.09903155,
            0.09361805,
            0.0042711576,
            -0.08695552,
            -0.042165346,
            0.064218745,
            -0.06707651,
            0.07846054,
            0.12235762,
            -0.060716823,
            0.18496591,
            -0.13023394,
            0.022469055,
            0.056764495,
            0.07168404,
            -0.08856144,
            -0.15343173,
            0.099879816,
            -0.033529017,
            0.0795304,
            -0.009242254,
            -0.10254546,
            0.13086525,
            -0.101518914,
            -0.1031299,
            -0.056826904,
            0.033196196,
            0.044143833,
            -0.049787212,
            -0.018148342,
            -0.11172959,
            -0.06776237,
            -0.09185828,
            -0.24171598,
            0.05080982,
            -0.0727684,
            0.045031235,
            -0.11363879,
            -0.063389264,
            0.105850354,
            -0.19847773,
            0.08828623,
            -0.087071925,
            0.033512704,
            0.16118294,
            0.14111553,
            0.020884402,
            -0.088860825,
            0.018745849,
            0.047522716,
            -0.03665169,
            0.15726231,
            -0.09930561,
            0.057844743,
            -0.10532736,
            -0.091297254,
            0.067029804,
            0.04153976,
            0.06393326,
            0.054578528,
            0.0038539872,
            0.1023088,
            -0.10653885,
            -0.108500294,
            -0.046606563,
            0.020439683,
            -0.120957725,
            -0.13334097,
            -0.13425854,
            -0.20481694,
            0.07009538,
            0.08660361,
            -0.0096641015,
            0.095316306,
            -0.002898167,
            -0.19680002,
            0.08466311,
            0.04812689,
            -0.028978813,
            0.04780206,
            -0.2001506,
            -0.036866356,
            -0.023720587,
            0.10731964,
            0.05517358,
            -0.09580819,
            0.14595725,
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
        ];

        (F32Slice112(a_slice), F32Slice112(b_slice))
    }

    /// Same vectors as [`get_turing_test_data`], converted element-wise with `Half::from_f32`.
    fn get_turing_test_data_f16() -> (F16Slice112, F16Slice112) {
        let (a_slice, b_slice) = get_turing_test_data();

        // `array::map` keeps the fixed size, so there is no Vec round-trip and no fallible
        // `try_into().unwrap()` back to `[Half; 112]`.
        (
            F16Slice112(a_slice.0.map(Half::from_f32)),
            F16Slice112(b_slice.0.map(Half::from_f32)),
        )
    }

    #[test]
    fn test_dist_l2_float_turing() {
        let (a_slice, b_slice) = get_turing_test_data();
        let distance = <[f32; 112] as FullPrecisionDistance<f32, 112>>::distance_compare(
            &a_slice.0,
            &b_slice.0,
            Metric::L2,
        );

        // Check the L2 kernel against the reference implementation from `test_util`.
        assert_abs_diff_eq!(
            distance,
            no_vector_compare_f32(&a_slice.0, &b_slice.0),
            epsilon = 1e-6
        );
    }

    #[test]
    fn test_dist_l2_f16_turing() {
        let (a_slice, b_slice) = get_turing_test_data_f16();
        let distance = <[Half; 112] as FullPrecisionDistance<Half, 112>>::distance_compare(
            &a_slice.0,
            &b_slice.0,
            Metric::L2,
        );

        // Compared against the f16 reference rather than the f32 one: rounding the inputs to
        // 16-bit precision makes the result differ from the full 32-bit computation.
        assert_eq!(distance, no_vector_compare_f16(&a_slice.0, &b_slice.0));
    }

    #[test]
    fn distance_test() {
        #[repr(C, align(32))]
        struct Vector32ByteAligned {
            v: [f32; 512],
        }

        // Two 256-element vectors stored back to back in one boxed (heap) buffer.
        let two_vec = Box::new(Vector32ByteAligned {
            v: [
                69.02492, 78.84786, 63.125072, 90.90581, 79.2592, 70.81731, 3.0829668, 33.33287,
                20.777142, 30.147898, 23.681915, 42.553043, 12.602162, 7.3808074, 19.157589,
                65.6791, 76.44677, 76.89124, 86.40756, 84.70118, 87.86142, 16.126896, 5.1277637,
                95.11038, 83.946945, 22.735607, 11.548555, 59.51482, 24.84603, 15.573776, 78.27185,
                71.13179, 38.574017, 80.0228, 13.175261, 62.887978, 15.205181, 18.89392, 96.13162,
                87.55455, 34.179806, 62.920044, 4.9305916, 54.349373, 21.731495, 14.982187,
                40.262867, 20.15214, 36.61963, 72.450806, 55.565, 95.5375, 93.73356, 95.36308,
                66.30762, 58.0397, 18.951357, 67.11702, 43.043316, 30.65622, 99.85361, 2.5889993,
                27.844774, 39.72441, 46.463238, 71.303764, 90.45308, 36.390602, 63.344395,
                26.427078, 35.99528, 82.35505, 32.529175, 23.165905, 74.73179, 9.856939, 59.38126,
                35.714924, 79.81213, 46.704124, 24.47884, 36.01743, 0.46678782, 29.528152,
                1.8980742, 24.68853, 75.58984, 98.72279, 68.62601, 11.890173, 49.49361, 55.45572,
                72.71067, 34.107483, 51.357758, 76.400635, 81.32725, 66.45081, 17.848074,
                62.398876, 94.20444, 2.10886, 17.416393, 64.88253, 29.000723, 62.434315, 53.907238,
                70.51412, 78.70744, 55.181683, 64.45116, 23.419212, 53.68544, 43.506958, 46.89598,
                35.905994, 64.51397, 91.95555, 20.322979, 74.80128, 97.548744, 58.312725, 78.81985,
                31.911612, 14.445949, 49.85094, 70.87396, 40.06766, 7.129991, 78.48008, 75.21636,
                93.623604, 95.95479, 29.571129, 22.721554, 26.73875, 52.075504, 56.783104,
                94.65493, 61.778534, 85.72401, 85.369514, 29.922367, 41.410553, 94.12884,
                80.276855, 55.604828, 54.70947, 74.07216, 44.61955, 31.38113, 68.48596, 34.56782,
                14.424729, 48.204506, 9.675444, 32.01946, 92.32695, 36.292683, 78.31955, 98.05327,
                14.343918, 46.017002, 95.90888, 82.63626, 16.873539, 3.698051, 7.8042626,
                64.194405, 96.71023, 67.93692, 21.618402, 51.92182, 22.834194, 61.56986, 19.749891,
                55.31206, 38.29552, 67.57593, 67.145836, 38.92673, 94.95708, 72.38746, 90.70901,
                69.43995, 9.394085, 31.646872, 88.20112, 9.134722, 99.98214, 5.423498, 41.51995,
                76.94409, 77.373276, 3.2966614, 9.611201, 57.231106, 30.747868, 76.10228, 91.98308,
                70.893585, 0.9067178, 43.96515, 16.321218, 27.734184, 83.271835, 88.23312,
                87.16445, 5.556643, 15.627432, 58.547127, 93.6459, 40.539192, 49.124157, 91.13276,
                57.485855, 8.827019, 4.9690843, 46.511234, 53.91469, 97.71925, 20.135271,
                23.353004, 70.92099, 93.38748, 87.520134, 51.684677, 29.89813, 9.110392, 65.809204,
                34.16554, 93.398605, 84.58669, 96.409645, 9.876037, 94.767784, 99.21523, 1.9330144,
                94.92429, 75.12728, 17.218828, 97.89164, 35.476578, 77.629456, 69.573746,
                40.200542, 42.117836, 5.861628, 75.45282, 82.73633, 0.98086596, 77.24894,
                11.248695, 61.070026, 52.692616, 80.5449, 80.76036, 29.270136, 67.60252, 48.782394,
                95.18851, 83.47162, 52.068756, 46.66002, 90.12216, 15.515327, 33.694042, 96.963036,
                73.49627, 62.805485, 44.715607, 59.98627, 3.8921833, 37.565327, 29.69184,
                39.429665, 83.46899, 44.286453, 21.54851, 56.096413, 18.169249, 5.214751,
                14.691341, 99.779335, 26.32643, 67.69903, 36.41243, 67.27333, 12.157213, 96.18984,
                2.438283, 78.14289, 0.14715195, 98.769, 53.649532, 21.615898, 39.657497, 95.45616,
                18.578386, 71.47976, 22.348118, 17.85519, 6.3717127, 62.176777, 22.033644,
                23.178005, 79.44858, 89.70233, 37.21273, 71.86182, 21.284317, 52.908623, 30.095518,
                63.64478, 77.55823, 80.04871, 15.133011, 30.439043, 70.16561, 4.4014096, 89.28944,
                26.29093, 46.827854, 11.764729, 61.887516, 47.774887, 57.19503, 59.444664,
                28.592825, 98.70386, 1.2497544, 82.28431, 46.76423, 83.746124, 53.032673, 86.53457,
                99.42168, 90.184, 92.27852, 9.059965, 71.75723, 70.45299, 10.924053, 68.329704,
                77.27232, 6.677854, 75.63629, 57.370533, 17.09031, 10.554659, 99.56178, 37.53221,
                72.311104, 75.7565, 65.2042, 36.096478, 64.69502, 38.88497, 64.33723, 84.87812,
                66.84958, 8.508932, 79.134, 83.431015, 66.72124, 61.801838, 64.30524, 37.194263,
                77.94725, 89.705185, 23.643505, 19.505919, 48.40264, 43.01083, 21.171177,
                18.717121, 10.805857, 69.66983, 77.85261, 57.323063, 3.28964, 38.758026, 5.349946,
                7.46572, 57.485138, 30.822384, 33.9411, 95.53746, 65.57723, 42.1077, 28.591347,
                11.917269, 5.031073, 31.835615, 19.34116, 85.71027, 87.4516, 1.3798475, 70.70583,
                51.988052, 45.217144, 14.308596, 54.557167, 86.18323, 79.13666, 76.866745,
                46.010685, 79.739235, 44.667603, 39.36416, 72.605896, 73.83187, 13.137412,
                6.7911267, 63.952374, 10.082436, 86.00318, 99.760376, 92.84948, 63.786434,
                3.4429908, 18.244314, 75.65299, 14.964747, 70.126366, 80.89449, 91.266655,
                96.58798, 46.439327, 38.253975, 87.31036, 21.093178, 37.19671, 58.28973, 9.75231,
                12.350321, 25.75115, 87.65073, 53.610504, 36.850048, 18.66356, 94.48941, 83.71898,
                44.49315, 44.186737, 19.360733, 84.365974, 46.76272, 44.924366, 50.279808,
                54.868866, 91.33004, 18.683397, 75.13282, 15.070831, 47.04839, 53.780903,
                26.911152, 74.65651, 57.659935, 25.604189, 37.235474, 65.39667, 53.952206,
                40.37131, 59.173275, 96.00756, 54.591274, 10.787476, 69.51549, 31.970142,
                25.408005, 55.972492, 85.01888, 97.48981, 91.006134, 28.98619, 97.151276,
                34.388496, 47.498177, 11.985874, 64.73775, 33.877014, 13.370312, 34.79146,
                86.19321, 15.019405, 94.07832, 93.50433, 60.168625, 50.95409, 38.27827, 47.458614,
                32.83715, 69.54998, 69.0361, 84.1418, 34.270298, 74.23852, 70.707466, 78.59845,
                9.651399, 24.186779, 58.255756, 53.72362, 92.46477, 97.75528, 20.257462, 30.122698,
                50.41517, 28.156603, 42.644154,
            ],
        });

        let distance = compare::<256>(Metric::L2, &two_vec.v);

        assert_eq!(distance, 429141.2);
    }

    /// Computes the `metric` distance between the first `N` floats of `v` and the `N` floats
    /// that immediately follow them.
    ///
    /// # Panics
    ///
    /// Panics if `v` holds fewer than `2 * N` elements.
    fn compare<const N: usize>(metric: Metric, v: &[f32]) -> f32 {
        // Bounds-checked slicing replaces the previous raw-pointer arithmetic, which was
        // undefined behavior whenever `v` was shorter than two vectors.
        let (a, rest) = v.split_at(N);
        let (b, _) = rest.split_at(N);
        let a_ref = <&[f32; N]>::try_from(a).expect("`a` has exactly N elements");
        let b_ref = <&[f32; N]>::try_from(b).expect("`b` has exactly N elements");

        <[f32; N] as FullPrecisionDistance<f32, N>>::distance_compare(a_ref, b_ref, metric)
    }
}