1#[cfg(target_arch = "x86_64")]
6use std::arch::x86_64::*;
7
8#[cfg(target_arch = "aarch64")]
9use std::arch::aarch64::*;
10
/// Minimum vector length for the AVX2 path. Below this the wider
/// setup and horizontal reduction are presumed to outweigh the
/// throughput gain — tuning constant, not a correctness bound.
#[cfg(target_arch = "x86_64")]
const MIN_DIM_SIZE_AVX: usize = 32;

/// Minimum vector length for the 128-bit SIMD paths (SSE / NEON);
/// shorter inputs fall through to the scalar implementations.
#[cfg(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64"))]
const MIN_DIM_SIZE_SIMD: usize = 16;
17
/// Dot product of `a` and `b`, dispatched at runtime to the fastest
/// SIMD implementation the current CPU supports, with a portable
/// scalar fallback.
///
/// Returns `0.0` when the slices differ in length — a sentinel that
/// callers cannot distinguish from a genuine zero dot product.
#[inline]
pub fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
    // Mismatched dimensions yield 0.0 rather than panicking; this also
    // guarantees the length precondition of the unsafe kernels below.
    if a.len() != b.len() {
        return 0.0;
    }

    // Widest path first: AVX2+FMA handles 16 lanes per iteration but
    // only pays off above MIN_DIM_SIZE_AVX elements.
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2")
            && is_x86_feature_detected!("fma")
            && a.len() >= MIN_DIM_SIZE_AVX
        {
            return unsafe { dot_product_avx2(a, b) };
        }
    }

    // SSE fallback: 4 lanes per iteration.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if is_x86_feature_detected!("sse") && a.len() >= MIN_DIM_SIZE_SIMD {
            return unsafe { dot_product_sse(a, b) };
        }
    }

    // NEON path on aarch64: 8 lanes per iteration in the main kernel.
    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") && a.len() >= MIN_DIM_SIZE_SIMD {
            return unsafe { dot_product_neon(a, b) };
        }
    }

    // Portable scalar fallback (also used for short vectors).
    dot_product_scalar(a, b)
}
56
/// AVX2/FMA dot-product kernel: 16 floats per iteration through two
/// independent 256-bit accumulators (hides FMA latency), followed by a
/// horizontal reduction and a scalar tail for the last `dim % 16`
/// elements.
///
/// # Safety
///
/// - The CPU must support AVX2 and FMA; the dispatcher verifies this
///   via `is_x86_feature_detected!` before calling.
/// - `b` must be at least as long as `a`: all loop bounds come from
///   `a.len()`, and the vector loads read 8 floats from `b` at a time,
///   so a shorter `b` is undefined behavior (not just a panic).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
unsafe fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
    let dim = a.len();
    let mut i = 0;

    // Two accumulators break the serial FMA dependency chain.
    let mut sum1 = _mm256_setzero_ps();
    let mut sum2 = _mm256_setzero_ps();

    // Main loop: 2 x 8 lanes per iteration; loadu tolerates unaligned
    // pointers, so no alignment precondition on the slices.
    while i + 15 < dim {
        let vx1 = _mm256_loadu_ps(a.as_ptr().add(i));
        let vy1 = _mm256_loadu_ps(b.as_ptr().add(i));
        let vx2 = _mm256_loadu_ps(a.as_ptr().add(i + 8));
        let vy2 = _mm256_loadu_ps(b.as_ptr().add(i + 8));

        sum1 = _mm256_fmadd_ps(vx1, vy1, sum1);
        sum2 = _mm256_fmadd_ps(vx2, vy2, sum2);

        i += 16;
    }

    let combined = _mm256_add_ps(sum1, sum2);

    // Horizontal reduction: fold the upper 128-bit half onto the
    // lower half, then two hadd passes sum the remaining 4 lanes.
    let sum_high = _mm256_extractf128_ps(combined, 1);
    let sum_low = _mm256_castps256_ps128(combined);
    let mut sum_128 = _mm_add_ps(sum_high, sum_low);

    sum_128 = _mm_hadd_ps(sum_128, sum_128);
    sum_128 = _mm_hadd_ps(sum_128, sum_128);

    let mut dot = _mm_cvtss_f32(sum_128);

    // Scalar tail for the final dim % 16 elements.
    while i < dim {
        dot += a[i] * b[i];
        i += 1;
    }

    dot
}
103
/// SSE dot-product kernel: 4 floats per iteration with one 128-bit
/// accumulator, shuffle-based horizontal reduction (no SSE3 `haddps`
/// required), and a scalar tail.
///
/// # Safety
///
/// - The CPU must support SSE; the dispatcher verifies this via
///   `is_x86_feature_detected!` before calling.
/// - `b` must be at least as long as `a`: loop bounds come from
///   `a.len()` and the 4-float loads from `b` would otherwise read out
///   of bounds (undefined behavior).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse")]
#[inline]
unsafe fn dot_product_sse(a: &[f32], b: &[f32]) -> f32 {
    // Pull in the intrinsics for whichever x86 flavor is compiled.
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let dim = a.len();
    let mut i = 0;
    let mut sum = _mm_setzero_ps();

    // Multiply-accumulate 4 lanes at a time (unaligned loads are fine
    // with loadu).
    while i + 3 < dim {
        let va = _mm_loadu_ps(a.as_ptr().add(i));
        let vb = _mm_loadu_ps(b.as_ptr().add(i));
        sum = _mm_add_ps(sum, _mm_mul_ps(va, vb));
        i += 4;
    }

    // Horizontal reduction: swap adjacent lanes and add (pairwise
    // sums), then bring the high pair down and add to collapse the
    // total into lane 0.
    let shuf = _mm_shuffle_ps(sum, sum, 0b10_11_00_01);
    sum = _mm_add_ps(sum, shuf);
    let shuf = _mm_movehl_ps(sum, sum);
    sum = _mm_add_ss(sum, shuf);

    let mut dot = _mm_cvtss_f32(sum);

    // Scalar tail for the final dim % 4 elements.
    while i < dim {
        dot += a[i] * b[i];
        i += 1;
    }

    dot
}
142
/// NEON dot-product kernel: 8 floats per iteration via two independent
/// 128-bit FMA accumulators, a 4-wide cleanup loop, an across-vector
/// add (`vaddvq_f32`), and a scalar tail.
///
/// # Safety
///
/// - The CPU must support NEON; the dispatcher verifies this via
///   `is_aarch64_feature_detected!` before calling.
/// - `b` must be at least as long as `a`: loop bounds come from
///   `a.len()` and the 4-float loads from `b` would otherwise read out
///   of bounds (undefined behavior).
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[inline]
unsafe fn dot_product_neon(a: &[f32], b: &[f32]) -> f32 {
    let dim = a.len();
    let mut i = 0;

    // Two accumulators break the serial FMA dependency chain.
    let mut sum1 = vdupq_n_f32(0.0);
    let mut sum2 = vdupq_n_f32(0.0);

    // Main loop: 2 x 4 lanes per iteration.
    while i + 7 < dim {
        let va1 = vld1q_f32(a.as_ptr().add(i));
        let vb1 = vld1q_f32(b.as_ptr().add(i));
        let va2 = vld1q_f32(a.as_ptr().add(i + 4));
        let vb2 = vld1q_f32(b.as_ptr().add(i + 4));

        sum1 = vfmaq_f32(sum1, va1, vb1);
        sum2 = vfmaq_f32(sum2, va2, vb2);

        i += 8;
    }

    // Cleanup loop: one more 4-lane block if available (folds into
    // sum1 only; sum2 is merged below).
    while i + 3 < dim {
        let va = vld1q_f32(a.as_ptr().add(i));
        let vb = vld1q_f32(b.as_ptr().add(i));
        sum1 = vfmaq_f32(sum1, va, vb);
        i += 4;
    }

    // vaddvq_f32 sums all 4 lanes of the combined accumulator.
    let combined = vaddq_f32(sum1, sum2);
    let mut dot = vaddvq_f32(combined);

    // Scalar tail for the final dim % 4 elements.
    while i < dim {
        dot += a[i] * b[i];
        i += 1;
    }

    dot
}
189
/// Portable scalar dot product.
///
/// Unrolled over 8-element blocks with two independent accumulators so
/// the floating-point adds can overlap; the tail (`len % 8` elements)
/// is handled by a plain indexed loop. Assumes the caller has already
/// matched the slice lengths (the SIMD dispatcher checks this); a
/// shorter `b` would panic on the tail indexing.
#[inline]
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    let mut acc_lo = 0.0f32;
    let mut acc_hi = 0.0f32;

    let a_blocks = a.chunks_exact(8);
    let tail_len = a_blocks.remainder().len();

    // First accumulator takes lanes 0..4 of each block, second takes
    // lanes 4..8 — same grouping as a 2-way unrolled SIMD reduction.
    for (x, y) in a_blocks.zip(b.chunks_exact(8)) {
        acc_lo += x[0] * y[0] +
                  x[1] * y[1] +
                  x[2] * y[2] +
                  x[3] * y[3];

        acc_hi += x[4] * y[4] +
                  x[5] * y[5] +
                  x[6] * y[6] +
                  x[7] * y[7];
    }

    // Remaining elements that did not fill a whole 8-block.
    for i in (a.len() - tail_len)..a.len() {
        acc_lo += a[i] * b[i];
    }

    acc_lo + acc_hi
}
220
221
/// Euclidean (L2) distance between `a` and `b`, dispatched at runtime
/// to the fastest SIMD implementation the current CPU supports, with a
/// portable scalar fallback.
///
/// Returns `f32::INFINITY` when the slices differ in length — a
/// sentinel meaning "incomparable", which callers cannot distinguish
/// from genuinely enormous distances.
#[inline]
pub fn l2_distance_simd(a: &[f32], b: &[f32]) -> f32 {
    // Mismatched dimensions yield infinity rather than panicking; this
    // also guarantees the length precondition of the unsafe kernels.
    if a.len() != b.len() {
        return f32::INFINITY;
    }

    // Widest path first: AVX2+FMA handles 16 lanes per iteration but
    // only pays off above MIN_DIM_SIZE_AVX elements.
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2")
            && is_x86_feature_detected!("fma")
            && a.len() >= MIN_DIM_SIZE_AVX
        {
            return unsafe { l2_distance_avx2(a, b) };
        }
    }

    // SSE fallback: 4 lanes per iteration.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if is_x86_feature_detected!("sse") && a.len() >= MIN_DIM_SIZE_SIMD {
            return unsafe { l2_distance_sse(a, b) };
        }
    }

    // NEON path on aarch64: 8 lanes per iteration in the main kernel.
    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") && a.len() >= MIN_DIM_SIZE_SIMD {
            return unsafe { l2_distance_neon(a, b) };
        }
    }

    // Portable scalar fallback (also used for short vectors).
    l2_distance_scalar(a, b)
}
256
/// AVX2/FMA L2-distance kernel: accumulates squared differences 16
/// floats per iteration through two independent 256-bit accumulators,
/// horizontally reduces, finishes the tail scalarly, and returns the
/// square root (the distance, not the squared distance).
///
/// # Safety
///
/// - The CPU must support AVX2 and FMA; the dispatcher verifies this
///   via `is_x86_feature_detected!` before calling.
/// - `b` must be at least as long as `a`: loop bounds come from
///   `a.len()` and the 8-float loads from `b` would otherwise read out
///   of bounds (undefined behavior).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
unsafe fn l2_distance_avx2(a: &[f32], b: &[f32]) -> f32 {
    let dim = a.len();
    let mut i = 0;

    // Two accumulators break the serial FMA dependency chain.
    let mut sum1 = _mm256_setzero_ps();
    let mut sum2 = _mm256_setzero_ps();

    // Main loop: 2 x 8 lanes per iteration; diff is squared via FMA
    // (diff * diff + acc).
    while i + 15 < dim {
        let va1 = _mm256_loadu_ps(a.as_ptr().add(i));
        let vb1 = _mm256_loadu_ps(b.as_ptr().add(i));
        let va2 = _mm256_loadu_ps(a.as_ptr().add(i + 8));
        let vb2 = _mm256_loadu_ps(b.as_ptr().add(i + 8));

        let diff1 = _mm256_sub_ps(va1, vb1);
        let diff2 = _mm256_sub_ps(va2, vb2);

        sum1 = _mm256_fmadd_ps(diff1, diff1, sum1);
        sum2 = _mm256_fmadd_ps(diff2, diff2, sum2);

        i += 16;
    }

    let combined = _mm256_add_ps(sum1, sum2);

    // Horizontal reduction: fold the upper 128-bit half onto the
    // lower half, then two hadd passes sum the remaining 4 lanes.
    let sum_high = _mm256_extractf128_ps(combined, 1);
    let sum_low = _mm256_castps256_ps128(combined);
    let mut sum_128 = _mm_add_ps(sum_high, sum_low);

    sum_128 = _mm_hadd_ps(sum_128, sum_128);
    sum_128 = _mm_hadd_ps(sum_128, sum_128);

    let mut sum_sq = _mm_cvtss_f32(sum_128);

    // Scalar tail for the final dim % 16 elements.
    while i < dim {
        let diff = a[i] - b[i];
        sum_sq += diff * diff;
        i += 1;
    }

    sum_sq.sqrt()
}
306
/// SSE L2-distance kernel: accumulates squared differences 4 floats
/// per iteration, reduces via shuffles (no SSE3 `haddps` required),
/// finishes the tail scalarly, and returns the square root.
///
/// # Safety
///
/// - The CPU must support SSE; the dispatcher verifies this via
///   `is_x86_feature_detected!` before calling.
/// - `b` must be at least as long as `a`: loop bounds come from
///   `a.len()` and the 4-float loads from `b` would otherwise read out
///   of bounds (undefined behavior).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse")]
#[inline]
unsafe fn l2_distance_sse(a: &[f32], b: &[f32]) -> f32 {
    // Pull in the intrinsics for whichever x86 flavor is compiled.
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let dim = a.len();
    let mut i = 0;
    let mut sum = _mm_setzero_ps();

    // Accumulate (a - b)^2 across 4 lanes per iteration.
    while i + 3 < dim {
        let va = _mm_loadu_ps(a.as_ptr().add(i));
        let vb = _mm_loadu_ps(b.as_ptr().add(i));
        let diff = _mm_sub_ps(va, vb);
        sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
        i += 4;
    }

    // Horizontal reduction: swap adjacent lanes and add (pairwise
    // sums), then bring the high pair down and add to collapse the
    // total into lane 0.
    let shuf = _mm_shuffle_ps(sum, sum, 0b10_11_00_01);
    sum = _mm_add_ps(sum, shuf);
    let shuf = _mm_movehl_ps(sum, sum);
    sum = _mm_add_ss(sum, shuf);

    let mut sum_sq = _mm_cvtss_f32(sum);

    // Scalar tail for the final dim % 4 elements.
    while i < dim {
        let diff = a[i] - b[i];
        sum_sq += diff * diff;
        i += 1;
    }

    sum_sq.sqrt()
}
347
/// NEON L2-distance kernel: accumulates squared differences 8 floats
/// per iteration via two independent 128-bit FMA accumulators, then a
/// 4-wide cleanup loop, an across-vector add (`vaddvq_f32`), a scalar
/// tail, and finally the square root.
///
/// # Safety
///
/// - The CPU must support NEON; the dispatcher verifies this via
///   `is_aarch64_feature_detected!` before calling.
/// - `b` must be at least as long as `a`: loop bounds come from
///   `a.len()` and the 4-float loads from `b` would otherwise read out
///   of bounds (undefined behavior).
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[inline]
unsafe fn l2_distance_neon(a: &[f32], b: &[f32]) -> f32 {
    let dim = a.len();
    let mut i = 0;

    // Two accumulators break the serial FMA dependency chain.
    let mut sum1 = vdupq_n_f32(0.0);
    let mut sum2 = vdupq_n_f32(0.0);

    // Main loop: 2 x 4 lanes of (a - b)^2 per iteration.
    while i + 7 < dim {
        let va1 = vld1q_f32(a.as_ptr().add(i));
        let vb1 = vld1q_f32(b.as_ptr().add(i));
        let va2 = vld1q_f32(a.as_ptr().add(i + 4));
        let vb2 = vld1q_f32(b.as_ptr().add(i + 4));

        let diff1 = vsubq_f32(va1, vb1);
        let diff2 = vsubq_f32(va2, vb2);

        sum1 = vfmaq_f32(sum1, diff1, diff1);
        sum2 = vfmaq_f32(sum2, diff2, diff2);

        i += 8;
    }

    // Cleanup loop: one more 4-lane block if available (folds into
    // sum1 only; sum2 is merged below).
    while i + 3 < dim {
        let va = vld1q_f32(a.as_ptr().add(i));
        let vb = vld1q_f32(b.as_ptr().add(i));
        let diff = vsubq_f32(va, vb);
        sum1 = vfmaq_f32(sum1, diff, diff);
        i += 4;
    }

    // vaddvq_f32 sums all 4 lanes of the combined accumulator.
    let combined = vaddq_f32(sum1, sum2);
    let mut sum_sq = vaddvq_f32(combined);

    // Scalar tail for the final dim % 4 elements.
    while i < dim {
        let diff = a[i] - b[i];
        sum_sq += diff * diff;
        i += 1;
    }

    sum_sq.sqrt()
}
399
/// Portable scalar Euclidean (L2) distance.
///
/// Unrolled over 4-element blocks with two independent accumulators so
/// the floating-point adds can overlap; the tail (`len % 4` elements)
/// is handled by a plain indexed loop. Assumes the caller has already
/// matched the slice lengths (the SIMD dispatcher checks this); a
/// shorter `b` would panic on the tail indexing.
#[inline]
fn l2_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    let mut acc_lo = 0.0f32;
    let mut acc_hi = 0.0f32;

    let a_quads = a.chunks_exact(4);
    let tail_len = a_quads.remainder().len();

    // First accumulator takes the squared diffs of lanes 0 and 1,
    // second takes lanes 2 and 3 — same grouping as the original
    // 2-way unroll, so results are bit-identical.
    for (p, q) in a_quads.zip(b.chunks_exact(4)) {
        let e0 = p[0] - q[0];
        let e1 = p[1] - q[1];
        let e2 = p[2] - q[2];
        let e3 = p[3] - q[3];

        acc_lo += e0 * e0 + e1 * e1;
        acc_hi += e2 * e2 + e3 * e3;
    }

    // Remaining elements that did not fill a whole 4-block.
    for i in (a.len() - tail_len)..a.len() {
        let d = a[i] - b[i];
        acc_lo += d * d;
    }

    (acc_lo + acc_hi).sqrt()
}
429
/// Squared Euclidean norm of `v`, computed as the SIMD-dispatched dot
/// product of `v` with itself. Always non-negative (up to FP rounding).
#[inline]
pub fn norm_squared_simd(v: &[f32]) -> f32 {
    dot_product_simd(v, v)
}
435
436#[inline]
438pub fn norm_simd(v: &[f32]) -> f32 {
439 norm_squared_simd(v).sqrt()
440}
441