//! Runtime-dispatched SIMD kernels over `f32` slices: AVX2 on x86_64 (behind
//! a runtime feature check), NEON on aarch64, and a portable scalar fallback
//! everywhere else.

/// Element-wise addition: `c[i] = a[i] + b[i]`.
pub fn vector_add_f32(a: &[f32], b: &[f32], c: &mut [f32]) {
    assert_eq!(a.len(), b.len(), "vector_add_f32: a.len() != b.len()");
    assert_eq!(a.len(), c.len(), "vector_add_f32: a.len() != c.len()");

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 availability was just verified at runtime.
            unsafe { avx2::vector_add_f32_avx2(a, b, c) };
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is mandatory on aarch64, so no runtime check is needed.
        unsafe { neon::vector_add_f32_neon(a, b, c) };
        return;
    }

    // Reachable on x86_64 without AVX2 and on all other architectures; the
    // `allow` silences the unreachable-code lint on aarch64, where the NEON
    // block above always returns.
    #[allow(unreachable_code)]
    scalar::vector_add_f32_scalar(a, b, c);
}

/// Element-wise multiplication: `c[i] = a[i] * b[i]`.
pub fn vector_mul_f32(a: &[f32], b: &[f32], c: &mut [f32]) {
    assert_eq!(a.len(), b.len(), "vector_mul_f32: a.len() != b.len()");
    assert_eq!(a.len(), c.len(), "vector_mul_f32: a.len() != c.len()");

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 availability was just verified at runtime.
            unsafe { avx2::vector_mul_f32_avx2(a, b, c) };
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is mandatory on aarch64.
        unsafe { neon::vector_mul_f32_neon(a, b, c) };
        return;
    }

    #[allow(unreachable_code)]
    scalar::vector_mul_f32_scalar(a, b, c);
}

/// Scalar multiplication: `c[i] = a[i] * scalar`.
pub fn vector_scale_f32(a: &[f32], scalar: f32, c: &mut [f32]) {
    assert_eq!(a.len(), c.len(), "vector_scale_f32: a.len() != c.len()");

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 availability was just verified at runtime.
            unsafe { avx2::vector_scale_f32_avx2(a, scalar, c) };
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is mandatory on aarch64.
        unsafe { neon::vector_scale_f32_neon(a, scalar, c) };
        return;
    }

    #[allow(unreachable_code)]
    scalar::vector_scale_f32_scalar(a, scalar, c);
}

/// Dot product: returns `sum(a[i] * b[i])`.
pub fn vector_dot_f32(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "vector_dot_f32: a.len() != b.len()");

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 availability was just verified at runtime.
            return unsafe { avx2::vector_dot_f32_avx2(a, b) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is mandatory on aarch64.
        return unsafe { neon::vector_dot_f32_neon(a, b) };
    }

    #[allow(unreachable_code)]
    scalar::vector_dot_f32_scalar(a, b)
}

/// Horizontal reduction: returns `sum(a[i])`.
pub fn vector_reduce_sum_f32(a: &[f32]) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 availability was just verified at runtime.
            return unsafe { avx2::vector_reduce_sum_f32_avx2(a) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is mandatory on aarch64.
        return unsafe { neon::vector_reduce_sum_f32_neon(a) };
    }

    #[allow(unreachable_code)]
    scalar::vector_reduce_sum_f32_scalar(a)
}
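
// A minimal usage sketch (the helper below is hypothetical, not part of the
// original API): the safe dispatchers above are the only entry points callers
// need, and backend selection happens per call at runtime. Composing them
// gives BLAS-style routines such as axpy.
#[allow(dead_code)]
fn example_axpy(alpha: f32, x: &[f32], y: &mut [f32]) {
    // y = alpha * x + y, built entirely from the public primitives.
    let mut scaled = vec![0.0f32; x.len()];
    vector_scale_f32(x, alpha, &mut scaled);
    let y_snapshot = y.to_vec();
    vector_add_f32(&scaled, &y_snapshot, y);
}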

mod scalar {
    //! Portable fallback: plain indexed loops that the optimizer is free to
    //! auto-vectorize.

    pub fn vector_add_f32_scalar(a: &[f32], b: &[f32], c: &mut [f32]) {
        for i in 0..a.len() {
            c[i] = a[i] + b[i];
        }
    }

    pub fn vector_mul_f32_scalar(a: &[f32], b: &[f32], c: &mut [f32]) {
        for i in 0..a.len() {
            c[i] = a[i] * b[i];
        }
    }

    pub fn vector_scale_f32_scalar(a: &[f32], scalar: f32, c: &mut [f32]) {
        for i in 0..a.len() {
            c[i] = a[i] * scalar;
        }
    }

    pub fn vector_dot_f32_scalar(a: &[f32], b: &[f32]) -> f32 {
        let mut sum = 0.0f32;
        for i in 0..a.len() {
            sum += a[i] * b[i];
        }
        sum
    }

    pub fn vector_reduce_sum_f32_scalar(a: &[f32]) -> f32 {
        let mut sum = 0.0f32;
        for &val in a {
            sum += val;
        }
        sum
    }
}

#[cfg(target_arch = "x86_64")]
mod avx2 {
    use std::arch::x86_64::*;

    /// Number of f32 lanes in a 256-bit AVX2 register.
    const AVX2_F32_LANES: usize = 8;

    // These kernels are gated with `#[target_feature]`; the dispatchers only
    // call them after `is_x86_feature_detected!("avx2")` succeeds.
    #[target_feature(enable = "avx2")]
    pub unsafe fn vector_add_f32_avx2(a: &[f32], b: &[f32], c: &mut [f32]) {
        let n = a.len();
        let chunks = n / AVX2_F32_LANES;
        let remainder = n % AVX2_F32_LANES;

        for i in 0..chunks {
            let offset = i * AVX2_F32_LANES;
            let va = _mm256_loadu_ps(a.as_ptr().add(offset));
            let vb = _mm256_loadu_ps(b.as_ptr().add(offset));
            let vc = _mm256_add_ps(va, vb);
            _mm256_storeu_ps(c.as_mut_ptr().add(offset), vc);
        }

        // Scalar tail for the final `remainder` elements.
        let tail_start = chunks * AVX2_F32_LANES;
        for i in 0..remainder {
            c[tail_start + i] = a[tail_start + i] + b[tail_start + i];
        }
    }

    #[target_feature(enable = "avx2")]
    pub unsafe fn vector_mul_f32_avx2(a: &[f32], b: &[f32], c: &mut [f32]) {
        let n = a.len();
        let chunks = n / AVX2_F32_LANES;
        let remainder = n % AVX2_F32_LANES;

        for i in 0..chunks {
            let offset = i * AVX2_F32_LANES;
            let va = _mm256_loadu_ps(a.as_ptr().add(offset));
            let vb = _mm256_loadu_ps(b.as_ptr().add(offset));
            let vc = _mm256_mul_ps(va, vb);
            _mm256_storeu_ps(c.as_mut_ptr().add(offset), vc);
        }

        let tail_start = chunks * AVX2_F32_LANES;
        for i in 0..remainder {
            c[tail_start + i] = a[tail_start + i] * b[tail_start + i];
        }
    }

    #[target_feature(enable = "avx2")]
    pub unsafe fn vector_scale_f32_avx2(a: &[f32], scalar: f32, c: &mut [f32]) {
        let n = a.len();
        let chunks = n / AVX2_F32_LANES;
        let remainder = n % AVX2_F32_LANES;
        // Broadcast the scalar into all eight lanes once, outside the loop.
        let vs = _mm256_set1_ps(scalar);

        for i in 0..chunks {
            let offset = i * AVX2_F32_LANES;
            let va = _mm256_loadu_ps(a.as_ptr().add(offset));
            let vc = _mm256_mul_ps(va, vs);
            _mm256_storeu_ps(c.as_mut_ptr().add(offset), vc);
        }

        let tail_start = chunks * AVX2_F32_LANES;
        for i in 0..remainder {
            c[tail_start + i] = a[tail_start + i] * scalar;
        }
    }

    #[target_feature(enable = "avx2")]
    pub unsafe fn vector_dot_f32_avx2(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len();
        let chunks = n / AVX2_F32_LANES;
        let remainder = n % AVX2_F32_LANES;

        let mut acc = _mm256_setzero_ps();

        for i in 0..chunks {
            let offset = i * AVX2_F32_LANES;
            let va = _mm256_loadu_ps(a.as_ptr().add(offset));
            let vb = _mm256_loadu_ps(b.as_ptr().add(offset));
            // Separate multiply and add: FMA is a distinct CPU feature from
            // AVX2 (see the fused variant sketched below).
            acc = _mm256_add_ps(acc, _mm256_mul_ps(va, vb));
        }

        // Collapse the eight per-lane partial sums, then fold in the tail.
        let sum = hsum_avx2(acc);

        let tail_start = chunks * AVX2_F32_LANES;
        let mut tail_sum = 0.0f32;
        for i in 0..remainder {
            tail_sum += a[tail_start + i] * b[tail_start + i];
        }

        sum + tail_sum
    }

    #[target_feature(enable = "avx2")]
    pub unsafe fn vector_reduce_sum_f32_avx2(a: &[f32]) -> f32 {
        let n = a.len();
        let chunks = n / AVX2_F32_LANES;
        let remainder = n % AVX2_F32_LANES;

        let mut acc = _mm256_setzero_ps();

        for i in 0..chunks {
            let offset = i * AVX2_F32_LANES;
            let va = _mm256_loadu_ps(a.as_ptr().add(offset));
            acc = _mm256_add_ps(acc, va);
        }

        let sum = hsum_avx2(acc);

        let tail_start = chunks * AVX2_F32_LANES;
        let mut tail_sum = 0.0f32;
        for i in 0..remainder {
            tail_sum += a[tail_start + i];
        }

        sum + tail_sum
    }

    /// Horizontal sum of all eight lanes of `v`.
    #[target_feature(enable = "avx2")]
    unsafe fn hsum_avx2(v: __m256) -> f32 {
        // 8 lanes -> 4: add the upper 128-bit half onto the lower half.
        let hi128 = _mm256_extractf128_ps(v, 1);
        let lo128 = _mm256_castps256_ps128(v);
        let sum128 = _mm_add_ps(lo128, hi128);
        // 4 lanes -> 2: duplicate the odd lanes and add them onto the evens.
        let shuf = _mm_movehdup_ps(sum128);
        let sums = _mm_add_ps(sum128, shuf);
        // 2 lanes -> 1: move the upper pair down and add into lane 0.
        let shuf2 = _mm_movehl_ps(sums, sums);
        let result = _mm_add_ss(sums, shuf2);
        _mm_cvtss_f32(result)
    }
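
    // A fused-multiply-add variant, sketched here as an optional extension
    // and not wired into the dispatchers above: FMA is a separate CPU feature
    // from AVX2, so a caller would also need
    // `is_x86_feature_detected!("fma")` before taking this path.
    #[allow(dead_code)]
    #[target_feature(enable = "avx2")]
    #[target_feature(enable = "fma")]
    pub unsafe fn vector_dot_f32_avx2_fma(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len();
        let chunks = n / AVX2_F32_LANES;
        let mut acc = _mm256_setzero_ps();

        for i in 0..chunks {
            let offset = i * AVX2_F32_LANES;
            let va = _mm256_loadu_ps(a.as_ptr().add(offset));
            let vb = _mm256_loadu_ps(b.as_ptr().add(offset));
            // One instruction, one rounding step: acc = va * vb + acc.
            acc = _mm256_fmadd_ps(va, vb, acc);
        }

        let mut sum = hsum_avx2(acc);
        for i in (chunks * AVX2_F32_LANES)..n {
            sum += a[i] * b[i];
        }
        sum
    }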
}

#[cfg(target_arch = "aarch64")]
mod neon {
    use std::arch::aarch64::*;

    /// Number of f32 lanes in a 128-bit NEON register.
    const NEON_F32_LANES: usize = 4;

    // NEON is part of the baseline aarch64 target, so these functions need no
    // `#[target_feature]` gate; they are `unsafe` because they use raw
    // pointer loads and stores.
    pub unsafe fn vector_add_f32_neon(a: &[f32], b: &[f32], c: &mut [f32]) {
        let n = a.len();
        let chunks = n / NEON_F32_LANES;
        let remainder = n % NEON_F32_LANES;

        for i in 0..chunks {
            let offset = i * NEON_F32_LANES;
            let va = vld1q_f32(a.as_ptr().add(offset));
            let vb = vld1q_f32(b.as_ptr().add(offset));
            let vc = vaddq_f32(va, vb);
            vst1q_f32(c.as_mut_ptr().add(offset), vc);
        }

        let tail_start = chunks * NEON_F32_LANES;
        for i in 0..remainder {
            c[tail_start + i] = a[tail_start + i] + b[tail_start + i];
        }
    }

    pub unsafe fn vector_mul_f32_neon(a: &[f32], b: &[f32], c: &mut [f32]) {
        let n = a.len();
        let chunks = n / NEON_F32_LANES;
        let remainder = n % NEON_F32_LANES;

        for i in 0..chunks {
            let offset = i * NEON_F32_LANES;
            let va = vld1q_f32(a.as_ptr().add(offset));
            let vb = vld1q_f32(b.as_ptr().add(offset));
            let vc = vmulq_f32(va, vb);
            vst1q_f32(c.as_mut_ptr().add(offset), vc);
        }

        let tail_start = chunks * NEON_F32_LANES;
        for i in 0..remainder {
            c[tail_start + i] = a[tail_start + i] * b[tail_start + i];
        }
    }

    pub unsafe fn vector_scale_f32_neon(a: &[f32], scalar: f32, c: &mut [f32]) {
        let n = a.len();
        let chunks = n / NEON_F32_LANES;
        let remainder = n % NEON_F32_LANES;
        let vs = vdupq_n_f32(scalar);

        for i in 0..chunks {
            let offset = i * NEON_F32_LANES;
            let va = vld1q_f32(a.as_ptr().add(offset));
            let vc = vmulq_f32(va, vs);
            vst1q_f32(c.as_mut_ptr().add(offset), vc);
        }

        let tail_start = chunks * NEON_F32_LANES;
        for i in 0..remainder {
            c[tail_start + i] = a[tail_start + i] * scalar;
        }
    }

    pub unsafe fn vector_dot_f32_neon(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len();
        let chunks = n / NEON_F32_LANES;
        let remainder = n % NEON_F32_LANES;

        let mut acc = vdupq_n_f32(0.0);

        for i in 0..chunks {
            let offset = i * NEON_F32_LANES;
            let va = vld1q_f32(a.as_ptr().add(offset));
            let vb = vld1q_f32(b.as_ptr().add(offset));
            // Fused multiply-add: acc = acc + va * vb.
            acc = vfmaq_f32(acc, va, vb);
        }

        // vaddvq_f32 sums the four lanes into a single scalar.
        let sum = vaddvq_f32(acc);

        let tail_start = chunks * NEON_F32_LANES;
        let mut tail_sum = 0.0f32;
        for i in 0..remainder {
            tail_sum += a[tail_start + i] * b[tail_start + i];
        }

        sum + tail_sum
    }

    pub unsafe fn vector_reduce_sum_f32_neon(a: &[f32]) -> f32 {
        let n = a.len();
        let chunks = n / NEON_F32_LANES;
        let remainder = n % NEON_F32_LANES;

        let mut acc = vdupq_n_f32(0.0);

        for i in 0..chunks {
            let offset = i * NEON_F32_LANES;
            let va = vld1q_f32(a.as_ptr().add(offset));
            acc = vaddq_f32(acc, va);
        }

        let sum = vaddvq_f32(acc);

        let tail_start = chunks * NEON_F32_LANES;
        let mut tail_sum = 0.0f32;
        for i in 0..remainder {
            tail_sum += a[tail_start + i];
        }

        sum + tail_sum
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // Absolute tolerance; sufficient here because every expected value is
    // exactly representable in f32.
    const EPSILON: f32 = 1e-5;

    fn approx_eq(a: f32, b: f32) -> bool {
        (a - b).abs() < EPSILON
    }

    #[test]
    fn test_vector_add_basic() {
        // 9 elements: one full 8-lane AVX2 chunk (or two NEON chunks) plus a
        // scalar tail element.
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let b = vec![9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
        let mut c = vec![0.0; 9];

        vector_add_f32(&a, &b, &mut c);

        for val in &c {
            assert!(approx_eq(*val, 10.0), "Expected 10.0, got {val}");
        }
    }

    #[test]
    fn test_vector_mul_basic() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![2.0, 3.0, 4.0, 5.0];
        let mut c = vec![0.0; 4];

        vector_mul_f32(&a, &b, &mut c);

        assert!(approx_eq(c[0], 2.0));
        assert!(approx_eq(c[1], 6.0));
        assert!(approx_eq(c[2], 12.0));
        assert!(approx_eq(c[3], 20.0));
    }

    #[test]
    fn test_vector_scale_basic() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mut c = vec![0.0; 5];

        vector_scale_f32(&a, 3.0, &mut c);

        assert!(approx_eq(c[0], 3.0));
        assert!(approx_eq(c[1], 6.0));
        assert!(approx_eq(c[2], 9.0));
        assert!(approx_eq(c[3], 12.0));
        assert!(approx_eq(c[4], 15.0));
    }

    #[test]
    fn test_vector_dot_basic() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![1.0, 1.0, 1.0, 1.0];

        let result = vector_dot_f32(&a, &b);
        assert!(approx_eq(result, 10.0));
    }

    #[test]
    fn test_vector_reduce_sum_basic() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let result = vector_reduce_sum_f32(&a);
        assert!(approx_eq(result, 55.0));
    }

    #[test]
    fn test_empty_vectors() {
        let a: Vec<f32> = vec![];
        let b: Vec<f32> = vec![];
        let mut c: Vec<f32> = vec![];

        vector_add_f32(&a, &b, &mut c);
        vector_mul_f32(&a, &b, &mut c);
        vector_scale_f32(&a, 2.0, &mut c);
        assert!(approx_eq(vector_dot_f32(&a, &b), 0.0));
        assert!(approx_eq(vector_reduce_sum_f32(&a), 0.0));
    }

    #[test]
    fn test_large_vector() {
        let n = 1024;
        let a: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..n).map(|i| (n - i) as f32).collect();
        let mut c = vec![0.0; n];

        vector_add_f32(&a, &b, &mut c);

        for val in &c {
            assert!(approx_eq(*val, n as f32));
        }
    }

    #[test]
    #[should_panic(expected = "a.len() != b.len()")]
    fn test_mismatched_lengths_add() {
        let a = vec![1.0, 2.0];
        let b = vec![1.0];
        let mut c = vec![0.0; 2];
        vector_add_f32(&a, &b, &mut c);
    }
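
    // Added coverage sketch (not in the original suite): lengths that are not
    // a multiple of the 8-lane AVX2 / 4-lane NEON width force every backend
    // through its scalar tail loop.
    #[test]
    fn test_tail_remainder_lengths() {
        for n in [1usize, 3, 5, 7, 9, 13] {
            let a: Vec<f32> = (0..n).map(|i| i as f32).collect();
            let b: Vec<f32> = (0..n).map(|i| (2 * i) as f32).collect();
            let mut c = vec![0.0; n];

            vector_add_f32(&a, &b, &mut c);
            for (i, val) in c.iter().enumerate() {
                assert!(approx_eq(*val, (3 * i) as f32));
            }

            let expected_dot: f32 = (0..n).map(|i| (2 * i * i) as f32).sum();
            assert!(approx_eq(vector_dot_f32(&a, &b), expected_dot));
        }
    }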
}