// diskann_rs/simd.rs

//! # SIMD-Accelerated Distance Functions
//!
//! Optimized distance calculations using SIMD instructions.
//! Falls back to scalar implementations when SIMD is not available.
//!
//! ## Supported Architectures
//!
//! - **x86_64**: AVX2 (with FMA), SSE4.1, F16C (auto-detected at runtime)
//! - **aarch64**: NEON (baseline on all aarch64 targets, including Apple Silicon)
//! - **Fallback**: portable scalar implementation
//!
//! ## Performance
//!
//! SIMD acceleration typically provides a 2-8x speedup for distance calculations:
//! - L2 (Euclidean): processes 8 floats per iteration (AVX2) or 4 (SSE/NEON)
//! - Dot product: same vectorization approach
//! - Cosine: computed as `1 - dot(a,b) / (||a|| * ||b||)`
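//!
//! ## Example
//!
//! A minimal usage sketch (assuming this module is exported as
//! `diskann_rs::simd`; adjust the path to match the crate layout):
//!
//! ```ignore
//! use diskann_rs::simd::{cosine_distance, l2_squared, simd_info};
//!
//! let a = [1.0_f32, 2.0, 3.0, 4.0];
//! let b = [5.0_f32, 6.0, 7.0, 8.0];
//!
//! // Dispatch picks AVX2/SSE4.1/NEON at runtime, falling back to scalar.
//! let d2 = l2_squared(&a, &b); // squared Euclidean distance
//! let cd = cosine_distance(&a, &b); // 1 - cosine similarity
//! println!("{}: l2^2 = {d2}, cosine = {cd}", simd_info());
//! ```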

use anndists::prelude::Distance;

/// SIMD-accelerated L2 (Euclidean squared) distance
#[derive(Clone, Copy, Debug, Default)]
pub struct SimdL2;

/// SIMD-accelerated dot product distance (for normalized vectors)
#[derive(Clone, Copy, Debug, Default)]
pub struct SimdDot;

/// SIMD-accelerated cosine distance
#[derive(Clone, Copy, Debug, Default)]
pub struct SimdCosine;

// =============================================================================
// Portable scalar implementations (fallback)
// =============================================================================

#[inline]
fn l2_squared_scalar(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(x, y)| {
            let d = x - y;
            d * d
        })
        .sum()
}

#[inline]
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

#[inline]
fn norm_squared_scalar(a: &[f32]) -> f32 {
    a.iter().map(|x| x * x).sum()
}

// =============================================================================
// x86_64 AVX2 implementations
// =============================================================================

#[cfg(target_arch = "x86_64")]
mod x86_simd {
    use std::arch::x86_64::*;

    /// Check if AVX2 is available at runtime.
    /// Also requires FMA, since the AVX2 kernels below use `_mm256_fmadd_ps`.
    #[inline]
    pub fn has_avx2() -> bool {
        is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")
    }

    /// Check if SSE4.1 is available at runtime
    #[inline]
    pub fn has_sse41() -> bool {
        is_x86_feature_detected!("sse4.1")
    }

    /// L2 squared distance using AVX2+FMA (8 floats at a time)
    #[target_feature(enable = "avx2", enable = "fma")]
    #[inline]
    pub unsafe fn l2_squared_avx2(a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());
        let n = a.len();

        let mut sum = _mm256_setzero_ps();
        let mut i = 0;

        // Process 8 elements at a time
        while i + 8 <= n {
            let va = _mm256_loadu_ps(a.as_ptr().add(i));
            let vb = _mm256_loadu_ps(b.as_ptr().add(i));
            let diff = _mm256_sub_ps(va, vb);
            sum = _mm256_fmadd_ps(diff, diff, sum);
            i += 8;
        }

        // Horizontal sum of 256-bit vector
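        // Fold 8 lanes -> 4 by adding the upper 128-bit half onto the lower,
        // then 4 -> 2 (movehdup duplicates odd lanes for a pairwise add),
        // then 2 -> 1 (movehl brings the high pair down for the final add).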
        let high = _mm256_extractf128_ps(sum, 1);
        let low = _mm256_castps256_ps128(sum);
        let sum128 = _mm_add_ps(high, low);
        let shuf = _mm_movehdup_ps(sum128);
        let sums = _mm_add_ps(sum128, shuf);
        let shuf2 = _mm_movehl_ps(sums, sums);
        let final_sum = _mm_add_ss(sums, shuf2);
        let mut result = _mm_cvtss_f32(final_sum);

        // Handle remaining elements
        while i < n {
            let d = a[i] - b[i];
            result += d * d;
            i += 1;
        }

        result
    }

    /// Dot product using AVX2+FMA
    #[target_feature(enable = "avx2", enable = "fma")]
    #[inline]
    pub unsafe fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());
        let n = a.len();

        let mut sum = _mm256_setzero_ps();
        let mut i = 0;

        while i + 8 <= n {
            let va = _mm256_loadu_ps(a.as_ptr().add(i));
            let vb = _mm256_loadu_ps(b.as_ptr().add(i));
            sum = _mm256_fmadd_ps(va, vb, sum);
            i += 8;
        }

        // Horizontal sum
        let high = _mm256_extractf128_ps(sum, 1);
        let low = _mm256_castps256_ps128(sum);
        let sum128 = _mm_add_ps(high, low);
        let shuf = _mm_movehdup_ps(sum128);
        let sums = _mm_add_ps(sum128, shuf);
        let shuf2 = _mm_movehl_ps(sums, sums);
        let final_sum = _mm_add_ss(sums, shuf2);
        let mut result = _mm_cvtss_f32(final_sum);

        while i < n {
            result += a[i] * b[i];
            i += 1;
        }

        result
    }

    /// Norm squared using AVX2+FMA
    #[target_feature(enable = "avx2", enable = "fma")]
    #[inline]
    pub unsafe fn norm_squared_avx2(a: &[f32]) -> f32 {
        let n = a.len();
        let mut sum = _mm256_setzero_ps();
        let mut i = 0;

        while i + 8 <= n {
            let va = _mm256_loadu_ps(a.as_ptr().add(i));
            sum = _mm256_fmadd_ps(va, va, sum);
            i += 8;
        }

        let high = _mm256_extractf128_ps(sum, 1);
        let low = _mm256_castps256_ps128(sum);
        let sum128 = _mm_add_ps(high, low);
        let shuf = _mm_movehdup_ps(sum128);
        let sums = _mm_add_ps(sum128, shuf);
        let shuf2 = _mm_movehl_ps(sums, sums);
        let final_sum = _mm_add_ss(sums, shuf2);
        let mut result = _mm_cvtss_f32(final_sum);

        while i < n {
            result += a[i] * a[i];
            i += 1;
        }

        result
    }

    /// L2 squared using SSE4.1 (4 floats at a time)
    #[target_feature(enable = "sse4.1")]
    #[inline]
    pub unsafe fn l2_squared_sse41(a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());
        let n = a.len();

        let mut sum = _mm_setzero_ps();
        let mut i = 0;

        while i + 4 <= n {
            let va = _mm_loadu_ps(a.as_ptr().add(i));
            let vb = _mm_loadu_ps(b.as_ptr().add(i));
            let diff = _mm_sub_ps(va, vb);
            let sq = _mm_mul_ps(diff, diff);
            sum = _mm_add_ps(sum, sq);
            i += 4;
        }

        // Horizontal sum
        let shuf = _mm_movehdup_ps(sum);
        let sums = _mm_add_ps(sum, shuf);
        let shuf2 = _mm_movehl_ps(sums, sums);
        let final_sum = _mm_add_ss(sums, shuf2);
        let mut result = _mm_cvtss_f32(final_sum);

        while i < n {
            let d = a[i] - b[i];
            result += d * d;
            i += 1;
        }

        result
    }

    /// Dot product using SSE4.1
    #[target_feature(enable = "sse4.1")]
    #[inline]
    pub unsafe fn dot_product_sse41(a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());
        let n = a.len();

        let mut sum = _mm_setzero_ps();
        let mut i = 0;

        while i + 4 <= n {
            let va = _mm_loadu_ps(a.as_ptr().add(i));
            let vb = _mm_loadu_ps(b.as_ptr().add(i));
            let prod = _mm_mul_ps(va, vb);
            sum = _mm_add_ps(sum, prod);
            i += 4;
        }

        let shuf = _mm_movehdup_ps(sum);
        let sums = _mm_add_ps(sum, shuf);
        let shuf2 = _mm_movehl_ps(sums, sums);
        let final_sum = _mm_add_ss(sums, shuf2);
        let mut result = _mm_cvtss_f32(final_sum);

        while i < n {
            result += a[i] * b[i];
            i += 1;
        }

        result
    }
}

// =============================================================================
// aarch64 NEON implementations
// =============================================================================

#[cfg(target_arch = "aarch64")]
mod neon_simd {
    use std::arch::aarch64::*;

    /// L2 squared distance using NEON (4 floats at a time)
    #[inline]
    pub fn l2_squared_neon(a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());
        let n = a.len();

        // SAFETY: NEON is always available on aarch64
        unsafe {
            let mut sum = vdupq_n_f32(0.0);
            let mut i = 0;

            while i + 4 <= n {
                let va = vld1q_f32(a.as_ptr().add(i));
                let vb = vld1q_f32(b.as_ptr().add(i));
                let diff = vsubq_f32(va, vb);
                sum = vfmaq_f32(sum, diff, diff);
                i += 4;
            }

            // Horizontal sum
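            // (vaddvq_f32 adds all four lanes in a single instruction.)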
            let mut result = vaddvq_f32(sum);

            while i < n {
                let d = a[i] - b[i];
                result += d * d;
                i += 1;
            }

            result
        }
    }

    /// Dot product using NEON
    #[inline]
    pub fn dot_product_neon(a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());
        let n = a.len();

        // SAFETY: NEON is always available on aarch64
        unsafe {
            let mut sum = vdupq_n_f32(0.0);
            let mut i = 0;

            while i + 4 <= n {
                let va = vld1q_f32(a.as_ptr().add(i));
                let vb = vld1q_f32(b.as_ptr().add(i));
                sum = vfmaq_f32(sum, va, vb);
                i += 4;
            }

            let mut result = vaddvq_f32(sum);

            while i < n {
                result += a[i] * b[i];
                i += 1;
            }

            result
        }
    }

    /// Norm squared using NEON
    #[inline]
    pub fn norm_squared_neon(a: &[f32]) -> f32 {
        let n = a.len();

        // SAFETY: NEON is always available on aarch64
        unsafe {
            let mut sum = vdupq_n_f32(0.0);
            let mut i = 0;

            while i + 4 <= n {
                let va = vld1q_f32(a.as_ptr().add(i));
                sum = vfmaq_f32(sum, va, va);
                i += 4;
            }

            let mut result = vaddvq_f32(sum);

            while i < n {
                result += a[i] * a[i];
                i += 1;
            }

            result
        }
    }
}

// =============================================================================
// aarch64 NEON F16/Int8 helpers
// =============================================================================

#[cfg(target_arch = "aarch64")]
mod neon_quant {
    use std::arch::aarch64::*;

    /// Convert f16 (as u16 bits) to f32.
    /// Uses the scalar `half` crate conversion, since NEON f16 intrinsics
    /// are unstable in Rust.
    #[inline]
    pub fn f16_to_f32_bulk_neon(input: &[u16], output: &mut [f32]) {
        debug_assert_eq!(input.len(), output.len());
        for (i, &bits) in input.iter().enumerate() {
            output[i] = half::f16::from_bits(bits).to_f32();
        }
    }

    /// L2 squared: f16 database vector vs f32 query.
    /// Converts f16->f32 in a temp buffer then uses NEON L2.
    #[inline]
    pub fn l2_f16_vs_f32_neon(f16_data: &[u16], query: &[f32]) -> f32 {
        debug_assert_eq!(f16_data.len(), query.len());
        let n = f16_data.len();

        // Convert f16 to f32 first
        let mut db = vec![0.0f32; n];
        for (i, &bits) in f16_data.iter().enumerate() {
            db[i] = half::f16::from_bits(bits).to_f32();
        }

        // Now use NEON for the L2 computation
        super::neon_simd::l2_squared_neon(&db, query)
    }

    /// L2 squared: u8 scaled database vector vs f32 query
    /// Dequantizes on the fly: val = u8 * scale + offset (per dimension)
    #[inline]
    pub fn l2_u8_scaled_vs_f32_neon(
        u8_data: &[u8],
        query: &[f32],
        scales: &[f32],
        offsets: &[f32],
    ) -> f32 {
        debug_assert_eq!(u8_data.len(), query.len());
        debug_assert_eq!(scales.len(), query.len());
        debug_assert_eq!(offsets.len(), query.len());
        let n = u8_data.len();
        let mut i = 0;

        unsafe {
            let mut sum = vdupq_n_f32(0.0);

            while i + 4 <= n {
                // Load 4 u8 values and convert to f32
                let b0 = u8_data[i] as f32;
                let b1 = u8_data[i + 1] as f32;
                let b2 = u8_data[i + 2] as f32;
                let b3 = u8_data[i + 3] as f32;
                let vals = [b0, b1, b2, b3];
                let vu8 = vld1q_f32(vals.as_ptr());

                let vscale = vld1q_f32(scales.as_ptr().add(i));
                let voff = vld1q_f32(offsets.as_ptr().add(i));
                let vq = vld1q_f32(query.as_ptr().add(i));

                // dequant = u8 * scale + offset
                let dequant = vfmaq_f32(voff, vu8, vscale);
                let diff = vsubq_f32(dequant, vq);
                sum = vfmaq_f32(sum, diff, diff);
                i += 4;
            }

            let mut result = vaddvq_f32(sum);

            while i < n {
                let dequant = u8_data[i] as f32 * scales[i] + offsets[i];
                let d = dequant - query[i];
                result += d * d;
                i += 1;
            }

            result
        }
    }
}

// =============================================================================
// x86_64 F16/Int8 SIMD helpers
// =============================================================================

#[cfg(target_arch = "x86_64")]
mod x86_quant {
    use std::arch::x86_64::*;

    #[inline]
    pub fn has_f16c() -> bool {
        is_x86_feature_detected!("f16c")
    }

    /// Bulk convert f16 (as u16 bits) to f32 using F16C
    #[target_feature(enable = "f16c")]
    #[inline]
    pub unsafe fn f16_to_f32_bulk_f16c(input: &[u16], output: &mut [f32]) {
        debug_assert_eq!(input.len(), output.len());
        let n = input.len();
        let mut i = 0;

        while i + 8 <= n {
            let half8 = _mm_loadu_si128(input.as_ptr().add(i) as *const __m128i);
            let f8 = _mm256_cvtph_ps(half8);
            _mm256_storeu_ps(output.as_mut_ptr().add(i), f8);
            i += 8;
        }

        while i < n {
            output[i] = half::f16::from_bits(input[i]).to_f32();
            i += 1;
        }
    }

    /// L2 squared: f16 database vs f32 query, fused F16C convert+distance
    #[target_feature(enable = "f16c", enable = "avx2", enable = "fma")]
    #[inline]
    pub unsafe fn l2_f16_vs_f32_f16c(f16_data: &[u16], query: &[f32]) -> f32 {
        debug_assert_eq!(f16_data.len(), query.len());
        let n = f16_data.len();
        let mut i = 0;
        let mut sum = _mm256_setzero_ps();

        while i + 8 <= n {
            let half8 = _mm_loadu_si128(f16_data.as_ptr().add(i) as *const __m128i);
            let db = _mm256_cvtph_ps(half8);
            let q = _mm256_loadu_ps(query.as_ptr().add(i));
            let diff = _mm256_sub_ps(db, q);
            sum = _mm256_fmadd_ps(diff, diff, sum);
            i += 8;
        }

        // Horizontal sum
        let high = _mm256_extractf128_ps(sum, 1);
        let low = _mm256_castps256_ps128(sum);
        let sum128 = _mm_add_ps(high, low);
        let shuf = _mm_movehdup_ps(sum128);
        let sums = _mm_add_ps(sum128, shuf);
        let shuf2 = _mm_movehl_ps(sums, sums);
        let final_sum = _mm_add_ss(sums, shuf2);
        let mut result = _mm_cvtss_f32(final_sum);

        while i < n {
            let f = half::f16::from_bits(f16_data[i]).to_f32();
            let d = f - query[i];
            result += d * d;
            i += 1;
        }

        result
    }
}

// =============================================================================
// Unified dispatch functions
// =============================================================================

/// Compute L2 squared distance with best available SIMD
#[inline]
pub fn l2_squared(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        if x86_simd::has_avx2() {
            return unsafe { x86_simd::l2_squared_avx2(a, b) };
        }
        if x86_simd::has_sse41() {
            return unsafe { x86_simd::l2_squared_sse41(a, b) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        return neon_simd::l2_squared_neon(a, b);
    }

    #[allow(unreachable_code)]
    l2_squared_scalar(a, b)
}

/// Compute dot product with best available SIMD
#[inline]
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        if x86_simd::has_avx2() {
            return unsafe { x86_simd::dot_product_avx2(a, b) };
        }
        if x86_simd::has_sse41() {
            return unsafe { x86_simd::dot_product_sse41(a, b) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        return neon_simd::dot_product_neon(a, b);
    }

    #[allow(unreachable_code)]
    dot_product_scalar(a, b)
}

/// Compute squared norm with best available SIMD
#[inline]
pub fn norm_squared(a: &[f32]) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        if x86_simd::has_avx2() {
            return unsafe { x86_simd::norm_squared_avx2(a) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        return neon_simd::norm_squared_neon(a);
    }

    #[allow(unreachable_code)]
    norm_squared_scalar(a)
}

/// Compute cosine distance with best available SIMD.
/// Returns `1 - cosine_similarity`, so 0 = identical, 2 = opposite.
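///
/// A quick sketch of the contract:
/// ```ignore
/// let d = cosine_distance(&[1.0, 0.0], &[0.0, 1.0]);
/// assert!((d - 1.0).abs() < 1e-6); // orthogonal vectors sit at distance ~1
/// ```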
#[inline]
pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    let dot = dot_product(a, b);
    let norm_a = norm_squared(a).sqrt();
    let norm_b = norm_squared(b).sqrt();

    if norm_a == 0.0 || norm_b == 0.0 {
        return 1.0;
    }

    let cosine_sim = dot / (norm_a * norm_b);
    1.0 - cosine_sim.clamp(-1.0, 1.0)
}

// =============================================================================
// F16 / Int8 quantization dispatch functions
// =============================================================================

/// Bulk convert f16 values (as u16 bits) to f32.
/// Uses F16C on x86_64 when available; otherwise (including aarch64, where
/// NEON f16 intrinsics are unstable in Rust) falls back to scalar conversion.
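///
/// A minimal sketch:
/// ```ignore
/// let bits = [half::f16::from_f32(1.5).to_bits(), half::f16::from_f32(-2.0).to_bits()];
/// let mut decoded = [0.0f32; 2];
/// f16_to_f32_bulk(&bits, &mut decoded);
/// assert_eq!(decoded, [1.5, -2.0]); // both values are exactly representable in f16
/// ```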
#[inline]
pub fn f16_to_f32_bulk(input: &[u16], output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    #[cfg(target_arch = "x86_64")]
    {
        if x86_quant::has_f16c() {
            unsafe { x86_quant::f16_to_f32_bulk_f16c(input, output) };
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        neon_quant::f16_to_f32_bulk_neon(input, output);
        return;
    }

    // Scalar fallback
    #[allow(unreachable_code)]
    for (i, &bits) in input.iter().enumerate() {
        output[i] = half::f16::from_bits(bits).to_f32();
    }
}

/// L2 squared distance: f16 database vector (as u16 bits) vs f32 query.
/// Fused convert + distance for fewer memory passes.
#[inline]
pub fn l2_f16_vs_f32(f16_data: &[u16], query: &[f32]) -> f32 {
    debug_assert_eq!(f16_data.len(), query.len());

    #[cfg(target_arch = "x86_64")]
    {
        if x86_quant::has_f16c() && x86_simd::has_avx2() {
            return unsafe { x86_quant::l2_f16_vs_f32_f16c(f16_data, query) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        return neon_quant::l2_f16_vs_f32_neon(f16_data, query);
    }

    // Scalar fallback
    #[allow(unreachable_code)]
    {
        let mut sum = 0.0f32;
        for (i, &bits) in f16_data.iter().enumerate() {
            let f = half::f16::from_bits(bits).to_f32();
            let d = f - query[i];
            sum += d * d;
        }
        sum
    }
}

/// L2 squared distance: u8 quantized vector vs f32 query.
/// Dequantizes on the fly: `value = u8_val * scale[dim] + offset[dim]`.
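///
/// For example, with `scale = 0.5` and `offset = -1.0` in a given dimension,
/// a stored byte of `4` dequantizes to `4.0 * 0.5 - 1.0 = 1.0`.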
#[inline]
pub fn l2_u8_scaled_vs_f32(
    u8_data: &[u8],
    query: &[f32],
    scales: &[f32],
    offsets: &[f32],
) -> f32 {
    debug_assert_eq!(u8_data.len(), query.len());
    debug_assert_eq!(scales.len(), query.len());
    debug_assert_eq!(offsets.len(), query.len());

    #[cfg(target_arch = "aarch64")]
    {
        return neon_quant::l2_u8_scaled_vs_f32_neon(u8_data, query, scales, offsets);
    }

    // Scalar fallback (also used on x86_64, which has no dedicated kernel)
    #[allow(unreachable_code)]
    {
        let mut sum = 0.0f32;
        for i in 0..u8_data.len() {
            let dequant = u8_data[i] as f32 * scales[i] + offsets[i];
            let d = dequant - query[i];
            sum += d * d;
        }
        sum
    }
}

// =============================================================================
// Distance trait implementations
// =============================================================================

impl Distance<f32> for SimdL2 {
    fn eval(&self, a: &[f32], b: &[f32]) -> f32 {
        l2_squared(a, b)
    }
}

impl Distance<f32> for SimdDot {
    fn eval(&self, a: &[f32], b: &[f32]) -> f32 {
        // For ANN, we want distance (lower = closer)
        // Assuming normalized vectors: distance = 1 - dot_product
        1.0 - dot_product(a, b)
    }
}

impl Distance<f32> for SimdCosine {
    fn eval(&self, a: &[f32], b: &[f32]) -> f32 {
        cosine_distance(a, b)
    }
}

// =============================================================================
// Runtime info
// =============================================================================

/// Returns information about SIMD capabilities
pub fn simd_info() -> SimdInfo {
    SimdInfo {
        #[cfg(target_arch = "x86_64")]
        avx2: x86_simd::has_avx2(),
        #[cfg(not(target_arch = "x86_64"))]
        avx2: false,

        #[cfg(target_arch = "x86_64")]
        sse41: x86_simd::has_sse41(),
        #[cfg(not(target_arch = "x86_64"))]
        sse41: false,

        #[cfg(target_arch = "aarch64")]
        neon: true,
        #[cfg(not(target_arch = "aarch64"))]
        neon: false,
    }
}

/// Information about available SIMD features
#[derive(Debug, Clone)]
pub struct SimdInfo {
    pub avx2: bool,
    pub sse41: bool,
    pub neon: bool,
}

impl std::fmt::Display for SimdInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut features = Vec::new();
        if self.avx2 {
            features.push("AVX2");
        }
        if self.sse41 {
            features.push("SSE4.1");
        }
        if self.neon {
            features.push("NEON");
        }
        if features.is_empty() {
            write!(f, "SIMD: none (scalar fallback)")
        } else {
            write!(f, "SIMD: {}", features.join(", "))
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_l2_squared_basic() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];

        let expected: f32 = a
            .iter()
            .zip(&b)
            .map(|(x, y)| (x - y) * (x - y))
            .sum();

        let result = l2_squared(&a, &b);
        assert!((result - expected).abs() < 1e-5, "expected {expected}, got {result}");
    }

    #[test]
    fn test_l2_squared_large() {
        // Test with a dimension that requires multiple SIMD iterations + remainder
        let dim = 133; // Not divisible by 4 or 8
        let a: Vec<f32> = (0..dim).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..dim).map(|i| (i * 2) as f32).collect();

        let expected = l2_squared_scalar(&a, &b);
        let result = l2_squared(&a, &b);

        assert!(
            (result - expected).abs() < 1e-3,
            "expected {expected}, got {result}"
        );
    }

    #[test]
    fn test_dot_product_basic() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];

        let expected: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
        let result = dot_product(&a, &b);

        assert!((result - expected).abs() < 1e-5, "expected {expected}, got {result}");
    }

    #[test]
    fn test_dot_product_large() {
        let dim = 128;
        let a: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.01).collect();
        let b: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.02).collect();

        let expected = dot_product_scalar(&a, &b);
        let result = dot_product(&a, &b);

        assert!(
            (result - expected).abs() < 1e-3,
            "expected {expected}, got {result}"
        );
    }

    #[test]
    fn test_cosine_identical() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let result = cosine_distance(&a, &a);
        assert!(result.abs() < 1e-5, "identical vectors should have distance ~0, got {result}");
    }

    #[test]
    fn test_cosine_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let result = cosine_distance(&a, &b);
        assert!((result - 1.0).abs() < 1e-5, "orthogonal vectors should have distance ~1, got {result}");
    }

    #[test]
    fn test_cosine_opposite() {
        let a = vec![1.0, 2.0, 3.0];
        let b: Vec<f32> = a.iter().map(|x| -x).collect();
        let result = cosine_distance(&a, &b);
        assert!((result - 2.0).abs() < 1e-5, "opposite vectors should have distance ~2, got {result}");
    }
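
    // Consistency sketch for the f16 path: both the fused kernels and the
    // scalar fallback decode the same IEEE half bits, so the results should
    // differ only by f32 summation order.
    #[test]
    fn test_l2_f16_vs_f32_matches_scalar() {
        let dim = 133; // exercises both the SIMD body and the remainder loop
        let a: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.003).collect();
        let bits: Vec<u16> = a.iter().map(|&x| half::f16::from_f32(x).to_bits()).collect();
        let query: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.001).collect();

        let mut decoded = vec![0.0f32; dim];
        f16_to_f32_bulk(&bits, &mut decoded);
        let expected = l2_squared_scalar(&decoded, &query);
        let result = l2_f16_vs_f32(&bits, &query);
        assert!(
            (result - expected).abs() < 1e-3,
            "expected {expected}, got {result}"
        );
    }

    // Consistency sketch for the u8 path: with a power-of-two scale and
    // integer offsets and queries, every dequantize/diff/square step is
    // exact in f32, so SIMD and scalar paths must agree closely.
    #[test]
    fn test_l2_u8_scaled_matches_scalar() {
        let dim = 13; // not a multiple of 4, so the remainder loop runs too
        let u8_data: Vec<u8> = (0..dim).map(|i| (i * 19 % 256) as u8).collect();
        let scales = vec![0.5f32; dim];
        let offsets = vec![-1.0f32; dim];
        let query: Vec<f32> = (0..dim).map(|i| i as f32).collect();

        let expected: f32 = (0..dim)
            .map(|i| {
                let dq = u8_data[i] as f32 * scales[i] + offsets[i];
                (dq - query[i]) * (dq - query[i])
            })
            .sum();
        let result = l2_u8_scaled_vs_f32(&u8_data, &query, &scales, &offsets);
        assert!(
            (result - expected).abs() < 1e-3,
            "expected {expected}, got {result}"
        );
    }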

    #[test]
    fn test_simd_info() {
        let info = simd_info();
        println!("{}", info);
        // Just verify it doesn't panic
    }

    #[test]
    fn test_distance_trait_impl() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];

        let l2 = SimdL2;
        let result = l2.eval(&a, &b);
        assert!(result > 0.0);

        let cosine = SimdCosine;
        let result = cosine.eval(&a, &b);
        assert!(result >= 0.0 && result <= 2.0);
    }
}