1use crate::vector::Metric;
10
/// A pluggable distance/similarity backend.
///
/// All methods return a *score* where higher means more similar: cosine and
/// inner product are returned as-is, while `l2` returns the **negated**
/// Euclidean distance (see the scalar implementation), so larger scores
/// always mean closer vectors regardless of metric.
///
/// Implementations must be `Send + Sync` so one boxed kernel can be shared
/// across threads.
pub trait DistanceKernel: Send + Sync {
    /// Cosine similarity between `query` and `vector`.
    fn cosine(&self, query: &[f32], vector: &[f32]) -> f32;
    /// Score for the L2 metric: negated Euclidean distance.
    fn l2(&self, query: &[f32], vector: &[f32]) -> f32;
    /// Plain dot product of `query` and `vector`.
    fn inner_product(&self, query: &[f32], vector: &[f32]) -> f32;
    /// Scores `vectors` — a flat buffer of consecutive `dimension`-length
    /// vectors — against `query`, writing one score per vector into
    /// `scores`. Stops early if `scores` has fewer slots than there are
    /// vectors.
    fn batch_score(
        &self,
        metric: Metric,
        query: &[f32],
        vectors: &[f32],
        dimension: usize,
        scores: &mut [f32],
    );
}
29
/// Portable scalar kernel: plain safe Rust, used when no SIMD backend is
/// available for the current CPU.
#[derive(Debug, Default)]
pub struct ScalarKernel;

impl ScalarKernel {
    /// Dot product over the overlapping prefix of the two slices
    /// (iteration stops at the shorter one).
    #[inline]
    fn dot(query: &[f32], vector: &[f32]) -> f32 {
        let mut acc = 0.0f32;
        for (&a, &b) in query.iter().zip(vector) {
            acc += a * b;
        }
        acc
    }

    /// Euclidean norm, computed as sqrt(v · v).
    #[inline]
    fn norm(v: &[f32]) -> f32 {
        Self::dot(v, v).sqrt()
    }
}
49
50impl DistanceKernel for ScalarKernel {
51 fn cosine(&self, query: &[f32], vector: &[f32]) -> f32 {
52 let dot = Self::dot(query, vector);
53 let q_norm = Self::norm(query);
54 let v_norm = Self::norm(vector);
55 if q_norm == 0.0 || v_norm == 0.0 {
56 0.0
57 } else {
58 dot / (q_norm * v_norm)
59 }
60 }
61
62 fn l2(&self, query: &[f32], vector: &[f32]) -> f32 {
63 let dist = query
64 .iter()
65 .zip(vector.iter())
66 .map(|(a, b)| {
67 let d = a - b;
68 d * d
69 })
70 .sum::<f32>()
71 .sqrt();
72 -dist
73 }
74
75 fn inner_product(&self, query: &[f32], vector: &[f32]) -> f32 {
76 Self::dot(query, vector)
77 }
78
79 fn batch_score(
80 &self,
81 metric: Metric,
82 query: &[f32],
83 vectors: &[f32],
84 dimension: usize,
85 scores: &mut [f32],
86 ) {
87 for (i, chunk) in vectors.chunks(dimension).enumerate() {
88 if i >= scores.len() {
89 break;
90 }
91 scores[i] = match metric {
92 Metric::Cosine => self.cosine(query, chunk),
93 Metric::L2 => self.l2(query, chunk),
94 Metric::InnerProduct => self.inner_product(query, chunk),
95 };
96 }
97 }
98}
99
100#[cfg(target_arch = "x86_64")]
105mod avx2 {
106 use super::{DistanceKernel, Metric};
107 use std::arch::x86_64::*;
108
109 #[derive(Debug, Default)]
110 pub struct Avx2Kernel;
111
112 #[inline]
113 fn horizontal_sum_ps(v: __m256) -> f32 {
114 unsafe {
115 let lo = _mm256_castps256_ps128(v);
117 let hi = _mm256_extractf128_ps(v, 1);
118 let sum128 = _mm_add_ps(lo, hi); let sum64 = _mm_hadd_ps(sum128, sum128); let sum32 = _mm_hadd_ps(sum64, sum64); _mm_cvtss_f32(sum32)
122 }
123 }
124
125 impl Avx2Kernel {
126 #[inline]
127 unsafe fn dot(query: &[f32], vector: &[f32]) -> f32 {
128 let mut acc = _mm256_setzero_ps();
129 let mut i = 0;
130 while i + 8 <= query.len() {
131 let q = _mm256_loadu_ps(query.as_ptr().add(i));
132 let v = _mm256_loadu_ps(vector.as_ptr().add(i));
133 acc = _mm256_fmadd_ps(q, v, acc);
135 i += 8;
136 }
137 let mut sum = horizontal_sum_ps(acc);
138 for j in i..query.len() {
139 sum += *query.get_unchecked(j) * *vector.get_unchecked(j);
140 }
141 sum
142 }
143
144 #[inline]
145 unsafe fn norm(v: &[f32]) -> f32 {
146 let mut acc = _mm256_setzero_ps();
147 let mut i = 0;
148 while i + 8 <= v.len() {
149 let x = _mm256_loadu_ps(v.as_ptr().add(i));
150 acc = _mm256_fmadd_ps(x, x, acc);
151 i += 8;
152 }
153 let mut sum = horizontal_sum_ps(acc);
154 for j in i..v.len() {
155 let x = *v.get_unchecked(j);
156 sum += x * x;
157 }
158 sum.sqrt()
159 }
160
161 #[inline]
162 unsafe fn cosine_impl(&self, query: &[f32], vector: &[f32]) -> f32 {
163 let dot = Self::dot(query, vector);
164 let q_norm = Self::norm(query);
165 let v_norm = Self::norm(vector);
166 if q_norm == 0.0 || v_norm == 0.0 {
167 0.0
168 } else {
169 dot / (q_norm * v_norm)
170 }
171 }
172
173 #[inline]
174 unsafe fn l2_impl(&self, query: &[f32], vector: &[f32]) -> f32 {
175 let mut acc = _mm256_setzero_ps();
176 let mut i = 0;
177 while i + 8 <= query.len() {
178 let q = _mm256_loadu_ps(query.as_ptr().add(i));
179 let v = _mm256_loadu_ps(vector.as_ptr().add(i));
180 let diff = _mm256_sub_ps(q, v);
181 acc = _mm256_fmadd_ps(diff, diff, acc);
182 i += 8;
183 }
184 let mut sum = horizontal_sum_ps(acc);
185 for j in i..query.len() {
186 let d = *query.get_unchecked(j) - *vector.get_unchecked(j);
187 sum += d * d;
188 }
189 -sum.sqrt()
190 }
191 }
192
193 impl DistanceKernel for Avx2Kernel {
194 fn cosine(&self, query: &[f32], vector: &[f32]) -> f32 {
195 unsafe { self.cosine_impl(query, vector) }
196 }
197
198 fn l2(&self, query: &[f32], vector: &[f32]) -> f32 {
199 unsafe { self.l2_impl(query, vector) }
200 }
201
202 fn inner_product(&self, query: &[f32], vector: &[f32]) -> f32 {
203 unsafe { Self::dot(query, vector) }
204 }
205
206 fn batch_score(
207 &self,
208 metric: Metric,
209 query: &[f32],
210 vectors: &[f32],
211 dimension: usize,
212 scores: &mut [f32],
213 ) {
214 for (i, chunk) in vectors.chunks(dimension).enumerate() {
215 if i >= scores.len() {
216 break;
217 }
218 scores[i] = match metric {
219 Metric::Cosine => unsafe { self.cosine_impl(query, chunk) },
220 Metric::L2 => unsafe { self.l2_impl(query, chunk) },
221 Metric::InnerProduct => unsafe { Self::dot(query, chunk) },
222 };
223 }
224 }
225 }
226
227 pub fn create() -> Box<dyn DistanceKernel> {
228 Box::new(Avx2Kernel)
229 }
230
231 #[cfg(all(test, not(target_arch = "wasm32")))]
232 mod tests {
233 use super::*;
234
235 #[test]
236 fn horizontal_sum_correct_for_ones() {
237 if !std::is_x86_feature_detected!("avx2") {
238 return;
239 }
240 unsafe {
241 let v = _mm256_set1_ps(1.0);
242 let total = horizontal_sum_ps(v);
243 assert!((total - 8.0).abs() < 1e-6);
244 }
245 }
246 }
247}
248
#[cfg(target_arch = "aarch64")]
mod neon {
    use super::{DistanceKernel, Metric};
    use core::arch::aarch64::*;

    /// NEON-accelerated distance kernel for aarch64.
    ///
    /// NOTE(review): the intrinsics below run without a
    /// `#[target_feature(enable = "neon")]` annotation; this relies on NEON
    /// being enabled by default for the aarch64 targets this crate is built
    /// for — confirm against the build configuration.
    #[derive(Debug, Default)]
    pub struct NeonKernel;

    /// Reduces the four f32 lanes of `v` to their scalar sum: the low and
    /// high 64-bit halves are added, then one pairwise add collapses 2 -> 1.
    #[inline]
    unsafe fn horizontal_sum(v: float32x4_t) -> f32 {
        let pair_sum = vadd_f32(vget_low_f32(v), vget_high_f32(v));
        let sum = vpadd_f32(pair_sum, pair_sum);
        vget_lane_f32(sum, 0)
    }

    /// Dot product: 4-lane fused multiply-add main loop
    /// (`vfmaq_f32(acc, q, v)` accumulates `acc + q * v`) plus a scalar tail.
    ///
    /// NOTE(review): the tail indexes `vector` with `get_unchecked(j)` for
    /// `j < query.len()`, which assumes `vector.len() >= query.len()` —
    /// confirm callers guarantee equal-length slices.
    #[inline]
    unsafe fn dot(query: &[f32], vector: &[f32]) -> f32 {
        let mut acc = vdupq_n_f32(0.0);
        let mut i = 0;
        while i + 4 <= query.len() {
            let q = vld1q_f32(query.as_ptr().add(i));
            let v = vld1q_f32(vector.as_ptr().add(i));
            acc = vfmaq_f32(acc, q, v);
            i += 4;
        }
        let mut sum = horizontal_sum(acc);
        // Scalar tail for the remaining < 4 elements.
        for j in i..query.len() {
            sum += *query.get_unchecked(j) * *vector.get_unchecked(j);
        }
        sum
    }

    /// Euclidean norm: 4-lane sum of squares, then sqrt.
    #[inline]
    unsafe fn norm(v: &[f32]) -> f32 {
        let mut acc = vdupq_n_f32(0.0);
        let mut i = 0;
        while i + 4 <= v.len() {
            let x = vld1q_f32(v.as_ptr().add(i));
            acc = vfmaq_f32(acc, x, x);
            i += 4;
        }
        let mut sum = horizontal_sum(acc);
        // Scalar tail for the remaining < 4 elements.
        for j in i..v.len() {
            let x = *v.get_unchecked(j);
            sum += x * x;
        }
        sum.sqrt()
    }

    impl DistanceKernel for NeonKernel {
        /// Cosine similarity; 0.0 for a zero-norm operand (matches the
        /// scalar kernel's convention).
        fn cosine(&self, query: &[f32], vector: &[f32]) -> f32 {
            unsafe {
                let dot = dot(query, vector);
                let q_norm = norm(query);
                let v_norm = norm(vector);
                if q_norm == 0.0 || v_norm == 0.0 {
                    0.0
                } else {
                    dot / (q_norm * v_norm)
                }
            }
        }

        /// Negated Euclidean distance (higher score = closer): 4-lane
        /// squared-difference accumulation plus a scalar tail.
        fn l2(&self, query: &[f32], vector: &[f32]) -> f32 {
            unsafe {
                let mut acc = vdupq_n_f32(0.0);
                let mut i = 0;
                while i + 4 <= query.len() {
                    let q = vld1q_f32(query.as_ptr().add(i));
                    let v = vld1q_f32(vector.as_ptr().add(i));
                    let diff = vsubq_f32(q, v);
                    acc = vfmaq_f32(acc, diff, diff);
                    i += 4;
                }
                let mut sum = horizontal_sum(acc);
                // Scalar tail for the remaining < 4 elements.
                for j in i..query.len() {
                    let d = *query.get_unchecked(j) - *vector.get_unchecked(j);
                    sum += d * d;
                }
                -sum.sqrt()
            }
        }

        /// Raw dot product.
        fn inner_product(&self, query: &[f32], vector: &[f32]) -> f32 {
            unsafe { dot(query, vector) }
        }

        /// Scores each `dimension`-sized chunk; stops early if `scores`
        /// runs out of slots.
        fn batch_score(
            &self,
            metric: Metric,
            query: &[f32],
            vectors: &[f32],
            dimension: usize,
            scores: &mut [f32],
        ) {
            for (i, chunk) in vectors.chunks(dimension).enumerate() {
                if i >= scores.len() {
                    break;
                }
                scores[i] = match metric {
                    Metric::Cosine => self.cosine(query, chunk),
                    Metric::L2 => self.l2(query, chunk),
                    Metric::InnerProduct => self.inner_product(query, chunk),
                };
            }
        }
    }

    /// Boxes a [`NeonKernel`].
    pub fn create() -> Box<dyn DistanceKernel> {
        Box::new(NeonKernel)
    }
}
365
366pub fn select_kernel() -> Box<dyn DistanceKernel> {
368 #[cfg(target_arch = "x86_64")]
369 {
370 if std::is_x86_feature_detected!("avx2") {
371 return avx2::create();
372 }
373 }
374
375 #[cfg(target_arch = "aarch64")]
376 {
377 if std::arch::is_aarch64_feature_detected!("neon") {
378 return neon::create();
379 }
380 }
381
382 Box::new(ScalarKernel)
383}
384
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use super::*;
    use crate::vector::score;

    // ScalarKernel must agree with the reference `score` function on every
    // metric for a well-formed input pair.
    #[test]
    fn scalar_matches_reference() {
        let k = ScalarKernel;
        let q = [1.0, 2.0, 3.0, 4.0];
        let v = [4.0, 3.0, 2.0, 1.0];
        let metrics = [Metric::Cosine, Metric::L2, Metric::InnerProduct];
        for &m in &metrics {
            let ref_score = score(m, &q, &v).unwrap();
            let k_score = match m {
                Metric::Cosine => k.cosine(&q, &v),
                Metric::L2 => k.l2(&q, &v),
                Metric::InnerProduct => k.inner_product(&q, &v),
            };
            assert!((ref_score - k_score).abs() < 1e-6);
        }
    }

    // A zero-norm operand is defined to score 0.0 rather than produce NaN.
    #[test]
    fn scalar_cosine_zero_norm_returns_zero() {
        let k = ScalarKernel;
        let q = [0.0, 0.0, 0.0];
        let v = [1.0, 2.0, 3.0];
        assert_eq!(k.cosine(&q, &v), 0.0);
    }

    // batch_score writes one score per dimension-sized chunk of `vectors`.
    #[test]
    fn batch_score_populates_all() {
        let k = ScalarKernel;
        let q = [1.0, 0.0];
        let vectors = [1.0, 0.0, 0.0, 1.0];
        let mut scores = [0.0f32; 2];
        k.batch_score(Metric::InnerProduct, &q, &vectors, 2, &mut scores);
        assert_eq!(scores[0], 1.0);
        assert_eq!(scores[1], 0.0);
    }

    // Whichever kernel the runtime dispatcher picks must compute correctly.
    #[test]
    fn select_kernel_returns_any() {
        let k = select_kernel();
        let q = [1.0, 2.0];
        let v = [2.0, 1.0];
        let s = k.inner_product(&q, &v);
        assert!((s - 4.0).abs() < 1e-6);
    }

    // The dispatched kernel (possibly SIMD) must match the scalar kernel on
    // every metric, for two representative vectors.
    #[test]
    fn select_kernel_matches_scalar_for_all_metrics() {
        let kernel = select_kernel();
        let scalar = ScalarKernel;
        let q = vec![1.0f32, 2.0, 3.0, 4.0];
        let v1 = vec![4.0f32, 3.0, 2.0, 1.0];
        let v2 = vec![1.0f32, 1.0, 1.0, 1.0];

        let metrics = [Metric::Cosine, Metric::L2, Metric::InnerProduct];
        for &m in &metrics {
            let s1 = match m {
                Metric::Cosine => scalar.cosine(&q, &v1),
                Metric::L2 => scalar.l2(&q, &v1),
                Metric::InnerProduct => scalar.inner_product(&q, &v1),
            };
            let k1 = match m {
                Metric::Cosine => kernel.cosine(&q, &v1),
                Metric::L2 => kernel.l2(&q, &v1),
                Metric::InnerProduct => kernel.inner_product(&q, &v1),
            };
            assert!((s1 - k1).abs() < 1e-6);

            let s2 = match m {
                Metric::Cosine => scalar.cosine(&q, &v2),
                Metric::L2 => scalar.l2(&q, &v2),
                Metric::InnerProduct => scalar.inner_product(&q, &v2),
            };
            let k2 = match m {
                Metric::Cosine => kernel.cosine(&q, &v2),
                Metric::L2 => kernel.l2(&q, &v2),
                Metric::InnerProduct => kernel.inner_product(&q, &v2),
            };
            assert!((s2 - k2).abs() < 1e-6);
        }
    }

    // Float comparison helper: NaN matches NaN, infinities must agree in
    // sign, and finite values must agree within 1e-5.
    fn assert_same_f32(a: f32, b: f32) {
        if a.is_nan() && b.is_nan() {
            return;
        }
        if a.is_infinite() && b.is_infinite() {
            assert_eq!(a.is_sign_positive(), b.is_sign_positive());
            return;
        }
        assert!((a - b).abs() < 1e-5, "a={a}, b={b}");
    }

    // Non-finite inputs must propagate through the dispatched (possibly
    // SIMD) kernel the same way they do through the scalar kernel.
    #[test]
    fn kernel_handles_nan_and_inf_like_scalar() {
        let kernel = select_kernel();
        let scalar = ScalarKernel;
        let cases = vec![
            (
                Metric::Cosine,
                vec![f32::NAN, 1.0, 2.0],
                vec![1.0, 2.0, 3.0],
            ),
            (
                Metric::InnerProduct,
                vec![f32::INFINITY, 1.0],
                vec![1.0, 2.0],
            ),
            (Metric::L2, vec![f32::INFINITY, 0.0], vec![1.0, 0.0]),
        ];

        for (metric, q, v) in cases {
            let s = match metric {
                Metric::Cosine => scalar.cosine(&q, &v),
                Metric::L2 => scalar.l2(&q, &v),
                Metric::InnerProduct => scalar.inner_product(&q, &v),
            };
            let k = match metric {
                Metric::Cosine => kernel.cosine(&q, &v),
                Metric::L2 => kernel.l2(&q, &v),
                Metric::InnerProduct => kernel.inner_product(&q, &v),
            };
            assert_same_f32(s, k);
        }
    }

    // NaN in the query: cosine must behave identically in both kernels.
    #[test]
    fn cosine_with_nan_matches_scalar() {
        let kernel = select_kernel();
        let scalar = ScalarKernel;
        let q = [f32::NAN, 1.0, 2.0, 3.0];
        let v = [1.0, 2.0, 3.0, 4.0];
        let s = scalar.cosine(&q, &v);
        let k = kernel.cosine(&q, &v);
        assert_same_f32(s, k);
    }

    // Infinity in the query: l2 must behave identically in both kernels.
    #[test]
    fn l2_with_inf_matches_scalar() {
        let kernel = select_kernel();
        let scalar = ScalarKernel;
        let q = [f32::INFINITY, 0.0, 1.0];
        let v = [1.0, 0.0, 1.0];
        let s = scalar.l2(&q, &v);
        let k = kernel.l2(&q, &v);
        assert_same_f32(s, k);
    }

    // NaN in the query: inner product must behave identically in both kernels.
    #[test]
    fn inner_product_with_nan_matches_scalar() {
        let kernel = select_kernel();
        let scalar = ScalarKernel;
        let q = [1.0, f32::NAN];
        let v = [2.0, 3.0];
        let s = scalar.inner_product(&q, &v);
        let k = kernel.inner_product(&q, &v);
        assert_same_f32(s, k);
    }

    // Batch scoring must propagate NaN/inf the same way element-wise
    // scoring does, in both kernels.
    #[test]
    fn batch_score_propagates_nan_inf_like_scalar() {
        let kernel = select_kernel();
        let scalar = ScalarKernel;
        let q = [1.0, f32::NAN];
        let vectors = [2.0, 3.0, f32::INFINITY, 0.0];
        let mut scores_kernel = [0.0f32; 2];
        let mut scores_scalar = [0.0f32; 2];
        kernel.batch_score(Metric::InnerProduct, &q, &vectors, 2, &mut scores_kernel);
        scalar.batch_score(Metric::InnerProduct, &q, &vectors, 2, &mut scores_scalar);
        for (a, b) in scores_scalar.iter().zip(scores_kernel.iter()) {
            assert_same_f32(*a, *b);
        }
    }
}