// foxstash_core/vector/simd.rs

//! SIMD-accelerated vector operations
//!
//! This module provides SIMD implementations of common vector operations
//! (dot product, L2 distance, L2 norm, cosine similarity and cosine distance).
//! The implementations use the `pulp` crate for portable SIMD abstraction: the
//! instruction set is chosen at runtime from what the CPU supports, with a
//! scalar fallback.
//!
//! # Performance
//!
//! SIMD implementations can provide roughly a 3-4x speedup over scalar
//! operations for typical embedding dimensions (384, 768, 1024). The exact
//! speedup depends on:
//! - Vector length (longer vectors benefit more)
//! - CPU architecture and SIMD support
//! - Memory alignment and cache behavior
//!
//! # Architecture Support
//!
//! Which levels are actually used depends on `pulp`'s runtime detection; the
//! intended targets are:
//!
//! - **x86_64**: AVX2 (8x f32), SSE (4x f32), scalar fallback
//! - **ARM**: NEON (4x f32), scalar fallback
//! - **Other**: Scalar fallback
//!
//! # Usage
//!
//! ```
//! use foxstash_core::vector::simd::{dot_product_simd, cosine_similarity_simd};
//!
//! let a = vec![1.0; 384];
//! let b = vec![2.0; 384];
//!
//! let dot = dot_product_simd(&a, &b);
//! let similarity = cosine_similarity_simd(&a, &b);
//! ```

33use pulp::Simd;
34
35/// Computes dot product using SIMD acceleration.
36///
37/// This function automatically detects and uses the best available SIMD
38/// instruction set (AVX2, SSE, NEON, or scalar fallback).
39///
40/// # Arguments
41///
42/// * `a` - First vector (must have same length as `b`)
43/// * `b` - Second vector (must have same length as `a`)
44///
45/// # Returns
46///
47/// Returns the dot product as a scalar f32 value.
48///
49/// # Panics
50///
51/// Panics if vectors have different lengths (use checked version for safety).
52///
53/// # Examples
54///
55/// ```
56/// use foxstash_core::vector::simd::dot_product_simd;
57///
58/// let a = vec![1.0, 2.0, 3.0];
59/// let b = vec![4.0, 5.0, 6.0];
60/// let result = dot_product_simd(&a, &b);
61/// assert!((result - 32.0).abs() < 1e-5);
62/// ```
63#[inline]
64pub fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
65    assert_eq!(a.len(), b.len(), "Vector dimensions must match");
66
67    let simd = pulp::Arch::new();
68
69    simd.dispatch(|| dot_product_simd_impl(simd, a, b))
70}
71
72/// Computes L2 (Euclidean) distance using SIMD acceleration.
73///
74/// Calculates: `sqrt(sum((a[i] - b[i])^2))`
75///
76/// # Arguments
77///
78/// * `a` - First vector
79/// * `b` - Second vector
80///
81/// # Returns
82///
83/// Returns the non-negative L2 distance.
84///
85/// # Panics
86///
87/// Panics if vectors have different lengths.
88///
89/// # Examples
90///
91/// ```
92/// use foxstash_core::vector::simd::l2_distance_simd;
93///
94/// let a = vec![0.0, 0.0];
95/// let b = vec![3.0, 4.0];
96/// let distance = l2_distance_simd(&a, &b);
97/// assert!((distance - 5.0).abs() < 1e-5);
98/// ```
99#[inline]
100pub fn l2_distance_simd(a: &[f32], b: &[f32]) -> f32 {
101    assert_eq!(a.len(), b.len(), "Vector dimensions must match");
102
103    let simd = pulp::Arch::new();
104
105    simd.dispatch(|| l2_distance_simd_impl(simd, a, b))
106}
107
108/// Computes cosine similarity using SIMD acceleration.
109///
110/// Calculates: dot(a, b) / (||a|| * ||b||)
111///
112/// Returns a value in [-1, 1] where:
113/// - 1.0 = identical direction
114/// - 0.0 = orthogonal
115/// - -1.0 = opposite direction
116///
117/// # Arguments
118///
119/// * `a` - First vector
120/// * `b` - Second vector
121///
122/// # Returns
123///
124/// Returns cosine similarity in range [-1, 1].
125///
126/// # Panics
127///
128/// Panics if vectors have different lengths.
129///
130/// # Examples
131///
132/// ```
133/// use foxstash_core::vector::simd::cosine_similarity_simd;
134///
135/// let a = vec![1.0, 0.0, 0.0];
136/// let b = vec![0.0, 1.0, 0.0];
137/// let similarity = cosine_similarity_simd(&a, &b);
138/// assert!((similarity - 0.0).abs() < 1e-5);
139/// ```
140#[inline]
141pub fn cosine_similarity_simd(a: &[f32], b: &[f32]) -> f32 {
142    assert_eq!(a.len(), b.len(), "Vector dimensions must match");
143
144    if a.is_empty() {
145        return 1.0; // Convention: empty vectors are maximally similar
146    }
147
148    let simd = pulp::Arch::new();
149
150    simd.dispatch(|| {
151        let dot = dot_product_simd_impl(simd, a, b);
152        let norm_a = magnitude_simd_impl(simd, a);
153        let norm_b = magnitude_simd_impl(simd, b);
154
155        // Handle zero vectors
156        if norm_a == 0.0 || norm_b == 0.0 {
157            return 0.0;
158        }
159
160        // Compute similarity and clamp to [-1, 1] to handle numerical errors
161        let similarity = dot / (norm_a * norm_b);
162        similarity.clamp(-1.0, 1.0)
163    })
164}
165
166/// Computes the L2 norm (magnitude) of a vector using SIMD acceleration.
167///
168/// Returns `sqrt(sum(v[i]^2))`.
169#[inline]
170pub fn norm_simd(v: &[f32]) -> f32 {
171    let simd = pulp::Arch::new();
172    simd.dispatch(Magnitude { vector: v })
173}
174
175/// Computes cosine distance with a precomputed norm for vector `b`.
176///
177/// This is the fused hot-path: a single `dispatch` call with two SIMD
178/// accumulators (dot product + norm_a²) in one pass over the data.
179/// The caller supplies `norm_b` (precomputed and cached per stored vector).
180///
181/// Returns `1.0 - dot(a,b) / (||a|| * norm_b)`, i.e. cosine distance in [0, 2].
182#[inline]
183pub fn cosine_distance_prenorm(a: &[f32], b: &[f32], norm_b: f32) -> f32 {
184    debug_assert_eq!(a.len(), b.len());
185
186    if norm_b == 0.0 {
187        return 1.0;
188    }
189
190    let simd = pulp::Arch::new();
191    simd.dispatch(FusedCosineDistance { a, b, norm_b })
192}
193
/// Kernel state for [`cosine_distance_prenorm`]: a single SIMD pass that
/// accumulates `dot(a, b)` and `||a||²` side by side.
struct FusedCosineDistance<'a> {
    /// Query vector; its norm is computed inside the kernel.
    a: &'a [f32],
    /// Stored vector; expected to have the same length as `a`.
    b: &'a [f32],
    /// Caller-supplied L2 norm of `b`.
    norm_b: f32,
}

201impl pulp::WithSimd for FusedCosineDistance<'_> {
202    type Output = f32;
203
204    #[inline(always)]
205    fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
206        let a = self.a;
207        let b = self.b;
208        let norm_b = self.norm_b;
209        let (a_chunks, a_tail) = S::as_simd_f32s(a);
210        let (b_chunks, b_tail) = S::as_simd_f32s(b);
211
212        let mut dot_acc = simd.splat_f32s(0.0);
213        let mut norm_a_acc = simd.splat_f32s(0.0);
214        for (&a_vec, &b_vec) in a_chunks.iter().zip(b_chunks.iter()) {
215            dot_acc = simd.mul_add_e_f32s(a_vec, b_vec, dot_acc);
216            norm_a_acc = simd.mul_add_e_f32s(a_vec, a_vec, norm_a_acc);
217        }
218
219        let mut dot = simd.reduce_sum_f32s(dot_acc);
220        let mut norm_a_sq = simd.reduce_sum_f32s(norm_a_acc);
221
222        debug_assert_eq!(a_tail.len(), b_tail.len());
223        for (&a_scalar, &b_scalar) in a_tail.iter().zip(b_tail.iter()) {
224            dot += a_scalar * b_scalar;
225            norm_a_sq += a_scalar * a_scalar;
226        }
227
228        let norm_a = norm_a_sq.sqrt();
229        if norm_a == 0.0 {
230            return 1.0;
231        }
232
233        let similarity = dot / (norm_a * norm_b);
234        1.0 - similarity.clamp(-1.0, 1.0)
235    }
236}
237
238/// Internal implementation of dot product with SIMD.
239///
240/// This function is generic over SIMD architecture and will use the best
241/// available instruction set at runtime.
242#[inline(always)]
243fn dot_product_simd_impl(simd: pulp::Arch, a: &[f32], b: &[f32]) -> f32 {
244    struct DotProduct<'a> {
245        a: &'a [f32],
246        b: &'a [f32],
247    }
248
249    impl pulp::WithSimd for DotProduct<'_> {
250        type Output = f32;
251
252        #[inline(always)]
253        fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
254            let a = self.a;
255            let b = self.b;
256            let (a_chunks, a_tail) = S::as_simd_f32s(a);
257            let (b_chunks, b_tail) = S::as_simd_f32s(b);
258
259            let mut sum = simd.splat_f32s(0.0);
260            for (&a_vec, &b_vec) in a_chunks.iter().zip(b_chunks.iter()) {
261                sum = simd.mul_add_e_f32s(a_vec, b_vec, sum);
262            }
263
264            let mut result = simd.reduce_sum_f32s(sum);
265            debug_assert_eq!(a_tail.len(), b_tail.len());
266            for (&a_scalar, &b_scalar) in a_tail.iter().zip(b_tail.iter()) {
267                result += a_scalar * b_scalar;
268            }
269
270            result
271        }
272    }
273
274    simd.dispatch(DotProduct { a, b })
275}
276
277/// Internal implementation of L2 distance with SIMD.
278#[inline(always)]
279fn l2_distance_simd_impl(simd: pulp::Arch, a: &[f32], b: &[f32]) -> f32 {
280    struct L2Distance<'a> {
281        a: &'a [f32],
282        b: &'a [f32],
283    }
284
285    impl pulp::WithSimd for L2Distance<'_> {
286        type Output = f32;
287
288        #[inline(always)]
289        fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
290            let a = self.a;
291            let b = self.b;
292            let (a_chunks, a_tail) = S::as_simd_f32s(a);
293            let (b_chunks, b_tail) = S::as_simd_f32s(b);
294
295            let mut sum_squares = simd.splat_f32s(0.0);
296            for (&a_vec, &b_vec) in a_chunks.iter().zip(b_chunks.iter()) {
297                let diff = simd.sub_f32s(a_vec, b_vec);
298                sum_squares = simd.mul_add_e_f32s(diff, diff, sum_squares);
299            }
300
301            let mut result = simd.reduce_sum_f32s(sum_squares);
302            debug_assert_eq!(a_tail.len(), b_tail.len());
303            for (&a_scalar, &b_scalar) in a_tail.iter().zip(b_tail.iter()) {
304                let diff = a_scalar - b_scalar;
305                result += diff * diff;
306            }
307
308            result.sqrt()
309        }
310    }
311
312    simd.dispatch(L2Distance { a, b })
313}
314
/// SIMD kernel computing the Euclidean norm of `vector`; shared by
/// `magnitude_simd_impl` and [`norm_simd`].
struct Magnitude<'a> {
    /// Input whose L2 norm is computed.
    vector: &'a [f32],
}

320impl pulp::WithSimd for Magnitude<'_> {
321    type Output = f32;
322
323    #[inline(always)]
324    fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
325        let vector = self.vector;
326        let (chunks, tail) = S::as_simd_f32s(vector);
327
328        let mut sum_squares = simd.splat_f32s(0.0);
329        for &vector_chunk in chunks {
330            sum_squares = simd.mul_add_e_f32s(vector_chunk, vector_chunk, sum_squares);
331        }
332
333        let mut result = simd.reduce_sum_f32s(sum_squares);
334
335        for &value in tail {
336            result += value * value;
337        }
338
339        result.sqrt()
340    }
341}
342
343/// Internal implementation of vector magnitude with SIMD.
344#[inline(always)]
345fn magnitude_simd_impl(simd: pulp::Arch, vector: &[f32]) -> f32 {
346    simd.dispatch(Magnitude { vector })
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vector::ops::{cosine_similarity, dot_product, l2_distance};

    const EPSILON: f32 = 1e-5;

    /// Vector lengths around common SIMD register widths, so that both the
    /// full-chunk loop and every scalar-tail length get exercised.
    const SIZES: [usize; 19] = [
        1, 2, 3, 4, 5, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 383, 384, 767, 768,
    ];

    /// Absolute tolerance for comparing a SIMD result with the scalar reference:
    /// `EPSILON` for small results, 0.001% relative error above 1000.
    fn tolerance(expected: f32) -> f32 {
        if expected.abs() > 1000.0 {
            expected.abs() * 1e-5
        } else {
            EPSILON
        }
    }

    #[test]
    fn test_dot_product_simd_basic() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];

        let result = dot_product_simd(&a, &b);
        let expected = dot_product(&a, &b).unwrap();

        assert!((result - expected).abs() < EPSILON);
        assert!((result - 32.0).abs() < EPSILON);
    }

    #[test]
    fn test_dot_product_simd_zero() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];

        let result = dot_product_simd(&a, &b);
        assert!((result - 0.0).abs() < EPSILON);
    }

    #[test]
    fn test_dot_product_simd_negative() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![-1.0, -2.0, -3.0];

        let result = dot_product_simd(&a, &b);
        let expected = dot_product(&a, &b).unwrap();

        assert!((result - expected).abs() < EPSILON);
    }

    #[test]
    fn test_dot_product_simd_various_sizes() {
        // Test different sizes to verify remainder handling
        for size in SIZES {
            let a: Vec<f32> = (0..size).map(|i| i as f32).collect();
            let b: Vec<f32> = (0..size).map(|i| (i * 2) as f32).collect();

            let simd_result = dot_product_simd(&a, &b);
            let scalar_result = dot_product(&a, &b).unwrap();

            assert!(
                (simd_result - scalar_result).abs() < tolerance(scalar_result),
                "Size {}: SIMD={}, Scalar={}",
                size,
                simd_result,
                scalar_result
            );
        }
    }

    #[test]
    fn test_dot_product_simd_misaligned_subslice_regression() {
        let size = 257;
        let a_storage: Vec<f32> = (0..(size + 3))
            .map(|i| ((i as f32) - 90.0) * 0.03125)
            .collect();
        let b_storage: Vec<f32> = (0..(size + 4))
            .map(|i| ((size + 4 - i) as f32 - 120.0) * 0.0625)
            .collect();

        let a = &a_storage[1..(1 + size)];
        let b = &b_storage[2..(2 + size)];

        let simd_result = dot_product_simd(a, b);
        let scalar_result = dot_product(a, b).unwrap();
        assert!((simd_result - scalar_result).abs() < 1e-4);
    }

    #[test]
    fn test_l2_distance_simd_basic() {
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];

        let result = l2_distance_simd(&a, &b);
        let expected = l2_distance(&a, &b).unwrap();

        assert!((result - expected).abs() < EPSILON);
        assert!((result - 5.0).abs() < EPSILON);
    }

    #[test]
    fn test_l2_distance_simd_zero() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];

        let result = l2_distance_simd(&a, &b);
        assert!(result < EPSILON);
    }

    #[test]
    fn test_l2_distance_simd_various_sizes() {
        for size in SIZES {
            let a: Vec<f32> = (0..size).map(|i| i as f32).collect();
            let b: Vec<f32> = (0..size).map(|i| (i * 2) as f32).collect();

            let simd_result = l2_distance_simd(&a, &b);
            let scalar_result = l2_distance(&a, &b).unwrap();

            assert!(
                (simd_result - scalar_result).abs() < tolerance(scalar_result),
                "Size {}: SIMD={}, Scalar={}",
                size,
                simd_result,
                scalar_result
            );
        }
    }

    #[test]
    fn test_l2_distance_simd_misaligned_subslice_regression() {
        let size = 257;
        let a_storage: Vec<f32> = (0..(size + 3))
            .map(|i| ((i as f32) - 30.0) * 0.125)
            .collect();
        let b_storage: Vec<f32> = (0..(size + 4))
            .map(|i| ((i as f32) - 170.0) * -0.09375)
            .collect();

        let a = &a_storage[1..(1 + size)];
        let b = &b_storage[2..(2 + size)];

        let simd_result = l2_distance_simd(a, b);
        let scalar_result = l2_distance(a, b).unwrap();
        assert!((simd_result - scalar_result).abs() < 1e-4);
    }

    #[test]
    fn test_cosine_similarity_simd_basic() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];

        let result = cosine_similarity_simd(&a, &b);
        let expected = cosine_similarity(&a, &b).unwrap();

        assert!((result - expected).abs() < EPSILON);
        assert!((result - 0.0).abs() < EPSILON);
    }

    #[test]
    fn test_cosine_similarity_simd_identical() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];

        let result = cosine_similarity_simd(&a, &b);
        assert!((result - 1.0).abs() < EPSILON);
    }

    #[test]
    fn test_cosine_similarity_simd_opposite() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![-1.0, -2.0, -3.0];

        let result = cosine_similarity_simd(&a, &b);
        assert!((result - (-1.0)).abs() < EPSILON);
    }

    #[test]
    fn test_cosine_similarity_simd_zero_vector() {
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![1.0, 2.0, 3.0];

        let result = cosine_similarity_simd(&a, &b);
        assert!((result - 0.0).abs() < EPSILON);
    }

    #[test]
    fn test_cosine_similarity_simd_various_sizes() {
        for size in [
            1, 2, 3, 4, 5, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 383, 384, 767, 768, 1023, 1024,
        ] {
            let a: Vec<f32> = (0..size).map(|i| (i as f32) / (size as f32)).collect();
            let b: Vec<f32> = (0..size)
                .map(|i| 1.0 - (i as f32) / (size as f32))
                .collect();

            let simd_result = cosine_similarity_simd(&a, &b);
            let scalar_result = cosine_similarity(&a, &b).unwrap();

            // Cosine similarity is always in [-1, 1], but may have more rounding for large vectors
            let epsilon = if size > 100 { 1e-4 } else { EPSILON };

            assert!(
                (simd_result - scalar_result).abs() < epsilon,
                "Size {}: SIMD={}, Scalar={}",
                size,
                simd_result,
                scalar_result
            );
        }
    }

    #[test]
    fn test_cosine_similarity_simd_misaligned_subslice_regression() {
        let size = 257;
        let a_storage: Vec<f32> = (0..(size + 3))
            .map(|i| (((i as f32) % 17.0) - 8.0) * 0.37)
            .collect();
        let b_storage: Vec<f32> = (0..(size + 4))
            .map(|i| (((i as f32) % 19.0) - 9.0) * -0.29)
            .collect();

        let a = &a_storage[1..(1 + size)];
        let b = &b_storage[2..(2 + size)];

        let simd_result = cosine_similarity_simd(a, b);
        let scalar_result = cosine_similarity(a, b).unwrap();
        assert!((simd_result - scalar_result).abs() < 1e-4);
    }

    #[test]
    fn test_simd_numerical_stability() {
        // Test with large values
        let a = vec![1e6; 384];
        let b = vec![2e6; 384];

        let simd_result = cosine_similarity_simd(&a, &b);
        let scalar_result = cosine_similarity(&a, &b).unwrap();

        assert!((simd_result - scalar_result).abs() < EPSILON);
        assert!((-1.0..=1.0).contains(&simd_result));

        // Test with small values
        let a = vec![1e-6; 384];
        let b = vec![2e-6; 384];

        let simd_result = cosine_similarity_simd(&a, &b);
        let scalar_result = cosine_similarity(&a, &b).unwrap();

        assert!((simd_result - scalar_result).abs() < EPSILON);
        assert!((-1.0..=1.0).contains(&simd_result));
    }

    #[test]
    fn test_empty_inputs() {
        let empty: [f32; 0] = [];
        assert_eq!(dot_product_simd(&empty, &empty), 0.0);
        assert_eq!(l2_distance_simd(&empty, &empty), 0.0);
        assert_eq!(norm_simd(&empty), 0.0);
        // Documented convention: empty vectors are maximally similar.
        assert_eq!(cosine_similarity_simd(&empty, &empty), 1.0);
        // An empty query has zero magnitude, which maps to distance 1.0.
        assert_eq!(cosine_distance_prenorm(&empty, &empty, 1.0), 1.0);
    }

    #[test]
    #[should_panic(expected = "Vector dimensions must match")]
    fn test_dot_product_simd_dimension_mismatch() {
        let a = vec![1.0, 2.0];
        let b = vec![1.0, 2.0, 3.0];
        let _ = dot_product_simd(&a, &b);
    }

    #[test]
    #[should_panic(expected = "Vector dimensions must match")]
    fn test_l2_distance_simd_dimension_mismatch() {
        let a = vec![1.0, 2.0];
        let b = vec![1.0, 2.0, 3.0];
        let _ = l2_distance_simd(&a, &b);
    }

    #[test]
    #[should_panic(expected = "Vector dimensions must match")]
    fn test_cosine_similarity_simd_dimension_mismatch() {
        let a = vec![1.0, 2.0];
        let b = vec![1.0, 2.0, 3.0];
        let _ = cosine_similarity_simd(&a, &b);
    }

    #[test]
    fn test_norm_simd() {
        let v = vec![3.0, 4.0];
        let norm = norm_simd(&v);
        assert!((norm - 5.0).abs() < EPSILON);
    }

    #[test]
    fn test_norm_simd_various_sizes() {
        for size in SIZES {
            // Integer-valued inputs keep every partial sum exact in f32, so the
            // SIMD and scalar summation orders must agree.
            let v: Vec<f32> = (0..size).map(|i| ((i % 11) as f32) - 5.0).collect();
            let expected = v.iter().map(|x| x * x).sum::<f32>().sqrt();
            let actual = norm_simd(&v);

            assert!(
                (actual - expected).abs() < EPSILON,
                "Size {}: SIMD={}, Scalar={}",
                size,
                actual,
                expected
            );
        }
    }

    #[test]
    fn test_cosine_distance_prenorm_matches_old() {
        // Compare fused prenorm distance against the original cosine_similarity_simd path
        for size in [3, 4, 8, 16, 32, 64, 128, 384, 768] {
            let a: Vec<f32> = (0..size).map(|i| (i as f32) / (size as f32)).collect();
            let b: Vec<f32> = (0..size)
                .map(|i| 1.0 - (i as f32) / (size as f32))
                .collect();

            let old_dist = 1.0 - cosine_similarity_simd(&a, &b);
            let norm_b = norm_simd(&b);
            let new_dist = cosine_distance_prenorm(&a, &b, norm_b);

            let epsilon = if size > 100 { 1e-4 } else { EPSILON };
            assert!(
                (old_dist - new_dist).abs() < epsilon,
                "Size {}: old={}, new={}",
                size,
                old_dist,
                new_dist
            );
        }
    }

    #[test]
    fn test_cosine_distance_prenorm_tail_sizes() {
        // Sizes with every tail length; no element is zero, so neither vector
        // is degenerate.
        for size in SIZES {
            let a: Vec<f32> = (0..size).map(|i| ((i % 7) as f32) - 2.5).collect();
            let b: Vec<f32> = (0..size).map(|i| 1.5 - ((i % 5) as f32)).collect();
            let norm_b = norm_simd(&b);

            let simd_result = cosine_distance_prenorm(&a, &b, norm_b);
            let scalar_result = 1.0 - cosine_similarity(&a, &b).unwrap();

            assert!(
                (simd_result - scalar_result).abs() < 1e-4,
                "Size {}: SIMD={}, Scalar={}",
                size,
                simd_result,
                scalar_result
            );
        }
    }

    #[test]
    fn test_cosine_distance_prenorm_misaligned_subslice_regression() {
        let size = 257;
        let a_storage: Vec<f32> = (0..(size + 3))
            .map(|i| (((i as f32) % 13.0) - 6.0) * 0.41)
            .collect();
        let b_storage: Vec<f32> = (0..(size + 4))
            .map(|i| (((i as f32) % 11.0) - 5.0) * -0.23)
            .collect();

        let a = &a_storage[1..(1 + size)];
        let b = &b_storage[2..(2 + size)];
        let norm_b = norm_simd(b);

        let simd_result = cosine_distance_prenorm(a, b, norm_b);
        let scalar_result = 1.0 - cosine_similarity(a, b).unwrap();
        assert!((simd_result - scalar_result).abs() < 1e-4);
    }

    #[test]
    fn test_cosine_distance_prenorm_zero_vectors() {
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![1.0, 2.0, 3.0];
        let norm_b = norm_simd(&b);
        // Zero query should return distance 1.0
        assert!((cosine_distance_prenorm(&a, &b, norm_b) - 1.0).abs() < EPSILON);
        // Zero stored vector (norm_b=0) should return distance 1.0
        assert!((cosine_distance_prenorm(&b, &a, 0.0) - 1.0).abs() < EPSILON);
    }

    #[test]
    fn test_cosine_distance_prenorm_identical() {
        let a = vec![1.0, 2.0, 3.0];
        let norm_a = norm_simd(&a);
        let dist = cosine_distance_prenorm(&a, &a, norm_a);
        assert!(
            dist.abs() < EPSILON,
            "Identical vectors should have distance ~0, got {}",
            dist
        );
    }

    #[test]
    fn test_cosine_distance_prenorm_opposite() {
        let a = vec![1.0, 2.0, 3.0];
        let b: Vec<f32> = a.iter().map(|x| -x).collect();
        let norm_b = norm_simd(&b);
        let dist = cosine_distance_prenorm(&a, &b, norm_b);
        assert!(
            (dist - 2.0).abs() < EPSILON,
            "Opposite vectors should have distance ~2, got {}",
            dist
        );
    }
}