1pub mod scalar;
11
12#[cfg(target_arch = "aarch64")]
13pub mod neon;
14
15#[cfg(target_arch = "x86_64")]
16pub mod avx2;
17
18#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
19pub mod avx512;
20
21pub mod checksum;
22pub mod convert;
23
/// CPU cache-line size in bytes on aarch64, used for alignment of hot data.
///
/// NOTE(review): 128 matches Apple-class aarch64 cores; many Cortex-A
/// designs use 64-byte lines, so 128 is presumably a conservative upper
/// bound here — confirm intent.
#[cfg(target_arch = "aarch64")]
pub const CACHE_LINE_SIZE: usize = 128;

/// CPU cache-line size in bytes on x86_64 (64 on all mainstream parts).
#[cfg(target_arch = "x86_64")]
pub const CACHE_LINE_SIZE: usize = 64;

/// Default cache-line size assumed for every other target.
#[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
pub const CACHE_LINE_SIZE: usize = 64;
40
41#[inline]
43pub fn align_to_cache_line(size: usize) -> usize {
44 (size + CACHE_LINE_SIZE - 1) & !(CACHE_LINE_SIZE - 1)
45}
46
/// SIMD instruction-set backends this crate can dispatch to.
///
/// NOTE(review): `detect_backend` can return `Sse4`, but the dispatch
/// functions below have no dedicated SSE4 arm — it currently falls
/// through to the scalar path via `_`. Confirm whether that is intended.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Backend {
    /// ARM NEON (always available on aarch64).
    Neon,
    /// x86_64 AVX2 (+FMA).
    Avx2,
    /// x86_64 AVX-512 (requires the `avx512` cargo feature).
    Avx512,
    /// x86_64 SSE4.1 fallback.
    Sse4,
    /// WebAssembly 128-bit SIMD.
    WasmSimd128,
    /// Portable scalar fallback; always available.
    Scalar,
}
63
/// Selects the best available SIMD backend for the current CPU at runtime.
///
/// - aarch64: NEON is mandatory, so it is returned without a probe.
/// - x86_64: widest-first — AVX-512 (only when the `avx512` cargo feature
///   is compiled in), then AVX2+FMA, then SSE4.1.
/// - wasm32: always reports `WasmSimd128` — NOTE(review): there is no
///   runtime probe here, so this presumably assumes the binary was built
///   with simd128 enabled; confirm.
/// - anything else falls through to `Backend::Scalar`.
pub fn detect_backend() -> Backend {
    #[cfg(target_arch = "aarch64")]
    {
        return Backend::Neon;
    }

    #[cfg(target_arch = "x86_64")]
    {
        #[cfg(feature = "avx512")]
        {
            if is_x86_feature_detected!("avx512f") {
                return Backend::Avx512;
            }
        }
        // AVX2 path also requires FMA, since the kernels fuse multiply-add.
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            return Backend::Avx2;
        }
        if is_x86_feature_detected!("sse4.1") {
            return Backend::Sse4;
        }
    }

    #[cfg(target_arch = "wasm32")]
    {
        return Backend::WasmSimd128;
    }

    // Unreachable on targets whose cfg block above returns unconditionally.
    #[allow(unreachable_code)]
    Backend::Scalar
}
95
/// Dot product of `a` and `b`, dispatched to the best SIMD backend.
///
/// The backend calls are `unsafe` because they use target-feature
/// intrinsics; `detect_backend()` returning the corresponding variant is
/// what establishes that the CPU supports those instructions.
///
/// NOTE(review): no length-agreement check between `a` and `b` here —
/// presumably the backend modules define behavior for mismatched
/// lengths; confirm against `scalar::dot_product`.
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    match detect_backend() {
        #[cfg(target_arch = "aarch64")]
        Backend::Neon => unsafe { neon::dot_product(a, b) },

        #[cfg(all(target_arch = "x86_64", feature = "avx512"))]
        Backend::Avx512 => unsafe { avx512::dot_product(a, b) },

        #[cfg(target_arch = "x86_64")]
        Backend::Avx2 => unsafe { avx2::dot_product(a, b) },

        // Sse4, WasmSimd128 (no dedicated impl here) and Scalar all
        // take the portable path.
        _ => scalar::dot_product(a, b),
    }
}
115
116pub fn vector_norm(v: &[f32]) -> f32 {
118 dot_product(v, v).sqrt()
119}
120
/// Cosine similarity of `a` and `b`, dispatched to the best SIMD backend.
///
/// The backend calls are `unsafe` target-feature intrinsics; soundness
/// rests on `detect_backend()` having verified the CPU features.
///
/// NOTE(review): zero-vector handling lives in the backend
/// implementations (the test suite expects 0.0 for a zero vector) —
/// confirm all backends agree with `scalar::cosine_similarity`.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    match detect_backend() {
        #[cfg(target_arch = "aarch64")]
        Backend::Neon => unsafe { neon::cosine_similarity(a, b) },

        #[cfg(all(target_arch = "x86_64", feature = "avx512"))]
        Backend::Avx512 => unsafe { avx512::cosine_similarity(a, b) },

        #[cfg(target_arch = "x86_64")]
        Backend::Avx2 => unsafe { avx2::cosine_similarity(a, b) },

        // Sse4, WasmSimd128 and Scalar fall through to the portable path.
        _ => scalar::cosine_similarity(a, b),
    }
}
136
137pub fn batch_cosine(query: &[f32], vectors: &[&[f32]], results: &mut [(usize, f32)]) {
141 assert!(results.len() >= vectors.len());
142 for (i, v) in vectors.iter().enumerate() {
143 results[i] = (i, cosine_similarity(query, v));
144 }
145}
146
147pub fn batch_cosine_prenorm(
152 query_normed: &[f32],
153 vectors: &[&[f32]],
154 norms: &[f32],
155 results: &mut [(usize, f32)],
156) {
157 assert!(results.len() >= vectors.len());
158 assert!(norms.len() >= vectors.len());
159 for (i, v) in vectors.iter().enumerate() {
160 let dot = dot_product(query_normed, v);
161 let sim = if norms[i] == 0.0 { 0.0 } else { dot / norms[i] };
162 results[i] = (i, sim);
163 }
164}
165
/// Euclidean (L2) distance between `a` and `b`, dispatched to the best
/// SIMD backend.
///
/// The backend calls are `unsafe` target-feature intrinsics; soundness
/// rests on `detect_backend()` having verified the CPU features.
pub fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    match detect_backend() {
        #[cfg(target_arch = "aarch64")]
        Backend::Neon => unsafe { neon::l2_distance(a, b) },

        #[cfg(all(target_arch = "x86_64", feature = "avx512"))]
        Backend::Avx512 => unsafe { avx512::l2_distance(a, b) },

        #[cfg(target_arch = "x86_64")]
        Backend::Avx2 => unsafe { avx2::l2_distance(a, b) },

        // Sse4, WasmSimd128 and Scalar fall through to the portable path.
        _ => scalar::l2_distance(a, b),
    }
}
181
182pub fn batch_norms(vectors: &[&[f32]], norms: &mut [f32]) {
184 assert!(norms.len() >= vectors.len());
185 for (i, v) in vectors.iter().enumerate() {
186 norms[i] = vector_norm(v);
187 }
188}
189
/// Converts a batch of IEEE 754 binary16 bit patterns in `input` into
/// `f32` values written to `output`, delegating to the `convert` module.
///
/// NOTE(review): length agreement between `input` and `output` is not
/// checked here — presumably `convert::f16_to_f32_batch` asserts or
/// defines it; confirm.
pub fn f16_to_f32_batch(input: &[u16], output: &mut [f32]) {
    convert::f16_to_f32_batch(input, output);
}
194
/// Fletcher-32 checksum of `data`, delegating to the `checksum` module.
///
/// NOTE(review): the test suite expects `0xFFFF_FFFF` for empty input,
/// which implies both running sums start at 0xFFFF — confirm against
/// the `checksum` implementation.
pub fn checksum_fletcher32(data: &[u8]) -> u32 {
    checksum::checksum_fletcher32(data)
}
199
//
// Unit tests: every dispatched kernel is checked against known values
// and cross-checked against the scalar reference implementation, so a
// miscompiled or divergent SIMD backend fails here on the host CPU.
//
#[cfg(test)]
mod tests {
    use super::*;

    // Tolerance for comparing f32 results against exact expected values.
    const EPSILON: f32 = 1e-5;

    // Absolute-difference comparison; `eps` is widened for long
    // reductions where SIMD summation order changes rounding.
    fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
        (a - b).abs() < eps
    }

    // ---- backend detection ----

    #[test]
    fn test_detect_backend_returns_valid() {
        // Exhaustive match: compiles only while every Backend variant is
        // handled, so adding a variant forces this test to be revisited.
        let backend = detect_backend();
        match backend {
            Backend::Neon
            | Backend::Avx2
            | Backend::Avx512
            | Backend::Sse4
            | Backend::WasmSimd128
            | Backend::Scalar => {}
        }
    }

    #[test]
    fn test_detect_backend_consistent() {
        // Detection must be deterministic on a fixed machine.
        let b1 = detect_backend();
        let b2 = detect_backend();
        assert_eq!(b1, b2);
    }

    // ---- dot product ----

    #[test]
    fn test_dot_product_known_values() {
        let a = [1.0, 2.0, 3.0, 4.0];
        let b = [5.0, 6.0, 7.0, 8.0];
        let result = dot_product(&a, &b);
        // 1*5 + 2*6 + 3*7 + 4*8 = 70
        assert!(approx_eq(result, 70.0, EPSILON), "got {result}");
    }

    #[test]
    fn test_dot_product_zero_vectors() {
        let a = [0.0f32; 16];
        let b = [1.0f32; 16];
        assert!(approx_eq(dot_product(&a, &b), 0.0, EPSILON));
    }

    #[test]
    fn test_dot_product_unit_vectors() {
        let mut a = [0.0f32; 3];
        let mut b = [0.0f32; 3];
        a[0] = 1.0;
        b[0] = 1.0;
        assert!(approx_eq(dot_product(&a, &b), 1.0, EPSILON));
    }

    #[test]
    fn test_dot_product_large_random() {
        // 1024 elements exercises multiple SIMD lanes plus any remainder
        // loop; wide 0.1 tolerance absorbs summation-order differences.
        let n = 1024;
        let a: Vec<f32> = (0..n).map(|i| (i as f32) * 0.01).collect();
        let b: Vec<f32> = (0..n).map(|i| ((n - i) as f32) * 0.01).collect();
        let scalar_result = scalar::dot_product(&a, &b);
        let simd_result = dot_product(&a, &b);
        assert!(
            approx_eq(scalar_result, simd_result, 0.1),
            "scalar={scalar_result} simd={simd_result}"
        );
    }

    #[test]
    fn test_dot_product_negative_values() {
        let a = [-1.0, -2.0, -3.0];
        let b = [1.0, 2.0, 3.0];
        assert!(approx_eq(dot_product(&a, &b), -14.0, EPSILON));
    }

    #[test]
    fn test_dot_product_single_element() {
        assert!(approx_eq(dot_product(&[3.0], &[4.0]), 12.0, EPSILON));
    }

    #[test]
    fn test_dot_product_empty() {
        // Empty inputs must neither panic nor produce NaN.
        assert!(approx_eq(dot_product(&[], &[]), 0.0, EPSILON));
    }

    #[test]
    fn test_dot_product_scalar_vs_dispatch() {
        // 384 is a common embedding dimension; compare dispatched SIMD
        // against the scalar reference.
        let a: Vec<f32> = (0..384).map(|i| (i as f32).sin()).collect();
        let b: Vec<f32> = (0..384).map(|i| (i as f32).cos()).collect();
        let s = scalar::dot_product(&a, &b);
        let d = dot_product(&a, &b);
        assert!(
            approx_eq(s, d, 0.01),
            "scalar={s} dispatched={d}"
        );
    }

    // ---- vector norm ----

    #[test]
    fn test_vector_norm_unit() {
        let v = [1.0, 0.0, 0.0];
        assert!(approx_eq(vector_norm(&v), 1.0, EPSILON));
    }

    #[test]
    fn test_vector_norm_345() {
        // Classic 3-4-5 right triangle.
        let v = [3.0, 4.0];
        assert!(approx_eq(vector_norm(&v), 5.0, EPSILON));
    }

    #[test]
    fn test_vector_norm_zero() {
        let v = [0.0f32; 10];
        assert!(approx_eq(vector_norm(&v), 0.0, EPSILON));
    }

    // ---- cosine similarity ----

    #[test]
    fn test_cosine_identical_is_one() {
        let v = [1.0, 2.0, 3.0, 4.0, 5.0];
        assert!(approx_eq(cosine_similarity(&v, &v), 1.0, EPSILON));
    }

    #[test]
    fn test_cosine_opposite_is_neg_one() {
        let a = [1.0, 2.0, 3.0];
        let b = [-1.0, -2.0, -3.0];
        assert!(approx_eq(cosine_similarity(&a, &b), -1.0, EPSILON));
    }

    #[test]
    fn test_cosine_orthogonal_is_zero() {
        let a = [1.0, 0.0, 0.0, 0.0];
        let b = [0.0, 1.0, 0.0, 0.0];
        assert!(approx_eq(cosine_similarity(&a, &b), 0.0, EPSILON));
    }

    #[test]
    fn test_cosine_zero_vector() {
        // Zero vector must yield 0.0, not NaN from 0/0.
        let a = [0.0f32; 4];
        let b = [1.0, 2.0, 3.0, 4.0];
        assert!(approx_eq(cosine_similarity(&a, &b), 0.0, EPSILON));
    }

    #[test]
    fn test_cosine_scalar_vs_dispatch() {
        let a: Vec<f32> = (0..384).map(|i| (i as f32).sin()).collect();
        let b: Vec<f32> = (0..384).map(|i| (i as f32 * 0.7).cos()).collect();
        let s = scalar::cosine_similarity(&a, &b);
        let d = cosine_similarity(&a, &b);
        assert!(
            approx_eq(s, d, 1e-4),
            "scalar={s} dispatched={d}"
        );
    }

    // ---- batch cosine ----

    #[test]
    fn test_batch_cosine_ranking_order() {
        let query = [1.0, 0.0, 0.0];
        let v0: Vec<f32> = vec![0.0, 1.0, 0.0];
        let v1: Vec<f32> = vec![1.0, 0.0, 0.0];
        let v2: Vec<f32> = vec![0.5, 0.5, 0.0];
        let vectors: Vec<&[f32]> = vec![&v0, &v1, &v2];
        let mut results = vec![(0usize, 0.0f32); 3];
        batch_cosine(&query, &vectors, &mut results);

        // Expected ranking: v1 (identical) > v2 (45 degrees) > v0 (orthogonal).
        assert!(results[1].1 > results[2].1);
        assert!(results[2].1 > results[0].1);
    }

    #[test]
    fn test_batch_cosine_scalar_vs_dispatch() {
        let query: Vec<f32> = (0..32).map(|i| (i as f32).sin()).collect();
        let v0: Vec<f32> = (0..32).map(|i| (i as f32).cos()).collect();
        let v1: Vec<f32> = (0..32).map(|i| (i as f32 * 2.0).sin()).collect();
        let vectors: Vec<&[f32]> = vec![&v0, &v1];

        let mut scalar_results = vec![(0usize, 0.0f32); 2];
        scalar::batch_cosine(&query, &vectors, &mut scalar_results);

        let mut simd_results = vec![(0usize, 0.0f32); 2];
        batch_cosine(&query, &vectors, &mut simd_results);

        for i in 0..2 {
            assert!(
                approx_eq(scalar_results[i].1, simd_results[i].1, 1e-4),
                "mismatch at {i}: scalar={} simd={}",
                scalar_results[i].1,
                simd_results[i].1
            );
        }
    }

    // ---- batch cosine with precomputed norms ----

    #[test]
    fn test_batch_cosine_prenorm() {
        // Query is already unit length; candidate norms supplied directly.
        let query = [1.0, 0.0, 0.0];
        let v0: Vec<f32> = vec![3.0, 4.0, 0.0];
        let v1: Vec<f32> = vec![0.0, 0.0, 5.0];
        let vectors: Vec<&[f32]> = vec![&v0, &v1];
        let norms = [5.0, 5.0];
        let mut results = vec![(0usize, 0.0f32); 2];
        batch_cosine_prenorm(&query, &vectors, &norms, &mut results);
        // dot(query, v0) = 3, / 5 = 0.6; v1 is orthogonal to the query.
        assert!(approx_eq(results[0].1, 0.6, EPSILON));
        assert!(approx_eq(results[1].1, 0.0, EPSILON));
    }

    // ---- L2 distance ----

    #[test]
    fn test_l2_distance_same_is_zero() {
        let v = [1.0, 2.0, 3.0, 4.0];
        assert!(approx_eq(l2_distance(&v, &v), 0.0, EPSILON));
    }

    #[test]
    fn test_l2_distance_known_triangle() {
        let a = [0.0, 0.0];
        let b = [3.0, 4.0];
        assert!(approx_eq(l2_distance(&a, &b), 5.0, EPSILON));
    }

    #[test]
    fn test_l2_distance_unit_axes() {
        let a = [1.0, 0.0, 0.0];
        let b = [0.0, 1.0, 0.0];
        assert!(approx_eq(l2_distance(&a, &b), 2.0f32.sqrt(), EPSILON));
    }

    #[test]
    fn test_l2_distance_scalar_vs_dispatch() {
        let a: Vec<f32> = (0..384).map(|i| (i as f32).sin()).collect();
        let b: Vec<f32> = (0..384).map(|i| (i as f32).cos()).collect();
        let s = scalar::l2_distance(&a, &b);
        let d = l2_distance(&a, &b);
        assert!(
            approx_eq(s, d, 0.01),
            "scalar={s} dispatched={d}"
        );
    }

    // ---- batch norms ----

    #[test]
    fn test_batch_norms() {
        // Vectors of differing lengths, including a zero vector.
        let v0: Vec<f32> = vec![3.0, 4.0];
        let v1: Vec<f32> = vec![0.0, 0.0];
        let v2: Vec<f32> = vec![1.0, 0.0, 0.0];
        let vectors: Vec<&[f32]> = vec![&v0, &v1, &v2];
        let mut norms = vec![0.0f32; 3];
        batch_norms(&vectors, &mut norms);
        assert!(approx_eq(norms[0], 5.0, EPSILON));
        assert!(approx_eq(norms[1], 0.0, EPSILON));
        assert!(approx_eq(norms[2], 1.0, EPSILON));
    }

    // ---- f16 -> f32 conversion ----

    #[test]
    fn test_f16_to_f32_known_values() {
        // IEEE binary16: 0x3C00 = 1.0, 0x4000 = 2.0, 0x0000 = +0.0.
        let input = [0x3C00u16, 0x4000, 0x0000];
        let mut output = [0.0f32; 3];
        f16_to_f32_batch(&input, &mut output);
        assert!(approx_eq(output[0], 1.0, EPSILON), "got {}", output[0]);
        assert!(approx_eq(output[1], 2.0, EPSILON), "got {}", output[1]);
        assert!(approx_eq(output[2], 0.0, EPSILON), "got {}", output[2]);
    }

    #[test]
    fn test_f16_to_f32_negative() {
        // 0xBC00 is -1.0 (sign bit set on 0x3C00).
        let input = [0xBC00u16];
        let mut output = [0.0f32; 1];
        f16_to_f32_batch(&input, &mut output);
        assert!(approx_eq(output[0], -1.0, EPSILON), "got {}", output[0]);
    }

    #[test]
    fn test_f16_to_f32_batch_larger() {
        // 32 elements exercises any SIMD fast path in the converter.
        let input: Vec<u16> = (0..32).map(|_| 0x3C00u16).collect();
        let mut output = vec![0.0f32; 32];
        f16_to_f32_batch(&input, &mut output);
        for (i, &v) in output.iter().enumerate() {
            assert!(approx_eq(v, 1.0, EPSILON), "mismatch at {i}: {v}");
        }
    }

    #[test]
    fn test_f16_to_f32_round_trip_accuracy() {
        // (bit pattern, exact f32 value) pairs — all exactly representable.
        let cases: Vec<(u16, f32)> = vec![
            (0x3C00, 1.0),
            (0x4000, 2.0),
            (0x3800, 0.5),
            (0x4200, 3.0),
            (0x4400, 4.0),
            (0x0000, 0.0),
            (0x8000, -0.0),
        ];
        let input: Vec<u16> = cases.iter().map(|(bits, _)| *bits).collect();
        let mut output = vec![0.0f32; cases.len()];
        f16_to_f32_batch(&input, &mut output);
        for (i, (_, expected)) in cases.iter().enumerate() {
            assert!(
                approx_eq(output[i], *expected, EPSILON),
                "f16 0x{:04X}: expected {expected}, got {}",
                input[i],
                output[i]
            );
        }
    }

    // ---- Fletcher-32 checksum ----

    #[test]
    fn test_fletcher32_empty() {
        // NOTE(review): expects both running sums to start at 0xFFFF so
        // the empty checksum is 0xFFFF_FFFF — tied to the implementation's
        // chosen Fletcher variant.
        let result = checksum_fletcher32(&[]);
        assert_eq!(result, 0xFFFF_FFFF);
    }

    #[test]
    fn test_fletcher32_known() {
        let data = [0x00u8, 0x01, 0x00, 0x02];
        let result = checksum_fletcher32(&data);
        let scalar = scalar::checksum_fletcher32(&data);
        assert_eq!(result, scalar);
    }

    #[test]
    fn test_fletcher32_scalar_vs_dispatch() {
        let data: Vec<u8> = (0..256).map(|i| i as u8).collect();
        let s = scalar::checksum_fletcher32(&data);
        let d = checksum_fletcher32(&data);
        assert_eq!(s, d);
    }

    // ---- performance smoke test ----

    #[test]
    fn test_dot_product_384_dim_perf() {
        // NOTE(review): wall-clock assertions are inherently flaky on
        // loaded/shared CI machines; the limits below are generous but
        // may still need per-environment tuning.
        use std::time::Instant;
        let a: Vec<f32> = (0..384).map(|i| (i as f32).sin()).collect();
        let b: Vec<f32> = (0..384).map(|i| (i as f32).cos()).collect();

        // Warm-up: populate caches and settle the dispatch path.
        for _ in 0..100 {
            let _ = dot_product(&a, &b);
        }

        let start = Instant::now();
        let iterations = 10_000;
        let mut sum = 0.0f32;
        for _ in 0..iterations {
            sum += dot_product(&a, &b);
        }
        let elapsed = start.elapsed();
        let per_call = elapsed / iterations;
        // Consume `sum` so the loop cannot be optimized away.
        assert!(sum.abs() >= 0.0);

        // Debug builds are 10-100x slower, hence the looser limit.
        let limit_ns = if cfg!(debug_assertions) { 20_000 } else { 1_000 };
        assert!(
            per_call.as_nanos() < limit_ns,
            "dot product too slow: {per_call:?} per call (limit {limit_ns}ns)"
        );
    }

    // ---- cache-line helpers ----

    #[test]
    fn test_cache_line_size_is_power_of_two() {
        // align_to_cache_line's bit trick requires a power of two.
        assert!(CACHE_LINE_SIZE.is_power_of_two());
    }

    #[test]
    fn test_cache_line_size_platform() {
        #[cfg(target_arch = "aarch64")]
        assert_eq!(CACHE_LINE_SIZE, 128);
        #[cfg(target_arch = "x86_64")]
        assert_eq!(CACHE_LINE_SIZE, 64);
    }

    #[test]
    fn test_align_to_cache_line() {
        assert_eq!(align_to_cache_line(0), 0);
        assert_eq!(align_to_cache_line(1), CACHE_LINE_SIZE);
        assert_eq!(align_to_cache_line(CACHE_LINE_SIZE), CACHE_LINE_SIZE);
        assert_eq!(align_to_cache_line(CACHE_LINE_SIZE + 1), CACHE_LINE_SIZE * 2);
        assert_eq!(align_to_cache_line(CACHE_LINE_SIZE * 3), CACHE_LINE_SIZE * 3);
    }

    #[test]
    fn test_align_to_cache_line_64_and_128() {
        // Platform-independent invariants: aligned, >= input, < input + line.
        let val = align_to_cache_line(100);
        assert_eq!(val % CACHE_LINE_SIZE, 0);
        assert!(val >= 100);
        assert!(val < 100 + CACHE_LINE_SIZE);
    }

    // ---- non-lane-aligned lengths (remainder-loop coverage) ----

    #[test]
    fn test_dot_product_non_aligned_length() {
        // Odd lengths around SIMD widths (4/8/16/32) hit remainder paths.
        for len in [1, 3, 5, 7, 9, 13, 17, 31, 33] {
            let a: Vec<f32> = (0..len).map(|i| i as f32).collect();
            let b: Vec<f32> = (0..len).map(|i| (i as f32) * 0.5).collect();
            let s = scalar::dot_product(&a, &b);
            let d = dot_product(&a, &b);
            assert!(
                approx_eq(s, d, 0.01),
                "len={len}: scalar={s} dispatched={d}"
            );
        }
    }

    #[test]
    fn test_cosine_non_aligned_length() {
        for len in [1, 3, 5, 7, 9, 13, 17, 31, 33] {
            let a: Vec<f32> = (0..len).map(|i| i as f32 + 1.0).collect();
            let b: Vec<f32> = (0..len).map(|i| (i as f32 + 1.0) * 2.0).collect();
            let s = scalar::cosine_similarity(&a, &b);
            let d = cosine_similarity(&a, &b);
            assert!(
                approx_eq(s, d, 1e-4),
                "len={len}: scalar={s} dispatched={d}"
            );
        }
    }

    #[test]
    fn test_l2_distance_non_aligned_length() {
        for len in [1, 3, 5, 7, 9, 13, 17, 31, 33] {
            let a: Vec<f32> = (0..len).map(|i| i as f32).collect();
            let b: Vec<f32> = (0..len).map(|i| (i as f32) + 1.0).collect();
            let s = scalar::l2_distance(&a, &b);
            let d = l2_distance(&a, &b);
            assert!(
                approx_eq(s, d, 0.01),
                "len={len}: scalar={s} dispatched={d}"
            );
        }
    }
}