Skip to main content

cuda_rust_wasm/simd/
vector_ops.rs

//! SIMD-accelerated vector operations for CPU fallback paths
//!
//! Provides vectorized element-wise add, multiply, scale, dot product, and
//! sum-reduction operations. The architecture-specific implementation is chosen
//! at compile time via `cfg(target_arch)` (NEON on aarch64, AVX2 on x86_64); on
//! x86_64 the AVX2 path is additionally gated by runtime CPU feature detection.
//! All other platforms, including wasm32, use a scalar fallback.
7
8// ---------------------------------------------------------------------------
9// Public API
10// ---------------------------------------------------------------------------
11
12/// Element-wise addition: `c[i] = a[i] + b[i]`
13///
14/// # Panics
15/// Panics if `a`, `b`, and `c` do not all have the same length.
16pub fn vector_add_f32(a: &[f32], b: &[f32], c: &mut [f32]) {
17    assert_eq!(a.len(), b.len(), "vector_add_f32: a.len() != b.len()");
18    assert_eq!(a.len(), c.len(), "vector_add_f32: a.len() != c.len()");
19
20    #[cfg(target_arch = "x86_64")]
21    {
22        if is_x86_feature_detected!("avx2") {
23            // Safety: length equality checked above; AVX2 detected at runtime.
24            unsafe { avx2::vector_add_f32_avx2(a, b, c) };
25            return;
26        }
27    }
28
29    #[cfg(target_arch = "aarch64")]
30    {
31        // Safety: NEON is mandatory on aarch64; length equality checked above.
32        unsafe { neon::vector_add_f32_neon(a, b, c) };
33        return;
34    }
35
36    // Scalar fallback (also used on wasm32 and other architectures)
37    #[allow(unreachable_code)]
38    scalar::vector_add_f32_scalar(a, b, c);
39}
40
41/// Element-wise multiplication: `c[i] = a[i] * b[i]`
42///
43/// # Panics
44/// Panics if `a`, `b`, and `c` do not all have the same length.
45pub fn vector_mul_f32(a: &[f32], b: &[f32], c: &mut [f32]) {
46    assert_eq!(a.len(), b.len(), "vector_mul_f32: a.len() != b.len()");
47    assert_eq!(a.len(), c.len(), "vector_mul_f32: a.len() != c.len()");
48
49    #[cfg(target_arch = "x86_64")]
50    {
51        if is_x86_feature_detected!("avx2") {
52            unsafe { avx2::vector_mul_f32_avx2(a, b, c) };
53            return;
54        }
55    }
56
57    #[cfg(target_arch = "aarch64")]
58    {
59        unsafe { neon::vector_mul_f32_neon(a, b, c) };
60        return;
61    }
62
63    #[allow(unreachable_code)]
64    scalar::vector_mul_f32_scalar(a, b, c);
65}
66
67/// Scale every element: `c[i] = a[i] * scalar`
68///
69/// # Panics
70/// Panics if `a` and `c` do not have the same length.
71pub fn vector_scale_f32(a: &[f32], scalar: f32, c: &mut [f32]) {
72    assert_eq!(a.len(), c.len(), "vector_scale_f32: a.len() != c.len()");
73
74    #[cfg(target_arch = "x86_64")]
75    {
76        if is_x86_feature_detected!("avx2") {
77            unsafe { avx2::vector_scale_f32_avx2(a, scalar, c) };
78            return;
79        }
80    }
81
82    #[cfg(target_arch = "aarch64")]
83    {
84        unsafe { neon::vector_scale_f32_neon(a, scalar, c) };
85        return;
86    }
87
88    #[allow(unreachable_code)]
89    scalar::vector_scale_f32_scalar(a, scalar, c);
90}
91
92/// Dot product: `sum(a[i] * b[i])`
93///
94/// # Panics
95/// Panics if `a` and `b` do not have the same length.
96pub fn vector_dot_f32(a: &[f32], b: &[f32]) -> f32 {
97    assert_eq!(a.len(), b.len(), "vector_dot_f32: a.len() != b.len()");
98
99    #[cfg(target_arch = "x86_64")]
100    {
101        if is_x86_feature_detected!("avx2") {
102            return unsafe { avx2::vector_dot_f32_avx2(a, b) };
103        }
104    }
105
106    #[cfg(target_arch = "aarch64")]
107    {
108        return unsafe { neon::vector_dot_f32_neon(a, b) };
109    }
110
111    #[allow(unreachable_code)]
112    scalar::vector_dot_f32_scalar(a, b)
113}
114
115/// Sum reduction: `sum(a[i])`
116pub fn vector_reduce_sum_f32(a: &[f32]) -> f32 {
117    #[cfg(target_arch = "x86_64")]
118    {
119        if is_x86_feature_detected!("avx2") {
120            return unsafe { avx2::vector_reduce_sum_f32_avx2(a) };
121        }
122    }
123
124    #[cfg(target_arch = "aarch64")]
125    {
126        return unsafe { neon::vector_reduce_sum_f32_neon(a) };
127    }
128
129    #[allow(unreachable_code)]
130    scalar::vector_reduce_sum_f32_scalar(a)
131}
132
133// ---------------------------------------------------------------------------
134// Scalar fallback implementation
135// ---------------------------------------------------------------------------
mod scalar {
    //! Portable kernels used when no SIMD path applies. That covers targets
    //! other than x86_64/aarch64 (e.g. wasm32) and x86_64 CPUs without AVX2.
    //!
    //! Each kernel re-slices its other arguments to `a.len()` before looping.
    //! The slicing still panics if an argument is too short. After that, the
    //! loops iterate without a per-element bounds check, which gives LLVM the
    //! chance to auto-vectorize the element-wise kernels.

    /// `c[i] = a[i] + b[i]` for every index of `a`.
    pub fn vector_add_f32_scalar(a: &[f32], b: &[f32], c: &mut [f32]) {
        let n = a.len();
        let (b, c) = (&b[..n], &mut c[..n]);
        for ((out, &x), &y) in c.iter_mut().zip(a).zip(b) {
            *out = x + y;
        }
    }

    /// `c[i] = a[i] * b[i]` for every index of `a`.
    pub fn vector_mul_f32_scalar(a: &[f32], b: &[f32], c: &mut [f32]) {
        let n = a.len();
        let (b, c) = (&b[..n], &mut c[..n]);
        for ((out, &x), &y) in c.iter_mut().zip(a).zip(b) {
            *out = x * y;
        }
    }

    /// `c[i] = a[i] * scalar` for every index of `a`.
    pub fn vector_scale_f32_scalar(a: &[f32], scalar: f32, c: &mut [f32]) {
        let c = &mut c[..a.len()];
        for (out, &x) in c.iter_mut().zip(a) {
            *out = x * scalar;
        }
    }

    /// Sequential dot product `a[0]*b[0] + a[1]*b[1] + ...`.
    pub fn vector_dot_f32_scalar(a: &[f32], b: &[f32]) -> f32 {
        let b = &b[..a.len()];
        // An explicit `fold` seeded with 0.0 pins both the starting value and
        // the left-to-right accumulation order.
        a.iter().zip(b).fold(0.0f32, |acc, (&x, &y)| acc + x * y)
    }

    /// Sequential left-to-right sum of `a`; an empty slice yields `0.0`.
    pub fn vector_reduce_sum_f32_scalar(a: &[f32]) -> f32 {
        a.iter().fold(0.0f32, |acc, &x| acc + x)
    }
}
171
172// ---------------------------------------------------------------------------
173// AVX2 implementation (x86_64)
174// ---------------------------------------------------------------------------
#[cfg(target_arch = "x86_64")]
mod avx2 {
    //! AVX2 kernels for x86_64, processing 8 `f32` lanes per 256-bit register.
    //!
    //! Safety contract shared by every `unsafe fn` in this module:
    //! * the running CPU must support AVX2 (the public wrappers check
    //!   `is_x86_feature_detected!("avx2")` before calling in), and
    //! * every slice argument must have the same length as `a`.
    //!
    //! The vector loops load and store through raw pointers without bounds
    //! checks, so a shorter slice would be undefined behaviour. Debug builds
    //! re-check the lengths with `debug_assert_eq!`.

    use std::arch::x86_64::*;

    /// Number of `f32` lanes in one `__m256` register.
    const AVX2_F32_LANES: usize = 8;

    /// AVX2 element-wise addition: `c[i] = a[i] + b[i]`, 8 elements per iteration.
    ///
    /// # Safety
    /// AVX2 must be available, and `b` and `c` must have the same length as `a`.
    #[target_feature(enable = "avx2")]
    pub unsafe fn vector_add_f32_avx2(a: &[f32], b: &[f32], c: &mut [f32]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), c.len());

        let n = a.len();
        let chunks = n / AVX2_F32_LANES;
        let remainder = n % AVX2_F32_LANES;

        for i in 0..chunks {
            let offset = i * AVX2_F32_LANES;
            // `loadu`/`storeu`: slices carry no 32-byte alignment guarantee.
            let va = _mm256_loadu_ps(a.as_ptr().add(offset));
            let vb = _mm256_loadu_ps(b.as_ptr().add(offset));
            let vc = _mm256_add_ps(va, vb);
            _mm256_storeu_ps(c.as_mut_ptr().add(offset), vc);
        }

        // Scalar tail: the final `n % 8` elements that do not fill a register.
        let tail_start = chunks * AVX2_F32_LANES;
        for i in 0..remainder {
            c[tail_start + i] = a[tail_start + i] + b[tail_start + i];
        }
    }

    /// AVX2 element-wise multiplication: `c[i] = a[i] * b[i]`.
    ///
    /// # Safety
    /// AVX2 must be available, and `b` and `c` must have the same length as `a`.
    #[target_feature(enable = "avx2")]
    pub unsafe fn vector_mul_f32_avx2(a: &[f32], b: &[f32], c: &mut [f32]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), c.len());

        let n = a.len();
        let chunks = n / AVX2_F32_LANES;
        let remainder = n % AVX2_F32_LANES;

        for i in 0..chunks {
            let offset = i * AVX2_F32_LANES;
            let va = _mm256_loadu_ps(a.as_ptr().add(offset));
            let vb = _mm256_loadu_ps(b.as_ptr().add(offset));
            let vc = _mm256_mul_ps(va, vb);
            _mm256_storeu_ps(c.as_mut_ptr().add(offset), vc);
        }

        let tail_start = chunks * AVX2_F32_LANES;
        for i in 0..remainder {
            c[tail_start + i] = a[tail_start + i] * b[tail_start + i];
        }
    }

    /// AVX2 scaling by a constant: `c[i] = a[i] * scalar`.
    ///
    /// # Safety
    /// AVX2 must be available, and `c` must have the same length as `a`.
    #[target_feature(enable = "avx2")]
    pub unsafe fn vector_scale_f32_avx2(a: &[f32], scalar: f32, c: &mut [f32]) {
        debug_assert_eq!(a.len(), c.len());

        let n = a.len();
        let chunks = n / AVX2_F32_LANES;
        let remainder = n % AVX2_F32_LANES;
        // Broadcast the scale factor into all 8 lanes once, outside the loop.
        let vs = _mm256_set1_ps(scalar);

        for i in 0..chunks {
            let offset = i * AVX2_F32_LANES;
            let va = _mm256_loadu_ps(a.as_ptr().add(offset));
            let vc = _mm256_mul_ps(va, vs);
            _mm256_storeu_ps(c.as_mut_ptr().add(offset), vc);
        }

        let tail_start = chunks * AVX2_F32_LANES;
        for i in 0..remainder {
            c[tail_start + i] = a[tail_start + i] * scalar;
        }
    }

    /// AVX2 dot product: `sum(a[i] * b[i])`.
    ///
    /// Products are accumulated in 8 independent lanes and combined by a
    /// horizontal sum at the end. Rounding can therefore differ from a
    /// sequential loop.
    ///
    /// # Safety
    /// AVX2 must be available, and `b` must have the same length as `a`.
    #[target_feature(enable = "avx2")]
    pub unsafe fn vector_dot_f32_avx2(a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());

        let n = a.len();
        let chunks = n / AVX2_F32_LANES;
        let remainder = n % AVX2_F32_LANES;

        let mut acc = _mm256_setzero_ps();

        for i in 0..chunks {
            let offset = i * AVX2_F32_LANES;
            let va = _mm256_loadu_ps(a.as_ptr().add(offset));
            let vb = _mm256_loadu_ps(b.as_ptr().add(offset));
            // Separate multiply and add rather than FMA: FMA is a distinct CPU
            // feature that the AVX2 runtime check does not guarantee.
            acc = _mm256_add_ps(acc, _mm256_mul_ps(va, vb));
        }

        let sum = hsum_avx2(acc);

        // Scalar tail, accumulated separately and added once at the end.
        let tail_start = chunks * AVX2_F32_LANES;
        let mut tail_sum = 0.0f32;
        for i in 0..remainder {
            tail_sum += a[tail_start + i] * b[tail_start + i];
        }

        sum + tail_sum
    }

    /// AVX2 sum reduction: `sum(a[i])`; an empty slice yields `0.0`.
    ///
    /// # Safety
    /// AVX2 must be available.
    #[target_feature(enable = "avx2")]
    pub unsafe fn vector_reduce_sum_f32_avx2(a: &[f32]) -> f32 {
        let n = a.len();
        let chunks = n / AVX2_F32_LANES;
        let remainder = n % AVX2_F32_LANES;

        let mut acc = _mm256_setzero_ps();

        for i in 0..chunks {
            let offset = i * AVX2_F32_LANES;
            let va = _mm256_loadu_ps(a.as_ptr().add(offset));
            acc = _mm256_add_ps(acc, va);
        }

        let sum = hsum_avx2(acc);

        let tail_start = chunks * AVX2_F32_LANES;
        let mut tail_sum = 0.0f32;
        for i in 0..remainder {
            tail_sum += a[tail_start + i];
        }

        sum + tail_sum
    }

    /// Horizontal sum of an `__m256` register (8 x f32 -> single f32).
    ///
    /// Lanes are combined as ((v0+v4)+(v1+v5)) + ((v2+v6)+(v3+v7)), not left
    /// to right.
    ///
    /// # Safety
    /// AVX2 must be available.
    #[target_feature(enable = "avx2")]
    unsafe fn hsum_avx2(v: __m256) -> f32 {
        // Fold the high 128-bit half onto the low half: s = [v0+v4, v1+v5, v2+v6, v3+v7].
        let hi128 = _mm256_extractf128_ps(v, 1);
        let lo128 = _mm256_castps256_ps128(v);
        let sum128 = _mm_add_ps(lo128, hi128);
        // Duplicate odd lanes: [s1, s1, s3, s3].
        let shuf = _mm_movehdup_ps(sum128);
        // Lane 0 = s0+s1, lane 2 = s2+s3 (lanes 1 and 3 are unused).
        let sums = _mm_add_ps(sum128, shuf);
        // Move lane 2 down to lane 0.
        let shuf2 = _mm_movehl_ps(sums, sums);
        // Lane 0 = (s0+s1) + (s2+s3).
        let result = _mm_add_ss(sums, shuf2);
        _mm_cvtss_f32(result)
    }
}
320
321// ---------------------------------------------------------------------------
322// NEON implementation (aarch64)
323// ---------------------------------------------------------------------------
#[cfg(target_arch = "aarch64")]
mod neon {
    //! NEON kernels for aarch64, processing 4 `f32` lanes per 128-bit register.
    //!
    //! NEON is mandatory on aarch64, so no runtime feature check is needed.
    //! Safety contract shared by every `unsafe fn` in this module: every slice
    //! argument must have the same length as `a`.
    //!
    //! The vector loops load and store through raw pointers without bounds
    //! checks, so a shorter slice would be undefined behaviour. Debug builds
    //! re-check the lengths with `debug_assert_eq!`.

    use std::arch::aarch64::*;

    /// Number of `f32` lanes in one `float32x4_t` register.
    const NEON_F32_LANES: usize = 4;

    /// NEON element-wise addition: `c[i] = a[i] + b[i]`, 4 elements per iteration.
    ///
    /// # Safety
    /// `b` and `c` must have the same length as `a`.
    pub unsafe fn vector_add_f32_neon(a: &[f32], b: &[f32], c: &mut [f32]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), c.len());

        let n = a.len();
        let chunks = n / NEON_F32_LANES;
        let remainder = n % NEON_F32_LANES;

        for i in 0..chunks {
            let offset = i * NEON_F32_LANES;
            let va = vld1q_f32(a.as_ptr().add(offset));
            let vb = vld1q_f32(b.as_ptr().add(offset));
            let vc = vaddq_f32(va, vb);
            vst1q_f32(c.as_mut_ptr().add(offset), vc);
        }

        // Scalar tail: the final `n % 4` elements that do not fill a register.
        let tail_start = chunks * NEON_F32_LANES;
        for i in 0..remainder {
            c[tail_start + i] = a[tail_start + i] + b[tail_start + i];
        }
    }

    /// NEON element-wise multiplication: `c[i] = a[i] * b[i]`.
    ///
    /// # Safety
    /// `b` and `c` must have the same length as `a`.
    pub unsafe fn vector_mul_f32_neon(a: &[f32], b: &[f32], c: &mut [f32]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), c.len());

        let n = a.len();
        let chunks = n / NEON_F32_LANES;
        let remainder = n % NEON_F32_LANES;

        for i in 0..chunks {
            let offset = i * NEON_F32_LANES;
            let va = vld1q_f32(a.as_ptr().add(offset));
            let vb = vld1q_f32(b.as_ptr().add(offset));
            let vc = vmulq_f32(va, vb);
            vst1q_f32(c.as_mut_ptr().add(offset), vc);
        }

        let tail_start = chunks * NEON_F32_LANES;
        for i in 0..remainder {
            c[tail_start + i] = a[tail_start + i] * b[tail_start + i];
        }
    }

    /// NEON scaling by a constant: `c[i] = a[i] * scalar`.
    ///
    /// # Safety
    /// `c` must have the same length as `a`.
    pub unsafe fn vector_scale_f32_neon(a: &[f32], scalar: f32, c: &mut [f32]) {
        debug_assert_eq!(a.len(), c.len());

        let n = a.len();
        let chunks = n / NEON_F32_LANES;
        let remainder = n % NEON_F32_LANES;
        // Broadcast the scale factor into all 4 lanes once, outside the loop.
        let vs = vdupq_n_f32(scalar);

        for i in 0..chunks {
            let offset = i * NEON_F32_LANES;
            let va = vld1q_f32(a.as_ptr().add(offset));
            let vc = vmulq_f32(va, vs);
            vst1q_f32(c.as_mut_ptr().add(offset), vc);
        }

        let tail_start = chunks * NEON_F32_LANES;
        for i in 0..remainder {
            c[tail_start + i] = a[tail_start + i] * scalar;
        }
    }

    /// NEON dot product: `sum(a[i] * b[i])`.
    ///
    /// Uses fused multiply-add (`vfmaq_f32`, one rounding per step) in 4
    /// independent lanes, combined at the end. Rounding can therefore differ
    /// from a sequential scalar loop.
    ///
    /// # Safety
    /// `b` must have the same length as `a`.
    pub unsafe fn vector_dot_f32_neon(a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());

        let n = a.len();
        let chunks = n / NEON_F32_LANES;
        let remainder = n % NEON_F32_LANES;

        let mut acc = vdupq_n_f32(0.0);

        for i in 0..chunks {
            let offset = i * NEON_F32_LANES;
            let va = vld1q_f32(a.as_ptr().add(offset));
            let vb = vld1q_f32(b.as_ptr().add(offset));
            acc = vfmaq_f32(acc, va, vb);
        }

        // Horizontal add of the 4 accumulator lanes.
        let sum = vaddvq_f32(acc);

        let tail_start = chunks * NEON_F32_LANES;
        let mut tail_sum = 0.0f32;
        for i in 0..remainder {
            tail_sum += a[tail_start + i] * b[tail_start + i];
        }

        sum + tail_sum
    }

    /// NEON sum reduction: `sum(a[i])`; an empty slice yields `0.0`.
    ///
    /// # Safety
    /// No extra requirements beyond NEON, which is mandatory on aarch64. The
    /// function is `unsafe` only because it uses raw-pointer intrinsics.
    pub unsafe fn vector_reduce_sum_f32_neon(a: &[f32]) -> f32 {
        let n = a.len();
        let chunks = n / NEON_F32_LANES;
        let remainder = n % NEON_F32_LANES;

        let mut acc = vdupq_n_f32(0.0);

        for i in 0..chunks {
            let offset = i * NEON_F32_LANES;
            let va = vld1q_f32(a.as_ptr().add(offset));
            acc = vaddq_f32(acc, va);
        }

        let sum = vaddvq_f32(acc);

        let tail_start = chunks * NEON_F32_LANES;
        let mut tail_sum = 0.0f32;
        for i in 0..remainder {
            tail_sum += a[tail_start + i];
        }

        sum + tail_sum
    }
}
444
445// ---------------------------------------------------------------------------
446// Tests
447// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for the small, well-conditioned inputs below.
    const EPSILON: f32 = 1e-5;

    fn approx_eq(a: f32, b: f32) -> bool {
        (a - b).abs() < EPSILON
    }

    /// Lengths that exercise every kernel shape: empty, shorter than one NEON
    /// (4-lane) or AVX2 (8-lane) register, exact multiples of both, and
    /// multiples plus a scalar tail.
    const LENGTHS: [usize; 11] = [0, 1, 3, 4, 5, 7, 8, 9, 16, 19, 33];

    /// Small integer-valued inputs. Every sum and product derived from them stays
    /// far below 2^24, so f32 results are exact regardless of summation order
    /// or FMA and can be compared with `assert_eq!`.
    fn inputs(n: usize) -> (Vec<f32>, Vec<f32>) {
        let a: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..n).map(|i| (i % 5) as f32 + 1.0).collect();
        (a, b)
    }

    #[test]
    fn test_vector_add_basic() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let b = vec![9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
        let mut c = vec![0.0; 9];

        vector_add_f32(&a, &b, &mut c);

        for val in &c {
            assert!(approx_eq(*val, 10.0), "Expected 10.0, got {val}");
        }
    }

    #[test]
    fn test_vector_mul_basic() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![2.0, 3.0, 4.0, 5.0];
        let mut c = vec![0.0; 4];

        vector_mul_f32(&a, &b, &mut c);

        assert!(approx_eq(c[0], 2.0));
        assert!(approx_eq(c[1], 6.0));
        assert!(approx_eq(c[2], 12.0));
        assert!(approx_eq(c[3], 20.0));
    }

    #[test]
    fn test_vector_scale_basic() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mut c = vec![0.0; 5];

        vector_scale_f32(&a, 3.0, &mut c);

        assert!(approx_eq(c[0], 3.0));
        assert!(approx_eq(c[1], 6.0));
        assert!(approx_eq(c[2], 9.0));
        assert!(approx_eq(c[3], 12.0));
        assert!(approx_eq(c[4], 15.0));
    }

    #[test]
    fn test_vector_dot_basic() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![1.0, 1.0, 1.0, 1.0];

        let result = vector_dot_f32(&a, &b);
        assert!(approx_eq(result, 10.0));
    }

    #[test]
    fn test_vector_reduce_sum_basic() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let result = vector_reduce_sum_f32(&a);
        assert!(approx_eq(result, 55.0));
    }

    #[test]
    fn test_empty_vectors() {
        let a: Vec<f32> = vec![];
        let b: Vec<f32> = vec![];
        let mut c: Vec<f32> = vec![];

        vector_add_f32(&a, &b, &mut c);
        vector_mul_f32(&a, &b, &mut c);
        vector_scale_f32(&a, 2.0, &mut c);
        assert!(approx_eq(vector_dot_f32(&a, &b), 0.0));
        assert!(approx_eq(vector_reduce_sum_f32(&a), 0.0));
    }

    #[test]
    fn test_large_vector() {
        let n = 1024;
        let a: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..n).map(|i| (n - i) as f32).collect();
        let mut c = vec![0.0; n];

        vector_add_f32(&a, &b, &mut c);

        for val in &c {
            assert!(approx_eq(*val, n as f32));
        }
    }

    /// Covers the SIMD main loop and the scalar tail of every element-wise op.
    #[test]
    fn test_elementwise_ops_every_length() {
        for &n in &LENGTHS {
            let (a, b) = inputs(n);
            // NaN sentinel: any element a kernel fails to write fails `assert_eq!`.
            let mut c = vec![f32::NAN; n];

            vector_add_f32(&a, &b, &mut c);
            for i in 0..n {
                assert_eq!(c[i], a[i] + b[i], "add: n={n}, i={i}");
            }

            c.fill(f32::NAN);
            vector_mul_f32(&a, &b, &mut c);
            for i in 0..n {
                assert_eq!(c[i], a[i] * b[i], "mul: n={n}, i={i}");
            }

            c.fill(f32::NAN);
            vector_scale_f32(&a, 3.0, &mut c);
            for i in 0..n {
                assert_eq!(c[i], a[i] * 3.0, "scale: n={n}, i={i}");
            }
        }
    }

    /// Covers the SIMD accumulation, horizontal sum and tail of both reductions.
    #[test]
    fn test_reductions_every_length() {
        for &n in &LENGTHS {
            let (a, b) = inputs(n);
            let expected_sum = (0..n).sum::<usize>() as f32;
            let expected_dot = (0..n).map(|i| i * (i % 5 + 1)).sum::<usize>() as f32;

            assert_eq!(vector_reduce_sum_f32(&a), expected_sum, "sum: n={n}");
            assert_eq!(vector_dot_f32(&a, &b), expected_dot, "dot: n={n}");
        }
    }

    #[test]
    #[should_panic(expected = "a.len() != b.len()")]
    fn test_mismatched_lengths_add() {
        let a = vec![1.0, 2.0];
        let b = vec![1.0];
        let mut c = vec![0.0; 2];
        vector_add_f32(&a, &b, &mut c);
    }

    #[test]
    #[should_panic(expected = "vector_mul_f32: a.len() != c.len()")]
    fn test_mismatched_output_length_mul() {
        let a = vec![1.0, 2.0];
        let b = vec![1.0, 2.0];
        let mut c = vec![0.0; 1];
        vector_mul_f32(&a, &b, &mut c);
    }

    #[test]
    #[should_panic(expected = "vector_scale_f32: a.len() != c.len()")]
    fn test_mismatched_lengths_scale() {
        let a = vec![1.0, 2.0, 3.0];
        let mut c = vec![0.0; 2];
        vector_scale_f32(&a, 2.0, &mut c);
    }

    #[test]
    #[should_panic(expected = "vector_dot_f32: a.len() != b.len()")]
    fn test_mismatched_lengths_dot() {
        let a = vec![1.0, 2.0];
        let b = vec![1.0];
        let _ = vector_dot_f32(&a, &b);
    }
}