use ndarray::Array1;
// Gate the x86 intrinsics import so the scalar fallback path still builds on other targets.
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

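/// SIMD-accelerated vector primitives (dot product, squared norm, addition, scaling).
/// Each public method detects CPU features at runtime and dispatches to an AVX2+FMA,
/// AVX2, or SSE4.1 kernel, falling back to scalar code elsewhere.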
pub struct SimdVectorOps;

impl SimdVectorOps {
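    /// Dot product of two `f32` vectors using the widest SIMD kernel the CPU supports.
    ///
    /// Both arrays must be contiguous (`as_slice` is unwrapped); in debug builds the
    /// lengths are also asserted to match.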
    pub fn dot_product(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
        debug_assert_eq!(a.len(), b.len(), "Vector lengths must match");

        let len = a.len();
        let a_slice = a.as_slice().unwrap();
        let b_slice = b.as_slice().unwrap();

        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("fma") && is_x86_feature_detected!("avx2") {
                unsafe { Self::dot_product_avx2_fma(a_slice, b_slice, len) }
            } else if is_x86_feature_detected!("avx2") {
                unsafe { Self::dot_product_avx2(a_slice, b_slice, len) }
            } else if is_x86_feature_detected!("sse4.1") {
                unsafe { Self::dot_product_sse41(a_slice, b_slice, len) }
            } else {
                Self::dot_product_fallback(a_slice, b_slice)
            }
        }

        #[cfg(not(target_arch = "x86_64"))]
        {
            Self::dot_product_fallback(a_slice, b_slice)
        }
    }

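    /// Sum of squared elements (the squared L2 norm), dispatched like `dot_product`.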
    pub fn squared_norm(a: &Array1<f32>) -> f32 {
        let a_slice = a.as_slice().unwrap();
        let len = a.len();

        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("fma") && is_x86_feature_detected!("avx2") {
                unsafe { Self::squared_norm_avx2_fma(a_slice, len) }
            } else if is_x86_feature_detected!("avx2") {
                unsafe { Self::squared_norm_avx2(a_slice, len) }
            } else if is_x86_feature_detected!("sse4.1") {
                unsafe { Self::squared_norm_sse41(a_slice, len) }
            } else {
                Self::squared_norm_fallback(a_slice)
            }
        }

        #[cfg(not(target_arch = "x86_64"))]
        {
            Self::squared_norm_fallback(a_slice)
        }
    }

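    /// Cosine similarity `dot(a, b) / (|a| * |b|)`, or `0.0` if either vector has zero norm.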
    pub fn cosine_similarity(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
        let dot = Self::dot_product(a, b);
        let norm_a = Self::squared_norm(a).sqrt();
        let norm_b = Self::squared_norm(b).sqrt();

        if norm_a == 0.0 || norm_b == 0.0 {
            0.0
        } else {
            dot / (norm_a * norm_b)
        }
    }

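    /// Element-wise sum of two vectors, written into a newly allocated array.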
    pub fn add_vectors(a: &Array1<f32>, b: &Array1<f32>) -> Array1<f32> {
        debug_assert_eq!(a.len(), b.len(), "Vector lengths must match");

        let len = a.len();
        let a_slice = a.as_slice().unwrap();
        let b_slice = b.as_slice().unwrap();
        let mut result = Array1::zeros(len);
        let result_slice = result.as_slice_mut().unwrap();

        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("avx2") {
                unsafe { Self::add_vectors_avx2(a_slice, b_slice, result_slice, len) }
            } else if is_x86_feature_detected!("sse4.1") {
                unsafe { Self::add_vectors_sse41(a_slice, b_slice, result_slice, len) }
            } else {
                Self::add_vectors_fallback(a_slice, b_slice, result_slice)
            }
        }

        #[cfg(not(target_arch = "x86_64"))]
        {
            Self::add_vectors_fallback(a_slice, b_slice, result_slice)
        }

        result
    }

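    /// Multiplies every element of `a` by `scalar`, returning a new array.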
    pub fn scale_vector(a: &Array1<f32>, scalar: f32) -> Array1<f32> {
        let len = a.len();
        let a_slice = a.as_slice().unwrap();
        let mut result = Array1::zeros(len);
        let result_slice = result.as_slice_mut().unwrap();

        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("avx2") {
                unsafe { Self::scale_vector_avx2(a_slice, scalar, result_slice, len) }
            } else if is_x86_feature_detected!("sse4.1") {
                unsafe { Self::scale_vector_sse41(a_slice, scalar, result_slice, len) }
            } else {
                Self::scale_vector_fallback(a_slice, scalar, result_slice)
            }
        }

        #[cfg(not(target_arch = "x86_64"))]
        {
            Self::scale_vector_fallback(a_slice, scalar, result_slice)
        }

        result
    }

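    // AVX2 + FMA kernels. The hot loops are unrolled four 8-lane chunks per iteration;
    // leftover chunks go through a plain loop and the last `len % 8` elements are
    // handled by a scalar tail.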
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2,fma")]
    unsafe fn dot_product_avx2_fma(a: &[f32], b: &[f32], len: usize) -> f32 {
        let mut sum = _mm256_setzero_ps();
        let chunks = len / 8;

        // Unrolled by four 8-lane chunks per iteration to cut loop overhead.
        let unroll_chunks = chunks / 4;
        let mut i = 0;

        for _ in 0..unroll_chunks {
            let a_vec1 = _mm256_loadu_ps(a.as_ptr().add(i * 8));
            let b_vec1 = _mm256_loadu_ps(b.as_ptr().add(i * 8));
            sum = _mm256_fmadd_ps(a_vec1, b_vec1, sum);

            let a_vec2 = _mm256_loadu_ps(a.as_ptr().add((i + 1) * 8));
            let b_vec2 = _mm256_loadu_ps(b.as_ptr().add((i + 1) * 8));
            sum = _mm256_fmadd_ps(a_vec2, b_vec2, sum);

            let a_vec3 = _mm256_loadu_ps(a.as_ptr().add((i + 2) * 8));
            let b_vec3 = _mm256_loadu_ps(b.as_ptr().add((i + 2) * 8));
            sum = _mm256_fmadd_ps(a_vec3, b_vec3, sum);

            let a_vec4 = _mm256_loadu_ps(a.as_ptr().add((i + 3) * 8));
            let b_vec4 = _mm256_loadu_ps(b.as_ptr().add((i + 3) * 8));
            sum = _mm256_fmadd_ps(a_vec4, b_vec4, sum);

            i += 4;
        }

        // Remaining whole chunks that did not fit the unrolled loop.
        for j in i..chunks {
            let a_vec = _mm256_loadu_ps(a.as_ptr().add(j * 8));
            let b_vec = _mm256_loadu_ps(b.as_ptr().add(j * 8));
            sum = _mm256_fmadd_ps(a_vec, b_vec, sum);
        }

        // Horizontal reduction: add the high and low 128-bit halves, then fold within the lane.
        let sum_low = _mm256_extractf128_ps(sum, 0);
        let sum_high = _mm256_extractf128_ps(sum, 1);
        let sum_combined = _mm_add_ps(sum_low, sum_high);

        let sum_shuffled = _mm_shuffle_ps(sum_combined, sum_combined, 0b01_00_11_10);
        let sum_partial = _mm_add_ps(sum_combined, sum_shuffled);
        let sum_final_shuffled = _mm_shuffle_ps(sum_partial, sum_partial, 0b00_00_00_01);
        let final_sum = _mm_add_ps(sum_partial, sum_final_shuffled);

        let mut result = _mm_cvtss_f32(final_sum);

        // Scalar tail for the last len % 8 elements.
        for k in (chunks * 8)..len {
            result += a[k] * b[k];
        }

        result
    }

    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2,fma")]
    unsafe fn squared_norm_avx2_fma(a: &[f32], len: usize) -> f32 {
        let mut sum = _mm256_setzero_ps();
        let chunks = len / 8;

        let unroll_chunks = chunks / 4;
        let mut i = 0;

        for _ in 0..unroll_chunks {
            let a_vec1 = _mm256_loadu_ps(a.as_ptr().add(i * 8));
            sum = _mm256_fmadd_ps(a_vec1, a_vec1, sum);

            let a_vec2 = _mm256_loadu_ps(a.as_ptr().add((i + 1) * 8));
            sum = _mm256_fmadd_ps(a_vec2, a_vec2, sum);

            let a_vec3 = _mm256_loadu_ps(a.as_ptr().add((i + 2) * 8));
            sum = _mm256_fmadd_ps(a_vec3, a_vec3, sum);

            let a_vec4 = _mm256_loadu_ps(a.as_ptr().add((i + 3) * 8));
            sum = _mm256_fmadd_ps(a_vec4, a_vec4, sum);

            i += 4;
        }

        for j in i..chunks {
            let a_vec = _mm256_loadu_ps(a.as_ptr().add(j * 8));
            sum = _mm256_fmadd_ps(a_vec, a_vec, sum);
        }

        let sum_low = _mm256_extractf128_ps(sum, 0);
        let sum_high = _mm256_extractf128_ps(sum, 1);
        let sum_combined = _mm_add_ps(sum_low, sum_high);

        let sum_shuffled = _mm_shuffle_ps(sum_combined, sum_combined, 0b01_00_11_10);
        let sum_partial = _mm_add_ps(sum_combined, sum_shuffled);
        let sum_final_shuffled = _mm_shuffle_ps(sum_partial, sum_partial, 0b00_00_00_01);
        let final_sum = _mm_add_ps(sum_partial, sum_final_shuffled);

        let mut result = _mm_cvtss_f32(final_sum);

        for k in (chunks * 8)..len {
            result += a[k] * a[k];
        }

        result
    }

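    // AVX2 kernels for CPUs without FMA: separate multiply and add per 8-lane chunk.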
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn dot_product_avx2(a: &[f32], b: &[f32], len: usize) -> f32 {
        let mut sum = _mm256_setzero_ps();
        let chunks = len / 8;

        for i in 0..chunks {
            let a_vec = _mm256_loadu_ps(a.as_ptr().add(i * 8));
            let b_vec = _mm256_loadu_ps(b.as_ptr().add(i * 8));
            let product = _mm256_mul_ps(a_vec, b_vec);
            sum = _mm256_add_ps(sum, product);
        }

        let sum_low = _mm256_extractf128_ps(sum, 0);
        let sum_high = _mm256_extractf128_ps(sum, 1);
        let sum_combined = _mm_add_ps(sum_low, sum_high);

        let sum_shuffled = _mm_shuffle_ps(sum_combined, sum_combined, 0b01_00_11_10);
        let sum_partial = _mm_add_ps(sum_combined, sum_shuffled);
        let sum_final_shuffled = _mm_shuffle_ps(sum_partial, sum_partial, 0b00_00_00_01);
        let final_sum = _mm_add_ps(sum_partial, sum_final_shuffled);

        let mut result = _mm_cvtss_f32(final_sum);

        for i in (chunks * 8)..len {
            result += a[i] * b[i];
        }

        result
    }

    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn squared_norm_avx2(a: &[f32], len: usize) -> f32 {
        let mut sum = _mm256_setzero_ps();
        let chunks = len / 8;

        for i in 0..chunks {
            let a_vec = _mm256_loadu_ps(a.as_ptr().add(i * 8));
            let squared = _mm256_mul_ps(a_vec, a_vec);
            sum = _mm256_add_ps(sum, squared);
        }

        let sum_low = _mm256_extractf128_ps(sum, 0);
        let sum_high = _mm256_extractf128_ps(sum, 1);
        let sum_combined = _mm_add_ps(sum_low, sum_high);

        let sum_shuffled = _mm_shuffle_ps(sum_combined, sum_combined, 0b01_00_11_10);
        let sum_partial = _mm_add_ps(sum_combined, sum_shuffled);
        let sum_final_shuffled = _mm_shuffle_ps(sum_partial, sum_partial, 0b00_00_00_01);
        let final_sum = _mm_add_ps(sum_partial, sum_final_shuffled);

        let mut result = _mm_cvtss_f32(final_sum);

        for i in (chunks * 8)..len {
            result += a[i] * a[i];
        }

        result
    }

    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn add_vectors_avx2(a: &[f32], b: &[f32], result: &mut [f32], len: usize) {
        let chunks = len / 8;

        for i in 0..chunks {
            let a_vec = _mm256_loadu_ps(a.as_ptr().add(i * 8));
            let b_vec = _mm256_loadu_ps(b.as_ptr().add(i * 8));
            let sum = _mm256_add_ps(a_vec, b_vec);
            _mm256_storeu_ps(result.as_mut_ptr().add(i * 8), sum);
        }

        for i in (chunks * 8)..len {
            result[i] = a[i] + b[i];
        }
    }

    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn scale_vector_avx2(a: &[f32], scalar: f32, result: &mut [f32], len: usize) {
        let scalar_vec = _mm256_set1_ps(scalar);
        let chunks = len / 8;

        for i in 0..chunks {
            let a_vec = _mm256_loadu_ps(a.as_ptr().add(i * 8));
            let scaled = _mm256_mul_ps(a_vec, scalar_vec);
            _mm256_storeu_ps(result.as_mut_ptr().add(i * 8), scaled);
        }

        for i in (chunks * 8)..len {
            result[i] = a[i] * scalar;
        }
    }

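    // SSE4.1-gated kernels processing 4 lanes per iteration.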
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "sse4.1")]
    unsafe fn dot_product_sse41(a: &[f32], b: &[f32], len: usize) -> f32 {
        let mut sum = _mm_setzero_ps();
        let chunks = len / 4;

        for i in 0..chunks {
            let a_vec = _mm_loadu_ps(a.as_ptr().add(i * 4));
            let b_vec = _mm_loadu_ps(b.as_ptr().add(i * 4));
            let product = _mm_mul_ps(a_vec, b_vec);
            sum = _mm_add_ps(sum, product);
        }

        let sum_shuffled = _mm_shuffle_ps(sum, sum, 0b01_00_11_10);
        let sum_partial = _mm_add_ps(sum, sum_shuffled);
        let sum_final_shuffled = _mm_shuffle_ps(sum_partial, sum_partial, 0b00_00_00_01);
        let final_sum = _mm_add_ps(sum_partial, sum_final_shuffled);

        let mut result = _mm_cvtss_f32(final_sum);

        for i in (chunks * 4)..len {
            result += a[i] * b[i];
        }

        result
    }

    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "sse4.1")]
    unsafe fn squared_norm_sse41(a: &[f32], len: usize) -> f32 {
        let mut sum = _mm_setzero_ps();
        let chunks = len / 4;

        for i in 0..chunks {
            let a_vec = _mm_loadu_ps(a.as_ptr().add(i * 4));
            let squared = _mm_mul_ps(a_vec, a_vec);
            sum = _mm_add_ps(sum, squared);
        }

        let sum_shuffled = _mm_shuffle_ps(sum, sum, 0b01_00_11_10);
        let sum_partial = _mm_add_ps(sum, sum_shuffled);
        let sum_final_shuffled = _mm_shuffle_ps(sum_partial, sum_partial, 0b00_00_00_01);
        let final_sum = _mm_add_ps(sum_partial, sum_final_shuffled);

        let mut result = _mm_cvtss_f32(final_sum);

        for i in (chunks * 4)..len {
            result += a[i] * a[i];
        }

        result
    }

    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "sse4.1")]
    unsafe fn add_vectors_sse41(a: &[f32], b: &[f32], result: &mut [f32], len: usize) {
        let chunks = len / 4;

        for i in 0..chunks {
            let a_vec = _mm_loadu_ps(a.as_ptr().add(i * 4));
            let b_vec = _mm_loadu_ps(b.as_ptr().add(i * 4));
            let sum = _mm_add_ps(a_vec, b_vec);
            _mm_storeu_ps(result.as_mut_ptr().add(i * 4), sum);
        }

        for i in (chunks * 4)..len {
            result[i] = a[i] + b[i];
        }
    }

    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "sse4.1")]
    unsafe fn scale_vector_sse41(a: &[f32], scalar: f32, result: &mut [f32], len: usize) {
        let scalar_vec = _mm_set1_ps(scalar);
        let chunks = len / 4;

        for i in 0..chunks {
            let a_vec = _mm_loadu_ps(a.as_ptr().add(i * 4));
            let scaled = _mm_mul_ps(a_vec, scalar_vec);
            _mm_storeu_ps(result.as_mut_ptr().add(i * 4), scaled);
        }

        for i in (chunks * 4)..len {
            result[i] = a[i] * scalar;
        }
    }

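    // Scalar fallbacks, also used on non-x86_64 targets.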
    fn dot_product_fallback(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum()
    }

    fn squared_norm_fallback(a: &[f32]) -> f32 {
        a.iter().map(|&x| x * x).sum()
    }

    fn add_vectors_fallback(a: &[f32], b: &[f32], result: &mut [f32]) {
        for i in 0..a.len() {
            result[i] = a[i] + b[i];
        }
    }

    fn scale_vector_fallback(a: &[f32], scalar: f32, result: &mut [f32]) {
        for i in 0..a.len() {
            result[i] = a[i] * scalar;
        }
    }
}

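/// Batch operations built on top of `SimdVectorOps`: pairwise similarity matrices,
/// top-k search, centroids, normalization, and matrix-vector products.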
pub struct BatchSimdOps;

impl BatchSimdOps {
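    /// Full pairwise cosine-similarity matrix. Norms are computed once per vector and
    /// only the upper triangle of dot products is evaluated before being mirrored.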
    pub fn pairwise_cosine_similarities(vectors: &[Array1<f32>]) -> Vec<Vec<f32>> {
        let n = vectors.len();
        let mut results = vec![vec![0.0; n]; n];

        let norms: Vec<f32> = vectors
            .iter()
            .map(|v| SimdVectorOps::squared_norm(v).sqrt())
            .collect();

        for i in 0..n {
            for j in i..n {
                if i == j {
                    results[i][j] = 1.0;
                } else {
                    let dot = SimdVectorOps::dot_product(&vectors[i], &vectors[j]);
                    let similarity = if norms[i] == 0.0 || norms[j] == 0.0 {
                        0.0
                    } else {
                        dot / (norms[i] * norms[j])
                    };
                    results[i][j] = similarity;
                    results[j][i] = similarity;
                }
            }
        }

        results
    }

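    /// Returns the `k` vectors most cosine-similar to `query` as `(index, similarity)`
    /// pairs, sorted by descending similarity.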
    pub fn find_k_most_similar(
        query: &Array1<f32>,
        vectors: &[Array1<f32>],
        k: usize,
    ) -> Vec<(usize, f32)> {
        let query_norm = SimdVectorOps::squared_norm(query).sqrt();

        let mut similarities: Vec<(usize, f32)> = vectors
            .iter()
            .enumerate()
            .map(|(i, v)| {
                let dot = SimdVectorOps::dot_product(query, v);
                let v_norm = SimdVectorOps::squared_norm(v).sqrt();
                let similarity = if query_norm == 0.0 || v_norm == 0.0 {
                    0.0
                } else {
                    dot / (query_norm * v_norm)
                };
                (i, similarity)
            })
            .collect();

        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        similarities.into_iter().take(k).collect()
    }

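    /// Mean of the input vectors; returns an empty array if `vectors` is empty.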
    pub fn compute_centroid(vectors: &[Array1<f32>]) -> Array1<f32> {
        if vectors.is_empty() {
            return Array1::zeros(0);
        }

        let len = vectors[0].len();
        let mut centroid = Array1::zeros(len);

        for vector in vectors {
            centroid = SimdVectorOps::add_vectors(&centroid, vector);
        }

        let count = vectors.len() as f32;
        SimdVectorOps::scale_vector(&centroid, 1.0 / count)
    }

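    /// Cosine similarity when both norms are already known, skipping the two norm passes.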
    pub fn fast_cosine_similarity_with_norms(
        a: &Array1<f32>,
        b: &Array1<f32>,
        norm_a: f32,
        norm_b: f32,
    ) -> f32 {
        if norm_a == 0.0 || norm_b == 0.0 {
            return 0.0;
        }
        let dot = SimdVectorOps::dot_product(a, b);
        dot / (norm_a * norm_b)
    }

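    /// Similarities of `query` against every vector, processed in `batch_size` chunks
    /// (intended to keep each chunk's data cache-resident); results are in input order.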
    pub fn cache_optimized_batch_similarities(
        query: &Array1<f32>,
        vectors: &[Array1<f32>],
        batch_size: usize,
    ) -> Vec<f32> {
        let mut results = Vec::with_capacity(vectors.len());
        let query_norm = SimdVectorOps::squared_norm(query).sqrt();

        for chunk in vectors.chunks(batch_size) {
            let norms: Vec<f32> = chunk
                .iter()
                .map(|v| SimdVectorOps::squared_norm(v).sqrt())
                .collect();

            for (vector, &norm) in chunk.iter().zip(norms.iter()) {
                let similarity =
                    Self::fast_cosine_similarity_with_norms(query, vector, query_norm, norm);
                results.push(similarity);
            }
        }

        results
    }

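    /// Dot product that uses aligned AVX2 loads when AVX2 is available and both slices
    /// are 32-byte aligned, otherwise defers to `SimdVectorOps::dot_product`.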
    pub fn aligned_dot_product(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
        debug_assert_eq!(a.len(), b.len(), "Vector lengths must match");

        let a_slice = a.as_slice().unwrap();
        let b_slice = b.as_slice().unwrap();

        #[cfg(target_arch = "x86_64")]
        {
            // The aligned kernel still requires AVX2 support at runtime.
            if is_x86_feature_detected!("avx2")
                && Self::is_aligned(a_slice.as_ptr(), 32)
                && Self::is_aligned(b_slice.as_ptr(), 32)
            {
                unsafe { Self::aligned_dot_product_avx2(a_slice, b_slice) }
            } else {
                SimdVectorOps::dot_product(a, b)
            }
        }

        #[cfg(not(target_arch = "x86_64"))]
        {
            SimdVectorOps::dot_product(a, b)
        }
    }

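    /// Normalizes each vector to unit length in place; zero vectors are left unchanged.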
    pub fn batch_normalize(vectors: &mut [Array1<f32>]) {
        for vector in vectors {
            let norm = SimdVectorOps::squared_norm(vector).sqrt();
            if norm > 0.0 {
                *vector = SimdVectorOps::scale_vector(vector, 1.0 / norm);
            }
        }
    }

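    /// Matrix-vector product where each row is an `Array1<f32>`; rows are processed in
    /// parallel with rayon once the matrix has more than 100 rows.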
    pub fn matrix_vector_multiply(matrix: &[Array1<f32>], vector: &Array1<f32>) -> Array1<f32> {
        let rows = matrix.len();
        let mut result = Array1::zeros(rows);

        if rows > 100 {
            use rayon::prelude::*;
            let results: Vec<f32> = matrix
                .par_iter()
                .map(|row| SimdVectorOps::dot_product(row, vector))
                .collect();
            result = Array1::from_vec(results);
        } else {
            for (i, row) in matrix.iter().enumerate() {
                result[i] = SimdVectorOps::dot_product(row, vector);
            }
        }

        result
    }

    #[cfg(target_arch = "x86_64")]
    fn is_aligned(ptr: *const f32, alignment: usize) -> bool {
        (ptr as usize) % alignment == 0
    }

    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn aligned_dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
        // Uses aligned loads: the caller must guarantee both slices are 32-byte aligned.
        let mut sum = _mm256_setzero_ps();
        let len = a.len();
        let chunks = len / 8;

        for i in 0..chunks {
            let a_vec = _mm256_load_ps(a.as_ptr().add(i * 8));
            let b_vec = _mm256_load_ps(b.as_ptr().add(i * 8));
            let product = _mm256_mul_ps(a_vec, b_vec);
            sum = _mm256_add_ps(sum, product);
        }

        let sum_low = _mm256_extractf128_ps(sum, 0);
        let sum_high = _mm256_extractf128_ps(sum, 1);
        let sum_combined = _mm_add_ps(sum_low, sum_high);

        let sum_shuffled = _mm_shuffle_ps(sum_combined, sum_combined, 0b01_00_11_10);
        let sum_partial = _mm_add_ps(sum_combined, sum_shuffled);
        let sum_final_shuffled = _mm_shuffle_ps(sum_partial, sum_partial, 0b00_00_00_01);
        let final_sum = _mm_add_ps(sum_partial, sum_final_shuffled);

        let mut result = _mm_cvtss_f32(final_sum);

        for i in (chunks * 8)..len {
            result += a[i] * b[i];
        }

        result
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array1;

    #[test]
    fn test_simd_dot_product() {
        let a = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let b = Array1::from_vec(vec![5.0, 6.0, 7.0, 8.0]);

        let result = SimdVectorOps::dot_product(&a, &b);
        let expected = 1.0 * 5.0 + 2.0 * 6.0 + 3.0 * 7.0 + 4.0 * 8.0;

        assert!((result - expected).abs() < 1e-6);
    }

    #[test]
    fn test_simd_squared_norm() {
        let a = Array1::from_vec(vec![3.0, 4.0]);
        let result = SimdVectorOps::squared_norm(&a);
        let expected = 9.0 + 16.0;

        assert!((result - expected).abs() < 1e-6);
    }

    #[test]
    fn test_simd_cosine_similarity() {
        let a = Array1::from_vec(vec![1.0, 0.0]);
        let b = Array1::from_vec(vec![0.0, 1.0]);
        let c = Array1::from_vec(vec![1.0, 0.0]);

        let sim_ab = SimdVectorOps::cosine_similarity(&a, &b);
        assert!((sim_ab - 0.0).abs() < 1e-6);

        let sim_ac = SimdVectorOps::cosine_similarity(&a, &c);
        assert!((sim_ac - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_simd_add_vectors() {
        let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Array1::from_vec(vec![4.0, 5.0, 6.0]);
        let result = SimdVectorOps::add_vectors(&a, &b);

        assert_eq!(result[0], 5.0);
        assert_eq!(result[1], 7.0);
        assert_eq!(result[2], 9.0);
    }

    #[test]
    fn test_simd_scale_vector() {
        let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let result = SimdVectorOps::scale_vector(&a, 2.0);

        assert_eq!(result[0], 2.0);
        assert_eq!(result[1], 4.0);
        assert_eq!(result[2], 6.0);
    }

    #[test]
    fn test_batch_pairwise_similarities() {
        let vectors = vec![
            Array1::from_vec(vec![1.0, 0.0]),
            Array1::from_vec(vec![0.0, 1.0]),
            Array1::from_vec(vec![1.0, 1.0]),
        ];

        let similarities = BatchSimdOps::pairwise_cosine_similarities(&vectors);

        assert!((similarities[0][0] - 1.0).abs() < 1e-6);
        assert!((similarities[1][1] - 1.0).abs() < 1e-6);
        assert!((similarities[2][2] - 1.0).abs() < 1e-6);

        assert!((similarities[0][1] - 0.0).abs() < 1e-6);
        assert!((similarities[1][0] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_find_k_most_similar() {
        let query = Array1::from_vec(vec![1.0, 0.0]);
        let vectors = vec![
            Array1::from_vec(vec![1.0, 0.0]),
            Array1::from_vec(vec![0.0, 1.0]),
            Array1::from_vec(vec![0.5, 0.5]),
        ];

        let results = BatchSimdOps::find_k_most_similar(&query, &vectors, 2);

        assert_eq!(results[0].0, 0);
        assert!(results[0].1 > results[1].1);
    }

    #[test]
    fn test_compute_centroid() {
        let vectors = vec![
            Array1::from_vec(vec![1.0, 2.0]),
            Array1::from_vec(vec![3.0, 4.0]),
            Array1::from_vec(vec![5.0, 6.0]),
        ];

        let centroid = BatchSimdOps::compute_centroid(&vectors);

        assert!((centroid[0] - 3.0).abs() < 1e-6);
        assert!((centroid[1] - 4.0).abs() < 1e-6);
    }

    #[test]
    fn test_large_vector_performance() {
        let size = 1024;
        let a = Array1::from_vec((0..size).map(|i| i as f32).collect());
        let b = Array1::from_vec((0..size).map(|i| (i * 2) as f32).collect());

        let dot_simd = SimdVectorOps::dot_product(&a, &b);
        let dot_naive: f32 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();

        let relative_error = (dot_simd - dot_naive).abs() / dot_naive.abs();
        assert!(
            relative_error < 1e-5,
            "SIMD dot product relative error too large: {}",
            relative_error
        );
    }
}