distx_core/simd.rs

// SIMD optimizations for vector operations
// Inspired by Redis (prefetch patterns, scalar fallbacks) and Qdrant (SIMD hierarchy)
// Uses platform-specific SIMD intrinsics for maximum performance

#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;

// Minimum dimension sizes for SIMD (Qdrant pattern)
#[cfg(target_arch = "x86_64")]
const MIN_DIM_SIZE_AVX: usize = 32;

#[cfg(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64"))]
const MIN_DIM_SIZE_SIMD: usize = 16;
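
// Dispatch order used by the public entry points below: try the widest SIMD
// path the CPU supports, falling back when the vector is shorter than the
// minimum dimension for that path.
//   x86_64:  AVX2 + FMA (dim >= 32) -> SSE (dim >= 16) -> scalar
//   aarch64: NEON (dim >= 16) -> scalar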
/// SIMD-optimized dot product for cosine similarity
/// Vectors should be normalized for cosine similarity
/// Falls back to optimized scalar code with better pipelining (Redis-style fallback)
#[inline]
pub fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Try platform-specific SIMD if available (Qdrant hierarchy pattern)
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2")
            && is_x86_feature_detected!("fma")
            && a.len() >= MIN_DIM_SIZE_AVX
        {
            return unsafe { dot_product_avx2(a, b) };
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if is_x86_feature_detected!("sse") && a.len() >= MIN_DIM_SIZE_SIMD {
            return unsafe { dot_product_sse(a, b) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") && a.len() >= MIN_DIM_SIZE_SIMD {
            return unsafe { dot_product_neon(a, b) };
        }
    }

    // Optimized scalar fallback (like Redis's scalar implementation)
    // Uses two accumulators for better pipelining
    dot_product_scalar(a, b)
}
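
// Illustrative usage sketch (hypothetical values, not from the original code):
// with unit-length inputs the dot product is the cosine similarity directly.
//
//     let query: Vec<f32> = vec![0.6, 0.8, 0.0];
//     let doc: Vec<f32> = vec![0.8, 0.6, 0.0];
//     let cosine = dot_product_simd(&query, &doc); // 0.6*0.8 + 0.8*0.6 = 0.96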

/// AVX2-optimized dot product (16 floats at a time)
/// Inspired by Redis's vectors_distance_float_avx2
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
unsafe fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
    let dim = a.len();
    let mut i = 0;

    let mut sum1 = _mm256_setzero_ps();
    let mut sum2 = _mm256_setzero_ps();

    // Process 16 floats at a time with two AVX2 registers
    while i + 15 < dim {
        let vx1 = _mm256_loadu_ps(a.as_ptr().add(i));
        let vy1 = _mm256_loadu_ps(b.as_ptr().add(i));
        let vx2 = _mm256_loadu_ps(a.as_ptr().add(i + 8));
        let vy2 = _mm256_loadu_ps(b.as_ptr().add(i + 8));

        sum1 = _mm256_fmadd_ps(vx1, vy1, sum1);
        sum2 = _mm256_fmadd_ps(vx2, vy2, sum2);

        i += 16;
    }

    // Combine the two sums
    let combined = _mm256_add_ps(sum1, sum2);
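
    // Lane view of the reduction below (an illustrative note): with
    // combined = [s0..s7], adding the high and low 128-bit halves gives
    // [s0+s4, s1+s5, s2+s6, s3+s7]; the first hadd folds adjacent pairs,
    // the second leaves the full sum in lane 0 for _mm_cvtss_f32.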
    // Horizontal sum of the 8 elements
    let sum_high = _mm256_extractf128_ps(combined, 1);
    let sum_low = _mm256_castps256_ps128(combined);
    let mut sum_128 = _mm_add_ps(sum_high, sum_low);

    sum_128 = _mm_hadd_ps(sum_128, sum_128);
    sum_128 = _mm_hadd_ps(sum_128, sum_128);

    let mut dot = _mm_cvtss_f32(sum_128);

    // Handle remaining elements
    while i < dim {
        dot += a[i] * b[i];
        i += 1;
    }

    dot
}

/// SSE-optimized dot product (Qdrant compatibility pattern)
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse")]
#[inline]
unsafe fn dot_product_sse(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let dim = a.len();
    let mut i = 0;
    let mut sum = _mm_setzero_ps();

    // Process 4 floats at a time
    while i + 3 < dim {
        let va = _mm_loadu_ps(a.as_ptr().add(i));
        let vb = _mm_loadu_ps(b.as_ptr().add(i));
        sum = _mm_add_ps(sum, _mm_mul_ps(va, vb));
        i += 4;
    }
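
    // Lane view (an illustrative note): the shuffle swaps adjacent lanes, so the
    // add yields [s0+s1, s0+s1, s2+s3, s2+s3]; movehl brings the upper pair down
    // and the final add_ss leaves s0+s1+s2+s3 in lane 0.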
    // Horizontal sum
    let shuf = _mm_shuffle_ps(sum, sum, 0b10_11_00_01);
    sum = _mm_add_ps(sum, shuf);
    let shuf = _mm_movehl_ps(sum, sum);
    sum = _mm_add_ss(sum, shuf);

    let mut dot = _mm_cvtss_f32(sum);

    // Handle remaining elements
    while i < dim {
        dot += a[i] * b[i];
        i += 1;
    }

    dot
}

/// NEON-optimized dot product for ARM/Apple Silicon
/// Uses 8-wide processing with two NEON registers for better throughput
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[inline]
unsafe fn dot_product_neon(a: &[f32], b: &[f32]) -> f32 {
    let dim = a.len();
    let mut i = 0;

    // Use two accumulators for better instruction pipelining
    let mut sum1 = vdupq_n_f32(0.0);
    let mut sum2 = vdupq_n_f32(0.0);

    // Process 8 floats at a time with two NEON registers
    while i + 7 < dim {
        let va1 = vld1q_f32(a.as_ptr().add(i));
        let vb1 = vld1q_f32(b.as_ptr().add(i));
        let va2 = vld1q_f32(a.as_ptr().add(i + 4));
        let vb2 = vld1q_f32(b.as_ptr().add(i + 4));

        sum1 = vfmaq_f32(sum1, va1, vb1);
        sum2 = vfmaq_f32(sum2, va2, vb2);

        i += 8;
    }

    // Process remaining 4 floats
    while i + 3 < dim {
        let va = vld1q_f32(a.as_ptr().add(i));
        let vb = vld1q_f32(b.as_ptr().add(i));
        sum1 = vfmaq_f32(sum1, va, vb);
        i += 4;
    }

    // Combine accumulators and horizontal sum
    let combined = vaddq_f32(sum1, sum2);
    let mut dot = vaddvq_f32(combined);

    // Handle remaining elements
    while i < dim {
        dot += a[i] * b[i];
        i += 1;
    }

    dot
}

/// Scalar fallback (two accumulators for better pipelining - Redis pattern)
#[inline]
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    let mut dot0 = 0.0f32;
    let mut dot1 = 0.0f32;

    // Process 8 elements at a time with two accumulators
    let chunks = a.chunks_exact(8);
    let remainder = chunks.remainder();
    let b_chunks = b.chunks_exact(8);

    for (a_chunk, b_chunk) in chunks.zip(b_chunks) {
        dot0 += a_chunk[0] * b_chunk[0] +
                a_chunk[1] * b_chunk[1] +
                a_chunk[2] * b_chunk[2] +
                a_chunk[3] * b_chunk[3];

        dot1 += a_chunk[4] * b_chunk[4] +
                a_chunk[5] * b_chunk[5] +
                a_chunk[6] * b_chunk[6] +
                a_chunk[7] * b_chunk[7];
    }

    // Handle remainder
    for i in (a.len() - remainder.len())..a.len() {
        dot0 += a[i] * b[i];
    }

    dot0 + dot1
}

/// SIMD-optimized L2 distance (Euclidean)
#[inline]
pub fn l2_distance_simd(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return f32::INFINITY;
    }

    // Try platform-specific SIMD if available (Qdrant hierarchy pattern)
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2")
            && is_x86_feature_detected!("fma")
            && a.len() >= MIN_DIM_SIZE_AVX
        {
            return unsafe { l2_distance_avx2(a, b) };
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if is_x86_feature_detected!("sse") && a.len() >= MIN_DIM_SIZE_SIMD {
            return unsafe { l2_distance_sse(a, b) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") && a.len() >= MIN_DIM_SIZE_SIMD {
            return unsafe { l2_distance_neon(a, b) };
        }
    }

    l2_distance_scalar(a, b)
}
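
// Illustrative usage sketch (hypothetical values): the classic 3-4-5 triangle.
// Two-dimensional inputs are below MIN_DIM_SIZE_SIMD, so this takes the scalar path.
//
//     let a: Vec<f32> = vec![0.0, 0.0];
//     let b: Vec<f32> = vec![3.0, 4.0];
//     let d = l2_distance_simd(&a, &b); // sqrt(3*3 + 4*4) = 5.0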

/// AVX2-optimized L2 distance
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
unsafe fn l2_distance_avx2(a: &[f32], b: &[f32]) -> f32 {
    let dim = a.len();
    let mut i = 0;

    let mut sum1 = _mm256_setzero_ps();
    let mut sum2 = _mm256_setzero_ps();

    // Process 16 floats at a time with two AVX2 registers
    while i + 15 < dim {
        let va1 = _mm256_loadu_ps(a.as_ptr().add(i));
        let vb1 = _mm256_loadu_ps(b.as_ptr().add(i));
        let va2 = _mm256_loadu_ps(a.as_ptr().add(i + 8));
        let vb2 = _mm256_loadu_ps(b.as_ptr().add(i + 8));

        let diff1 = _mm256_sub_ps(va1, vb1);
        let diff2 = _mm256_sub_ps(va2, vb2);

        sum1 = _mm256_fmadd_ps(diff1, diff1, sum1);
        sum2 = _mm256_fmadd_ps(diff2, diff2, sum2);

        i += 16;
    }

    // Combine the two sums
    let combined = _mm256_add_ps(sum1, sum2);

    // Horizontal sum of the 8 elements
    let sum_high = _mm256_extractf128_ps(combined, 1);
    let sum_low = _mm256_castps256_ps128(combined);
    let mut sum_128 = _mm_add_ps(sum_high, sum_low);

    sum_128 = _mm_hadd_ps(sum_128, sum_128);
    sum_128 = _mm_hadd_ps(sum_128, sum_128);

    let mut sum_sq = _mm_cvtss_f32(sum_128);

    // Handle remaining elements
    while i < dim {
        let diff = a[i] - b[i];
        sum_sq += diff * diff;
        i += 1;
    }

    sum_sq.sqrt()
}

/// SSE-optimized L2 distance
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse")]
#[inline]
unsafe fn l2_distance_sse(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let dim = a.len();
    let mut i = 0;
    let mut sum = _mm_setzero_ps();

    // Process 4 floats at a time
    while i + 3 < dim {
        let va = _mm_loadu_ps(a.as_ptr().add(i));
        let vb = _mm_loadu_ps(b.as_ptr().add(i));
        let diff = _mm_sub_ps(va, vb);
        sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
        i += 4;
    }

    // Horizontal sum
    let shuf = _mm_shuffle_ps(sum, sum, 0b10_11_00_01);
    sum = _mm_add_ps(sum, shuf);
    let shuf = _mm_movehl_ps(sum, sum);
    sum = _mm_add_ss(sum, shuf);

    let mut sum_sq = _mm_cvtss_f32(sum);

    // Handle remaining elements
    while i < dim {
        let diff = a[i] - b[i];
        sum_sq += diff * diff;
        i += 1;
    }

    sum_sq.sqrt()
}

/// NEON-optimized L2 distance for ARM/Apple Silicon
/// Uses 8-wide processing with two NEON registers for better throughput
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[inline]
unsafe fn l2_distance_neon(a: &[f32], b: &[f32]) -> f32 {
    let dim = a.len();
    let mut i = 0;

    // Use two accumulators for better instruction pipelining
    let mut sum1 = vdupq_n_f32(0.0);
    let mut sum2 = vdupq_n_f32(0.0);

    // Process 8 floats at a time with two NEON registers
    while i + 7 < dim {
        let va1 = vld1q_f32(a.as_ptr().add(i));
        let vb1 = vld1q_f32(b.as_ptr().add(i));
        let va2 = vld1q_f32(a.as_ptr().add(i + 4));
        let vb2 = vld1q_f32(b.as_ptr().add(i + 4));

        let diff1 = vsubq_f32(va1, vb1);
        let diff2 = vsubq_f32(va2, vb2);

        sum1 = vfmaq_f32(sum1, diff1, diff1);
        sum2 = vfmaq_f32(sum2, diff2, diff2);

        i += 8;
    }

    // Process remaining 4 floats
    while i + 3 < dim {
        let va = vld1q_f32(a.as_ptr().add(i));
        let vb = vld1q_f32(b.as_ptr().add(i));
        let diff = vsubq_f32(va, vb);
        sum1 = vfmaq_f32(sum1, diff, diff);
        i += 4;
    }

    // Combine accumulators and horizontal sum
    let combined = vaddq_f32(sum1, sum2);
    let mut sum_sq = vaddvq_f32(combined);

    // Handle remaining elements
    while i < dim {
        let diff = a[i] - b[i];
        sum_sq += diff * diff;
        i += 1;
    }

    sum_sq.sqrt()
}

/// Scalar L2 distance (two accumulators for better pipelining)
#[inline]
fn l2_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    let mut sum0 = 0.0f32;
    let mut sum1 = 0.0f32;

    // Process 4 elements at a time with two accumulators
    let chunks = a.chunks_exact(4);
    let remainder = chunks.remainder();
    let b_chunks = b.chunks_exact(4);

    for (a_chunk, b_chunk) in chunks.zip(b_chunks) {
        let d0 = a_chunk[0] - b_chunk[0];
        let d1 = a_chunk[1] - b_chunk[1];
        let d2 = a_chunk[2] - b_chunk[2];
        let d3 = a_chunk[3] - b_chunk[3];

        sum0 += d0 * d0 + d1 * d1;
        sum1 += d2 * d2 + d3 * d3;
    }

    // Handle remainder
    for i in (a.len() - remainder.len())..a.len() {
        let diff = a[i] - b[i];
        sum0 += diff * diff;
    }

    (sum0 + sum1).sqrt()
}

/// SIMD-optimized squared norm of a vector (squared length)
#[inline]
pub fn norm_squared_simd(v: &[f32]) -> f32 {
    dot_product_simd(v, v)
}

/// SIMD-optimized vector norm (Euclidean length)
#[inline]
pub fn norm_simd(v: &[f32]) -> f32 {
    norm_squared_simd(v).sqrt()
}
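
// Illustrative test sketch (test vectors, dimensions, and tolerances are arbitrary
// choices): checks that the SIMD dispatch paths agree with the scalar fallbacks on
// dimensions chosen to exercise the scalar (< 16), SSE/NEON (>= 16), and AVX2 (>= 32)
// branches. Tolerances are loose because FMA and reassociation change rounding.
#[cfg(test)]
mod tests {
    use super::*;

    // Deterministic test vector so the tests need no external crates.
    fn test_vector(dim: usize, seed: f32) -> Vec<f32> {
        (0..dim).map(|i| (i as f32 * 0.37 + seed).sin()).collect()
    }

    fn assert_close(x: f32, y: f32) {
        assert!((x - y).abs() < 1e-2, "{} vs {}", x, y);
    }

    #[test]
    fn dot_product_matches_scalar() {
        for &dim in &[3usize, 16, 31, 32, 129] {
            let a = test_vector(dim, 0.1);
            let b = test_vector(dim, 0.9);
            assert_close(dot_product_simd(&a, &b), dot_product_scalar(&a, &b));
        }
    }

    #[test]
    fn l2_distance_matches_scalar() {
        for &dim in &[3usize, 16, 31, 32, 129] {
            let a = test_vector(dim, 0.2);
            let b = test_vector(dim, 0.7);
            assert_close(l2_distance_simd(&a, &b), l2_distance_scalar(&a, &b));
        }
    }

    #[test]
    fn norm_of_unit_vector_is_one() {
        // A normalized vector should have norm and squared norm of ~1.
        let v = vec![0.6f32, 0.8, 0.0];
        assert_close(norm_simd(&v), 1.0);
        assert_close(norm_squared_simd(&v), 1.0);
    }

    #[test]
    fn mismatched_lengths_return_sentinels() {
        // The dispatchers return 0.0 / INFINITY rather than panicking on length mismatch.
        assert_eq!(dot_product_simd(&[1.0, 2.0], &[1.0]), 0.0);
        assert_eq!(l2_distance_simd(&[1.0, 2.0], &[1.0]), f32::INFINITY);
    }
}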