// oxify_vector/simd.rs

//! SIMD-accelerated distance calculations
//!
//! This module provides SIMD-optimized implementations of distance metrics
//! for improved performance on supported CPUs.
//!
//! ## Features
//!
//! - **Auto-vectorization hints**: Helps the compiler generate SIMD code
//! - **x86_64 optimizations**: Automatically uses AVX-512, FMA+AVX2, or AVX2 when available
//! - **aarch64 optimizations**: Automatically uses NEON (always available on ARM64)
//! - **Cache-friendly**: Optimized memory access patterns
//! - **Fallback**: Automatically falls back to auto-vectorization on unsupported platforms
//!
//! ## Performance Hierarchy
//!
//! - **x86_64**: AVX-512 (16-wide) → FMA+AVX2 (8-wide) → AVX2 (8-wide) → auto-vectorization
//! - **aarch64**: NEON (4-wide)
//! - **other**: auto-vectorization
//!
//! ## Usage
//!
//! ```rust
//! use oxify_vector::simd::cosine_similarity_simd;
//!
//! let v1 = vec![1.0, 2.0, 3.0, 4.0];
//! let v2 = vec![2.0, 3.0, 4.0, 5.0];
//! let similarity = cosine_similarity_simd(&v1, &v2);
//! ```

// Allow unreachable code for architecture-specific optimizations
// On aarch64, NEON is always available so fallback code is never reached
// On x86_64 with AVX2, fallback code may not be reached
#![allow(unreachable_code)]

use crate::types::DistanceMetric;

// AVX2 intrinsics (x86_64 only)
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

// NEON intrinsics (aarch64 only)
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;

/// Check if AVX2 is available at runtime
#[cfg(target_arch = "x86_64")]
#[inline]
pub fn is_avx2_available() -> bool {
    is_x86_feature_detected!("avx2")
}

/// Check if AVX2 is available at runtime (non-x86_64 always returns false)
#[cfg(not(target_arch = "x86_64"))]
#[inline]
pub fn is_avx2_available() -> bool {
    false
}

/// Check if FMA is available at runtime
///
/// Also requires AVX2: the FMA code paths in this module are compiled with
/// `#[target_feature(enable = "avx2,fma")]`, and a few older CPUs supported
/// FMA without AVX2.
#[cfg(target_arch = "x86_64")]
#[inline]
pub fn is_fma_available() -> bool {
    is_x86_feature_detected!("fma") && is_x86_feature_detected!("avx2")
}

/// Check if FMA is available at runtime (non-x86_64 always returns false)
#[cfg(not(target_arch = "x86_64"))]
#[inline]
pub fn is_fma_available() -> bool {
    false
}

/// Check if NEON is available at runtime (aarch64 always has NEON)
#[cfg(target_arch = "aarch64")]
#[inline]
pub fn is_neon_available() -> bool {
    // NEON is a mandatory feature on aarch64, always available
    true
}

/// Check if NEON is available at runtime (non-aarch64 always returns false)
#[cfg(not(target_arch = "aarch64"))]
#[inline]
pub fn is_neon_available() -> bool {
    false
}

/// Check if AVX-512 is available at runtime (x86_64 only)
///
/// Also requires `avx512dq`, since the horizontal-sum helper uses
/// `_mm512_extractf32x8_ps`, which is an AVX512DQ instruction.
#[cfg(target_arch = "x86_64")]
#[inline]
pub fn is_avx512_available() -> bool {
    is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512dq")
}

/// Check if AVX-512 is available at runtime (non-x86_64 always returns false)
#[cfg(not(target_arch = "x86_64"))]
#[inline]
pub fn is_avx512_available() -> bool {
    false
}

// ============================================================================
// AVX-512 Explicit Intrinsics (x86_64 only)
// ============================================================================

/// Horizontal sum of 16 f32 values in a 512-bit AVX-512 register
///
/// Uses `_mm512_extractf32x8_ps` (AVX512DQ); `is_avx512_available` checks
/// for that feature alongside AVX512F.
#[cfg(target_arch = "x86_64")]
#[inline]
unsafe fn horizontal_sum_avx512(v: __m512) -> f32 {
    // Extract high and low 256-bit lanes
    let low = _mm512_castps512_ps256(v); // Lower 8 elements
    let high = _mm512_extractf32x8_ps(v, 1); // Upper 8 elements

    // Add them together into a 256-bit vector
    let sum256 = _mm256_add_ps(low, high);

    // Use the existing AVX2 horizontal sum
    horizontal_sum_avx2(sum256)
}

/// AVX-512 optimized dot product (x86_64 only)
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn dot_product_avx512(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut sum = _mm512_setzero_ps();

    // Process 16 floats at a time with AVX-512
    let chunks = len / 16;
    for i in 0..chunks {
        let offset = i * 16;
        let a_ptr = a.as_ptr().add(offset);
        let b_ptr = b.as_ptr().add(offset);

        let a_vec = _mm512_loadu_ps(a_ptr);
        let b_vec = _mm512_loadu_ps(b_ptr);
        // FMA is part of AVX-512, so we can use it directly
        sum = _mm512_fmadd_ps(a_vec, b_vec, sum);
    }

    // Horizontal sum
    let mut total = horizontal_sum_avx512(sum);

    // Process remainder
    for i in (chunks * 16)..len {
        total += a[i] * b[i];
    }

    total
}

/// AVX-512 optimized cosine similarity (x86_64 only)
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn cosine_similarity_avx512(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut dot_sum = _mm512_setzero_ps();
    let mut norm_a_sum = _mm512_setzero_ps();
    let mut norm_b_sum = _mm512_setzero_ps();

    // Process 16 floats at a time
    let chunks = len / 16;
    for i in 0..chunks {
        let offset = i * 16;
        let a_ptr = a.as_ptr().add(offset);
        let b_ptr = b.as_ptr().add(offset);

        let a_vec = _mm512_loadu_ps(a_ptr);
        let b_vec = _mm512_loadu_ps(b_ptr);

        // Use FMA for all three accumulations
        dot_sum = _mm512_fmadd_ps(a_vec, b_vec, dot_sum);
        norm_a_sum = _mm512_fmadd_ps(a_vec, a_vec, norm_a_sum);
        norm_b_sum = _mm512_fmadd_ps(b_vec, b_vec, norm_b_sum);
    }

    // Horizontal sum
    let mut dot = horizontal_sum_avx512(dot_sum);
    let mut norm_a = horizontal_sum_avx512(norm_a_sum);
    let mut norm_b = horizontal_sum_avx512(norm_b_sum);

    // Process remainder
    for i in (chunks * 16)..len {
        dot += a[i] * b[i];
        norm_a += a[i] * a[i];
        norm_b += b[i] * b[i];
    }

    let denominator = (norm_a.sqrt() * norm_b.sqrt()).max(1e-10);
    dot / denominator
}

/// AVX-512 optimized Euclidean distance (x86_64 only)
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn euclidean_distance_avx512(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut sum_sq = _mm512_setzero_ps();

    // Process 16 floats at a time
    let chunks = len / 16;
    for i in 0..chunks {
        let offset = i * 16;
        let a_ptr = a.as_ptr().add(offset);
        let b_ptr = b.as_ptr().add(offset);

        let a_vec = _mm512_loadu_ps(a_ptr);
        let b_vec = _mm512_loadu_ps(b_ptr);
        let diff = _mm512_sub_ps(a_vec, b_vec);
        // FMA: sum_sq = diff * diff + sum_sq
        sum_sq = _mm512_fmadd_ps(diff, diff, sum_sq);
    }

    // Horizontal sum
    let mut total = horizontal_sum_avx512(sum_sq);

    // Process remainder
    for i in (chunks * 16)..len {
        let diff = a[i] - b[i];
        total += diff * diff;
    }

    total.sqrt()
}

/// AVX-512 optimized Manhattan distance (x86_64 only)
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn manhattan_distance_avx512(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut sum = _mm512_setzero_ps();

    // Process 16 floats at a time
    let chunks = len / 16;
    for i in 0..chunks {
        let offset = i * 16;
        let a_ptr = a.as_ptr().add(offset);
        let b_ptr = b.as_ptr().add(offset);

        let a_vec = _mm512_loadu_ps(a_ptr);
        let b_vec = _mm512_loadu_ps(b_ptr);
        let diff = _mm512_sub_ps(a_vec, b_vec);
        // abs(x) using the built-in AVX-512 abs instruction
        let abs_diff = _mm512_abs_ps(diff);
        sum = _mm512_add_ps(sum, abs_diff);
    }

    // Horizontal sum
    let mut total = horizontal_sum_avx512(sum);

    // Process remainder
    for i in (chunks * 16)..len {
        total += (a[i] - b[i]).abs();
    }

    total
}

// ============================================================================
// ARM NEON Explicit Intrinsics (aarch64 only)
// ============================================================================

/// Horizontal sum of 4 f32 values in a 128-bit NEON register
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn horizontal_sum_neon(v: float32x4_t) -> f32 {
    // Pairwise add: [a0, a1, a2, a3] -> [a0+a1, a2+a3, a0+a1, a2+a3]
    let pair_sum = vpaddq_f32(v, v);
    // Add pairs: [a0+a1, a2+a3, ...] -> [a0+a1+a2+a3, ...]
    let final_sum = vpaddq_f32(pair_sum, pair_sum);
    // Extract the result
    vgetq_lane_f32(final_sum, 0)
}

/// NEON-optimized dot product (aarch64 only)
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn dot_product_neon(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut sum = vdupq_n_f32(0.0);

    // Process 4 floats at a time with NEON
    let chunks = len / 4;
    for i in 0..chunks {
        let offset = i * 4;
        let a_ptr = a.as_ptr().add(offset);
        let b_ptr = b.as_ptr().add(offset);

        let a_vec = vld1q_f32(a_ptr);
        let b_vec = vld1q_f32(b_ptr);
        // Multiply and add: sum = a * b + sum
        sum = vmlaq_f32(sum, a_vec, b_vec);
    }

    // Horizontal sum
    let mut total = horizontal_sum_neon(sum);

    // Process remainder
    for i in (chunks * 4)..len {
        total += a[i] * b[i];
    }

    total
}

/// NEON-optimized cosine similarity (aarch64 only)
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn cosine_similarity_neon(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut dot_sum = vdupq_n_f32(0.0);
    let mut norm_a_sum = vdupq_n_f32(0.0);
    let mut norm_b_sum = vdupq_n_f32(0.0);

    // Process 4 floats at a time
    let chunks = len / 4;
    for i in 0..chunks {
        let offset = i * 4;
        let a_ptr = a.as_ptr().add(offset);
        let b_ptr = b.as_ptr().add(offset);

        let a_vec = vld1q_f32(a_ptr);
        let b_vec = vld1q_f32(b_ptr);

        // Multiply and add
        dot_sum = vmlaq_f32(dot_sum, a_vec, b_vec);
        norm_a_sum = vmlaq_f32(norm_a_sum, a_vec, a_vec);
        norm_b_sum = vmlaq_f32(norm_b_sum, b_vec, b_vec);
    }

    // Horizontal sum
    let mut dot = horizontal_sum_neon(dot_sum);
    let mut norm_a = horizontal_sum_neon(norm_a_sum);
    let mut norm_b = horizontal_sum_neon(norm_b_sum);

    // Process remainder
    for i in (chunks * 4)..len {
        dot += a[i] * b[i];
        norm_a += a[i] * a[i];
        norm_b += b[i] * b[i];
    }

    let denominator = (norm_a.sqrt() * norm_b.sqrt()).max(1e-10);
    dot / denominator
}

/// NEON-optimized Euclidean distance (aarch64 only)
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn euclidean_distance_neon(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut sum_sq = vdupq_n_f32(0.0);

    // Process 4 floats at a time
    let chunks = len / 4;
    for i in 0..chunks {
        let offset = i * 4;
        let a_ptr = a.as_ptr().add(offset);
        let b_ptr = b.as_ptr().add(offset);

        let a_vec = vld1q_f32(a_ptr);
        let b_vec = vld1q_f32(b_ptr);
        let diff = vsubq_f32(a_vec, b_vec);
        // Multiply and add: sum_sq = diff * diff + sum_sq
        sum_sq = vmlaq_f32(sum_sq, diff, diff);
    }

    // Horizontal sum
    let mut total = horizontal_sum_neon(sum_sq);

    // Process remainder
    for i in (chunks * 4)..len {
        let diff = a[i] - b[i];
        total += diff * diff;
    }

    total.sqrt()
}

/// NEON-optimized Manhattan distance (aarch64 only)
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn manhattan_distance_neon(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut sum = vdupq_n_f32(0.0);

    // Process 4 floats at a time
    let chunks = len / 4;
    for i in 0..chunks {
        let offset = i * 4;
        let a_ptr = a.as_ptr().add(offset);
        let b_ptr = b.as_ptr().add(offset);

        let a_vec = vld1q_f32(a_ptr);
        let b_vec = vld1q_f32(b_ptr);
        let diff = vsubq_f32(a_vec, b_vec);
        let abs_diff = vabsq_f32(diff);
        sum = vaddq_f32(sum, abs_diff);
    }

    // Horizontal sum
    let mut total = horizontal_sum_neon(sum);

    // Process remainder
    for i in (chunks * 4)..len {
        total += (a[i] - b[i]).abs();
    }

    total
}

// ============================================================================
// AVX2 + FMA Explicit Intrinsics (x86_64 only)
// ============================================================================

/// FMA-optimized dot product (x86_64 only, requires FMA)
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
#[inline]
unsafe fn dot_product_fma(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut sum = _mm256_setzero_ps();

    // Process 8 floats at a time with FMA
    let chunks = len / 8;
    for i in 0..chunks {
        let offset = i * 8;
        let a_ptr = a.as_ptr().add(offset);
        let b_ptr = b.as_ptr().add(offset);

        let a_vec = _mm256_loadu_ps(a_ptr);
        let b_vec = _mm256_loadu_ps(b_ptr);
        // FMA: sum = a * b + sum (single instruction!)
        sum = _mm256_fmadd_ps(a_vec, b_vec, sum);
    }

    // Horizontal sum using optimized intrinsics
    let mut total = horizontal_sum_avx2(sum);

    // Process remainder
    for i in (chunks * 8)..len {
        total += a[i] * b[i];
    }

    total
}

/// FMA-optimized cosine similarity (x86_64 only, requires FMA)
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
#[inline]
unsafe fn cosine_similarity_fma(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut dot_sum = _mm256_setzero_ps();
    let mut norm_a_sum = _mm256_setzero_ps();
    let mut norm_b_sum = _mm256_setzero_ps();

    // Process 8 floats at a time with FMA
    let chunks = len / 8;
    for i in 0..chunks {
        let offset = i * 8;
        let a_ptr = a.as_ptr().add(offset);
        let b_ptr = b.as_ptr().add(offset);

        let a_vec = _mm256_loadu_ps(a_ptr);
        let b_vec = _mm256_loadu_ps(b_ptr);

        // Use FMA for all three accumulations
        dot_sum = _mm256_fmadd_ps(a_vec, b_vec, dot_sum);
        norm_a_sum = _mm256_fmadd_ps(a_vec, a_vec, norm_a_sum);
        norm_b_sum = _mm256_fmadd_ps(b_vec, b_vec, norm_b_sum);
    }

    // Horizontal sum
    let mut dot = horizontal_sum_avx2(dot_sum);
    let mut norm_a = horizontal_sum_avx2(norm_a_sum);
    let mut norm_b = horizontal_sum_avx2(norm_b_sum);

    // Process remainder
    for i in (chunks * 8)..len {
        dot += a[i] * b[i];
        norm_a += a[i] * a[i];
        norm_b += b[i] * b[i];
    }

    let denominator = (norm_a.sqrt() * norm_b.sqrt()).max(1e-10);
    dot / denominator
}

/// FMA-optimized Euclidean distance (x86_64 only, requires FMA)
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
#[inline]
unsafe fn euclidean_distance_fma(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut sum_sq = _mm256_setzero_ps();

    // Process 8 floats at a time with FMA
    let chunks = len / 8;
    for i in 0..chunks {
        let offset = i * 8;
        let a_ptr = a.as_ptr().add(offset);
        let b_ptr = b.as_ptr().add(offset);

        let a_vec = _mm256_loadu_ps(a_ptr);
        let b_vec = _mm256_loadu_ps(b_ptr);
        let diff = _mm256_sub_ps(a_vec, b_vec);
        // FMA: sum_sq = diff * diff + sum_sq
        sum_sq = _mm256_fmadd_ps(diff, diff, sum_sq);
    }

    // Horizontal sum
    let mut total = horizontal_sum_avx2(sum_sq);

    // Process remainder
    for i in (chunks * 8)..len {
        let diff = a[i] - b[i];
        total += diff * diff;
    }

    total.sqrt()
}

/// Horizontal sum of 8 f32 values in a 256-bit register (AVX2 helper)
#[cfg(target_arch = "x86_64")]
#[inline]
unsafe fn horizontal_sum_avx2(v: __m256) -> f32 {
    // v = [a0, a1, a2, a3, a4, a5, a6, a7]
    // Extract high and low 128-bit lanes
    let hi = _mm256_extractf128_ps(v, 1); // [a4, a5, a6, a7]
    let lo = _mm256_castps256_ps128(v); // [a0, a1, a2, a3]

    // Add high and low lanes
    let sum128 = _mm_add_ps(lo, hi); // [a0+a4, a1+a5, a2+a6, a3+a7]

    // Horizontal add twice to sum all 4 elements
    let sum64 = _mm_hadd_ps(sum128, sum128); // [a0+a4+a1+a5, a2+a6+a3+a7, ...]
    let sum32 = _mm_hadd_ps(sum64, sum64); // [sum_all, sum_all, ...]

    // Extract the final sum
    _mm_cvtss_f32(sum32)
}

/// AVX2-optimized dot product (x86_64 only)
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut sum = _mm256_setzero_ps();

    // Process 8 floats at a time with AVX2
    let chunks = len / 8;
    for i in 0..chunks {
        let offset = i * 8;
        let a_ptr = a.as_ptr().add(offset);
        let b_ptr = b.as_ptr().add(offset);

        let a_vec = _mm256_loadu_ps(a_ptr);
        let b_vec = _mm256_loadu_ps(b_ptr);
        let mul = _mm256_mul_ps(a_vec, b_vec);
        sum = _mm256_add_ps(sum, mul);
    }

    // Horizontal sum of 8 floats using optimized intrinsics
    let mut total = horizontal_sum_avx2(sum);

    // Process remainder
    for i in (chunks * 8)..len {
        total += a[i] * b[i];
    }

    total
}

/// AVX2-optimized cosine similarity (x86_64 only)
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn cosine_similarity_avx2(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut dot_sum = _mm256_setzero_ps();
    let mut norm_a_sum = _mm256_setzero_ps();
    let mut norm_b_sum = _mm256_setzero_ps();

    // Process 8 floats at a time
    let chunks = len / 8;
    for i in 0..chunks {
        let offset = i * 8;
        let a_ptr = a.as_ptr().add(offset);
        let b_ptr = b.as_ptr().add(offset);

        let a_vec = _mm256_loadu_ps(a_ptr);
        let b_vec = _mm256_loadu_ps(b_ptr);

        dot_sum = _mm256_add_ps(dot_sum, _mm256_mul_ps(a_vec, b_vec));
        norm_a_sum = _mm256_add_ps(norm_a_sum, _mm256_mul_ps(a_vec, a_vec));
        norm_b_sum = _mm256_add_ps(norm_b_sum, _mm256_mul_ps(b_vec, b_vec));
    }

    // Horizontal sum using optimized intrinsics
    let mut dot = horizontal_sum_avx2(dot_sum);
    let mut norm_a = horizontal_sum_avx2(norm_a_sum);
    let mut norm_b = horizontal_sum_avx2(norm_b_sum);

    // Process remainder
    for i in (chunks * 8)..len {
        dot += a[i] * b[i];
        norm_a += a[i] * a[i];
        norm_b += b[i] * b[i];
    }

    let denominator = (norm_a.sqrt() * norm_b.sqrt()).max(1e-10);
    dot / denominator
}

/// AVX2-optimized Euclidean distance (x86_64 only)
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn euclidean_distance_avx2(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut sum_sq = _mm256_setzero_ps();

    // Process 8 floats at a time
    let chunks = len / 8;
    for i in 0..chunks {
        let offset = i * 8;
        let a_ptr = a.as_ptr().add(offset);
        let b_ptr = b.as_ptr().add(offset);

        let a_vec = _mm256_loadu_ps(a_ptr);
        let b_vec = _mm256_loadu_ps(b_ptr);
        let diff = _mm256_sub_ps(a_vec, b_vec);
        sum_sq = _mm256_add_ps(sum_sq, _mm256_mul_ps(diff, diff));
    }

    // Horizontal sum using optimized intrinsics
    let mut total = horizontal_sum_avx2(sum_sq);

    // Process remainder
    for i in (chunks * 8)..len {
        let diff = a[i] - b[i];
        total += diff * diff;
    }

    total.sqrt()
}

/// AVX2-optimized Manhattan distance (x86_64 only)
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn manhattan_distance_avx2(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut sum = _mm256_setzero_ps();
    let sign_mask = _mm256_set1_ps(-0.0); // Mask for abs

    // Process 8 floats at a time
    let chunks = len / 8;
    for i in 0..chunks {
        let offset = i * 8;
        let a_ptr = a.as_ptr().add(offset);
        let b_ptr = b.as_ptr().add(offset);

        let a_vec = _mm256_loadu_ps(a_ptr);
        let b_vec = _mm256_loadu_ps(b_ptr);
        let diff = _mm256_sub_ps(a_vec, b_vec);
        // abs(x) = andnot(sign_bit, x)
        let abs_diff = _mm256_andnot_ps(sign_mask, diff);
        sum = _mm256_add_ps(sum, abs_diff);
    }

    // Horizontal sum using optimized intrinsics
    let mut total = horizontal_sum_avx2(sum);

    // Process remainder
    for i in (chunks * 8)..len {
        total += (a[i] - b[i]).abs();
    }

    total
}

// ============================================================================
// Auto-Vectorization Fallback Implementations
// These functions are kept for testing and as fallbacks on platforms without SIMD
// ============================================================================

/// Auto-vectorization fallback for cosine similarity
#[inline]
#[allow(dead_code)]
fn cosine_similarity_autovec(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vectors must have same dimension");

    // Use chunks for better vectorization
    let chunk_size = 8; // Process 8 elements at a time for better SIMD utilization
    let len = a.len();
    let chunks = len / chunk_size;

    let mut dot_product = 0.0f32;
    let mut norm_a = 0.0f32;
    let mut norm_b = 0.0f32;

    // Process chunks (compiler will auto-vectorize this)
    for i in 0..chunks {
        let offset = i * chunk_size;
        for j in 0..chunk_size {
            let idx = offset + j;
            let a_val = unsafe { *a.get_unchecked(idx) };
            let b_val = unsafe { *b.get_unchecked(idx) };

            dot_product += a_val * b_val;
            norm_a += a_val * a_val;
            norm_b += b_val * b_val;
        }
    }

    // Process remainder
    for i in (chunks * chunk_size)..len {
        let a_val = unsafe { *a.get_unchecked(i) };
        let b_val = unsafe { *b.get_unchecked(i) };

        dot_product += a_val * b_val;
        norm_a += a_val * a_val;
        norm_b += b_val * b_val;
    }

    let denominator = (norm_a.sqrt() * norm_b.sqrt()).max(1e-10);
    dot_product / denominator
}

/// SIMD-optimized cosine similarity calculation
///
/// Automatically uses the best available SIMD implementation:
/// - x86_64: AVX-512 → FMA+AVX2 → AVX2 → auto-vectorization
/// - aarch64: NEON (always available)
/// - other: auto-vectorization
#[inline]
pub fn cosine_similarity_simd(a: &[f32], b: &[f32]) -> f32 {
    // The unsafe SIMD paths read `b` using `a`'s length, so mismatched
    // dimensions must be rejected up front.
    assert_eq!(a.len(), b.len(), "Vector dimension mismatch");

    #[cfg(target_arch = "x86_64")]
    {
        if is_avx512_available() {
            unsafe { cosine_similarity_avx512(a, b) }
        } else if is_fma_available() {
            unsafe { cosine_similarity_fma(a, b) }
        } else if is_avx2_available() {
            unsafe { cosine_similarity_avx2(a, b) }
        } else {
            cosine_similarity_autovec(a, b)
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        unsafe { cosine_similarity_neon(a, b) }
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        cosine_similarity_autovec(a, b)
    }
}

/// Auto-vectorization fallback for Euclidean distance
#[inline]
#[allow(dead_code)]
fn euclidean_distance_autovec(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vectors must have same dimension");

    let chunk_size = 8;
    let len = a.len();
    let chunks = len / chunk_size;

    let mut sum_sq = 0.0f32;

    // Process chunks (compiler will auto-vectorize this)
    for i in 0..chunks {
        let offset = i * chunk_size;
        for j in 0..chunk_size {
            let idx = offset + j;
            let diff = unsafe { *a.get_unchecked(idx) - *b.get_unchecked(idx) };
            sum_sq += diff * diff;
        }
    }

    // Process remainder
    for i in (chunks * chunk_size)..len {
        let diff = unsafe { *a.get_unchecked(i) - *b.get_unchecked(i) };
        sum_sq += diff * diff;
    }

    sum_sq.sqrt()
}

/// SIMD-optimized Euclidean distance calculation
///
/// Automatically uses the best available SIMD implementation:
/// - x86_64: AVX-512 → FMA+AVX2 → AVX2 → auto-vectorization
/// - aarch64: NEON (always available)
/// - other: auto-vectorization
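///
/// # Example
///
/// A quick sanity check on a 3-4-5 triangle:
///
/// ```rust
/// use oxify_vector::simd::euclidean_distance_simd;
///
/// let a = vec![0.0, 0.0];
/// let b = vec![3.0, 4.0];
/// assert!((euclidean_distance_simd(&a, &b) - 5.0).abs() < 1e-6);
/// ```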
#[inline]
pub fn euclidean_distance_simd(a: &[f32], b: &[f32]) -> f32 {
    // Reject mismatched dimensions before the unsafe SIMD paths run.
    assert_eq!(a.len(), b.len(), "Vector dimension mismatch");

    #[cfg(target_arch = "x86_64")]
    {
        if is_avx512_available() {
            unsafe { euclidean_distance_avx512(a, b) }
        } else if is_fma_available() {
            unsafe { euclidean_distance_fma(a, b) }
        } else if is_avx2_available() {
            unsafe { euclidean_distance_avx2(a, b) }
        } else {
            euclidean_distance_autovec(a, b)
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        unsafe { euclidean_distance_neon(a, b) }
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        euclidean_distance_autovec(a, b)
    }
}

/// Auto-vectorization fallback for dot product
#[inline]
#[allow(dead_code)]
fn dot_product_autovec(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vectors must have same dimension");

    let chunk_size = 8;
    let len = a.len();
    let chunks = len / chunk_size;

    let mut dot = 0.0f32;

    // Process chunks (compiler will auto-vectorize this)
    for i in 0..chunks {
        let offset = i * chunk_size;
        for j in 0..chunk_size {
            let idx = offset + j;
            dot += unsafe { *a.get_unchecked(idx) * *b.get_unchecked(idx) };
        }
    }

    // Process remainder
    for i in (chunks * chunk_size)..len {
        dot += unsafe { *a.get_unchecked(i) * *b.get_unchecked(i) };
    }

    dot
}

/// SIMD-optimized dot product calculation
///
/// Automatically uses the best available SIMD implementation:
/// - x86_64: AVX-512 → FMA+AVX2 → AVX2 → auto-vectorization
/// - aarch64: NEON (always available)
/// - other: auto-vectorization
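///
/// # Example
///
/// ```rust
/// use oxify_vector::simd::dot_product_simd;
///
/// let a = vec![1.0, 2.0, 3.0];
/// let b = vec![4.0, 5.0, 6.0];
/// // 1*4 + 2*5 + 3*6 = 32
/// assert!((dot_product_simd(&a, &b) - 32.0).abs() < 1e-6);
/// ```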
#[inline]
pub fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
    // Reject mismatched dimensions before the unsafe SIMD paths run.
    assert_eq!(a.len(), b.len(), "Vector dimension mismatch");

    #[cfg(target_arch = "x86_64")]
    {
        if is_avx512_available() {
            unsafe { dot_product_avx512(a, b) }
        } else if is_fma_available() {
            unsafe { dot_product_fma(a, b) }
        } else if is_avx2_available() {
            unsafe { dot_product_avx2(a, b) }
        } else {
            dot_product_autovec(a, b)
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        unsafe { dot_product_neon(a, b) }
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        dot_product_autovec(a, b)
    }
}

/// Auto-vectorization fallback for Manhattan distance
#[inline]
#[allow(dead_code)]
fn manhattan_distance_autovec(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vectors must have same dimension");

    let chunk_size = 8;
    let len = a.len();
    let chunks = len / chunk_size;

    let mut sum = 0.0f32;

    // Process chunks (compiler will auto-vectorize this)
    for i in 0..chunks {
        let offset = i * chunk_size;
        for j in 0..chunk_size {
            let idx = offset + j;
            sum += unsafe { (*a.get_unchecked(idx) - *b.get_unchecked(idx)).abs() };
        }
    }

    // Process remainder
    for i in (chunks * chunk_size)..len {
        sum += unsafe { (*a.get_unchecked(i) - *b.get_unchecked(i)).abs() };
    }

    sum
}

/// SIMD-optimized Manhattan distance calculation
///
/// Automatically uses the best available SIMD implementation:
/// - x86_64: AVX-512 → AVX2 → auto-vectorization
/// - aarch64: NEON (always available)
/// - other: auto-vectorization
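///
/// # Example
///
/// ```rust
/// use oxify_vector::simd::manhattan_distance_simd;
///
/// let a = vec![1.0, 2.0];
/// let b = vec![4.0, 6.0];
/// // |1-4| + |2-6| = 7
/// assert!((manhattan_distance_simd(&a, &b) - 7.0).abs() < 1e-6);
/// ```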
#[inline]
pub fn manhattan_distance_simd(a: &[f32], b: &[f32]) -> f32 {
    // Reject mismatched dimensions before the unsafe SIMD paths run.
    assert_eq!(a.len(), b.len(), "Vector dimension mismatch");

    #[cfg(target_arch = "x86_64")]
    {
        if is_avx512_available() {
            unsafe { manhattan_distance_avx512(a, b) }
        } else if is_avx2_available() {
            unsafe { manhattan_distance_avx2(a, b) }
        } else {
            manhattan_distance_autovec(a, b)
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        unsafe { manhattan_distance_neon(a, b) }
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        manhattan_distance_autovec(a, b)
    }
}

/// Compute similarity/distance using the specified metric with SIMD optimization
///
/// Returns a score where higher is better (for VectorSearchIndex).
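///
/// Illustrative sketch (marked `ignore` because the exact re-export path of
/// `DistanceMetric` is assumed here):
///
/// ```ignore
/// use oxify_vector::simd::compute_distance_simd;
/// use oxify_vector::types::DistanceMetric;
///
/// let a = vec![1.0, 0.0];
/// let b = vec![1.0, 0.0];
/// // Identical vectors: cosine score 1.0. The Euclidean score is the
/// // negated distance, so it is 0.0 here and negative for distinct vectors.
/// assert!((compute_distance_simd(DistanceMetric::Cosine, &a, &b) - 1.0).abs() < 1e-6);
/// assert!(compute_distance_simd(DistanceMetric::Euclidean, &a, &b).abs() < 1e-6);
/// ```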
pub fn compute_distance_simd(metric: DistanceMetric, a: &[f32], b: &[f32]) -> f32 {
    match metric {
        DistanceMetric::Cosine => cosine_similarity_simd(a, b),
        DistanceMetric::Euclidean => -euclidean_distance_simd(a, b),
        DistanceMetric::DotProduct => dot_product_simd(a, b),
        DistanceMetric::Manhattan => -manhattan_distance_simd(a, b),
    }
}

/// Compute distance using the specified metric with SIMD optimization
///
/// Returns a distance where lower is better (for HNSW and other ANN algorithms).
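///
/// Illustrative sketch (marked `ignore` because the exact re-export path of
/// `DistanceMetric` is assumed here):
///
/// ```ignore
/// use oxify_vector::simd::compute_distance_lower_is_better_simd;
/// use oxify_vector::types::DistanceMetric;
///
/// let a = vec![1.0, 0.0];
/// let b = vec![0.0, 1.0];
/// // Orthogonal vectors: cosine distance = 1 - cosine similarity = 1.0.
/// let d = compute_distance_lower_is_better_simd(DistanceMetric::Cosine, &a, &b);
/// assert!((d - 1.0).abs() < 1e-6);
/// ```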
#[inline]
pub fn compute_distance_lower_is_better_simd(metric: DistanceMetric, a: &[f32], b: &[f32]) -> f32 {
    match metric {
        DistanceMetric::Cosine => {
            // 1 - cosine similarity = cosine distance
            1.0 - cosine_similarity_simd(a, b)
        }
        DistanceMetric::Euclidean => euclidean_distance_simd(a, b),
        DistanceMetric::DotProduct => {
            // Negative dot product (lower is better)
            -dot_product_simd(a, b)
        }
        DistanceMetric::Manhattan => manhattan_distance_simd(a, b),
    }
}

// ============================================================================
// Quantized Vector Distance (u8/int8) - SIMD Optimized
// ============================================================================

/// Compute Manhattan distance between two quantized (u8) vectors using SIMD
///
/// This is significantly faster than converting to f32 and using regular distance.
/// Optimized for scalar quantization (8-bit).
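///
/// # Example
///
/// ```rust
/// use oxify_vector::simd::quantized_manhattan_distance_simd;
///
/// let a = vec![0u8, 10, 20];
/// let b = vec![5u8, 10, 25];
/// // |0-5| + |10-10| + |20-25| = 10
/// assert_eq!(quantized_manhattan_distance_simd(&a, &b), 10);
/// ```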
#[inline]
pub fn quantized_manhattan_distance_simd(a: &[u8], b: &[u8]) -> u32 {
    assert_eq!(a.len(), b.len(), "Vector dimension mismatch");

    #[cfg(target_arch = "x86_64")]
    {
        if is_avx2_available() {
            return unsafe { quantized_manhattan_distance_avx2(a, b) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        return unsafe { quantized_manhattan_distance_neon(a, b) };
    }

    // Fallback: scalar implementation
    quantized_manhattan_distance_scalar(a, b)
}

/// Compute dot product between two quantized (u8) vectors using SIMD
///
/// Returns u32 to avoid overflow (each product is at most 255*255; the
/// accumulated sum stays within u32 for vectors up to ~66,000 elements).
/// For normalized comparison, you may need to convert to f32 afterward.
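///
/// # Example
///
/// ```rust
/// use oxify_vector::simd::quantized_dot_product_simd;
///
/// let a = vec![1u8, 2, 3];
/// let b = vec![4u8, 5, 6];
/// // 1*4 + 2*5 + 3*6 = 32
/// assert_eq!(quantized_dot_product_simd(&a, &b), 32);
/// ```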
#[inline]
pub fn quantized_dot_product_simd(a: &[u8], b: &[u8]) -> u32 {
    assert_eq!(a.len(), b.len(), "Vector dimension mismatch");

    #[cfg(target_arch = "x86_64")]
    {
        if is_avx2_available() {
            return unsafe { quantized_dot_product_avx2(a, b) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        return unsafe { quantized_dot_product_neon(a, b) };
    }

    // Fallback: scalar implementation
    quantized_dot_product_scalar(a, b)
}

/// Compute Euclidean distance between two quantized (u8) vectors using SIMD
///
/// Returns the squared distance to avoid sqrt overhead.
/// If you need the actual distance, take the square root of the result.
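///
/// # Example
///
/// ```rust
/// use oxify_vector::simd::quantized_euclidean_squared_simd;
///
/// let a = vec![3u8, 7];
/// let b = vec![0u8, 3];
/// // 3^2 + 4^2 = 25; the actual distance is 5.0 after sqrt.
/// assert_eq!(quantized_euclidean_squared_simd(&a, &b), 25);
/// let dist = (quantized_euclidean_squared_simd(&a, &b) as f32).sqrt();
/// assert!((dist - 5.0).abs() < 1e-6);
/// ```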
#[inline]
pub fn quantized_euclidean_squared_simd(a: &[u8], b: &[u8]) -> u32 {
    assert_eq!(a.len(), b.len(), "Vector dimension mismatch");

    #[cfg(target_arch = "x86_64")]
    {
        if is_avx2_available() {
            return unsafe { quantized_euclidean_squared_avx2(a, b) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        return unsafe { quantized_euclidean_squared_neon(a, b) };
    }

    // Fallback: scalar implementation
    quantized_euclidean_squared_scalar(a, b)
}

// ============================================================================
// Scalar implementations (fallback)
// ============================================================================

#[inline]
fn quantized_manhattan_distance_scalar(a: &[u8], b: &[u8]) -> u32 {
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| (x as i32 - y as i32).unsigned_abs())
        .sum()
}

#[inline]
fn quantized_dot_product_scalar(a: &[u8], b: &[u8]) -> u32 {
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| x as u32 * y as u32)
        .sum()
}

#[inline]
fn quantized_euclidean_squared_scalar(a: &[u8], b: &[u8]) -> u32 {
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| {
            let diff = x as i32 - y as i32;
            (diff * diff) as u32
        })
        .sum()
}

// ============================================================================
// AVX2 implementations (x86_64)
// ============================================================================

1086#[cfg(target_arch = "x86_64")]
1087#[target_feature(enable = "avx2")]
1088#[inline]
1089unsafe fn quantized_manhattan_distance_avx2(a: &[u8], b: &[u8]) -> u32 {
1090    let len = a.len();
1091    let mut sum = _mm256_setzero_si256();
1092
1093    let mut i = 0;
1094    // Process 32 bytes at a time with AVX2
1095    while i + 32 <= len {
1096        let va = _mm256_loadu_si256(a.as_ptr().add(i) as *const __m256i);
1097        let vb = _mm256_loadu_si256(b.as_ptr().add(i) as *const __m256i);
1098
1099        // Compute absolute difference using unsigned saturation trick
1100        let diff1 = _mm256_subs_epu8(va, vb);
1101        let diff2 = _mm256_subs_epu8(vb, va);
1102        let abs_diff = _mm256_or_si256(diff1, diff2);
1103
1104        // Extend to 16-bit to avoid overflow in horizontal sum
1105        let abs_diff_lo = _mm256_unpacklo_epi8(abs_diff, _mm256_setzero_si256());
1106        let abs_diff_hi = _mm256_unpackhi_epi8(abs_diff, _mm256_setzero_si256());
1107
1108        // Add to accumulator
1109        sum = _mm256_add_epi16(sum, abs_diff_lo);
1110        sum = _mm256_add_epi16(sum, abs_diff_hi);
1111
1112        i += 32;
1113    }
1114
1115    // Horizontal sum of 16-bit values
1116    let sum_lo = _mm256_unpacklo_epi16(sum, _mm256_setzero_si256());
1117    let sum_hi = _mm256_unpackhi_epi16(sum, _mm256_setzero_si256());
1118    let sum32 = _mm256_add_epi32(sum_lo, sum_hi);
1119
1120    // Extract and sum all lanes
1121    let mut result_arr = [0u32; 8];
1122    _mm256_storeu_si256(result_arr.as_mut_ptr() as *mut __m256i, sum32);
1123    let mut result: u32 = result_arr.iter().sum();
1124
1125    // Handle remaining elements
1126    while i < len {
1127        result += (a[i] as i32 - b[i] as i32).unsigned_abs();
1128        i += 1;
1129    }
1130
1131    result
1132}
1133
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn quantized_dot_product_avx2(a: &[u8], b: &[u8]) -> u32 {
    let len = a.len();
    let mut sum = _mm256_setzero_si256();

    let mut i = 0;
    // Process 16 bytes at a time (to fit in 16-bit accumulators)
    while i + 16 <= len {
        // Load 16 bytes from each vector
        let va_128 = _mm_loadu_si128(a.as_ptr().add(i) as *const __m128i);
        let vb_128 = _mm_loadu_si128(b.as_ptr().add(i) as *const __m128i);

        // Extend to 256-bit
        let va = _mm256_cvtepu8_epi16(va_128);
        let vb = _mm256_cvtepu8_epi16(vb_128);

        // Multiply: 16-bit * 16-bit = 32-bit (using _mm256_madd_epi16)
        let prod = _mm256_madd_epi16(va, vb);
        sum = _mm256_add_epi32(sum, prod);

        i += 16;
    }

    // Extract and sum all lanes
    let mut result_arr = [0u32; 8];
    _mm256_storeu_si256(result_arr.as_mut_ptr() as *mut __m256i, sum);
    let mut result: u32 = result_arr.iter().sum();

    // Handle remaining elements
    while i < len {
        result += a[i] as u32 * b[i] as u32;
        i += 1;
    }

    result
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn quantized_euclidean_squared_avx2(a: &[u8], b: &[u8]) -> u32 {
    let len = a.len();
    let mut sum = _mm256_setzero_si256();

    let mut i = 0;
    // Process 16 bytes at a time (to fit in 16-bit intermediate results)
    while i + 16 <= len {
        let va_128 = _mm_loadu_si128(a.as_ptr().add(i) as *const __m128i);
        let vb_128 = _mm_loadu_si128(b.as_ptr().add(i) as *const __m128i);

        // Extend to 16-bit signed
        let va = _mm256_cvtepu8_epi16(va_128);
        let vb = _mm256_cvtepu8_epi16(vb_128);

        // Compute difference
        let diff = _mm256_sub_epi16(va, vb);

        // Square using _mm256_madd_epi16 (diff * diff)
        let squared = _mm256_madd_epi16(diff, diff);
        sum = _mm256_add_epi32(sum, squared);

        i += 16;
    }

    // Extract and sum all lanes
    let mut result_arr = [0u32; 8];
    _mm256_storeu_si256(result_arr.as_mut_ptr() as *mut __m256i, sum);
    let mut result: u32 = result_arr.iter().sum();

    // Handle remaining elements
    while i < len {
        let diff = a[i] as i32 - b[i] as i32;
        result += (diff * diff) as u32;
        i += 1;
    }

    result
}

// ============================================================================
// NEON implementations (aarch64)
// ============================================================================

#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn quantized_manhattan_distance_neon(a: &[u8], b: &[u8]) -> u32 {
    let len = a.len();
    let mut sum = vdupq_n_u32(0);

    let mut i = 0;
    // Process 16 bytes at a time
    while i + 16 <= len {
        let va = vld1q_u8(a.as_ptr().add(i));
        let vb = vld1q_u8(b.as_ptr().add(i));

        // Compute absolute difference
        let abs_diff = vabdq_u8(va, vb);

        // Extend to 16-bit and accumulate
        let abs_diff_lo = vmovl_u8(vget_low_u8(abs_diff));
        let abs_diff_hi = vmovl_u8(vget_high_u8(abs_diff));

        // Accumulate into 32-bit
        sum = vaddw_u16(sum, vget_low_u16(abs_diff_lo));
        sum = vaddw_u16(sum, vget_high_u16(abs_diff_lo));
        sum = vaddw_u16(sum, vget_low_u16(abs_diff_hi));
        sum = vaddw_u16(sum, vget_high_u16(abs_diff_hi));

        i += 16;
    }

    // Horizontal sum
    let mut result = vaddvq_u32(sum);

    // Handle remaining elements
    while i < len {
        result += (a[i] as i32 - b[i] as i32).unsigned_abs();
        i += 1;
    }

    result
}

#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn quantized_dot_product_neon(a: &[u8], b: &[u8]) -> u32 {
    let len = a.len();
    let mut sum = vdupq_n_u32(0);

    let mut i = 0;
    // Process 8 bytes at a time (to avoid overflow)
    while i + 8 <= len {
        let va = vld1_u8(a.as_ptr().add(i));
        let vb = vld1_u8(b.as_ptr().add(i));

        // Extend to 16-bit
        let va_16 = vmovl_u8(va);
        let vb_16 = vmovl_u8(vb);

        // Multiply and accumulate
        let prod = vmull_u16(vget_low_u16(va_16), vget_low_u16(vb_16));
        sum = vaddq_u32(sum, prod);

        let prod_hi = vmull_u16(vget_high_u16(va_16), vget_high_u16(vb_16));
        sum = vaddq_u32(sum, prod_hi);

        i += 8;
    }

    // Horizontal sum
    let mut result = vaddvq_u32(sum);

    // Handle remaining elements
    while i < len {
        result += a[i] as u32 * b[i] as u32;
        i += 1;
    }

    result
}

#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn quantized_euclidean_squared_neon(a: &[u8], b: &[u8]) -> u32 {
    let len = a.len();
    let mut sum = vdupq_n_u32(0);

    let mut i = 0;
    // Process 8 bytes at a time
    while i + 8 <= len {
        let va = vld1_u8(a.as_ptr().add(i));
        let vb = vld1_u8(b.as_ptr().add(i));

        // Compute absolute difference and extend to 16-bit
        let abs_diff = vabd_u8(va, vb);
        let diff_16 = vmovl_u8(abs_diff);

        // Square and accumulate
        let squared = vmull_u16(vget_low_u16(diff_16), vget_low_u16(diff_16));
        sum = vaddq_u32(sum, squared);

        let squared_hi = vmull_u16(vget_high_u16(diff_16), vget_high_u16(diff_16));
        sum = vaddq_u32(sum, squared_hi);

        i += 8;
    }

    // Horizontal sum
    let mut result = vaddvq_u32(sum);

    // Handle remaining elements
    while i < len {
        let diff = a[i] as i32 - b[i] as i32;
        result += (diff * diff) as u32;
        i += 1;
    }

    result
}

// ============================================================================
// Vector Normalization (SIMD-optimized)
// ============================================================================

/// Normalize a vector in-place using SIMD optimization
///
/// Computes the L2 norm and divides each element by it, using the
/// SIMD-optimized dot product for the norm calculation. Vectors whose norm
/// is at most 1e-10 are left unchanged (guards against division by zero).
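///
/// # Example
///
/// ```rust
/// use oxify_vector::simd::normalize_vector_simd;
///
/// let mut v = vec![3.0, 4.0];
/// normalize_vector_simd(&mut v);
/// // A 3-4-5 triangle normalizes to [0.6, 0.8].
/// assert!((v[0] - 0.6).abs() < 1e-6);
/// assert!((v[1] - 0.8).abs() < 1e-6);
/// ```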
#[inline]
pub fn normalize_vector_simd(vec: &mut [f32]) {
    // Compute L2 norm using SIMD dot product
    let norm_squared = dot_product_simd(vec, vec);
    let norm = norm_squared.sqrt();

    if norm > 1e-10 {
        let inv_norm = 1.0 / norm;
        scale_vector_simd(vec, inv_norm);
    }
}

/// Scale a vector by a constant using SIMD optimization
///
/// Multiplies each element by the given scalar.
1360pub fn scale_vector_simd(vec: &mut [f32], scalar: f32) {
1361    #[cfg(target_arch = "x86_64")]
1362    {
1363        if is_avx512_available() {
1364            unsafe {
1365                scale_vector_avx512(vec, scalar);
1366            }
1367            return;
1368        }
1369        if is_avx2_available() {
1370            unsafe {
1371                scale_vector_avx2(vec, scalar);
1372            }
1373            return;
1374        }
1375    }
1376
1377    #[cfg(target_arch = "aarch64")]
1378    {
1379        unsafe {
1380            scale_vector_neon(vec, scalar);
1381        }
1382        return;
1383    }
1384
1385    // Fallback to scalar implementation with auto-vectorization hints
1386    for x in vec.iter_mut() {
1387        *x *= scalar;
1388    }
1389}
1390
1391#[cfg(target_arch = "x86_64")]
1392#[target_feature(enable = "avx512f")]
1393#[inline]
1394unsafe fn scale_vector_avx512(vec: &mut [f32], scalar: f32) {
1395    let len = vec.len();
1396    let scalar_vec = _mm512_set1_ps(scalar);
1397    let mut i = 0;
1398
1399    // Process 16 floats at a time
1400    while i + 16 <= len {
1401        let ptr = vec.as_mut_ptr().add(i);
1402        let v = _mm512_loadu_ps(ptr);
1403        let scaled = _mm512_mul_ps(v, scalar_vec);
1404        _mm512_storeu_ps(ptr, scaled);
1405        i += 16;
1406    }
1407
1408    // Handle remainder
1409    while i < len {
1410        vec[i] *= scalar;
1411        i += 1;
1412    }
1413}
1414
1415#[cfg(target_arch = "x86_64")]
1416#[target_feature(enable = "avx2")]
1417#[inline]
1418unsafe fn scale_vector_avx2(vec: &mut [f32], scalar: f32) {
1419    let len = vec.len();
1420    let scalar_vec = _mm256_set1_ps(scalar);
1421    let mut i = 0;
1422
1423    // Process 8 floats at a time
1424    while i + 8 <= len {
1425        let ptr = vec.as_mut_ptr().add(i);
1426        let v = _mm256_loadu_ps(ptr);
1427        let scaled = _mm256_mul_ps(v, scalar_vec);
1428        _mm256_storeu_ps(ptr, scaled);
1429        i += 8;
1430    }
1431
1432    // Handle remainder
1433    while i < len {
1434        vec[i] *= scalar;
1435        i += 1;
1436    }
1437}
1438
1439#[cfg(target_arch = "aarch64")]
1440#[inline]
1441unsafe fn scale_vector_neon(vec: &mut [f32], scalar: f32) {
1442    let len = vec.len();
1443    let scalar_vec = vdupq_n_f32(scalar);
1444    let mut i = 0;
1445
1446    // Process 4 floats at a time
1447    while i + 4 <= len {
1448        let ptr = vec.as_mut_ptr().add(i);
1449        let v = vld1q_f32(ptr);
1450        let scaled = vmulq_f32(v, scalar_vec);
1451        vst1q_f32(ptr, scaled);
1452        i += 4;
1453    }
1454
1455    // Handle remainder
1456    while i < len {
1457        vec[i] *= scalar;
1458        i += 1;
1459    }
1460}
1461
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cosine_similarity_simd() {
        let v1 = vec![1.0, 0.0, 0.0];
        let v2 = vec![1.0, 0.0, 0.0];
        let sim = cosine_similarity_simd(&v1, &v2);
        assert!((sim - 1.0).abs() < 1e-6);

        let v1 = vec![1.0, 0.0, 0.0];
        let v2 = vec![0.0, 1.0, 0.0];
        let sim = cosine_similarity_simd(&v1, &v2);
        assert!(sim.abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_simd_large() {
        // Test with vectors larger than chunk size
        let v1: Vec<f32> = (0..100).map(|i| i as f32).collect();
        let v2: Vec<f32> = (0..100).map(|i| (i + 1) as f32).collect();
        let sim = cosine_similarity_simd(&v1, &v2);
        assert!(sim > 0.99); // Highly correlated
    }

    #[test]
    fn test_euclidean_distance_simd() {
        let v1 = vec![0.0, 0.0, 0.0];
        let v2 = vec![3.0, 4.0, 0.0];
        let dist = euclidean_distance_simd(&v1, &v2);
        assert!((dist - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_euclidean_distance_simd_large() {
        let v1 = vec![0.0; 100];
        let v2 = vec![1.0; 100];
        let dist = euclidean_distance_simd(&v1, &v2);
        assert!((dist - 10.0).abs() < 1e-6); // sqrt(100)
    }

    #[test]
    fn test_dot_product_simd() {
        let v1 = vec![1.0, 2.0, 3.0];
        let v2 = vec![4.0, 5.0, 6.0];
        let dot = dot_product_simd(&v1, &v2);
        assert!((dot - 32.0).abs() < 1e-6); // 1*4 + 2*5 + 3*6 = 32
    }

    #[test]
    fn test_dot_product_simd_large() {
        let v1: Vec<f32> = (1..=100).map(|i| i as f32).collect();
        let v2: Vec<f32> = (1..=100).map(|i| i as f32).collect();
        let dot = dot_product_simd(&v1, &v2);
        let expected: f32 = (1..=100).map(|i| (i * i) as f32).sum();
        assert!((dot - expected).abs() < 1e-3);
    }

    #[test]
    fn test_manhattan_distance_simd() {
        let v1 = vec![1.0, 2.0, 3.0];
        let v2 = vec![4.0, 5.0, 6.0];
        let dist = manhattan_distance_simd(&v1, &v2);
        assert!((dist - 9.0).abs() < 1e-6); // |1-4| + |2-5| + |3-6| = 9
    }

    #[test]
    fn test_manhattan_distance_simd_large() {
        let v1 = vec![0.0; 100];
        let v2 = vec![1.0; 100];
        let dist = manhattan_distance_simd(&v1, &v2);
        assert!((dist - 100.0).abs() < 1e-6);
    }

    #[test]
    fn test_compute_distance_simd() {
        let v1 = vec![1.0, 0.0, 0.0];
        let v2 = vec![1.0, 0.0, 0.0];

        let sim = compute_distance_simd(DistanceMetric::Cosine, &v1, &v2);
        assert!((sim - 1.0).abs() < 1e-6);

        let dist = compute_distance_simd(DistanceMetric::Euclidean, &v1, &v2);
        assert!(dist.abs() < 1e-6); // Distance is 0, but returned as -0.0

        let dot = compute_distance_simd(DistanceMetric::DotProduct, &v1, &v2);
        assert!((dot - 1.0).abs() < 1e-6);

        let manhattan = compute_distance_simd(DistanceMetric::Manhattan, &v1, &v2);
        assert!(manhattan.abs() < 1e-6);
    }

    #[test]
    fn test_is_avx2_available() {
        // Test that the function doesn't panic
        let _available = is_avx2_available();
        // On x86_64, it should detect AVX2 support (or not)
        // On other architectures, it should always return false
        #[cfg(not(target_arch = "x86_64"))]
        assert!(!is_avx2_available());
    }

    #[test]
    fn test_is_neon_available() {
        // Test that the function doesn't panic
        let available = is_neon_available();
        // On aarch64, NEON is always available (mandatory feature)
        #[cfg(target_arch = "aarch64")]
        assert!(available, "NEON should always be available on aarch64");
        // On other architectures, it should always return false
        #[cfg(not(target_arch = "aarch64"))]
        assert!(!available, "NEON should not be available on non-aarch64");
    }

    #[test]
    fn test_is_avx512_available() {
        // Test that the function doesn't panic
        let _available = is_avx512_available();
        // On x86_64, it should detect AVX-512 support (or not)
        // On other architectures, it should always return false
        #[cfg(not(target_arch = "x86_64"))]
        assert!(!is_avx512_available());
    }

    #[test]
    fn test_avx2_correctness() {
        // Check that the dispatched SIMD implementation (AVX2/FMA where
        // available) gives the same results as the auto-vectorized reference
        let v1: Vec<f32> = (0..768).map(|i| (i as f32) * 0.01).collect();
        let v2: Vec<f32> = (0..768).map(|i| (i as f32) * 0.02).collect();

        let cosine = cosine_similarity_simd(&v1, &v2);
        let euclidean = euclidean_distance_simd(&v1, &v2);
        let dot = dot_product_simd(&v1, &v2);
        let manhattan = manhattan_distance_simd(&v1, &v2);

        // Verify results are reasonable
        assert!(cosine > 0.0 && cosine <= 1.0);
        assert!(euclidean > 0.0);
        assert!(dot > 0.0);
        assert!(manhattan > 0.0);

        // Compare with autovec versions
        let cosine_autovec = cosine_similarity_autovec(&v1, &v2);
        let euclidean_autovec = euclidean_distance_autovec(&v1, &v2);
        let dot_autovec = dot_product_autovec(&v1, &v2);
        let manhattan_autovec = manhattan_distance_autovec(&v1, &v2);

        // Use relative error for large values
        let relative_error = |a: f32, b: f32| (a - b).abs() / a.max(b).max(1.0);
        assert!(relative_error(cosine, cosine_autovec) < 1e-5);
        assert!(relative_error(euclidean, euclidean_autovec) < 1e-5);
        assert!(relative_error(dot, dot_autovec) < 1e-5);
        assert!(relative_error(manhattan, manhattan_autovec) < 1e-5);
    }

    #[test]
    fn test_neon_correctness() {
        // Check that the dispatched SIMD implementation (NEON on aarch64)
        // gives the same results as the auto-vectorized reference
        let v1: Vec<f32> = (0..768).map(|i| (i as f32) * 0.01).collect();
        let v2: Vec<f32> = (0..768).map(|i| (i as f32) * 0.02).collect();

        let cosine = cosine_similarity_simd(&v1, &v2);
        let euclidean = euclidean_distance_simd(&v1, &v2);
        let dot = dot_product_simd(&v1, &v2);
        let manhattan = manhattan_distance_simd(&v1, &v2);

        // Verify results are reasonable
        assert!(cosine > 0.0 && cosine <= 1.0);
        assert!(euclidean > 0.0);
        assert!(dot > 0.0);
        assert!(manhattan > 0.0);

        // Compare with autovec versions
        let cosine_autovec = cosine_similarity_autovec(&v1, &v2);
        let euclidean_autovec = euclidean_distance_autovec(&v1, &v2);
        let dot_autovec = dot_product_autovec(&v1, &v2);
        let manhattan_autovec = manhattan_distance_autovec(&v1, &v2);

        // Use relative error for large values
        let relative_error = |a: f32, b: f32| (a - b).abs() / a.max(b).max(1.0);
        assert!(relative_error(cosine, cosine_autovec) < 1e-5);
        assert!(relative_error(euclidean, euclidean_autovec) < 1e-5);
        assert!(relative_error(dot, dot_autovec) < 1e-5);
        assert!(relative_error(manhattan, manhattan_autovec) < 1e-5);
    }

    #[test]
    fn test_avx512_correctness() {
        // Check that the dispatched SIMD implementation (AVX-512 where
        // available) gives the same results as the auto-vectorized reference;
        // 1024 elements exercise the 16-wide path
        let v1: Vec<f32> = (0..1024).map(|i| (i as f32) * 0.01).collect();
        let v2: Vec<f32> = (0..1024).map(|i| (i as f32) * 0.02).collect();

        let cosine = cosine_similarity_simd(&v1, &v2);
        let euclidean = euclidean_distance_simd(&v1, &v2);
        let dot = dot_product_simd(&v1, &v2);
        let manhattan = manhattan_distance_simd(&v1, &v2);

        // Verify results are reasonable
        assert!(cosine > 0.0 && cosine <= 1.0);
        assert!(euclidean > 0.0);
        assert!(dot > 0.0);
        assert!(manhattan > 0.0);

        // Compare with autovec versions
        let cosine_autovec = cosine_similarity_autovec(&v1, &v2);
        let euclidean_autovec = euclidean_distance_autovec(&v1, &v2);
        let dot_autovec = dot_product_autovec(&v1, &v2);
        let manhattan_autovec = manhattan_distance_autovec(&v1, &v2);

        // Use relative error for large values
        let relative_error = |a: f32, b: f32| (a - b).abs() / a.max(b).max(1.0);
        assert!(relative_error(cosine, cosine_autovec) < 1e-5);
        assert!(relative_error(euclidean, euclidean_autovec) < 1e-5);
        assert!(relative_error(dot, dot_autovec) < 1e-5);
        assert!(relative_error(manhattan, manhattan_autovec) < 1e-5);
    }

    #[test]
    fn test_quantized_manhattan_distance() {
        // Test quantized Manhattan distance
        let a = vec![10u8, 20, 30, 40, 50, 60, 70, 80];
        let b = vec![15u8, 25, 35, 45, 55, 65, 75, 85];

        let distance_simd = quantized_manhattan_distance_simd(&a, &b);
        let distance_scalar = quantized_manhattan_distance_scalar(&a, &b);

        assert_eq!(distance_simd, distance_scalar);
        assert_eq!(distance_simd, 40); // |10-15| + |20-25| + ... = 5*8 = 40
    }

    #[test]
    fn test_quantized_manhattan_distance_large() {
        // Test with larger vectors (768 dimensions)
        let a: Vec<u8> = (0..768).map(|i| (i % 256) as u8).collect();
        let b: Vec<u8> = (0..768).map(|i| ((i + 10) % 256) as u8).collect();

        let distance_simd = quantized_manhattan_distance_simd(&a, &b);
        let distance_scalar = quantized_manhattan_distance_scalar(&a, &b);

        assert_eq!(distance_simd, distance_scalar);
    }

    #[test]
    fn test_quantized_dot_product() {
        // Test quantized dot product
        let a = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
        let b = vec![8u8, 7, 6, 5, 4, 3, 2, 1];

        let dot_simd = quantized_dot_product_simd(&a, &b);
        let dot_scalar = quantized_dot_product_scalar(&a, &b);

        assert_eq!(dot_simd, dot_scalar);
        // 1*8 + 2*7 + 3*6 + 4*5 + 5*4 + 6*3 + 7*2 + 8*1 = 8+14+18+20+20+18+14+8 = 120
        assert_eq!(dot_simd, 120);
    }

    #[test]
    fn test_quantized_dot_product_large() {
        // Test with larger vectors (768 dimensions)
        let a: Vec<u8> = (0..768).map(|i| (i % 256) as u8).collect();
        let b: Vec<u8> = (0..768).map(|i| ((255 - i) % 256) as u8).collect();

        let dot_simd = quantized_dot_product_simd(&a, &b);
        let dot_scalar = quantized_dot_product_scalar(&a, &b);

        assert_eq!(dot_simd, dot_scalar);
    }

    #[test]
    fn test_quantized_euclidean_squared() {
        // Test quantized Euclidean distance (squared)
        let a = vec![10u8, 20, 30, 40];
        let b = vec![13u8, 24, 27, 45];

        let dist_simd = quantized_euclidean_squared_simd(&a, &b);
        let dist_scalar = quantized_euclidean_squared_scalar(&a, &b);

        assert_eq!(dist_simd, dist_scalar);
        // (10-13)^2 + (20-24)^2 + (30-27)^2 + (40-45)^2 = 9 + 16 + 9 + 25 = 59
        assert_eq!(dist_simd, 59);
    }

    #[test]
    fn test_quantized_euclidean_squared_large() {
        // Test with larger vectors (768 dimensions)
        let a: Vec<u8> = (0..768).map(|i| (i % 256) as u8).collect();
        let b: Vec<u8> = (0..768).map(|i| ((i + 5) % 256) as u8).collect();

        let dist_simd = quantized_euclidean_squared_simd(&a, &b);
        let dist_scalar = quantized_euclidean_squared_scalar(&a, &b);

        assert_eq!(dist_simd, dist_scalar);
    }

    #[test]
    fn test_quantized_edge_cases() {
        // Test with identical vectors
        let a = vec![100u8; 100];
        let b = vec![100u8; 100];

        assert_eq!(quantized_manhattan_distance_simd(&a, &b), 0);
        assert_eq!(quantized_euclidean_squared_simd(&a, &b), 0);

        // Test with maximum difference
        let c = vec![0u8; 100];
        let d = vec![255u8; 100];

        assert_eq!(quantized_manhattan_distance_simd(&c, &d), 255 * 100);
        assert_eq!(quantized_euclidean_squared_simd(&c, &d), 255 * 255 * 100);
    }

    #[test]
    fn test_quantized_simd_correctness() {
        // Comprehensive correctness test with random-like values
        let a: Vec<u8> = (0..1024).map(|i| ((i * 17 + 42) % 256) as u8).collect();
        let b: Vec<u8> = (0..1024).map(|i| ((i * 23 + 99) % 256) as u8).collect();

        // All SIMD implementations should match scalar
        let manhattan_simd = quantized_manhattan_distance_simd(&a, &b);
        let manhattan_scalar = quantized_manhattan_distance_scalar(&a, &b);
        assert_eq!(manhattan_simd, manhattan_scalar);

        let dot_simd = quantized_dot_product_simd(&a, &b);
        let dot_scalar = quantized_dot_product_scalar(&a, &b);
        assert_eq!(dot_simd, dot_scalar);

        let euclidean_simd = quantized_euclidean_squared_simd(&a, &b);
        let euclidean_scalar = quantized_euclidean_squared_scalar(&a, &b);
        assert_eq!(euclidean_simd, euclidean_scalar);
    }

    #[test]
    fn test_normalize_vector_simd() {
        let mut vec = vec![3.0, 4.0, 0.0];
        normalize_vector_simd(&mut vec);

        // Expected: [3/5, 4/5, 0] = [0.6, 0.8, 0.0]
        assert!((vec[0] - 0.6).abs() < 1e-6);
        assert!((vec[1] - 0.8).abs() < 1e-6);
        assert!((vec[2] - 0.0).abs() < 1e-6);

        // Check that norm is 1.0
        let norm_squared: f32 = vec.iter().map(|x| x * x).sum();
        assert!((norm_squared - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_normalize_vector_simd_large() {
        // Test with large vector (768 dimensions)
        let mut vec: Vec<f32> = (0..768).map(|i| (i % 100) as f32).collect();
        normalize_vector_simd(&mut vec);

        // Check that norm is 1.0
        let norm_squared: f32 = vec.iter().map(|x| x * x).sum();
        assert!((norm_squared - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_normalize_vector_simd_zero() {
        // Test with zero vector (should not panic)
        let mut vec = vec![0.0, 0.0, 0.0];
        normalize_vector_simd(&mut vec);

        // Should remain zero
        assert_eq!(vec, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_scale_vector_simd() {
        let mut vec = vec![1.0, 2.0, 3.0, 4.0];
        scale_vector_simd(&mut vec, 2.0);

        assert_eq!(vec, vec![2.0, 4.0, 6.0, 8.0]);
    }

    #[test]
    fn test_scale_vector_simd_large() {
        // Test with large vector (1024 dimensions)
        let mut vec: Vec<f32> = (0..1024).map(|i| i as f32).collect();
        scale_vector_simd(&mut vec, 0.5);

        for (i, &value) in vec.iter().enumerate() {
            assert!((value - (i as f32 * 0.5)).abs() < 1e-5);
        }
    }
}