oxiblas_core/simd/complex.rs

//! Complex number SIMD types.
//!
//! This module provides SIMD types for complex numbers in interleaved format.
//! Complex numbers are stored as `[re0, im0, re1, im1, ...]`, which allows
//! efficient SIMD operations.
//!
//! # Layout
//!
//! - `C64x2`: 2 complex f64 values (256 bits of data; on AArch64 NEON this is
//!   held as two 128-bit registers)
//! - `C32x4`: 4 complex f32 values (256 bits of data; on AArch64 NEON this is
//!   held as two 128-bit registers)
//!
//! # Operations
//!
//! Complex multiplication: `(a + bi)(c + di) = (ac - bd) + (ad + bc)i`.
//! In the interleaved layout this requires shuffle operations to separate the
//! real and imaginary parts.
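//!
//! # Example
//!
//! A minimal sketch of the trait in use, via the scalar fallback register.
//! The import path is assumed from this file's location and may need adjusting
//! to the crate's actual re-exports, hence the `ignore` attribute:
//!
//! ```ignore
//! use num_complex::Complex64;
//! use oxiblas_core::simd::complex::{ComplexSimdRegister, ScalarC64};
//!
//! let a = ScalarC64::splat(Complex64::new(2.0, 3.0));
//! let b = ScalarC64::splat(Complex64::new(4.0, 5.0));
//! // (2 + 3i)(4 + 5i) = -7 + 22i
//! assert_eq!(a.mul(b).extract(0), Complex64::new(-7.0, 22.0));
//! ```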

use num_complex::{Complex32, Complex64};

/// Trait for complex SIMD register types.
pub trait ComplexSimdRegister: Copy + Clone {
    /// The underlying real scalar type.
    type Real;
    /// The complex scalar type.
    type Complex;
    /// Number of complex values in this register.
    const COMPLEX_LANES: usize;

    /// Creates a register with all complex values set to zero.
    fn zero() -> Self;

    /// Creates a register with all complex values set to the same value.
    fn splat(value: Self::Complex) -> Self;

    /// Loads complex values from aligned memory.
    ///
    /// # Safety
    /// The pointer must meet the implementation's alignment requirement and point
    /// to at least `COMPLEX_LANES` complex values.
    unsafe fn load_aligned(ptr: *const Self::Complex) -> Self;

    /// Loads complex values from unaligned memory.
    ///
    /// # Safety
    /// The pointer must point to at least `COMPLEX_LANES` complex values.
    unsafe fn load_unaligned(ptr: *const Self::Complex) -> Self;

    /// Stores complex values to aligned memory.
    ///
    /// # Safety
    /// The pointer must meet the implementation's alignment requirement and point
    /// to space for at least `COMPLEX_LANES` complex values.
    unsafe fn store_aligned(self, ptr: *mut Self::Complex);

    /// Stores complex values to unaligned memory.
    ///
    /// # Safety
    /// The pointer must point to space for at least `COMPLEX_LANES` complex values.
    unsafe fn store_unaligned(self, ptr: *mut Self::Complex);

    /// Complex addition.
    fn add(self, other: Self) -> Self;

    /// Complex subtraction.
    fn sub(self, other: Self) -> Self;

    /// Complex multiplication.
    fn mul(self, other: Self) -> Self;

    /// Multiplies by a real scalar.
    fn scale_real(self, scalar: Self::Real) -> Self;

    /// Conjugate: negates the imaginary part.
    fn conj(self) -> Self;

    /// Extracts a complex value at the given index.
    fn extract(self, index: usize) -> Self::Complex;

    /// Inserts a complex value at the given index.
    fn insert(self, index: usize, value: Self::Complex) -> Self;

    /// Horizontal sum of all complex values.
    fn reduce_sum(self) -> Self::Complex;
}

// =============================================================================
// Scalar fallback for complex SIMD
// =============================================================================

/// Scalar "register" for Complex64 - processes one complex value at a time.
#[derive(Clone, Copy, Debug)]
#[repr(transparent)]
pub struct ScalarC64(pub Complex64);

impl ComplexSimdRegister for ScalarC64 {
    type Real = f64;
    type Complex = Complex64;
    const COMPLEX_LANES: usize = 1;

    #[inline]
    fn zero() -> Self {
        ScalarC64(Complex64::new(0.0, 0.0))
    }

    #[inline]
    fn splat(value: Complex64) -> Self {
        ScalarC64(value)
    }

    #[inline]
    unsafe fn load_aligned(ptr: *const Complex64) -> Self {
        ScalarC64(*ptr)
    }

    #[inline]
    unsafe fn load_unaligned(ptr: *const Complex64) -> Self {
        // A plain dereference would be UB for a misaligned pointer; use an
        // unaligned read to honour the looser safety contract.
        ScalarC64(ptr.read_unaligned())
    }

    #[inline]
    unsafe fn store_aligned(self, ptr: *mut Complex64) {
        *ptr = self.0;
    }

    #[inline]
    unsafe fn store_unaligned(self, ptr: *mut Complex64) {
        ptr.write_unaligned(self.0);
    }

    #[inline]
    fn add(self, other: Self) -> Self {
        ScalarC64(self.0 + other.0)
    }

    #[inline]
    fn sub(self, other: Self) -> Self {
        ScalarC64(self.0 - other.0)
    }

    #[inline]
    fn mul(self, other: Self) -> Self {
        ScalarC64(self.0 * other.0)
    }

    #[inline]
    fn scale_real(self, scalar: f64) -> Self {
        ScalarC64(Complex64::new(self.0.re * scalar, self.0.im * scalar))
    }

    #[inline]
    fn conj(self) -> Self {
        ScalarC64(self.0.conj())
    }

    #[inline]
    fn extract(self, _index: usize) -> Complex64 {
        self.0
    }

    #[inline]
    fn insert(self, _index: usize, value: Complex64) -> Self {
        ScalarC64(value)
    }

    #[inline]
    fn reduce_sum(self) -> Complex64 {
        self.0
    }
}

/// Scalar "register" for Complex32 - processes one complex value at a time.
#[derive(Clone, Copy, Debug)]
#[repr(transparent)]
pub struct ScalarC32(pub Complex32);

impl ComplexSimdRegister for ScalarC32 {
    type Real = f32;
    type Complex = Complex32;
    const COMPLEX_LANES: usize = 1;

    #[inline]
    fn zero() -> Self {
        ScalarC32(Complex32::new(0.0, 0.0))
    }

    #[inline]
    fn splat(value: Complex32) -> Self {
        ScalarC32(value)
    }

    #[inline]
    unsafe fn load_aligned(ptr: *const Complex32) -> Self {
        ScalarC32(*ptr)
    }

    #[inline]
    unsafe fn load_unaligned(ptr: *const Complex32) -> Self {
        // As for `ScalarC64`: the pointer is not guaranteed to be aligned.
        ScalarC32(ptr.read_unaligned())
    }

    #[inline]
    unsafe fn store_aligned(self, ptr: *mut Complex32) {
        *ptr = self.0;
    }

    #[inline]
    unsafe fn store_unaligned(self, ptr: *mut Complex32) {
        ptr.write_unaligned(self.0);
    }

    #[inline]
    fn add(self, other: Self) -> Self {
        ScalarC32(self.0 + other.0)
    }

    #[inline]
    fn sub(self, other: Self) -> Self {
        ScalarC32(self.0 - other.0)
    }

    #[inline]
    fn mul(self, other: Self) -> Self {
        ScalarC32(self.0 * other.0)
    }

    #[inline]
    fn scale_real(self, scalar: f32) -> Self {
        ScalarC32(Complex32::new(self.0.re * scalar, self.0.im * scalar))
    }

    #[inline]
    fn conj(self) -> Self {
        ScalarC32(self.0.conj())
    }

    #[inline]
    fn extract(self, _index: usize) -> Complex32 {
        self.0
    }

    #[inline]
    fn insert(self, _index: usize, value: Complex32) -> Self {
        ScalarC32(value)
    }

    #[inline]
    fn reduce_sum(self) -> Complex32 {
        self.0
    }
}

// =============================================================================
// AArch64 NEON complex SIMD
// =============================================================================

#[cfg(target_arch = "aarch64")]
mod aarch64_impl {
    use super::*;
    use core::arch::aarch64::*;

    /// 2 complex f64 values using two 128-bit NEON registers.
    #[derive(Clone, Copy)]
    pub struct C64x2 {
        /// First complex value (re0, im0)
        c0: float64x2_t,
        /// Second complex value (re1, im1)
        c1: float64x2_t,
    }

    impl ComplexSimdRegister for C64x2 {
        type Real = f64;
        type Complex = Complex64;
        const COMPLEX_LANES: usize = 2;

        #[inline]
        fn zero() -> Self {
            unsafe {
                C64x2 {
                    c0: vdupq_n_f64(0.0),
                    c1: vdupq_n_f64(0.0),
                }
            }
        }

        #[inline]
        fn splat(value: Complex64) -> Self {
            unsafe {
                let c = vld1q_f64([value.re, value.im].as_ptr());
                C64x2 { c0: c, c1: c }
            }
        }

        #[inline]
        unsafe fn load_aligned(ptr: *const Complex64) -> Self {
            let p = ptr as *const f64;
            C64x2 {
                c0: vld1q_f64(p),
                c1: vld1q_f64(p.add(2)),
            }
        }

        #[inline]
        unsafe fn load_unaligned(ptr: *const Complex64) -> Self {
            // vld1q has no alignment requirement beyond the element type,
            // so the aligned path is reused.
            Self::load_aligned(ptr)
        }

        #[inline]
        unsafe fn store_aligned(self, ptr: *mut Complex64) {
            let p = ptr as *mut f64;
            vst1q_f64(p, self.c0);
            vst1q_f64(p.add(2), self.c1);
        }

        #[inline]
        unsafe fn store_unaligned(self, ptr: *mut Complex64) {
            self.store_aligned(ptr);
        }

        #[inline]
        fn add(self, other: Self) -> Self {
            unsafe {
                C64x2 {
                    c0: vaddq_f64(self.c0, other.c0),
                    c1: vaddq_f64(self.c1, other.c1),
                }
            }
        }

        #[inline]
        fn sub(self, other: Self) -> Self {
            unsafe {
                C64x2 {
                    c0: vsubq_f64(self.c0, other.c0),
                    c1: vsubq_f64(self.c1, other.c1),
                }
            }
        }

        #[inline]
        fn mul(self, other: Self) -> Self {
            // (a + bi)(c + di) = (ac - bd) + (ad + bc)i
            unsafe {
                // For c0: self.c0 = [a, b], other.c0 = [c, d].
                // Lane indices are const generics on the NEON intrinsics.
                let a = vdupq_laneq_f64::<0>(self.c0); // [a, a]
                let b = vdupq_laneq_f64::<1>(self.c0); // [b, b]
                let c = vdupq_laneq_f64::<0>(other.c0); // [c, c]
                let d = vdupq_laneq_f64::<1>(other.c0); // [d, d]

                // ac, ad
                let ac = vmulq_f64(a, c);
                let ad = vmulq_f64(a, d);
                // bd, bc
                let bd = vmulq_f64(b, d);
                let bc = vmulq_f64(b, c);

                // [ac - bd, ad + bc]
                let re0 = vsubq_f64(ac, bd);
                let im0 = vaddq_f64(ad, bc);
                let c0_new = vzip1q_f64(re0, im0);

                // Same for c1
                let a1 = vdupq_laneq_f64::<0>(self.c1);
                let b1 = vdupq_laneq_f64::<1>(self.c1);
                let c1 = vdupq_laneq_f64::<0>(other.c1);
                let d1 = vdupq_laneq_f64::<1>(other.c1);

                let ac1 = vmulq_f64(a1, c1);
                let ad1 = vmulq_f64(a1, d1);
                let bd1 = vmulq_f64(b1, d1);
                let bc1 = vmulq_f64(b1, c1);

                let re1 = vsubq_f64(ac1, bd1);
                let im1 = vaddq_f64(ad1, bc1);
                let c1_new = vzip1q_f64(re1, im1);

                C64x2 {
                    c0: c0_new,
                    c1: c1_new,
                }
            }
        }

        #[inline]
        fn scale_real(self, scalar: f64) -> Self {
            unsafe {
                let s = vdupq_n_f64(scalar);
                C64x2 {
                    c0: vmulq_f64(self.c0, s),
                    c1: vmulq_f64(self.c1, s),
                }
            }
        }

        #[inline]
        fn conj(self) -> Self {
            unsafe {
                // Negate imaginary parts: [re, -im]
                let neg_mask = vld1q_f64([1.0, -1.0].as_ptr());
                C64x2 {
                    c0: vmulq_f64(self.c0, neg_mask),
                    c1: vmulq_f64(self.c1, neg_mask),
                }
            }
        }

        #[inline]
        fn extract(self, index: usize) -> Complex64 {
            debug_assert!(index < 2);
            unsafe {
                let arr = if index == 0 {
                    let mut a = [0.0f64; 2];
                    vst1q_f64(a.as_mut_ptr(), self.c0);
                    a
                } else {
                    let mut a = [0.0f64; 2];
                    vst1q_f64(a.as_mut_ptr(), self.c1);
                    a
                };
                Complex64::new(arr[0], arr[1])
            }
        }

        #[inline]
        fn insert(self, index: usize, value: Complex64) -> Self {
            debug_assert!(index < 2);
            unsafe {
                let new_c = vld1q_f64([value.re, value.im].as_ptr());
                if index == 0 {
                    C64x2 {
                        c0: new_c,
                        c1: self.c1,
                    }
                } else {
                    C64x2 {
                        c0: self.c0,
                        c1: new_c,
                    }
                }
            }
        }

        #[inline]
        fn reduce_sum(self) -> Complex64 {
            unsafe {
                let sum = vaddq_f64(self.c0, self.c1);
                let mut arr = [0.0f64; 2];
                vst1q_f64(arr.as_mut_ptr(), sum);
                Complex64::new(arr[0], arr[1])
            }
        }
    }

    /// 4 complex f32 values using two 128-bit NEON registers.
    #[derive(Clone, Copy)]
    pub struct C32x4 {
        /// First two complex values (re0, im0, re1, im1)
        lo: float32x4_t,
        /// Second two complex values (re2, im2, re3, im3)
        hi: float32x4_t,
    }

    impl ComplexSimdRegister for C32x4 {
        type Real = f32;
        type Complex = Complex32;
        const COMPLEX_LANES: usize = 4;

        #[inline]
        fn zero() -> Self {
            unsafe {
                C32x4 {
                    lo: vdupq_n_f32(0.0),
                    hi: vdupq_n_f32(0.0),
                }
            }
        }

        #[inline]
        fn splat(value: Complex32) -> Self {
            unsafe {
                let vals = [value.re, value.im, value.re, value.im];
                let v = vld1q_f32(vals.as_ptr());
                C32x4 { lo: v, hi: v }
            }
        }

        #[inline]
        unsafe fn load_aligned(ptr: *const Complex32) -> Self {
            let p = ptr as *const f32;
            C32x4 {
                lo: vld1q_f32(p),
                hi: vld1q_f32(p.add(4)),
            }
        }

        #[inline]
        unsafe fn load_unaligned(ptr: *const Complex32) -> Self {
            Self::load_aligned(ptr)
        }

        #[inline]
        unsafe fn store_aligned(self, ptr: *mut Complex32) {
            let p = ptr as *mut f32;
            vst1q_f32(p, self.lo);
            vst1q_f32(p.add(4), self.hi);
        }

        #[inline]
        unsafe fn store_unaligned(self, ptr: *mut Complex32) {
            self.store_aligned(ptr);
        }

        #[inline]
        fn add(self, other: Self) -> Self {
            unsafe {
                C32x4 {
                    lo: vaddq_f32(self.lo, other.lo),
                    hi: vaddq_f32(self.hi, other.hi),
                }
            }
        }

        #[inline]
        fn sub(self, other: Self) -> Self {
            unsafe {
                C32x4 {
                    lo: vsubq_f32(self.lo, other.lo),
                    hi: vsubq_f32(self.hi, other.hi),
                }
            }
        }

        #[inline]
        fn mul(self, other: Self) -> Self {
            // For each pair [a, b] * [c, d] = [ac-bd, ad+bc]
            // Manual implementation using shuffles, multiplies and adds.
            unsafe {
                // lo = [a0, b0, a1, b1], other.lo = [c0, d0, c1, d1]
                // We need: [a0*c0 - b0*d0, a0*d0 + b0*c0, a1*c1 - b1*d1, a1*d1 + b1*c1]

                // Separate real and imaginary parts with unzip:
                // uzp1q(x, x) gathers the even lanes (reals), uzp2q(x, x) the
                // odd lanes (imaginaries).

                // For lo register:
                let reals_self_lo = vuzp1q_f32(self.lo, self.lo); // [a0, a1, a0, a1]
                let imags_self_lo = vuzp2q_f32(self.lo, self.lo); // [b0, b1, b0, b1]
                let reals_other_lo = vuzp1q_f32(other.lo, other.lo); // [c0, c1, c0, c1]
                let imags_other_lo = vuzp2q_f32(other.lo, other.lo); // [d0, d1, d0, d1]

                // ac, bd, ad, bc
                let ac_lo = vmulq_f32(reals_self_lo, reals_other_lo);
                let bd_lo = vmulq_f32(imags_self_lo, imags_other_lo);
                let ad_lo = vmulq_f32(reals_self_lo, imags_other_lo);
                let bc_lo = vmulq_f32(imags_self_lo, reals_other_lo);

                // ac - bd (real part), ad + bc (imag part)
                let re_lo = vsubq_f32(ac_lo, bd_lo);
                let im_lo = vaddq_f32(ad_lo, bc_lo);

                // Interleave back: [re0, im0, re1, im1]
                let lo_result = vzip1q_f32(re_lo, im_lo);

                // Same for hi register
                let reals_self_hi = vuzp1q_f32(self.hi, self.hi);
                let imags_self_hi = vuzp2q_f32(self.hi, self.hi);
                let reals_other_hi = vuzp1q_f32(other.hi, other.hi);
                let imags_other_hi = vuzp2q_f32(other.hi, other.hi);

                let ac_hi = vmulq_f32(reals_self_hi, reals_other_hi);
                let bd_hi = vmulq_f32(imags_self_hi, imags_other_hi);
                let ad_hi = vmulq_f32(reals_self_hi, imags_other_hi);
                let bc_hi = vmulq_f32(imags_self_hi, reals_other_hi);

                let re_hi = vsubq_f32(ac_hi, bd_hi);
                let im_hi = vaddq_f32(ad_hi, bc_hi);
                let hi_result = vzip1q_f32(re_hi, im_hi);

                C32x4 {
                    lo: lo_result,
                    hi: hi_result,
                }
            }
        }

        #[inline]
        fn scale_real(self, scalar: f32) -> Self {
            unsafe {
                let s = vdupq_n_f32(scalar);
                C32x4 {
                    lo: vmulq_f32(self.lo, s),
                    hi: vmulq_f32(self.hi, s),
                }
            }
        }

        #[inline]
        fn conj(self) -> Self {
            unsafe {
                let neg_mask = vld1q_f32([1.0, -1.0, 1.0, -1.0].as_ptr());
                C32x4 {
                    lo: vmulq_f32(self.lo, neg_mask),
                    hi: vmulq_f32(self.hi, neg_mask),
                }
            }
        }

        #[inline]
        fn extract(self, index: usize) -> Complex32 {
            debug_assert!(index < 4);
            unsafe {
                let mut arr = [0.0f32; 8];
                vst1q_f32(arr.as_mut_ptr(), self.lo);
                vst1q_f32(arr.as_mut_ptr().add(4), self.hi);
                Complex32::new(arr[index * 2], arr[index * 2 + 1])
            }
        }

        #[inline]
        fn insert(self, index: usize, value: Complex32) -> Self {
            debug_assert!(index < 4);
            unsafe {
                let mut arr = [0.0f32; 8];
                vst1q_f32(arr.as_mut_ptr(), self.lo);
                vst1q_f32(arr.as_mut_ptr().add(4), self.hi);
                arr[index * 2] = value.re;
                arr[index * 2 + 1] = value.im;
                C32x4 {
                    lo: vld1q_f32(arr.as_ptr()),
                    hi: vld1q_f32(arr.as_ptr().add(4)),
                }
            }
        }

        #[inline]
        fn reduce_sum(self) -> Complex32 {
            unsafe {
                let sum = vaddq_f32(self.lo, self.hi);
                // sum = [a, b, c, d] where (a,b) and (c,d) are complex
                let mut arr = [0.0f32; 4];
                vst1q_f32(arr.as_mut_ptr(), sum);
                Complex32::new(arr[0] + arr[2], arr[1] + arr[3])
            }
        }
    }
}

#[cfg(target_arch = "aarch64")]
pub use aarch64_impl::{C32x4, C64x2};

// =============================================================================
// Trait for complex scalar types to select SIMD register
// =============================================================================

/// Trait for complex scalar types with associated SIMD register.
pub trait ComplexSimdScalar: Copy {
    /// Preferred SIMD register type for this complex scalar, covering 256 bits
    /// of complex data where a vector implementation is available; on other
    /// targets this falls back to the scalar register.
    type Simd256: ComplexSimdRegister<Complex = Self>;
}

impl ComplexSimdScalar for Complex64 {
    #[cfg(target_arch = "aarch64")]
    type Simd256 = C64x2;
    #[cfg(not(target_arch = "aarch64"))]
    type Simd256 = ScalarC64;
}

impl ComplexSimdScalar for Complex32 {
    #[cfg(target_arch = "aarch64")]
    type Simd256 = C32x4;
    #[cfg(not(target_arch = "aarch64"))]
    type Simd256 = ScalarC32;
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scalar_c64_basic() {
        let a = ScalarC64::splat(Complex64::new(2.0, 3.0));
        let b = ScalarC64::splat(Complex64::new(4.0, 5.0));

        // Addition
        let sum = a.add(b);
        assert_eq!(sum.0, Complex64::new(6.0, 8.0));

        // Multiplication: (2+3i)(4+5i) = 8 + 10i + 12i - 15 = -7 + 22i
        let prod = a.mul(b);
        assert_eq!(prod.0, Complex64::new(-7.0, 22.0));

        // Conjugate
        let conj = a.conj();
        assert_eq!(conj.0, Complex64::new(2.0, -3.0));
    }

    #[test]
    fn test_scalar_c32_basic() {
        let a = ScalarC32::splat(Complex32::new(1.0, 2.0));
        let b = ScalarC32::splat(Complex32::new(3.0, 4.0));

        let sum = a.add(b);
        assert_eq!(sum.0, Complex32::new(4.0, 6.0));

        // (1+2i)(3+4i) = 3 + 4i + 6i - 8 = -5 + 10i
        let prod = a.mul(b);
        assert_eq!(prod.0, Complex32::new(-5.0, 10.0));
    }

    #[test]
    fn test_scalar_scale_real() {
        let a = ScalarC64::splat(Complex64::new(2.0, 3.0));
        let scaled = a.scale_real(2.0);
        assert_eq!(scaled.0, Complex64::new(4.0, 6.0));
    }
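
    // Added illustrative test (not in the original suite): with a single
    // complex lane, `insert`/`extract` ignore the index and `reduce_sum`
    // returns the contained value unchanged.
    #[test]
    fn test_scalar_c64_single_lane() {
        let a = ScalarC64::zero().insert(0, Complex64::new(1.5, -2.5));
        assert_eq!(a.extract(0), Complex64::new(1.5, -2.5));
        assert_eq!(a.reduce_sum(), Complex64::new(1.5, -2.5));
        assert_eq!(ScalarC64::COMPLEX_LANES, 1);
    }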

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_c64x2_basic() {
        let a = C64x2::splat(Complex64::new(2.0, 3.0));
        let b = C64x2::splat(Complex64::new(4.0, 5.0));

        let sum = a.add(b);
        assert_eq!(sum.extract(0), Complex64::new(6.0, 8.0));
        assert_eq!(sum.extract(1), Complex64::new(6.0, 8.0));

        let conj = a.conj();
        assert_eq!(conj.extract(0), Complex64::new(2.0, -3.0));
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_c64x2_mul() {
        let a = C64x2::splat(Complex64::new(2.0, 3.0));
        let b = C64x2::splat(Complex64::new(4.0, 5.0));

        // (2+3i)(4+5i) = 8 + 10i + 12i - 15 = -7 + 22i
        let prod = a.mul(b);
        let result = prod.extract(0);

        assert!((result.re - (-7.0)).abs() < 1e-10);
        assert!((result.im - 22.0).abs() < 1e-10);
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_c64x2_reduce_sum() {
        let a = C64x2::zero()
            .insert(0, Complex64::new(1.0, 2.0))
            .insert(1, Complex64::new(3.0, 4.0));

        let sum = a.reduce_sum();
        assert_eq!(sum, Complex64::new(4.0, 6.0));
    }
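
    // Added illustrative round-trip test, mirroring `test_c32x4_load_store`
    // below for the f64 register.
    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_c64x2_load_store() {
        unsafe {
            let data = [Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)];

            let v = C64x2::load_unaligned(data.as_ptr());
            assert_eq!(v.extract(0), Complex64::new(1.0, 2.0));
            assert_eq!(v.extract(1), Complex64::new(3.0, 4.0));

            let mut out = [Complex64::new(0.0, 0.0); 2];
            v.store_unaligned(out.as_mut_ptr());
            assert_eq!(out, data);
        }
    }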

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_c32x4_basic() {
        let a = C32x4::splat(Complex32::new(1.0, 2.0));
        let b = C32x4::splat(Complex32::new(3.0, 4.0));

        let sum = a.add(b);
        assert_eq!(sum.extract(0), Complex32::new(4.0, 6.0));
        assert_eq!(sum.extract(3), Complex32::new(4.0, 6.0));
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_c32x4_reduce_sum() {
        let a = C32x4::zero()
            .insert(0, Complex32::new(1.0, 0.0))
            .insert(1, Complex32::new(2.0, 0.0))
            .insert(2, Complex32::new(3.0, 0.0))
            .insert(3, Complex32::new(4.0, 0.0));

        let sum = a.reduce_sum();
        assert_eq!(sum, Complex32::new(10.0, 0.0));
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_c32x4_mul() {
        let a = C32x4::splat(Complex32::new(1.0, 2.0));
        let b = C32x4::splat(Complex32::new(3.0, 4.0));

        // (1+2i)(3+4i) = 3 + 4i + 6i - 8 = -5 + 10i
        let prod = a.mul(b);
        let result = prod.extract(0);

        assert!((result.re - (-5.0)).abs() < 1e-5);
        assert!((result.im - 10.0).abs() < 1e-5);
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_c32x4_load_store() {
        unsafe {
            let data = [
                Complex32::new(1.0, 2.0),
                Complex32::new(3.0, 4.0),
                Complex32::new(5.0, 6.0),
                Complex32::new(7.0, 8.0),
            ];

            let v = C32x4::load_unaligned(data.as_ptr());

            assert_eq!(v.extract(0), Complex32::new(1.0, 2.0));
            assert_eq!(v.extract(1), Complex32::new(3.0, 4.0));
            assert_eq!(v.extract(2), Complex32::new(5.0, 6.0));
            assert_eq!(v.extract(3), Complex32::new(7.0, 8.0));

            let mut out = [Complex32::new(0.0, 0.0); 4];
            v.store_unaligned(out.as_mut_ptr());

            assert_eq!(out, data);
        }
    }
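
    // Added illustrative test (not in the original suite): exercises the
    // `ComplexSimdScalar` association so the same code path picks `C64x2`/`C32x4`
    // on aarch64 and the scalar registers elsewhere. The helper name
    // `mul_via_simd` is introduced only for this sketch.
    fn mul_via_simd<T: ComplexSimdScalar>(a: T, b: T) -> T {
        let va = T::Simd256::splat(a);
        let vb = T::Simd256::splat(b);
        va.mul(vb).extract(0)
    }

    #[test]
    fn test_generic_simd_selection() {
        // (2+3i)(4+5i) = -7 + 22i
        let p = mul_via_simd(Complex64::new(2.0, 3.0), Complex64::new(4.0, 5.0));
        assert!((p.re - (-7.0)).abs() < 1e-10);
        assert!((p.im - 22.0).abs() < 1e-10);

        // (1+2i)(3+4i) = -5 + 10i
        let q = mul_via_simd(Complex32::new(1.0, 2.0), Complex32::new(3.0, 4.0));
        assert!((q.re - (-5.0)).abs() < 1e-5);
        assert!((q.im - 10.0).abs() < 1e-5);
    }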
}