1use crate::l2_float_distance::{
6 distance_cosine_vector_f32, distance_l2_vector_f16, distance_l2_vector_f32,
7};
8use crate::{Half, Metric};
9
10pub trait FullPrecisionDistance<T, const N: usize> {
12 fn distance_compare(a: &[T; N], b: &[T; N], vec_type: Metric) -> f32;
14}
15
16#[allow(clippy::panic)]
18impl<const N: usize> FullPrecisionDistance<f32, N> for [f32; N] {
19 #[inline(always)]
21 fn distance_compare(a: &[f32; N], b: &[f32; N], metric: Metric) -> f32 {
22 match metric {
23 Metric::L2 => distance_l2_vector_f32::<N>(a, b),
24 Metric::Cosine => distance_cosine_vector_f32::<N>(a, b),
25 }
27 }
28}
29
30#[allow(clippy::panic)]
32impl<const N: usize> FullPrecisionDistance<Half, N> for [Half; N] {
33 fn distance_compare(a: &[Half; N], b: &[Half; N], metric: Metric) -> f32 {
34 match metric {
35 Metric::L2 => distance_l2_vector_f16::<N>(a, b),
36 _ => panic!("Not supported Metric type {:?}", metric),
37 }
38 }
39}
40
41#[allow(clippy::panic)]
43impl<const N: usize> FullPrecisionDistance<i8, N> for [i8; N] {
44 fn distance_compare(_a: &[i8; N], _b: &[i8; N], _metric: Metric) -> f32 {
45 panic!("Not supported VectorType i8")
46 }
47}
48
49#[allow(clippy::panic)]
51impl<const N: usize> FullPrecisionDistance<u8, N> for [u8; N] {
52 fn distance_compare(_a: &[u8; N], _b: &[u8; N], _metric: Metric) -> f32 {
53 panic!("Not supported VectorType u8")
54 }
55}
56
57#[cfg(test)]
58mod distance_test {
59 use super::*;
60
/// Test fixture: 112 `f32` values stored with C layout and 32-byte alignment.
// NOTE(review): the 32-byte alignment presumably matches what the SIMD distance
// kernels expect for aligned loads — confirm against `l2_float_distance`.
#[repr(C, align(32))]
pub struct F32Slice112([f32; 112]);
63
/// Test fixture: 112 `Half` values stored with C layout and 32-byte alignment.
// NOTE(review): the 32-byte alignment presumably matches what the SIMD distance
// kernels expect for aligned loads — confirm against `l2_float_distance`.
#[repr(C, align(32))]
pub struct F16Slice112([Half; 112]);
66
/// Builds the fixed pair of 112-element `f32` vectors shared by the
/// `*_turing` tests in this module.
///
/// In both vectors the first 100 entries are non-zero sample values and the
/// final 12 entries are `0.0`. Each vector is returned wrapped in the 32-byte
/// aligned [`F32Slice112`].
fn get_turing_test_data() -> (F32Slice112, F32Slice112) {
    let a_slice: [f32; 112] = [
        0.13961786,
        -0.031577103,
        -0.09567415,
        0.06695563,
        -0.1588727,
        0.089852564,
        -0.019837005,
        0.07497972,
        0.010418192,
        -0.054594643,
        0.08613386,
        -0.05103466,
        0.16568437,
        -0.02703799,
        0.00728657,
        -0.15313251,
        0.16462992,
        -0.030570814,
        0.11635703,
        0.23938893,
        0.018022912,
        -0.12646551,
        0.018048918,
        -0.035986554,
        0.031986624,
        -0.015286017,
        0.010117953,
        -0.032691937,
        0.12163067,
        -0.04746277,
        0.010213069,
        -0.043672588,
        -0.099362016,
        0.06599016,
        -0.19397286,
        -0.13285528,
        -0.22040887,
        0.017690737,
        -0.104262285,
        -0.0044555613,
        -0.07383778,
        -0.108652934,
        0.13399786,
        0.054912474,
        0.20181285,
        0.1795591,
        -0.05425621,
        -0.10765217,
        0.1405377,
        -0.14101997,
        -0.12017701,
        0.011565498,
        0.06952187,
        0.060136646,
        0.0023214167,
        0.04204699,
        0.048470616,
        0.17398086,
        0.024218207,
        -0.15626553,
        -0.11291045,
        -0.09688122,
        0.14393932,
        -0.14713104,
        -0.108876854,
        0.035279203,
        -0.05440188,
        0.017205412,
        0.011413814,
        0.04009471,
        0.11070237,
        -0.058998976,
        0.07260045,
        -0.057893746,
        -0.0036240944,
        -0.0064988653,
        -0.13842176,
        -0.023219328,
        0.0035885905,
        -0.0719257,
        -0.21335067,
        0.11415403,
        -0.0059823603,
        0.12091869,
        0.08136634,
        -0.10769281,
        0.024518685,
        0.0009200326,
        -0.11628049,
        0.07448965,
        0.13736208,
        -0.04144517,
        -0.16426727,
        -0.06380103,
        -0.21386267,
        0.022373492,
        -0.05874115,
        0.017314062,
        -0.040344074,
        0.01059176,
        // The remaining 12 entries are zero.
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
    ];
    let b_slice: [f32; 112] = [
        -0.07209058,
        -0.17755842,
        -0.030627966,
        0.163028,
        -0.2233766,
        0.057412963,
        0.0076995124,
        -0.017121306,
        -0.015759075,
        -0.026947778,
        -0.010282468,
        -0.23968373,
        -0.021486737,
        -0.09903155,
        0.09361805,
        0.0042711576,
        -0.08695552,
        -0.042165346,
        0.064218745,
        -0.06707651,
        0.07846054,
        0.12235762,
        -0.060716823,
        0.18496591,
        -0.13023394,
        0.022469055,
        0.056764495,
        0.07168404,
        -0.08856144,
        -0.15343173,
        0.099879816,
        -0.033529017,
        0.0795304,
        -0.009242254,
        -0.10254546,
        0.13086525,
        -0.101518914,
        -0.1031299,
        -0.056826904,
        0.033196196,
        0.044143833,
        -0.049787212,
        -0.018148342,
        -0.11172959,
        -0.06776237,
        -0.09185828,
        -0.24171598,
        0.05080982,
        -0.0727684,
        0.045031235,
        -0.11363879,
        -0.063389264,
        0.105850354,
        -0.19847773,
        0.08828623,
        -0.087071925,
        0.033512704,
        0.16118294,
        0.14111553,
        0.020884402,
        -0.088860825,
        0.018745849,
        0.047522716,
        -0.03665169,
        0.15726231,
        -0.09930561,
        0.057844743,
        -0.10532736,
        -0.091297254,
        0.067029804,
        0.04153976,
        0.06393326,
        0.054578528,
        0.0038539872,
        0.1023088,
        -0.10653885,
        -0.108500294,
        -0.046606563,
        0.020439683,
        -0.120957725,
        -0.13334097,
        -0.13425854,
        -0.20481694,
        0.07009538,
        0.08660361,
        -0.0096641015,
        0.095316306,
        -0.002898167,
        -0.19680002,
        0.08466311,
        0.04812689,
        -0.028978813,
        0.04780206,
        -0.2001506,
        -0.036866356,
        -0.023720587,
        0.10731964,
        0.05517358,
        -0.09580819,
        0.14595725,
        // The remaining 12 entries are zero.
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
    ];

    (F32Slice112(a_slice), F32Slice112(b_slice))
}
299
300 fn get_turing_test_data_f16() -> (F16Slice112, F16Slice112) {
301 let (a_slice, b_slice) = get_turing_test_data();
302 let a_data = a_slice.0.iter().map(|x| Half::from_f32(*x));
303 let b_data = b_slice.0.iter().map(|x| Half::from_f32(*x));
304
305 (
306 F16Slice112(a_data.collect::<Vec<Half>>().try_into().unwrap()),
307 F16Slice112(b_data.collect::<Vec<Half>>().try_into().unwrap()),
308 )
309 }
310
311 use crate::test_util::*;
312 use approx::assert_abs_diff_eq;
313
/// The `f32` L2 distance over the shared 112-element fixture must agree with
/// the `no_vector_compare_f32` reference to within `1e-6`.
#[test]
fn test_dist_l2_float_turing() {
    let (lhs, rhs) = get_turing_test_data();

    // Reference value from the test utility.
    let expected = no_vector_compare_f32(&lhs.0, &rhs.0);
    // Value under test, obtained through the trait dispatch.
    let actual = <[f32; 112]>::distance_compare(&lhs.0, &rhs.0, Metric::L2);

    assert_abs_diff_eq!(actual, expected, epsilon = 1e-6);
}
330
/// The `Half` L2 distance over the shared 112-element fixture must be exactly
/// equal to the `no_vector_compare_f16` reference.
#[test]
fn test_dist_l2_f16_turing() {
    let (lhs, rhs) = get_turing_test_data_f16();

    // Reference value from the test utility.
    let expected = no_vector_compare_f16(&lhs.0, &rhs.0);
    // Value under test, obtained through the trait dispatch.
    let actual = <[Half; 112]>::distance_compare(&lhs.0, &rhs.0, Metric::L2);

    assert_eq!(actual, expected);
}
344
/// Checks the `f32` L2 distance on two 256-element vectors against a
/// hard-coded expected value.
///
/// The two vectors are stored back to back in one 512-element buffer;
/// `compare` treats the first 256 values as one vector and the next 256 as
/// the other.
#[test]
fn distance_test() {
    // Backing storage with 32-byte alignment.
    #[repr(C, align(32))]
    struct Vector32ByteAligned {
        v: [f32; 512],
    }

    // Boxed, so the 512-element buffer is heap-allocated.
    let two_vec = Box::new(Vector32ByteAligned {
        v: [
            69.02492, 78.84786, 63.125072, 90.90581, 79.2592, 70.81731, 3.0829668, 33.33287,
            20.777142, 30.147898, 23.681915, 42.553043, 12.602162, 7.3808074, 19.157589,
            65.6791, 76.44677, 76.89124, 86.40756, 84.70118, 87.86142, 16.126896, 5.1277637,
            95.11038, 83.946945, 22.735607, 11.548555, 59.51482, 24.84603, 15.573776, 78.27185,
            71.13179, 38.574017, 80.0228, 13.175261, 62.887978, 15.205181, 18.89392, 96.13162,
            87.55455, 34.179806, 62.920044, 4.9305916, 54.349373, 21.731495, 14.982187,
            40.262867, 20.15214, 36.61963, 72.450806, 55.565, 95.5375, 93.73356, 95.36308,
            66.30762, 58.0397, 18.951357, 67.11702, 43.043316, 30.65622, 99.85361, 2.5889993,
            27.844774, 39.72441, 46.463238, 71.303764, 90.45308, 36.390602, 63.344395,
            26.427078, 35.99528, 82.35505, 32.529175, 23.165905, 74.73179, 9.856939, 59.38126,
            35.714924, 79.81213, 46.704124, 24.47884, 36.01743, 0.46678782, 29.528152,
            1.8980742, 24.68853, 75.58984, 98.72279, 68.62601, 11.890173, 49.49361, 55.45572,
            72.71067, 34.107483, 51.357758, 76.400635, 81.32725, 66.45081, 17.848074,
            62.398876, 94.20444, 2.10886, 17.416393, 64.88253, 29.000723, 62.434315, 53.907238,
            70.51412, 78.70744, 55.181683, 64.45116, 23.419212, 53.68544, 43.506958, 46.89598,
            35.905994, 64.51397, 91.95555, 20.322979, 74.80128, 97.548744, 58.312725, 78.81985,
            31.911612, 14.445949, 49.85094, 70.87396, 40.06766, 7.129991, 78.48008, 75.21636,
            93.623604, 95.95479, 29.571129, 22.721554, 26.73875, 52.075504, 56.783104,
            94.65493, 61.778534, 85.72401, 85.369514, 29.922367, 41.410553, 94.12884,
            80.276855, 55.604828, 54.70947, 74.07216, 44.61955, 31.38113, 68.48596, 34.56782,
            14.424729, 48.204506, 9.675444, 32.01946, 92.32695, 36.292683, 78.31955, 98.05327,
            14.343918, 46.017002, 95.90888, 82.63626, 16.873539, 3.698051, 7.8042626,
            64.194405, 96.71023, 67.93692, 21.618402, 51.92182, 22.834194, 61.56986, 19.749891,
            55.31206, 38.29552, 67.57593, 67.145836, 38.92673, 94.95708, 72.38746, 90.70901,
            69.43995, 9.394085, 31.646872, 88.20112, 9.134722, 99.98214, 5.423498, 41.51995,
            76.94409, 77.373276, 3.2966614, 9.611201, 57.231106, 30.747868, 76.10228, 91.98308,
            70.893585, 0.9067178, 43.96515, 16.321218, 27.734184, 83.271835, 88.23312,
            87.16445, 5.556643, 15.627432, 58.547127, 93.6459, 40.539192, 49.124157, 91.13276,
            57.485855, 8.827019, 4.9690843, 46.511234, 53.91469, 97.71925, 20.135271,
            23.353004, 70.92099, 93.38748, 87.520134, 51.684677, 29.89813, 9.110392, 65.809204,
            34.16554, 93.398605, 84.58669, 96.409645, 9.876037, 94.767784, 99.21523, 1.9330144,
            94.92429, 75.12728, 17.218828, 97.89164, 35.476578, 77.629456, 69.573746,
            40.200542, 42.117836, 5.861628, 75.45282, 82.73633, 0.98086596, 77.24894,
            11.248695, 61.070026, 52.692616, 80.5449, 80.76036, 29.270136, 67.60252, 48.782394,
            95.18851, 83.47162, 52.068756, 46.66002, 90.12216, 15.515327, 33.694042, 96.963036,
            73.49627, 62.805485, 44.715607, 59.98627, 3.8921833, 37.565327, 29.69184,
            39.429665, 83.46899, 44.286453, 21.54851, 56.096413, 18.169249, 5.214751,
            14.691341, 99.779335, 26.32643, 67.69903, 36.41243, 67.27333, 12.157213, 96.18984,
            2.438283, 78.14289, 0.14715195, 98.769, 53.649532, 21.615898, 39.657497, 95.45616,
            18.578386, 71.47976, 22.348118, 17.85519, 6.3717127, 62.176777, 22.033644,
            23.178005, 79.44858, 89.70233, 37.21273, 71.86182, 21.284317, 52.908623, 30.095518,
            63.64478, 77.55823, 80.04871, 15.133011, 30.439043, 70.16561, 4.4014096, 89.28944,
            26.29093, 46.827854, 11.764729, 61.887516, 47.774887, 57.19503, 59.444664,
            28.592825, 98.70386, 1.2497544, 82.28431, 46.76423, 83.746124, 53.032673, 86.53457,
            99.42168, 90.184, 92.27852, 9.059965, 71.75723, 70.45299, 10.924053, 68.329704,
            77.27232, 6.677854, 75.63629, 57.370533, 17.09031, 10.554659, 99.56178, 37.53221,
            72.311104, 75.7565, 65.2042, 36.096478, 64.69502, 38.88497, 64.33723, 84.87812,
            66.84958, 8.508932, 79.134, 83.431015, 66.72124, 61.801838, 64.30524, 37.194263,
            77.94725, 89.705185, 23.643505, 19.505919, 48.40264, 43.01083, 21.171177,
            18.717121, 10.805857, 69.66983, 77.85261, 57.323063, 3.28964, 38.758026, 5.349946,
            7.46572, 57.485138, 30.822384, 33.9411, 95.53746, 65.57723, 42.1077, 28.591347,
            11.917269, 5.031073, 31.835615, 19.34116, 85.71027, 87.4516, 1.3798475, 70.70583,
            51.988052, 45.217144, 14.308596, 54.557167, 86.18323, 79.13666, 76.866745,
            46.010685, 79.739235, 44.667603, 39.36416, 72.605896, 73.83187, 13.137412,
            6.7911267, 63.952374, 10.082436, 86.00318, 99.760376, 92.84948, 63.786434,
            3.4429908, 18.244314, 75.65299, 14.964747, 70.126366, 80.89449, 91.266655,
            96.58798, 46.439327, 38.253975, 87.31036, 21.093178, 37.19671, 58.28973, 9.75231,
            12.350321, 25.75115, 87.65073, 53.610504, 36.850048, 18.66356, 94.48941, 83.71898,
            44.49315, 44.186737, 19.360733, 84.365974, 46.76272, 44.924366, 50.279808,
            54.868866, 91.33004, 18.683397, 75.13282, 15.070831, 47.04839, 53.780903,
            26.911152, 74.65651, 57.659935, 25.604189, 37.235474, 65.39667, 53.952206,
            40.37131, 59.173275, 96.00756, 54.591274, 10.787476, 69.51549, 31.970142,
            25.408005, 55.972492, 85.01888, 97.48981, 91.006134, 28.98619, 97.151276,
            34.388496, 47.498177, 11.985874, 64.73775, 33.877014, 13.370312, 34.79146,
            86.19321, 15.019405, 94.07832, 93.50433, 60.168625, 50.95409, 38.27827, 47.458614,
            32.83715, 69.54998, 69.0361, 84.1418, 34.270298, 74.23852, 70.707466, 78.59845,
            9.651399, 24.186779, 58.255756, 53.72362, 92.46477, 97.75528, 20.257462, 30.122698,
            50.41517, 28.156603, 42.644154,
        ],
    });

    // First 256 values vs. the following 256 values, L2 metric.
    let distance = compare::<f32, 256>(256, Metric::L2, &two_vec.v);

    // Exact float equality against a hard-coded expected result.
    assert_eq!(distance, 429141.2);
}
430
431 fn compare<T, const N: usize>(dim: usize, metric: Metric, v: &[f32]) -> f32
432 where
433 for<'a> [T; N]: FullPrecisionDistance<T, N>,
434 {
435 let a_ptr = v.as_ptr();
436 let b_ptr = unsafe { a_ptr.add(dim) };
437
438 let a_ref =
439 <&[f32; N]>::try_from(unsafe { std::slice::from_raw_parts(a_ptr, dim) }).unwrap();
440 let b_ref =
441 <&[f32; N]>::try_from(unsafe { std::slice::from_raw_parts(b_ptr, dim) }).unwrap();
442
443 <[f32; N]>::distance_compare(a_ref, b_ref, metric)
444 }
445}