//! SIMD abstraction for complex arithmetic: a `ComplexSimdRegister` trait with
//! scalar fallbacks and aarch64 NEON implementations, plus `ComplexSimdScalar`
//! for selecting the widest register type available on the current target.

use num_complex::{Complex32, Complex64};

/// Abstraction over a SIMD register holding one or more complex values, with
/// lane-wise arithmetic, loads/stores, lane access, and a horizontal sum.
pub trait ComplexSimdRegister: Copy + Clone {
    /// Real component type (`f32` or `f64`).
    type Real;
    /// Complex element type (`Complex32` or `Complex64`).
    type Complex;
    /// Number of complex values held in one register.
    const COMPLEX_LANES: usize;

    /// Returns a register with every lane set to zero.
    fn zero() -> Self;

    /// Broadcasts `value` into every lane.
    fn splat(value: Self::Complex) -> Self;

    /// # Safety
    /// `ptr` must be valid for reading `COMPLEX_LANES` complex values and
    /// satisfy the implementation's alignment requirement.
    unsafe fn load_aligned(ptr: *const Self::Complex) -> Self;

    /// # Safety
    /// `ptr` must be valid for reading `COMPLEX_LANES` complex values.
    unsafe fn load_unaligned(ptr: *const Self::Complex) -> Self;

    /// # Safety
    /// `ptr` must be valid for writing `COMPLEX_LANES` complex values and
    /// satisfy the implementation's alignment requirement.
    unsafe fn store_aligned(self, ptr: *mut Self::Complex);

    /// # Safety
    /// `ptr` must be valid for writing `COMPLEX_LANES` complex values.
    unsafe fn store_unaligned(self, ptr: *mut Self::Complex);

    /// Lane-wise complex addition.
    fn add(self, other: Self) -> Self;

    /// Lane-wise complex subtraction.
    fn sub(self, other: Self) -> Self;

    /// Lane-wise complex multiplication.
    fn mul(self, other: Self) -> Self;

    /// Multiplies every lane by a real scalar.
    fn scale_real(self, scalar: Self::Real) -> Self;

    /// Lane-wise complex conjugation.
    fn conj(self) -> Self;

    /// Returns the complex value in lane `index`.
    fn extract(self, index: usize) -> Self::Complex;

    /// Returns a copy of `self` with lane `index` replaced by `value`.
    fn insert(self, index: usize, value: Self::Complex) -> Self;

    /// Horizontally sums all lanes into a single complex value.
    fn reduce_sum(self) -> Self::Complex;
}

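// Illustrative sketch only: a generic Hermitian dot product written against
// `ComplexSimdRegister`, showing the intended load / conjugate-multiply /
// accumulate / reduce pattern. The helper name `dot_sketch`, its slice-based
// signature, and the "length divides evenly into the lane count" restriction
// are assumptions made for this example, not part of the original API.
#[allow(dead_code)]
fn dot_sketch<R: ComplexSimdRegister>(a: &[R::Complex], b: &[R::Complex]) -> R::Complex {
    assert_eq!(a.len(), b.len());
    assert_eq!(a.len() % R::COMPLEX_LANES, 0);
    let mut acc = R::zero();
    let mut i = 0;
    while i < a.len() {
        // SAFETY: i + COMPLEX_LANES <= a.len(), so both loads stay in bounds.
        let va = unsafe { R::load_unaligned(a.as_ptr().add(i)) };
        let vb = unsafe { R::load_unaligned(b.as_ptr().add(i)) };
        acc = acc.add(va.conj().mul(vb));
        i += R::COMPLEX_LANES;
    }
    acc.reduce_sum()
}
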
/// Scalar fallback "register" holding a single `Complex64`.
#[derive(Clone, Copy, Debug)]
#[repr(transparent)]
pub struct ScalarC64(pub Complex64);

impl ComplexSimdRegister for ScalarC64 {
    type Real = f64;
    type Complex = Complex64;
    const COMPLEX_LANES: usize = 1;

    #[inline]
    fn zero() -> Self {
        ScalarC64(Complex64::new(0.0, 0.0))
    }

    #[inline]
    fn splat(value: Complex64) -> Self {
        ScalarC64(value)
    }

    #[inline]
    unsafe fn load_aligned(ptr: *const Complex64) -> Self {
        ScalarC64(*ptr)
    }

    #[inline]
    unsafe fn load_unaligned(ptr: *const Complex64) -> Self {
        ScalarC64(*ptr)
    }

    #[inline]
    unsafe fn store_aligned(self, ptr: *mut Complex64) {
        *ptr = self.0;
    }

    #[inline]
    unsafe fn store_unaligned(self, ptr: *mut Complex64) {
        *ptr = self.0;
    }

    #[inline]
    fn add(self, other: Self) -> Self {
        ScalarC64(self.0 + other.0)
    }

    #[inline]
    fn sub(self, other: Self) -> Self {
        ScalarC64(self.0 - other.0)
    }

    #[inline]
    fn mul(self, other: Self) -> Self {
        ScalarC64(self.0 * other.0)
    }

    #[inline]
    fn scale_real(self, scalar: f64) -> Self {
        ScalarC64(Complex64::new(self.0.re * scalar, self.0.im * scalar))
    }

    #[inline]
    fn conj(self) -> Self {
        ScalarC64(self.0.conj())
    }

    #[inline]
    fn extract(self, _index: usize) -> Complex64 {
        self.0
    }

    #[inline]
    fn insert(self, _index: usize, value: Complex64) -> Self {
        ScalarC64(value)
    }

    #[inline]
    fn reduce_sum(self) -> Complex64 {
        self.0
    }
}

/// Scalar fallback "register" holding a single `Complex32`.
#[derive(Clone, Copy, Debug)]
#[repr(transparent)]
pub struct ScalarC32(pub Complex32);

impl ComplexSimdRegister for ScalarC32 {
    type Real = f32;
    type Complex = Complex32;
    const COMPLEX_LANES: usize = 1;

    #[inline]
    fn zero() -> Self {
        ScalarC32(Complex32::new(0.0, 0.0))
    }

    #[inline]
    fn splat(value: Complex32) -> Self {
        ScalarC32(value)
    }

    #[inline]
    unsafe fn load_aligned(ptr: *const Complex32) -> Self {
        ScalarC32(*ptr)
    }

    #[inline]
    unsafe fn load_unaligned(ptr: *const Complex32) -> Self {
        ScalarC32(*ptr)
    }

    #[inline]
    unsafe fn store_aligned(self, ptr: *mut Complex32) {
        *ptr = self.0;
    }

    #[inline]
    unsafe fn store_unaligned(self, ptr: *mut Complex32) {
        *ptr = self.0;
    }

    #[inline]
    fn add(self, other: Self) -> Self {
        ScalarC32(self.0 + other.0)
    }

    #[inline]
    fn sub(self, other: Self) -> Self {
        ScalarC32(self.0 - other.0)
    }

    #[inline]
    fn mul(self, other: Self) -> Self {
        ScalarC32(self.0 * other.0)
    }

    #[inline]
    fn scale_real(self, scalar: f32) -> Self {
        ScalarC32(Complex32::new(self.0.re * scalar, self.0.im * scalar))
    }

    #[inline]
    fn conj(self) -> Self {
        ScalarC32(self.0.conj())
    }

    #[inline]
    fn extract(self, _index: usize) -> Complex32 {
        self.0
    }

    #[inline]
    fn insert(self, _index: usize, value: Complex32) -> Self {
        ScalarC32(value)
    }

    #[inline]
    fn reduce_sum(self) -> Complex32 {
        self.0
    }
}

#[cfg(target_arch = "aarch64")]
mod aarch64_impl {
    use super::*;
    use core::arch::aarch64::*;

    /// Two `Complex64` values held in a pair of 128-bit NEON registers, each
    /// stored interleaved as `[re, im]`.
    #[derive(Clone, Copy)]
    pub struct C64x2 {
        /// First complex lane.
        c0: float64x2_t,
        /// Second complex lane.
        c1: float64x2_t,
    }

    impl ComplexSimdRegister for C64x2 {
        type Real = f64;
        type Complex = Complex64;
        const COMPLEX_LANES: usize = 2;

        #[inline]
        fn zero() -> Self {
            unsafe {
                C64x2 {
                    c0: vdupq_n_f64(0.0),
                    c1: vdupq_n_f64(0.0),
                }
            }
        }

        #[inline]
        fn splat(value: Complex64) -> Self {
            unsafe {
                let c = vld1q_f64([value.re, value.im].as_ptr());
                C64x2 { c0: c, c1: c }
            }
        }

        #[inline]
        unsafe fn load_aligned(ptr: *const Complex64) -> Self {
            let p = ptr as *const f64;
            C64x2 {
                c0: vld1q_f64(p),
                c1: vld1q_f64(p.add(2)),
            }
        }

        #[inline]
        unsafe fn load_unaligned(ptr: *const Complex64) -> Self {
            // vld1q only requires element alignment, which `*const Complex64`
            // already guarantees, so the aligned and unaligned paths coincide.
            Self::load_aligned(ptr)
        }

        #[inline]
        unsafe fn store_aligned(self, ptr: *mut Complex64) {
            let p = ptr as *mut f64;
            vst1q_f64(p, self.c0);
            vst1q_f64(p.add(2), self.c1);
        }

        #[inline]
        unsafe fn store_unaligned(self, ptr: *mut Complex64) {
            self.store_aligned(ptr);
        }

        #[inline]
        fn add(self, other: Self) -> Self {
            unsafe {
                C64x2 {
                    c0: vaddq_f64(self.c0, other.c0),
                    c1: vaddq_f64(self.c1, other.c1),
                }
            }
        }

        #[inline]
        fn sub(self, other: Self) -> Self {
            unsafe {
                C64x2 {
                    c0: vsubq_f64(self.c0, other.c0),
                    c1: vsubq_f64(self.c1, other.c1),
                }
            }
        }

        #[inline]
        fn mul(self, other: Self) -> Self {
            unsafe {
                // (a + bi)(c + di) = (ac - bd) + (ad + bc)i, per complex lane.
                // Lane indices for vdupq_laneq_f64 are const generics.
                let a = vdupq_laneq_f64::<0>(self.c0);
                let b = vdupq_laneq_f64::<1>(self.c0);
                let c = vdupq_laneq_f64::<0>(other.c0);
                let d = vdupq_laneq_f64::<1>(other.c0);

                let ac = vmulq_f64(a, c);
                let ad = vmulq_f64(a, d);
                let bd = vmulq_f64(b, d);
                let bc = vmulq_f64(b, c);

                let re0 = vsubq_f64(ac, bd);
                let im0 = vaddq_f64(ad, bc);
                // Re-interleave into [re, im] layout.
                let c0_new = vzip1q_f64(re0, im0);

                let a1 = vdupq_laneq_f64::<0>(self.c1);
                let b1 = vdupq_laneq_f64::<1>(self.c1);
                let c1 = vdupq_laneq_f64::<0>(other.c1);
                let d1 = vdupq_laneq_f64::<1>(other.c1);

                let ac1 = vmulq_f64(a1, c1);
                let ad1 = vmulq_f64(a1, d1);
                let bd1 = vmulq_f64(b1, d1);
                let bc1 = vmulq_f64(b1, c1);

                let re1 = vsubq_f64(ac1, bd1);
                let im1 = vaddq_f64(ad1, bc1);
                let c1_new = vzip1q_f64(re1, im1);

                C64x2 {
                    c0: c0_new,
                    c1: c1_new,
                }
            }
        }

        #[inline]
        fn scale_real(self, scalar: f64) -> Self {
            unsafe {
                let s = vdupq_n_f64(scalar);
                C64x2 {
                    c0: vmulq_f64(self.c0, s),
                    c1: vmulq_f64(self.c1, s),
                }
            }
        }

        #[inline]
        fn conj(self) -> Self {
            unsafe {
                let neg_mask = vld1q_f64([1.0, -1.0].as_ptr());
                C64x2 {
                    c0: vmulq_f64(self.c0, neg_mask),
                    c1: vmulq_f64(self.c1, neg_mask),
                }
            }
        }

        #[inline]
        fn extract(self, index: usize) -> Complex64 {
            debug_assert!(index < 2);
            unsafe {
                let arr = if index == 0 {
                    let mut a = [0.0f64; 2];
                    vst1q_f64(a.as_mut_ptr(), self.c0);
                    a
                } else {
                    let mut a = [0.0f64; 2];
                    vst1q_f64(a.as_mut_ptr(), self.c1);
                    a
                };
                Complex64::new(arr[0], arr[1])
            }
        }

        #[inline]
        fn insert(self, index: usize, value: Complex64) -> Self {
            debug_assert!(index < 2);
            unsafe {
                let new_c = vld1q_f64([value.re, value.im].as_ptr());
                if index == 0 {
                    C64x2 {
                        c0: new_c,
                        c1: self.c1,
                    }
                } else {
                    C64x2 {
                        c0: self.c0,
                        c1: new_c,
                    }
                }
            }
        }

        #[inline]
        fn reduce_sum(self) -> Complex64 {
            unsafe {
                let sum = vaddq_f64(self.c0, self.c1);
                let mut arr = [0.0f64; 2];
                vst1q_f64(arr.as_mut_ptr(), sum);
                Complex64::new(arr[0], arr[1])
            }
        }
    }

    /// Four `Complex32` values held in a pair of 128-bit NEON registers,
    /// stored interleaved as `[re, im, re, im]`.
    #[derive(Clone, Copy)]
    pub struct C32x4 {
        /// Lanes 0 and 1: `[re0, im0, re1, im1]`.
        lo: float32x4_t,
        /// Lanes 2 and 3: `[re2, im2, re3, im3]`.
        hi: float32x4_t,
    }

    impl ComplexSimdRegister for C32x4 {
        type Real = f32;
        type Complex = Complex32;
        const COMPLEX_LANES: usize = 4;

        #[inline]
        fn zero() -> Self {
            unsafe {
                C32x4 {
                    lo: vdupq_n_f32(0.0),
                    hi: vdupq_n_f32(0.0),
                }
            }
        }

        #[inline]
        fn splat(value: Complex32) -> Self {
            unsafe {
                let vals = [value.re, value.im, value.re, value.im];
                let v = vld1q_f32(vals.as_ptr());
                C32x4 { lo: v, hi: v }
            }
        }

        #[inline]
        unsafe fn load_aligned(ptr: *const Complex32) -> Self {
            let p = ptr as *const f32;
            C32x4 {
                lo: vld1q_f32(p),
                hi: vld1q_f32(p.add(4)),
            }
        }

        #[inline]
        unsafe fn load_unaligned(ptr: *const Complex32) -> Self {
            // vld1q only requires element alignment, which `*const Complex32`
            // already guarantees, so the aligned and unaligned paths coincide.
            Self::load_aligned(ptr)
        }

        #[inline]
        unsafe fn store_aligned(self, ptr: *mut Complex32) {
            let p = ptr as *mut f32;
            vst1q_f32(p, self.lo);
            vst1q_f32(p.add(4), self.hi);
        }

        #[inline]
        unsafe fn store_unaligned(self, ptr: *mut Complex32) {
            self.store_aligned(ptr);
        }

        #[inline]
        fn add(self, other: Self) -> Self {
            unsafe {
                C32x4 {
                    lo: vaddq_f32(self.lo, other.lo),
                    hi: vaddq_f32(self.hi, other.hi),
                }
            }
        }

        #[inline]
        fn sub(self, other: Self) -> Self {
            unsafe {
                C32x4 {
                    lo: vsubq_f32(self.lo, other.lo),
                    hi: vsubq_f32(self.hi, other.hi),
                }
            }
        }

        #[inline]
        fn mul(self, other: Self) -> Self {
            unsafe {
                // De-interleave the real and imaginary parts of each half, then
                // apply (a + bi)(c + di) = (ac - bd) + (ad + bc)i.
                let reals_self_lo = vuzp1q_f32(self.lo, self.lo);
                let imags_self_lo = vuzp2q_f32(self.lo, self.lo);
                let reals_other_lo = vuzp1q_f32(other.lo, other.lo);
                let imags_other_lo = vuzp2q_f32(other.lo, other.lo);

                let ac_lo = vmulq_f32(reals_self_lo, reals_other_lo);
                let bd_lo = vmulq_f32(imags_self_lo, imags_other_lo);
                let ad_lo = vmulq_f32(reals_self_lo, imags_other_lo);
                let bc_lo = vmulq_f32(imags_self_lo, reals_other_lo);

                let re_lo = vsubq_f32(ac_lo, bd_lo);
                let im_lo = vaddq_f32(ad_lo, bc_lo);
                // Re-interleave into [re, im, re, im] layout.
                let lo_result = vzip1q_f32(re_lo, im_lo);

                let reals_self_hi = vuzp1q_f32(self.hi, self.hi);
                let imags_self_hi = vuzp2q_f32(self.hi, self.hi);
                let reals_other_hi = vuzp1q_f32(other.hi, other.hi);
                let imags_other_hi = vuzp2q_f32(other.hi, other.hi);

                let ac_hi = vmulq_f32(reals_self_hi, reals_other_hi);
                let bd_hi = vmulq_f32(imags_self_hi, imags_other_hi);
                let ad_hi = vmulq_f32(reals_self_hi, imags_other_hi);
                let bc_hi = vmulq_f32(imags_self_hi, reals_other_hi);

                let re_hi = vsubq_f32(ac_hi, bd_hi);
                let im_hi = vaddq_f32(ad_hi, bc_hi);
                let hi_result = vzip1q_f32(re_hi, im_hi);

                C32x4 {
                    lo: lo_result,
                    hi: hi_result,
                }
            }
        }

        #[inline]
        fn scale_real(self, scalar: f32) -> Self {
            unsafe {
                let s = vdupq_n_f32(scalar);
                C32x4 {
                    lo: vmulq_f32(self.lo, s),
                    hi: vmulq_f32(self.hi, s),
                }
            }
        }

        #[inline]
        fn conj(self) -> Self {
            unsafe {
                let neg_mask = vld1q_f32([1.0, -1.0, 1.0, -1.0].as_ptr());
                C32x4 {
                    lo: vmulq_f32(self.lo, neg_mask),
                    hi: vmulq_f32(self.hi, neg_mask),
                }
            }
        }

        #[inline]
        fn extract(self, index: usize) -> Complex32 {
            debug_assert!(index < 4);
            unsafe {
                let mut arr = [0.0f32; 8];
                vst1q_f32(arr.as_mut_ptr(), self.lo);
                vst1q_f32(arr.as_mut_ptr().add(4), self.hi);
                Complex32::new(arr[index * 2], arr[index * 2 + 1])
            }
        }

        #[inline]
        fn insert(self, index: usize, value: Complex32) -> Self {
            debug_assert!(index < 4);
            unsafe {
                let mut arr = [0.0f32; 8];
                vst1q_f32(arr.as_mut_ptr(), self.lo);
                vst1q_f32(arr.as_mut_ptr().add(4), self.hi);
                arr[index * 2] = value.re;
                arr[index * 2 + 1] = value.im;
                C32x4 {
                    lo: vld1q_f32(arr.as_ptr()),
                    hi: vld1q_f32(arr.as_ptr().add(4)),
                }
            }
        }

        #[inline]
        fn reduce_sum(self) -> Complex32 {
            unsafe {
                let sum = vaddq_f32(self.lo, self.hi);
                let mut arr = [0.0f32; 4];
                vst1q_f32(arr.as_mut_ptr(), sum);
                Complex32::new(arr[0] + arr[2], arr[1] + arr[3])
            }
        }
    }
}

#[cfg(target_arch = "aarch64")]
pub use aarch64_impl::{C32x4, C64x2};

/// Maps a complex scalar type to the widest SIMD register type available on
/// the current target (NEON on aarch64, the scalar fallback elsewhere).
pub trait ComplexSimdScalar: Copy {
    /// Register type used for vectorized kernels over this scalar.
    type Simd256: ComplexSimdRegister<Complex = Self>;
}

impl ComplexSimdScalar for Complex64 {
    #[cfg(target_arch = "aarch64")]
    type Simd256 = C64x2;
    #[cfg(not(target_arch = "aarch64"))]
    type Simd256 = ScalarC64;
}

impl ComplexSimdScalar for Complex32 {
    #[cfg(target_arch = "aarch64")]
    type Simd256 = C32x4;
    #[cfg(not(target_arch = "aarch64"))]
    type Simd256 = ScalarC32;
}

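// Illustrative sketch of dispatching through `ComplexSimdScalar`: callers pick
// the scalar type and the associated `Simd256` register (NEON on aarch64, the
// scalar fallback elsewhere) is selected automatically. The function name
// `hermitian_dot_sketch` is an assumption and builds on the `dot_sketch`
// example above, not on any original API.
#[allow(dead_code)]
fn hermitian_dot_sketch<T: ComplexSimdScalar>(a: &[T], b: &[T]) -> T {
    dot_sketch::<T::Simd256>(a, b)
}
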
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scalar_c64_basic() {
        let a = ScalarC64::splat(Complex64::new(2.0, 3.0));
        let b = ScalarC64::splat(Complex64::new(4.0, 5.0));

        let sum = a.add(b);
        assert_eq!(sum.0, Complex64::new(6.0, 8.0));

        // (2 + 3i)(4 + 5i) = -7 + 22i
        let prod = a.mul(b);
        assert_eq!(prod.0, Complex64::new(-7.0, 22.0));

        let conj = a.conj();
        assert_eq!(conj.0, Complex64::new(2.0, -3.0));
    }

    #[test]
    fn test_scalar_c32_basic() {
        let a = ScalarC32::splat(Complex32::new(1.0, 2.0));
        let b = ScalarC32::splat(Complex32::new(3.0, 4.0));

        let sum = a.add(b);
        assert_eq!(sum.0, Complex32::new(4.0, 6.0));

        // (1 + 2i)(3 + 4i) = -5 + 10i
        let prod = a.mul(b);
        assert_eq!(prod.0, Complex32::new(-5.0, 10.0));
    }

    #[test]
    fn test_scalar_scale_real() {
        let a = ScalarC64::splat(Complex64::new(2.0, 3.0));
        let scaled = a.scale_real(2.0);
        assert_eq!(scaled.0, Complex64::new(4.0, 6.0));
    }

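    // Exercises the illustrative `dot_sketch` / `hermitian_dot_sketch` helpers
    // added above; two elements divide evenly into both the 1-lane scalar and
    // 2-lane NEON register widths, so this runs on every target.
    #[test]
    fn test_hermitian_dot_sketch() {
        let a = [Complex64::new(1.0, 2.0), Complex64::new(5.0, 6.0)];
        let b = [Complex64::new(3.0, 4.0), Complex64::new(7.0, 8.0)];
        let expected = a[0].conj() * b[0] + a[1].conj() * b[1];
        let got = hermitian_dot_sketch(&a, &b);
        assert!((got - expected).norm() < 1e-12);
    }
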
    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_c64x2_basic() {
        let a = C64x2::splat(Complex64::new(2.0, 3.0));
        let b = C64x2::splat(Complex64::new(4.0, 5.0));

        let sum = a.add(b);
        assert_eq!(sum.extract(0), Complex64::new(6.0, 8.0));
        assert_eq!(sum.extract(1), Complex64::new(6.0, 8.0));

        let conj = a.conj();
        assert_eq!(conj.extract(0), Complex64::new(2.0, -3.0));
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_c64x2_mul() {
        let a = C64x2::splat(Complex64::new(2.0, 3.0));
        let b = C64x2::splat(Complex64::new(4.0, 5.0));

        // (2 + 3i)(4 + 5i) = -7 + 22i
        let prod = a.mul(b);
        let result = prod.extract(0);

        assert!((result.re - (-7.0)).abs() < 1e-10);
        assert!((result.im - 22.0).abs() < 1e-10);
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_c64x2_reduce_sum() {
        let a = C64x2::zero()
            .insert(0, Complex64::new(1.0, 2.0))
            .insert(1, Complex64::new(3.0, 4.0));

        let sum = a.reduce_sum();
        assert_eq!(sum, Complex64::new(4.0, 6.0));
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_c32x4_basic() {
        let a = C32x4::splat(Complex32::new(1.0, 2.0));
        let b = C32x4::splat(Complex32::new(3.0, 4.0));

        let sum = a.add(b);
        assert_eq!(sum.extract(0), Complex32::new(4.0, 6.0));
        assert_eq!(sum.extract(3), Complex32::new(4.0, 6.0));
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_c32x4_reduce_sum() {
        let a = C32x4::zero()
            .insert(0, Complex32::new(1.0, 0.0))
            .insert(1, Complex32::new(2.0, 0.0))
            .insert(2, Complex32::new(3.0, 0.0))
            .insert(3, Complex32::new(4.0, 0.0));

        let sum = a.reduce_sum();
        assert_eq!(sum, Complex32::new(10.0, 0.0));
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_c32x4_mul() {
        let a = C32x4::splat(Complex32::new(1.0, 2.0));
        let b = C32x4::splat(Complex32::new(3.0, 4.0));

        // (1 + 2i)(3 + 4i) = -5 + 10i
        let prod = a.mul(b);
        let result = prod.extract(0);

        assert!((result.re - (-5.0)).abs() < 1e-5);
        assert!((result.im - 10.0).abs() < 1e-5);
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_c32x4_load_store() {
        unsafe {
            let data = [
                Complex32::new(1.0, 2.0),
                Complex32::new(3.0, 4.0),
                Complex32::new(5.0, 6.0),
                Complex32::new(7.0, 8.0),
            ];

            let v = C32x4::load_unaligned(data.as_ptr());

            assert_eq!(v.extract(0), Complex32::new(1.0, 2.0));
            assert_eq!(v.extract(1), Complex32::new(3.0, 4.0));
            assert_eq!(v.extract(2), Complex32::new(5.0, 6.0));
            assert_eq!(v.extract(3), Complex32::new(7.0, 8.0));

            let mut out = [Complex32::new(0.0, 0.0); 4];
            v.store_unaligned(out.as_mut_ptr());

            assert_eq!(out, data);
        }
    }
}