1use rustfft::num_complex::Complex;
18
/// Complex multiply-accumulate on two packed complex `f32` values using FCMLA.
///
/// Each vector holds two interleaved complex numbers `[re0, im0, re1, im1]`
/// (this is how the callers reinterpret `Complex<f32>` slices). Returns
/// `r + a * b` per complex element.
///
/// FCMLA with rotation `#0` adds `a.re * b.re` to the real slot and
/// `a.re * b.im` to the imaginary slot. Rotation `#90` then subtracts
/// `a.im * b.im` from the real slot and adds `a.im * b.re` to the imaginary
/// slot. Together the two issues form the full complex product.
///
/// # Safety
///
/// Only compiled when the `fcma` target feature is enabled, so the instruction
/// is available. The asm touches registers only (`nomem`, `nostack`), so there
/// is no extra contract on the arguments.
#[cfg(all(target_arch = "aarch64", target_feature = "fcma"))]
#[inline(always)]
unsafe fn fcmla_mul_acc(
    mut r: std::arch::aarch64::float32x4_t,
    a: std::arch::aarch64::float32x4_t,
    b: std::arch::aarch64::float32x4_t,
) -> std::arch::aarch64::float32x4_t {
    unsafe {
        std::arch::asm!(
            // Rotation 0: r.re += a.re*b.re, r.im += a.re*b.im
            "fcmla {r:v}.4s, {a:v}.4s, {b:v}.4s, #0",
            // Rotation 90: r.re -= a.im*b.im, r.im += a.im*b.re
            "fcmla {r:v}.4s, {a:v}.4s, {b:v}.4s, #90",
            r = inout(vreg) r,
            a = in(vreg) a,
            b = in(vreg) b,
            options(pure, nomem, nostack),
        );
    }
    r
}
40
/// `_mm256_shuffle_ps` immediate that swaps the two `f32` lanes of every 64-bit
/// pair, turning `[re, im, re, im, ...]` into `[im, re, im, re, ...]`.
/// Lane selectors from low to high: 1, 0, 3, 2.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
const SHUFFLE_SWAP_RE_IM: i32 = 0b10110001;

/// Complex multiply-accumulate for four complex values using AVX:
/// `dst[start + k] += src[start + k] * hrtf[start + k]` for `k` in `0..4`.
///
/// The number of elements handled per call is platform-specific. This variant
/// processes 4, the aarch64 variants process 2 and the scalar fallback
/// processes 1.
///
/// # Safety
///
/// The caller must guarantee `start + 4 <= len` for each of `dst`, `src` and
/// `hrtf`. Loads and stores go through raw pointers and are only
/// bounds-checked in debug builds.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
#[inline]
pub unsafe fn complex_mul_add_simd_chunk(
    dst: &mut [Complex<f32>],
    src: &[Complex<f32>],
    hrtf: &[Complex<f32>],
    start: usize,
) {
    use std::arch::x86_64::*;

    debug_assert!(
        start + 4 <= dst.len() && start + 4 <= src.len() && start + 4 <= hrtf.len(),
        "complex_mul_add_simd_chunk: chunk at {start} is out of bounds"
    );

    // SAFETY: the caller guarantees four complex values (eight f32s) starting
    // at `start` are in bounds for every slice. Each `Complex<f32>` is read as
    // two consecutive f32s (re, im). The unaligned load/store intrinsics have no
    // alignment requirement.
    unsafe {
        let src_ptr = src.as_ptr().add(start) as *const f32;
        let hrtf_ptr = hrtf.as_ptr().add(start) as *const f32;
        let dst_ptr = dst.as_mut_ptr().add(start) as *mut f32;

        // a = [ar0, ai0, ar1, ai1, ...], b = [br0, bi0, br1, bi1, ...]
        let a = _mm256_loadu_ps(src_ptr);
        let b = _mm256_loadu_ps(hrtf_ptr);
        let dst_val = _mm256_loadu_ps(dst_ptr);

        // Duplicate real / imaginary parts: [ar, ar, ...] and [ai, ai, ...].
        let a_re = _mm256_moveldup_ps(a);
        let a_im = _mm256_movehdup_ps(a);

        // [ar*br, ar*bi, ...]
        let ac_ad = _mm256_mul_ps(a_re, b);

        // [bi, br, ...] -> [ai*bi, ai*br, ...]
        let b_swapped = _mm256_shuffle_ps(b, b, SHUFFLE_SWAP_RE_IM);
        let bd_bc = _mm256_mul_ps(a_im, b_swapped);

        // addsub subtracts in even lanes and adds in odd lanes:
        // [ar*br - ai*bi, ar*bi + ai*br, ...] == a * b
        let result = _mm256_addsub_ps(ac_ad, bd_bc);

        let final_result = _mm256_add_ps(dst_val, result);
        _mm256_storeu_ps(dst_ptr, final_result);
    }
}
96
/// Complex multiply-accumulate for two complex values using FCMLA:
/// `dst[start + k] += src[start + k] * hrtf[start + k]` for `k` in `0..2`.
///
/// # Safety
///
/// The caller must guarantee `start + 2 <= len` for each of `dst`, `src` and
/// `hrtf`. Loads and stores go through raw pointers and are only
/// bounds-checked in debug builds.
#[cfg(all(target_arch = "aarch64", target_feature = "fcma"))]
#[inline]
pub unsafe fn complex_mul_add_simd_chunk(
    dst: &mut [Complex<f32>],
    src: &[Complex<f32>],
    hrtf: &[Complex<f32>],
    start: usize,
) {
    use std::arch::aarch64::*;

    debug_assert!(
        start + 2 <= dst.len() && start + 2 <= src.len() && start + 2 <= hrtf.len(),
        "complex_mul_add_simd_chunk: chunk at {start} is out of bounds"
    );

    // SAFETY: the caller guarantees two complex values (four f32s) starting at
    // `start` are in bounds for every slice. Each `Complex<f32>` is read as two
    // consecutive f32s (re, im).
    unsafe {
        let src_ptr = src.as_ptr().add(start) as *const f32;
        let hrtf_ptr = hrtf.as_ptr().add(start) as *const f32;
        let dst_ptr = dst.as_mut_ptr().add(start) as *mut f32;

        let a = vld1q_f32(src_ptr);
        let b = vld1q_f32(hrtf_ptr);
        let r = vld1q_f32(dst_ptr);
        // r + a * b, computed in two FCMLA rotations.
        let result = fcmla_mul_acc(r, a, b);
        vst1q_f32(dst_ptr, result);
    }
}
121
/// Complex multiply-accumulate for two complex values using plain NEON:
/// `dst[start + k] += src[start + k] * hrtf[start + k]` for `k` in `0..2`.
///
/// # Safety
///
/// The caller must guarantee `start + 2 <= len` for each of `dst`, `src` and
/// `hrtf`. Loads and stores go through raw pointers and are only
/// bounds-checked in debug builds.
#[cfg(all(
    target_arch = "aarch64",
    target_feature = "neon",
    not(target_feature = "fcma")
))]
#[inline]
pub unsafe fn complex_mul_add_simd_chunk(
    dst: &mut [Complex<f32>],
    src: &[Complex<f32>],
    hrtf: &[Complex<f32>],
    start: usize,
) {
    use std::arch::aarch64::*;

    debug_assert!(
        start + 2 <= dst.len() && start + 2 <= src.len() && start + 2 <= hrtf.len(),
        "complex_mul_add_simd_chunk: chunk at {start} is out of bounds"
    );

    // SAFETY: the caller guarantees two complex values (four f32s) starting at
    // `start` are in bounds for every slice. Each `Complex<f32>` is read as two
    // consecutive f32s (re, im).
    unsafe {
        let src_ptr = src.as_ptr().add(start) as *const f32;
        let hrtf_ptr = hrtf.as_ptr().add(start) as *const f32;
        let dst_ptr = dst.as_mut_ptr().add(start) as *mut f32;

        let a = vld1q_f32(src_ptr); // [ar0, ai0, ar1, ai1]
        let b = vld1q_f32(hrtf_ptr); // [br0, bi0, br1, bi1]
        let dst_val = vld1q_f32(dst_ptr);

        let a_re = vtrn1q_f32(a, a); // [ar0, ar0, ar1, ar1]
        let a_im = vtrn2q_f32(a, a); // [ai0, ai0, ai1, ai1]
        let ac_ad = vmulq_f32(a_re, b); // [ar*br, ar*bi, ...]
        let b_swapped = vrev64q_f32(b); // [bi, br, ...]
        let bd_bc = vmulq_f32(a_im, b_swapped); // [ai*bi, ai*br, ...]

        // Flip the sign bit of lanes 0 and 2 (the real slots) so a plain add
        // yields [ar*br - ai*bi, ar*bi + ai*br, ...] == a * b.
        let sign_bit: u32 = 0x80000000;
        let neg_mask = vreinterpretq_f32_u32(vsetq_lane_u32::<2>(
            sign_bit,
            vsetq_lane_u32::<0>(sign_bit, vdupq_n_u32(0)),
        ));

        let bd_bc_negated = vreinterpretq_f32_u32(veorq_u32(
            vreinterpretq_u32_f32(bd_bc),
            vreinterpretq_u32_f32(neg_mask),
        ));
        let result = vaddq_f32(ac_ad, bd_bc_negated);
        let final_result = vaddq_f32(dst_val, result);

        vst1q_f32(dst_ptr, final_result);
    }
}
169
170#[cfg(not(any(
171 all(target_arch = "x86_64", target_feature = "avx2"),
172 all(target_arch = "aarch64", target_feature = "neon")
173)))]
174#[inline]
175pub fn complex_mul_add_simd_chunk(
176 dst: &mut [Complex<f32>],
177 src: &[Complex<f32>],
178 hrtf: &[Complex<f32>],
179 start: usize,
180) {
181 dst[start] += src[start] * hrtf[start];
183}
184
185#[inline]
194pub fn complex_mul_add_simd(dst: &mut [Complex<f32>], src: &[Complex<f32>], hrtf: &[Complex<f32>]) {
195 let len = dst.len();
196
197 #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
198 {
199 let simd_len = (len / 4) * 4;
201
202 for i in (0..simd_len).step_by(4) {
203 unsafe {
204 complex_mul_add_simd_chunk(dst, src, hrtf, i);
205 }
206 }
207
208 for i in simd_len..len {
210 dst[i] += src[i] * hrtf[i];
211 }
212 }
213
214 #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
215 {
216 let simd_len = (len / 2) * 2;
218
219 for i in (0..simd_len).step_by(2) {
220 unsafe {
221 complex_mul_add_simd_chunk(dst, src, hrtf, i);
222 }
223 }
224
225 for i in simd_len..len {
227 dst[i] += src[i] * hrtf[i];
228 }
229 }
230
231 #[cfg(not(any(
232 all(target_arch = "x86_64", target_feature = "avx2"),
233 all(target_arch = "aarch64", target_feature = "neon")
234 )))]
235 {
236 for i in 0..len {
238 dst[i] += src[i] * hrtf[i];
239 }
240 }
241}
242
243#[inline]
247pub fn complex_mul_simd(dst: &mut [Complex<f32>], src: &[Complex<f32>], hrtf: &[Complex<f32>]) {
248 let len = dst.len();
249
250 #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
251 {
252 use std::arch::x86_64::*;
253
254 let simd_len = (len / 4) * 4;
255
256 for i in (0..simd_len).step_by(4) {
257 unsafe {
258 let src_ptr = src.as_ptr().add(i) as *const f32;
259 let hrtf_ptr = hrtf.as_ptr().add(i) as *const f32;
260 let dst_ptr = dst.as_mut_ptr().add(i) as *mut f32;
261
262 let a = _mm256_loadu_ps(src_ptr);
263 let b = _mm256_loadu_ps(hrtf_ptr);
264
265 let a_re = _mm256_moveldup_ps(a);
266 let a_im = _mm256_movehdup_ps(a);
267 let ac_ad = _mm256_mul_ps(a_re, b);
268 let b_swapped = _mm256_shuffle_ps(b, b, SHUFFLE_SWAP_RE_IM);
269 let bd_bc = _mm256_mul_ps(a_im, b_swapped);
270 let result = _mm256_addsub_ps(ac_ad, bd_bc);
271
272 _mm256_storeu_ps(dst_ptr, result);
273 }
274 }
275
276 for i in simd_len..len {
277 dst[i] = src[i] * hrtf[i];
278 }
279 }
280
281 #[cfg(all(target_arch = "aarch64", target_feature = "fcma"))]
282 {
283 use std::arch::aarch64::*;
284
285 let simd_len = (len / 2) * 2;
286
287 for i in (0..simd_len).step_by(2) {
288 unsafe {
289 let src_ptr = src.as_ptr().add(i) as *const f32;
290 let hrtf_ptr = hrtf.as_ptr().add(i) as *const f32;
291 let dst_ptr = dst.as_mut_ptr().add(i) as *mut f32;
292
293 let a = vld1q_f32(src_ptr);
294 let b = vld1q_f32(hrtf_ptr);
295 let r = vdupq_n_f32(0.0);
296 let result = fcmla_mul_acc(r, a, b);
297 vst1q_f32(dst_ptr, result);
298 }
299 }
300
301 for i in simd_len..len {
302 dst[i] = src[i] * hrtf[i];
303 }
304 }
305
306 #[cfg(all(
307 target_arch = "aarch64",
308 target_feature = "neon",
309 not(target_feature = "fcma")
310 ))]
311 {
312 use std::arch::aarch64::*;
313
314 let simd_len = (len / 2) * 2;
315
316 for i in (0..simd_len).step_by(2) {
317 unsafe {
318 let src_ptr = src.as_ptr().add(i) as *const f32;
319 let hrtf_ptr = hrtf.as_ptr().add(i) as *const f32;
320 let dst_ptr = dst.as_mut_ptr().add(i) as *mut f32;
321
322 let a = vld1q_f32(src_ptr);
323 let b = vld1q_f32(hrtf_ptr);
324
325 let a_re = vtrn1q_f32(a, a);
326 let a_im = vtrn2q_f32(a, a);
327 let ac_ad = vmulq_f32(a_re, b);
328 let b_swapped = vrev64q_f32(b);
329 let bd_bc = vmulq_f32(a_im, b_swapped);
330
331 let sign_bit: u32 = 0x80000000;
332 let neg_mask = vreinterpretq_f32_u32(vsetq_lane_u32::<2>(
333 sign_bit,
334 vsetq_lane_u32::<0>(sign_bit, vdupq_n_u32(0)),
335 ));
336
337 let bd_bc_negated = vreinterpretq_f32_u32(veorq_u32(
338 vreinterpretq_u32_f32(bd_bc),
339 vreinterpretq_u32_f32(neg_mask),
340 ));
341
342 let result = vaddq_f32(ac_ad, bd_bc_negated);
343 vst1q_f32(dst_ptr, result);
344 }
345 }
346
347 for i in simd_len..len {
348 dst[i] = src[i] * hrtf[i];
349 }
350 }
351
352 #[cfg(not(any(
353 all(target_arch = "x86_64", target_feature = "avx2"),
354 all(target_arch = "aarch64", target_feature = "neon")
355 )))]
356 {
357 for i in 0..len {
358 dst[i] = src[i] * hrtf[i];
359 }
360 }
361}
362
363#[inline]
367#[allow(dead_code)]
368pub fn complex_mul_inplace_simd(dst: &mut [Complex<f32>], hrtf: &[Complex<f32>]) {
369 let len = dst.len();
370
371 #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
372 {
373 use std::arch::x86_64::*;
374
375 let simd_len = (len / 4) * 4;
376
377 for i in (0..simd_len).step_by(4) {
378 unsafe {
379 let dst_ptr = dst.as_mut_ptr().add(i) as *mut f32;
380 let hrtf_ptr = hrtf.as_ptr().add(i) as *const f32;
381
382 let a = _mm256_loadu_ps(dst_ptr);
383 let b = _mm256_loadu_ps(hrtf_ptr);
384
385 let a_re = _mm256_moveldup_ps(a);
386 let a_im = _mm256_movehdup_ps(a);
387 let ac_ad = _mm256_mul_ps(a_re, b);
388 let b_swapped = _mm256_shuffle_ps(b, b, SHUFFLE_SWAP_RE_IM);
389 let bd_bc = _mm256_mul_ps(a_im, b_swapped);
390 let result = _mm256_addsub_ps(ac_ad, bd_bc);
391
392 _mm256_storeu_ps(dst_ptr, result);
393 }
394 }
395
396 for i in simd_len..len {
397 dst[i] *= hrtf[i];
398 }
399 }
400
401 #[cfg(all(target_arch = "aarch64", target_feature = "fcma"))]
402 {
403 use std::arch::aarch64::*;
404
405 let simd_len = (len / 2) * 2;
406
407 for i in (0..simd_len).step_by(2) {
408 unsafe {
409 let dst_ptr = dst.as_mut_ptr().add(i) as *mut f32;
410 let hrtf_ptr = hrtf.as_ptr().add(i) as *const f32;
411
412 let a = vld1q_f32(dst_ptr);
413 let b = vld1q_f32(hrtf_ptr);
414 let r = vdupq_n_f32(0.0);
415 let result = fcmla_mul_acc(r, a, b);
416 vst1q_f32(dst_ptr, result);
417 }
418 }
419
420 for i in simd_len..len {
421 dst[i] *= hrtf[i];
422 }
423 }
424
425 #[cfg(all(
426 target_arch = "aarch64",
427 target_feature = "neon",
428 not(target_feature = "fcma")
429 ))]
430 {
431 use std::arch::aarch64::*;
432
433 let simd_len = (len / 2) * 2;
434
435 for i in (0..simd_len).step_by(2) {
436 unsafe {
437 let dst_ptr = dst.as_mut_ptr().add(i) as *mut f32;
438 let hrtf_ptr = hrtf.as_ptr().add(i) as *const f32;
439
440 let a = vld1q_f32(dst_ptr);
441 let b = vld1q_f32(hrtf_ptr);
442
443 let a_re = vtrn1q_f32(a, a);
444 let a_im = vtrn2q_f32(a, a);
445 let ac_ad = vmulq_f32(a_re, b);
446 let b_swapped = vrev64q_f32(b);
447 let bd_bc = vmulq_f32(a_im, b_swapped);
448
449 let sign_bit: u32 = 0x80000000;
450 let neg_mask = vreinterpretq_f32_u32(vsetq_lane_u32::<2>(
451 sign_bit,
452 vsetq_lane_u32::<0>(sign_bit, vdupq_n_u32(0)),
453 ));
454
455 let bd_bc_negated = vreinterpretq_f32_u32(veorq_u32(
456 vreinterpretq_u32_f32(bd_bc),
457 vreinterpretq_u32_f32(neg_mask),
458 ));
459
460 let result = vaddq_f32(ac_ad, bd_bc_negated);
461 vst1q_f32(dst_ptr, result);
462 }
463 }
464
465 for i in simd_len..len {
466 dst[i] *= hrtf[i];
467 }
468 }
469
470 #[cfg(not(any(
471 all(target_arch = "x86_64", target_feature = "avx2"),
472 all(target_arch = "aarch64", target_feature = "neon")
473 )))]
474 {
475 for i in 0..len {
476 dst[i] *= hrtf[i];
477 }
478 }
479}
480
/// Scaled accumulate: `dst[i] += src[i] * scale` for every `i` in `0..dst.len()`.
///
/// The AVX2 path (8 lanes) uses a separate multiply and add. The NEON path
/// (4 lanes) uses a fused multiply-add, so its results may differ in the last
/// bit from the scalar remainder loop.
///
/// # Panics
///
/// Panics if `src` is shorter than `dst`. In debug builds it also panics if
/// the lengths differ at all.
#[inline]
pub fn scale_add_simd(dst: &mut [f32], src: &[f32], scale: f32) {
    debug_assert_eq!(dst.len(), src.len());
    let len = dst.len();
    // The vector loops read `src` through raw pointers up to `len`. Re-slicing
    // makes a short `src` panic in release builds too, instead of being read
    // out of bounds.
    let src = &src[..len];

    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        use std::arch::x86_64::*;

        let scale_vec = unsafe { _mm256_set1_ps(scale) };
        let simd_len = (len / 8) * 8;

        for i in (0..simd_len).step_by(8) {
            // SAFETY: i + 8 <= simd_len <= len and both slices have length `len`.
            unsafe {
                let src_ptr = src.as_ptr().add(i);
                let dst_ptr = dst.as_mut_ptr().add(i);

                let s = _mm256_loadu_ps(src_ptr);
                let d = _mm256_loadu_ps(dst_ptr);

                let ss = _mm256_mul_ps(s, scale_vec);
                let result = _mm256_add_ps(d, ss);

                _mm256_storeu_ps(dst_ptr, result);
            }
        }

        for (d, &s) in dst[simd_len..].iter_mut().zip(&src[simd_len..]) {
            *d += s * scale;
        }
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    {
        use std::arch::aarch64::*;

        let scale_vec = unsafe { vdupq_n_f32(scale) };
        let simd_len = (len / 4) * 4;

        for i in (0..simd_len).step_by(4) {
            // SAFETY: i + 4 <= simd_len <= len and both slices have length `len`.
            unsafe {
                let src_ptr = src.as_ptr().add(i);
                let dst_ptr = dst.as_mut_ptr().add(i);

                let s = vld1q_f32(src_ptr);
                let d = vld1q_f32(dst_ptr);

                // d + s * scale (fused)
                let result = vfmaq_f32(d, s, scale_vec);

                vst1q_f32(dst_ptr, result);
            }
        }

        for (d, &s) in dst[simd_len..].iter_mut().zip(&src[simd_len..]) {
            *d += s * scale;
        }
    }

    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "avx2"),
        all(target_arch = "aarch64", target_feature = "neon")
    )))]
    {
        for (d, &s) in dst.iter_mut().zip(src) {
            *d += s * scale;
        }
    }
}
560
/// Multiplies every element of `data` by `scale`, in place.
///
/// Despite the name, nothing is added; this is a pure in-place scale. Full
/// 8-lane (AVX2) or 4-lane (NEON) blocks are vectorised when those features are
/// enabled at compile time. The leftover elements and all other targets use
/// scalar code.
#[inline]
pub fn scale_add_simd_inplace(data: &mut [f32], scale: f32) {
    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        use std::arch::x86_64::*;

        let factor = unsafe { _mm256_set1_ps(scale) };
        let mut blocks = data.chunks_exact_mut(8);
        for block in &mut blocks {
            // SAFETY: `block` holds exactly 8 f32s, i.e. one unaligned 256-bit load/store.
            unsafe {
                let p = block.as_mut_ptr();
                _mm256_storeu_ps(p, _mm256_mul_ps(_mm256_loadu_ps(p), factor));
            }
        }
        for x in blocks.into_remainder() {
            *x *= scale;
        }
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    {
        use std::arch::aarch64::*;

        let factor = unsafe { vdupq_n_f32(scale) };
        let mut blocks = data.chunks_exact_mut(4);
        for block in &mut blocks {
            // SAFETY: `block` holds exactly 4 f32s, i.e. one 128-bit load/store.
            unsafe {
                let p = block.as_mut_ptr();
                vst1q_f32(p, vmulq_f32(vld1q_f32(p), factor));
            }
        }
        for x in blocks.into_remainder() {
            *x *= scale;
        }
    }

    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "avx2"),
        all(target_arch = "aarch64", target_feature = "neon")
    )))]
    {
        data.iter_mut().for_each(|x| *x *= scale);
    }
}
618
/// Linear crossfade: `dst[i] = prev[i] + alpha * (dst[i] - prev[i])` for every
/// `i` in `0..dst.len()`.
///
/// `alpha = 0` yields `prev`. `alpha = 1` yields `dst`, up to floating-point
/// rounding.
///
/// The AVX2 path uses a separate multiply and add, so it matches the scalar
/// remainder loop exactly. The NEON path uses a fused multiply-add.
///
/// # Panics
///
/// Panics if `prev` is shorter than `dst`. In debug builds it also panics if
/// the lengths differ at all.
#[inline]
pub fn blend_simd(dst: &mut [f32], prev: &[f32], alpha: f32) {
    debug_assert_eq!(dst.len(), prev.len());
    let len = dst.len();
    // The vector loops read `prev` through raw pointers up to `len`. Re-slicing
    // makes a short `prev` panic in release builds too, instead of being read
    // out of bounds.
    let prev = &prev[..len];

    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        use std::arch::x86_64::*;

        let alpha_vec = unsafe { _mm256_set1_ps(alpha) };
        let simd_len = (len / 8) * 8;

        for i in (0..simd_len).step_by(8) {
            // SAFETY: i + 8 <= simd_len <= len and both slices have length `len`.
            unsafe {
                let prev_ptr = prev.as_ptr().add(i);
                let dst_ptr = dst.as_mut_ptr().add(i);

                let p = _mm256_loadu_ps(prev_ptr);
                let d = _mm256_loadu_ps(dst_ptr);

                let diff = _mm256_sub_ps(d, p);
                // Plain AVX multiply + add. `_mm256_fmadd_ps` needs the `fma`
                // target feature, which this cfg does not guarantee.
                let result = _mm256_add_ps(p, _mm256_mul_ps(alpha_vec, diff));

                _mm256_storeu_ps(dst_ptr, result);
            }
        }

        for (d, &p) in dst[simd_len..].iter_mut().zip(&prev[simd_len..]) {
            *d = p + alpha * (*d - p);
        }
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    {
        use std::arch::aarch64::*;

        let alpha_vec = unsafe { vdupq_n_f32(alpha) };
        let simd_len = (len / 4) * 4;

        for i in (0..simd_len).step_by(4) {
            // SAFETY: i + 4 <= simd_len <= len and both slices have length `len`.
            unsafe {
                let prev_ptr = prev.as_ptr().add(i);
                let dst_ptr = dst.as_mut_ptr().add(i);

                let p = vld1q_f32(prev_ptr);
                let d = vld1q_f32(dst_ptr);

                let diff = vsubq_f32(d, p);
                // p + alpha * diff (fused; FMA is part of baseline AArch64 NEON)
                let result = vfmaq_f32(p, alpha_vec, diff);

                vst1q_f32(dst_ptr, result);
            }
        }

        for (d, &p) in dst[simd_len..].iter_mut().zip(&prev[simd_len..]) {
            *d = p + alpha * (*d - p);
        }
    }

    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "avx2"),
        all(target_arch = "aarch64", target_feature = "neon")
    )))]
    {
        for (d, &p) in dst.iter_mut().zip(prev) {
            *d = p + alpha * (*d - p);
        }
    }
}
696
/// Applies a window out of place: `dst[i] = src[i] * window[i]` for every `i`
/// in `0..dst.len()`.
///
/// # Panics
///
/// Panics if `src` or `window` is shorter than `dst`. In debug builds it also
/// panics if any lengths differ.
#[inline]
pub fn window_mul_simd(dst: &mut [f32], src: &[f32], window: &[f32]) {
    debug_assert_eq!(dst.len(), src.len());
    debug_assert_eq!(dst.len(), window.len());
    let len = dst.len();
    // The vector loops read `src`/`window` through raw pointers up to `len`.
    // Re-slicing makes a short input panic in release builds too, instead of
    // being read out of bounds.
    let src = &src[..len];
    let window = &window[..len];

    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        use std::arch::x86_64::*;

        let simd_len = (len / 8) * 8;

        for i in (0..simd_len).step_by(8) {
            // SAFETY: i + 8 <= simd_len <= len and all three slices have length `len`.
            unsafe {
                let src_ptr = src.as_ptr().add(i);
                let win_ptr = window.as_ptr().add(i);
                let dst_ptr = dst.as_mut_ptr().add(i);

                let s = _mm256_loadu_ps(src_ptr);
                let w = _mm256_loadu_ps(win_ptr);
                let result = _mm256_mul_ps(s, w);

                _mm256_storeu_ps(dst_ptr, result);
            }
        }

        for ((d, &s), &w) in dst[simd_len..]
            .iter_mut()
            .zip(&src[simd_len..])
            .zip(&window[simd_len..])
        {
            *d = s * w;
        }
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    {
        use std::arch::aarch64::*;

        let simd_len = (len / 4) * 4;

        for i in (0..simd_len).step_by(4) {
            // SAFETY: i + 4 <= simd_len <= len and all three slices have length `len`.
            unsafe {
                let src_ptr = src.as_ptr().add(i);
                let win_ptr = window.as_ptr().add(i);
                let dst_ptr = dst.as_mut_ptr().add(i);

                let s = vld1q_f32(src_ptr);
                let w = vld1q_f32(win_ptr);
                let result = vmulq_f32(s, w);

                vst1q_f32(dst_ptr, result);
            }
        }

        for ((d, &s), &w) in dst[simd_len..]
            .iter_mut()
            .zip(&src[simd_len..])
            .zip(&window[simd_len..])
        {
            *d = s * w;
        }
    }

    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "avx2"),
        all(target_arch = "aarch64", target_feature = "neon")
    )))]
    {
        for ((d, &s), &w) in dst.iter_mut().zip(src).zip(window) {
            *d = s * w;
        }
    }
}
766
/// Applies a window in place: `data[i] *= window[i]`.
///
/// Only the overlapping prefix (`min(data.len(), window.len())` elements) is
/// processed; any excess in either slice is left untouched.
#[inline]
pub fn window_mul_simd_inplace(data: &mut [f32], window: &[f32]) {
    let overlap = data.len().min(window.len());
    let (data, window) = (&mut data[..overlap], &window[..overlap]);

    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        use std::arch::x86_64::*;

        let mut data_blocks = data.chunks_exact_mut(8);
        let mut win_blocks = window.chunks_exact(8);
        for (d, w) in (&mut data_blocks).zip(&mut win_blocks) {
            // SAFETY: both blocks hold exactly 8 f32s (one unaligned 256-bit access each).
            unsafe {
                let dp = d.as_mut_ptr();
                let product = _mm256_mul_ps(_mm256_loadu_ps(dp), _mm256_loadu_ps(w.as_ptr()));
                _mm256_storeu_ps(dp, product);
            }
        }
        for (d, &w) in data_blocks.into_remainder().iter_mut().zip(win_blocks.remainder()) {
            *d *= w;
        }
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    {
        use std::arch::aarch64::*;

        let mut data_blocks = data.chunks_exact_mut(4);
        let mut win_blocks = window.chunks_exact(4);
        for (d, w) in (&mut data_blocks).zip(&mut win_blocks) {
            // SAFETY: both blocks hold exactly 4 f32s (one 128-bit access each).
            unsafe {
                let dp = d.as_mut_ptr();
                vst1q_f32(dp, vmulq_f32(vld1q_f32(dp), vld1q_f32(w.as_ptr())));
            }
        }
        for (d, &w) in data_blocks.into_remainder().iter_mut().zip(win_blocks.remainder()) {
            *d *= w;
        }
    }

    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "avx2"),
        all(target_arch = "aarch64", target_feature = "neon")
    )))]
    {
        for (d, &w) in data.iter_mut().zip(window) {
            *d *= w;
        }
    }
}
818
/// Splits interleaved stereo `[L0, R0, L1, R1, ...]` into planar `left` / `right`.
///
/// The AVX2 path copies `left.len()` frames and handles 8 frames per iteration.
/// The portable path copies one frame per 2-sample chunk of `input`.
///
/// # Panics
///
/// With AVX2, panics if `input` holds fewer than `2 * left.len()` samples or
/// `right` is shorter than `left`. The portable path panics (via indexing) if
/// `left` or `right` is shorter than `input.len() / 2`. Debug builds also
/// assert the lengths match exactly.
#[inline]
pub fn deinterleave_stereo(input: &[f32], left: &mut [f32], right: &mut [f32]) {
    debug_assert_eq!(input.len(), left.len() * 2);
    debug_assert_eq!(left.len(), right.len());

    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        use std::arch::x86_64::*;

        let len = left.len();
        // The vector loop reads `input` and writes `right` through raw pointers.
        // Re-slicing makes short buffers panic instead of causing
        // out-of-bounds access in release builds.
        let input = &input[..len * 2];
        let right = &mut right[..len];
        let simd_len = (len / 8) * 8;

        for i in (0..simd_len).step_by(8) {
            // SAFETY: i + 8 <= simd_len <= len, so input[2i..2i + 16],
            // left[i..i + 8] and right[i..i + 8] are all in bounds.
            unsafe {
                let in_ptr = input.as_ptr().add(i * 2);
                // v0 = L0 R0 L1 R1 | L2 R2 L3 R3, v1 = L4 R4 L5 R5 | L6 R6 L7 R7
                let v0 = _mm256_loadu_ps(in_ptr);
                let v1 = _mm256_loadu_ps(in_ptr.add(8));

                // Within each 128-bit lane pick even (left) / odd (right) samples:
                // shuf_l = L0 L1 L4 L5 | L2 L3 L6 L7 (likewise for right).
                let shuf_l = _mm256_shuffle_ps(v0, v1, 0b10_00_10_00);
                let shuf_r = _mm256_shuffle_ps(v0, v1, 0b11_01_11_01);

                // Reorder the 64-bit pairs as 0, 2, 1, 3 to get L0..L7 / R0..R7.
                let left_vec = _mm256_castpd_ps(_mm256_permute4x64_pd(
                    _mm256_castps_pd(shuf_l),
                    0b11_01_10_00,
                ));
                let right_vec = _mm256_castpd_ps(_mm256_permute4x64_pd(
                    _mm256_castps_pd(shuf_r),
                    0b11_01_10_00,
                ));

                _mm256_storeu_ps(left.as_mut_ptr().add(i), left_vec);
                _mm256_storeu_ps(right.as_mut_ptr().add(i), right_vec);
            }
        }

        for ((frame, l), r) in input[simd_len * 2..]
            .chunks_exact(2)
            .zip(&mut left[simd_len..])
            .zip(&mut right[simd_len..])
        {
            *l = frame[0];
            *r = frame[1];
        }
    }

    #[cfg(not(all(target_arch = "x86_64", target_feature = "avx2")))]
    {
        for (i, chunk) in input.chunks_exact(2).enumerate() {
            left[i] = chunk[0];
            right[i] = chunk[1];
        }
    }
}
884
/// Interleaves planar `left` / `right` channels into `output` as
/// `[L0, R0, L1, R1, ...]`.
///
/// Scalar implementation. Writes `2 * left.len()` samples and panics through
/// indexing if `right` or `output` is too short.
#[inline]
#[allow(dead_code)]
pub fn interleave_stereo(left: &[f32], right: &[f32], output: &mut [f32]) {
    debug_assert_eq!(left.len(), right.len());
    debug_assert_eq!(output.len(), left.len() * 2);

    for (frame, &l) in left.iter().enumerate() {
        let slot = frame * 2;
        output[slot] = l;
        output[slot + 1] = right[frame];
    }
}
901
/// Replaces every sample whose magnitude is strictly below `1e-30` with `+0.0`.
///
/// Values at or above the threshold, and NaNs, are left unchanged. The
/// threshold is far above the `f32` subnormal range (`f32::MIN_POSITIVE` is
/// about 1.18e-38), so very small normal values are flushed as well.
/// Vectorised with AVX2 or NEON when enabled at compile time.
#[inline]
pub fn flush_denormals_inplace(samples: &mut [f32]) {
    const DENORM_THRESHOLD: f32 = 1e-30;

    // Scalar rule shared by the remainder loops and the portable path.
    let flush_scalar = |s: &mut f32| {
        if s.abs() < DENORM_THRESHOLD {
            *s = 0.0;
        }
    };

    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        use std::arch::x86_64::*;

        let (threshold, zero, sign_bit) = unsafe {
            (
                _mm256_set1_ps(DENORM_THRESHOLD),
                _mm256_set1_ps(0.0),
                _mm256_set1_ps(-0.0),
            )
        };
        let mut blocks = samples.chunks_exact_mut(8);
        for block in &mut blocks {
            // SAFETY: `block` holds exactly 8 f32s (one unaligned 256-bit load/store).
            unsafe {
                let p = block.as_mut_ptr();
                let v = _mm256_loadu_ps(p);
                // Clearing the sign bit gives |v|. The ordered compare is false for NaN.
                let magnitude = _mm256_andnot_ps(sign_bit, v);
                let tiny = _mm256_cmp_ps(magnitude, threshold, _CMP_LT_OQ);
                _mm256_storeu_ps(p, _mm256_blendv_ps(v, zero, tiny));
            }
        }
        blocks.into_remainder().iter_mut().for_each(flush_scalar);
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    {
        use std::arch::aarch64::*;

        let (threshold, zero) = unsafe { (vdupq_n_f32(DENORM_THRESHOLD), vdupq_n_f32(0.0)) };
        let mut blocks = samples.chunks_exact_mut(4);
        for block in &mut blocks {
            // SAFETY: `block` holds exactly 4 f32s (one 128-bit load/store).
            unsafe {
                let p = block.as_mut_ptr();
                let v = vld1q_f32(p);
                let tiny = vcltq_f32(vabsq_f32(v), threshold);
                vst1q_f32(p, vbslq_f32(tiny, zero, v));
            }
        }
        blocks.into_remainder().iter_mut().for_each(flush_scalar);
    }

    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "avx2"),
        all(target_arch = "aarch64", target_feature = "neon")
    )))]
    {
        samples.iter_mut().for_each(flush_scalar);
    }
}
977
/// Turns on flush-to-zero handling of subnormal floats in the CPU's
/// floating-point control register.
///
/// * x86_64: sets MXCSR bit 15 (FTZ, flush-to-zero) and bit 6 (DAZ,
///   denormals-are-zero) via `stmxcsr` / `ldmxcsr`.
/// * aarch64: sets FPCR bit 24 (FZ, flush-to-zero) via `mrs` / `msr`.
///
/// Returns `true` on those two architectures. On any other architecture
/// nothing is changed and it returns `false`.
///
/// NOTE(review): MXCSR/FPCR are normally saved and restored per thread by the
/// OS, so this presumably affects only the calling thread; confirm for the
/// target platform and call it on each processing thread if so.
#[inline]
pub fn enable_ftz_daz() -> bool {
    #[cfg(target_arch = "x86_64")]
    {
        unsafe {
            // Read-modify-write so the other MXCSR control bits are preserved.
            let mut mxcsr: u32 = 0;
            std::arch::asm!("stmxcsr [{}]", in(reg) &mut mxcsr, options(nostack, preserves_flags));
            // Bit 15 = FTZ, bit 6 = DAZ.
            mxcsr |= (1 << 15) | (1 << 6);
            std::arch::asm!("ldmxcsr [{}]", in(reg) &mxcsr, options(nostack, preserves_flags));
        }
        true
    }

    #[cfg(target_arch = "aarch64")]
    {
        unsafe {
            // Read-modify-write so the other FPCR control bits are preserved.
            let mut fpcr: u64;
            std::arch::asm!("mrs {}, fpcr", out(reg) fpcr);
            // Bit 24 = FZ (flush-to-zero).
            fpcr |= 1 << 24;
            std::arch::asm!("msr fpcr, {}", in(reg) fpcr);
        }
        true
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        false
    }
}
1023
1024#[inline]
1026pub fn flush_denormals_complex_inplace(samples: &mut [Complex<f32>]) {
1027 let len = samples.len() * 2;
1030 let ptr = samples.as_mut_ptr() as *mut f32;
1031 let f32_samples = unsafe { std::slice::from_raw_parts_mut(ptr, len) };
1032 flush_denormals_inplace(f32_samples);
1033}
1034
/// Unit tests for the tiny-value flushing helpers.
#[cfg(test)]
mod denorm_tests {
    use super::*;

    #[test]
    fn test_flush_denormals_basic() {
        let mut samples = [1e-31_f32, 1e-20, 1e-10, 0.0, -1e-31, 1.0];
        flush_denormals_inplace(&mut samples);
        // Both +/-1e-31 are below the threshold and become zero; the rest survive.
        assert_eq!(samples, [0.0, 1e-20, 1e-10, 0.0, 0.0, 1.0]);
    }

    #[test]
    fn test_flush_denormals_complex() {
        use rustfft::num_complex::Complex;
        let mut samples = [
            Complex::new(1e-31, 1e-30),
            Complex::new(1.0, 1e-31),
            Complex::new(0.0, 0.0),
        ];
        flush_denormals_complex_inplace(&mut samples);
        let [first, second, third] = samples;
        assert_eq!(first.re, 0.0);
        // 1e-30 is not strictly below the threshold, so it is kept.
        assert!((first.im - 1e-30).abs() < 1e-35);
        assert_eq!((second.re, second.im), (1.0, 0.0));
        assert_eq!((third.re, third.im), (0.0, 0.0));
    }

    #[test]
    fn test_flush_denormals_empty() {
        let mut samples: [f32; 0] = [];
        flush_denormals_inplace(&mut samples);
    }

    #[test]
    fn test_flush_denormals_unaligned() {
        // Seven elements: not a multiple of any SIMD width, so the tail path runs.
        let mut samples = [1e-31_f32; 7];
        flush_denormals_inplace(&mut samples);
        assert!(samples.iter().all(|&s| s == 0.0));
    }
}
1083
1084#[cfg(test)]
1085#[allow(clippy::needless_range_loop)]
1086mod tests {
1087 use super::*;
1088
1089 #[test]
1094 fn test_flush_denormals_basic() {
1095 let mut samples = [1e-31_f32, 1e-20, 1e-10, 0.0, -1e-31, 1.0];
1096 flush_denormals_inplace(&mut samples);
1097 assert_eq!(samples[0], 0.0);
1098 assert_eq!(samples[1], 1e-20);
1099 assert_eq!(samples[2], 1e-10);
1100 assert_eq!(samples[3], 0.0);
1101 assert_eq!(samples[4], 0.0);
1102 assert_eq!(samples[5], 1.0);
1103 }
1104
1105 #[test]
1106 fn test_flush_denormals_complex() {
1107 use rustfft::num_complex::Complex;
1108 let mut samples = [
1109 Complex::new(1e-31, 1e-30),
1110 Complex::new(1.0, 1e-31),
1111 Complex::new(0.0, 0.0),
1112 ];
1113 flush_denormals_complex_inplace(&mut samples);
1114 assert_eq!(samples[0].re, 0.0);
1115 assert!((samples[0].im - 1e-30).abs() < 1e-35);
1116 assert_eq!(samples[1].re, 1.0);
1117 assert_eq!(samples[1].im, 0.0);
1118 assert_eq!(samples[2].re, 0.0);
1119 assert_eq!(samples[2].im, 0.0);
1120 }
1121
1122 #[test]
1123 fn test_flush_denormals_empty() {
1124 let mut samples: [f32; 0] = [];
1125 flush_denormals_inplace(&mut samples);
1126 }
1127
1128 #[test]
1129 fn test_flush_denormals_unaligned() {
1130 let mut samples = [1e-31_f32; 7];
1131 flush_denormals_inplace(&mut samples);
1132 for s in samples.iter() {
1133 assert_eq!(*s, 0.0);
1134 }
1135 }
1136
1137 #[test]
1138 fn test_enable_ftz_daz_does_not_panic() {
1139 let result = enable_ftz_daz();
1142 #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
1143 assert!(
1144 result,
1145 "enable_ftz_daz should return true on supported platforms"
1146 );
1147 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
1148 assert!(
1149 !result,
1150 "enable_ftz_daz should return false on unsupported platforms"
1151 );
1152 }
1153
1154 #[test]
1155 fn test_apply_gain_simd_known_values() {
1156 let mut buffer = vec![1.0, 2.0, 3.0, 4.0, -1.0, 0.5, -0.5, 0.0, 1.5];
1158 let expected: Vec<f32> = buffer.iter().map(|&x| x * 2.0).collect();
1159 apply_gain_simd(&mut buffer, 2.0);
1160 for (i, (&got, &exp)) in buffer.iter().zip(expected.iter()).enumerate() {
1161 assert!(
1162 (got - exp).abs() < 1e-6,
1163 "apply_gain_simd mismatch at index {}: got {}, expected {}",
1164 i,
1165 got,
1166 exp
1167 );
1168 }
1169 }
1170
1171 #[test]
1172 fn test_apply_gain_simd_zero_gain() {
1173 let mut buffer = vec![1.0, -2.0, 3.5, 0.7];
1174 apply_gain_simd(&mut buffer, 0.0);
1175 for (i, &v) in buffer.iter().enumerate() {
1176 assert_eq!(
1177 v, 0.0,
1178 "apply_gain_simd with zero gain: index {} not zero",
1179 i
1180 );
1181 }
1182 }
1183
1184 #[test]
1185 fn test_apply_gain_simd_unity_gain() {
1186 let original = vec![1.0, -2.0, 3.5, 0.7, 0.0, -0.1];
1187 let mut buffer = original.clone();
1188 apply_gain_simd(&mut buffer, 1.0);
1189 assert_eq!(buffer, original);
1190 }
1191
1192 #[test]
1193 fn test_apply_per_channel_gain_simd_stereo() {
1194 let mut buffer = vec![1.0, 2.0, 3.0, 4.0]; let gains = vec![0.5, 2.0]; apply_per_channel_gain_simd(&mut buffer, 2, &gains);
1198 assert!((buffer[0] - 0.5).abs() < 1e-6, "L frame 0");
1199 assert!((buffer[1] - 4.0).abs() < 1e-6, "R frame 0");
1200 assert!((buffer[2] - 1.5).abs() < 1e-6, "L frame 1");
1201 assert!((buffer[3] - 8.0).abs() < 1e-6, "R frame 1");
1202 }
1203
1204 #[test]
1212 fn test_simd_complex_mul_add_correctness() {
1213 use rustfft::num_complex::Complex;
1215
1216 let src = vec![
1218 Complex::new(1.0, 2.0),
1219 Complex::new(3.0, 4.0),
1220 Complex::new(-1.0, 0.5),
1221 Complex::new(0.0, -2.0),
1222 Complex::new(2.5, -1.5),
1223 Complex::new(-3.5, 2.5),
1224 Complex::new(1.1, -0.9),
1225 Complex::new(-0.8, 1.2),
1226 ];
1227
1228 let hrtf = vec![
1229 Complex::new(0.5, 0.25),
1230 Complex::new(-1.0, 1.5),
1231 Complex::new(2.0, -0.5),
1232 Complex::new(0.75, 0.75),
1233 Complex::new(-0.5, 2.0),
1234 Complex::new(1.5, -1.0),
1235 Complex::new(0.9, 0.3),
1236 Complex::new(-1.1, 0.7),
1237 ];
1238
1239 let initial = vec![
1240 Complex::new(0.1, 0.2),
1241 Complex::new(0.3, 0.4),
1242 Complex::new(0.5, 0.6),
1243 Complex::new(0.7, 0.8),
1244 Complex::new(0.9, 1.0),
1245 Complex::new(1.1, 1.2),
1246 Complex::new(1.3, 1.4),
1247 Complex::new(1.5, 1.6),
1248 ];
1249
1250 let mut expected = initial.clone();
1252 for i in 0..src.len() {
1253 expected[i] += src[i] * hrtf[i];
1254 }
1255
1256 let mut result = initial.clone();
1258 complex_mul_add_simd(&mut result, &src, &hrtf);
1259
1260 const EPSILON: f32 = 1e-6;
1262 for i in 0..src.len() {
1263 assert!(
1264 (result[i].re - expected[i].re).abs() < EPSILON,
1265 "SIMD result[{}].re = {}, expected = {} (diff = {})",
1266 i,
1267 result[i].re,
1268 expected[i].re,
1269 (result[i].re - expected[i].re).abs()
1270 );
1271 assert!(
1272 (result[i].im - expected[i].im).abs() < EPSILON,
1273 "SIMD result[{}].im = {}, expected = {} (diff = {})",
1274 i,
1275 result[i].im,
1276 expected[i].im,
1277 (result[i].im - expected[i].im).abs()
1278 );
1279 }
1280 }
1281
1282 #[test]
1283 fn test_simd_complex_mul_correctness() {
1284 use rustfft::num_complex::Complex;
1286
1287 let src = vec![
1288 Complex::new(2.0, 3.0),
1289 Complex::new(-1.5, 2.5),
1290 Complex::new(0.5, -1.0),
1291 Complex::new(4.0, -2.0),
1292 ];
1293
1294 let hrtf = vec![
1295 Complex::new(1.0, 0.5),
1296 Complex::new(2.0, -1.0),
1297 Complex::new(-0.5, 1.5),
1298 Complex::new(0.75, 0.25),
1299 ];
1300
1301 let expected: Vec<Complex<f32>> = src.iter().zip(hrtf.iter()).map(|(a, b)| a * b).collect();
1303
1304 let mut result = vec![Complex::new(0.0, 0.0); src.len()];
1306 complex_mul_simd(&mut result, &src, &hrtf);
1307
1308 const EPSILON: f32 = 1e-6;
1310 for i in 0..src.len() {
1311 assert!(
1312 (result[i].re - expected[i].re).abs() < EPSILON,
1313 "SIMD result[{}].re = {}, expected = {}",
1314 i,
1315 result[i].re,
1316 expected[i].re
1317 );
1318 assert!(
1319 (result[i].im - expected[i].im).abs() < EPSILON,
1320 "SIMD result[{}].im = {}, expected = {}",
1321 i,
1322 result[i].im,
1323 expected[i].im
1324 );
1325 }
1326 }
1327
1328 #[test]
1329 fn test_simd_edge_cases() {
1330 use rustfft::num_complex::Complex;
1332
1333 let src = vec![
1335 Complex::new(1.0, 2.0),
1336 Complex::new(3.0, 4.0),
1337 Complex::new(5.0, 6.0),
1338 Complex::new(7.0, 8.0),
1339 ];
1340 let zero = vec![Complex::new(0.0, 0.0); 4];
1341 let mut result = src.clone();
1342 let input = result.clone();
1343 complex_mul_simd(&mut result, &input, &zero);
1344 for i in 0..4 {
1345 assert_eq!(result[i].re, 0.0);
1346 assert_eq!(result[i].im, 0.0);
1347 }
1348
1349 let one = vec![Complex::new(1.0, 0.0); 4];
1351 let mut result = vec![Complex::new(0.0, 0.0); 4];
1352 complex_mul_simd(&mut result, &src, &one);
1353 for i in 0..4 {
1354 assert!((result[i].re - src[i].re).abs() < 1e-6);
1355 assert!((result[i].im - src[i].im).abs() < 1e-6);
1356 }
1357
1358 let a = Complex::new(3.0, 4.0);
1360 let a_conj = Complex::new(3.0, -4.0);
1361 let src = vec![a, a, a, a];
1362 let conj = vec![a_conj, a_conj, a_conj, a_conj];
1363 let mut result = vec![Complex::new(0.0, 0.0); 4];
1364 complex_mul_simd(&mut result, &src, &conj);
1365
1366 for i in 0..4 {
1368 assert!((result[i].re - 25.0).abs() < 1e-5);
1369 assert!(result[i].im.abs() < 1e-5); }
1371 }
1372
1373 #[test]
1374 fn test_simd_large_buffer() {
1375 use rustfft::num_complex::Complex;
1377
1378 for fft_size in [512, 1024, 2048, 4096] {
1379 let mut src = Vec::with_capacity(fft_size);
1380 let mut hrtf = Vec::with_capacity(fft_size);
1381
1382 for i in 0..fft_size {
1384 let phase = (i as f32) * 0.01;
1385 src.push(Complex::new(phase.cos(), phase.sin()));
1386 hrtf.push(Complex::new(0.5, 0.25));
1387 }
1388
1389 let mut expected = vec![Complex::new(0.1, 0.2); fft_size];
1391 for i in 0..fft_size {
1392 expected[i] += src[i] * hrtf[i];
1393 }
1394
1395 let mut result = vec![Complex::new(0.1, 0.2); fft_size];
1397 complex_mul_add_simd(&mut result, &src, &hrtf);
1398
1399 for i in 0..fft_size {
1401 assert!(
1402 (result[i].re - expected[i].re).abs() < 1e-5,
1403 "FFT size {}, index {}: SIMD mismatch",
1404 fft_size,
1405 i
1406 );
1407 assert!(
1408 (result[i].im - expected[i].im).abs() < 1e-5,
1409 "FFT size {}, index {}: SIMD mismatch",
1410 fft_size,
1411 i
1412 );
1413 }
1414 }
1415 }
1416
1417 #[test]
1418 fn test_simd_unaligned_sizes() {
1419 use rustfft::num_complex::Complex;
1422
1423 for size in [1, 3, 5, 7, 9, 13, 17] {
1424 let src: Vec<Complex<f32>> = (0..size)
1425 .map(|i| Complex::new(i as f32, (i as f32) * 0.5))
1426 .collect();
1427 let hrtf: Vec<Complex<f32>> = (0..size)
1428 .map(|i| Complex::new(0.5, (i as f32) * 0.1))
1429 .collect();
1430
1431 let expected: Vec<Complex<f32>> =
1433 src.iter().zip(hrtf.iter()).map(|(a, b)| a * b).collect();
1434
1435 let mut result = vec![Complex::new(0.0, 0.0); size];
1437 complex_mul_simd(&mut result, &src, &hrtf);
1438
1439 for i in 0..size {
1441 assert!(
1442 (result[i].re - expected[i].re).abs() < 1e-6,
1443 "Size {}, index {}: re mismatch",
1444 size,
1445 i
1446 );
1447 assert!(
1448 (result[i].im - expected[i].im).abs() < 1e-6,
1449 "Size {}, index {}: im mismatch",
1450 size,
1451 i
1452 );
1453 }
1454 }
1455 }
1456
1457 #[test]
1462 fn test_simd_complex_mul_inplace_correctness() {
1463 use rustfft::num_complex::Complex;
1465
1466 let src = vec![
1467 Complex::new(2.0, 3.0),
1468 Complex::new(-1.5, 2.5),
1469 Complex::new(0.5, -1.0),
1470 Complex::new(4.0, -2.0),
1471 ];
1472
1473 let hrtf = vec![
1474 Complex::new(1.0, 0.5),
1475 Complex::new(2.0, -1.0),
1476 Complex::new(-0.5, 1.5),
1477 Complex::new(0.75, 0.25),
1478 ];
1479
1480 let mut expected = src.clone();
1482 for i in 0..expected.len() {
1483 expected[i] *= hrtf[i];
1484 }
1485
1486 let mut result = src.clone();
1488 complex_mul_inplace_simd(&mut result, &hrtf);
1489
1490 const EPSILON: f32 = 1e-6;
1491 for i in 0..result.len() {
1492 assert!(
1493 (result[i].re - expected[i].re).abs() < EPSILON,
1494 "Index {}: re mismatch {} vs {}",
1495 i,
1496 result[i].re,
1497 expected[i].re
1498 );
1499 assert!(
1500 (result[i].im - expected[i].im).abs() < EPSILON,
1501 "Index {}: im mismatch {} vs {}",
1502 i,
1503 result[i].im,
1504 expected[i].im
1505 );
1506 }
1507 }
1508
1509 #[test]
1510 fn test_simd_inplace_large_buffers() {
1511 use rustfft::num_complex::Complex;
1513
1514 for fft_size in [128, 256, 512, 1024, 2048] {
1515 let mut src: Vec<Complex<f32>> = (0..fft_size)
1516 .map(|i| {
1517 let phase = (i as f32) * 0.01;
1518 Complex::new(phase.cos(), phase.sin())
1519 })
1520 .collect();
1521
1522 let hrtf: Vec<Complex<f32>> = (0..fft_size)
1523 .map(|i| Complex::new(0.5 + (i as f32) * 0.001, 0.25))
1524 .collect();
1525
1526 let mut expected = src.clone();
1528 for i in 0..fft_size {
1529 expected[i] *= hrtf[i];
1530 }
1531
1532 complex_mul_inplace_simd(&mut src, &hrtf);
1534
1535 for i in 0..fft_size {
1537 assert!(
1538 (src[i].re - expected[i].re).abs() < 1e-5,
1539 "FFT size {}, index {}: re mismatch",
1540 fft_size,
1541 i
1542 );
1543 assert!(
1544 (src[i].im - expected[i].im).abs() < 1e-5,
1545 "FFT size {}, index {}: im mismatch",
1546 fft_size,
1547 i
1548 );
1549 }
1550 }
1551 }
1552
1553 #[test]
1554 fn test_simd_inplace_unaligned() {
1555 use rustfft::num_complex::Complex;
1557
1558 for size in [1, 2, 3, 5, 6, 7, 9, 10, 11, 15, 17, 19, 23] {
1559 let mut src: Vec<Complex<f32>> = (0..size)
1560 .map(|i| Complex::new((i as f32) * 0.5, (i as f32) * -0.3))
1561 .collect();
1562
1563 let hrtf: Vec<Complex<f32>> = (0..size)
1564 .map(|i| Complex::new(1.0 + (i as f32) * 0.1, 0.5))
1565 .collect();
1566
1567 let mut expected = src.clone();
1569 for i in 0..size {
1570 expected[i] *= hrtf[i];
1571 }
1572
1573 complex_mul_inplace_simd(&mut src, &hrtf);
1575
1576 for i in 0..size {
1578 assert!(
1579 (src[i].re - expected[i].re).abs() < 1e-6,
1580 "Size {}, index {}: re mismatch",
1581 size,
1582 i
1583 );
1584 assert!(
1585 (src[i].im - expected[i].im).abs() < 1e-6,
1586 "Size {}, index {}: im mismatch",
1587 size,
1588 i
1589 );
1590 }
1591 }
1592 }
1593
1594 #[test]
1595 fn test_simd_inplace_edge_cases() {
1596 use rustfft::num_complex::Complex;
1598
1599 let mut src = vec![
1601 Complex::new(1.0, 2.0),
1602 Complex::new(3.0, 4.0),
1603 Complex::new(5.0, 6.0),
1604 Complex::new(7.0, 8.0),
1605 ];
1606 let zero = vec![Complex::new(0.0, 0.0); 4];
1607 complex_mul_inplace_simd(&mut src, &zero);
1608 for i in 0..4 {
1609 assert!(src[i].re.abs() < 1e-6, "Expected zero, got {}", src[i].re);
1610 assert!(src[i].im.abs() < 1e-6, "Expected zero, got {}", src[i].im);
1611 }
1612
1613 let original = vec![
1615 Complex::new(1.5, 2.5),
1616 Complex::new(-3.5, 4.5),
1617 Complex::new(5.5, -6.5),
1618 Complex::new(-7.5, -8.5),
1619 ];
1620 let mut src = original.clone();
1621 let one = vec![Complex::new(1.0, 0.0); 4];
1622 complex_mul_inplace_simd(&mut src, &one);
1623 for i in 0..4 {
1624 assert!((src[i].re - original[i].re).abs() < 1e-6);
1625 assert!((src[i].im - original[i].im).abs() < 1e-6);
1626 }
1627
1628 let a = Complex::new(3.0, 4.0);
1630 let a_conj = Complex::new(3.0, -4.0);
1631 let mut src = vec![a; 8];
1632 let conj = vec![a_conj; 8];
1633 complex_mul_inplace_simd(&mut src, &conj);
1634 for i in 0..8 {
1636 assert!(
1637 (src[i].re - 25.0).abs() < 1e-5,
1638 "Expected 25.0, got {}",
1639 src[i].re
1640 );
1641 assert!(src[i].im.abs() < 1e-5, "Expected ~0, got {}", src[i].im);
1642 }
1643
1644 let mut src = vec![Complex::new(1.0, 0.0); 4];
1646 let i_val = vec![Complex::new(0.0, 1.0); 4];
1647 complex_mul_inplace_simd(&mut src, &i_val);
1648 for idx in 0..4 {
1649 assert!(src[idx].re.abs() < 1e-6, "Expected 0, got {}", src[idx].re);
1650 assert!(
1651 (src[idx].im - 1.0).abs() < 1e-6,
1652 "Expected 1, got {}",
1653 src[idx].im
1654 );
1655 }
1656 }
1657
1658 #[test]
1659 fn test_simd_inplace_negative_values() {
1660 use rustfft::num_complex::Complex;
1662
1663 let mut src = vec![
1664 Complex::new(-1.0, -2.0),
1665 Complex::new(-3.0, -4.0),
1666 Complex::new(-5.0, -6.0),
1667 Complex::new(-7.0, -8.0),
1668 ];
1669
1670 let hrtf = vec![
1671 Complex::new(-0.5, -0.25),
1672 Complex::new(-1.0, -1.5),
1673 Complex::new(-2.0, 0.5),
1674 Complex::new(0.75, -0.75),
1675 ];
1676
1677 let mut expected = src.clone();
1679 for i in 0..expected.len() {
1680 expected[i] *= hrtf[i];
1681 }
1682
1683 complex_mul_inplace_simd(&mut src, &hrtf);
1685
1686 const EPSILON: f32 = 1e-6;
1687 for i in 0..src.len() {
1688 assert!((src[i].re - expected[i].re).abs() < EPSILON);
1689 assert!((src[i].im - expected[i].im).abs() < EPSILON);
1690 }
1691 }
1692
1693 #[test]
1698 fn test_covariance_basic_correctness() {
1699 use rustfft::num_complex::Complex;
1701
1702 let left = vec![
1703 Complex::new(1.0, 2.0),
1704 Complex::new(3.0, 4.0),
1705 Complex::new(-1.0, 0.5),
1706 Complex::new(0.0, -2.0),
1707 Complex::new(2.5, -1.5),
1708 Complex::new(-3.5, 2.5),
1709 Complex::new(1.1, -0.9),
1710 Complex::new(-0.8, 1.2),
1711 ];
1712
1713 let right = vec![
1714 Complex::new(0.5, 0.25),
1715 Complex::new(-1.0, 1.5),
1716 Complex::new(2.0, -0.5),
1717 Complex::new(0.75, 0.75),
1718 Complex::new(-0.5, 2.0),
1719 Complex::new(1.5, -1.0),
1720 Complex::new(0.9, 0.3),
1721 Complex::new(-1.1, 0.7),
1722 ];
1723
1724 let mut expected_xx = 0.0_f32;
1726 let mut expected_yy = 0.0_f32;
1727 let mut expected_xy = Complex::new(0.0, 0.0);
1728 for i in 0..left.len() {
1729 expected_xx += left[i].norm_sqr();
1730 expected_yy += right[i].norm_sqr();
1731 expected_xy += left[i] * right[i].conj();
1732 }
1733
1734 let (cov_xx, cov_yy, cov_xy) = compute_covariance_simd(&left, &right, 0, left.len());
1736
1737 const EPSILON: f32 = 1e-5;
1738 assert!(
1739 (cov_xx - expected_xx).abs() < EPSILON,
1740 "cov_xx mismatch: {} vs {}",
1741 cov_xx,
1742 expected_xx
1743 );
1744 assert!(
1745 (cov_yy - expected_yy).abs() < EPSILON,
1746 "cov_yy mismatch: {} vs {}",
1747 cov_yy,
1748 expected_yy
1749 );
1750 assert!(
1751 (cov_xy.re - expected_xy.re).abs() < EPSILON,
1752 "cov_xy.re mismatch: {} vs {}",
1753 cov_xy.re,
1754 expected_xy.re
1755 );
1756 assert!(
1757 (cov_xy.im - expected_xy.im).abs() < EPSILON,
1758 "cov_xy.im mismatch: {} vs {}",
1759 cov_xy.im,
1760 expected_xy.im
1761 );
1762 }
1763
1764 #[test]
1765 fn test_covariance_with_ranges() {
1766 use rustfft::num_complex::Complex;
1768
1769 let left: Vec<Complex<f32>> = (0..32)
1770 .map(|i| Complex::new(i as f32 * 0.5, i as f32 * -0.3))
1771 .collect();
1772 let right: Vec<Complex<f32>> = (0..32)
1773 .map(|i| Complex::new(i as f32 * -0.4, i as f32 * 0.6))
1774 .collect();
1775
1776 for (start, end) in [(0, 8), (4, 12), (10, 20), (5, 25), (0, 32)] {
1778 let mut expected_xx = 0.0_f32;
1780 let mut expected_yy = 0.0_f32;
1781 let mut expected_xy = Complex::new(0.0, 0.0);
1782 for i in start..end {
1783 expected_xx += left[i].norm_sqr();
1784 expected_yy += right[i].norm_sqr();
1785 expected_xy += left[i] * right[i].conj();
1786 }
1787
1788 let (cov_xx, cov_yy, cov_xy) = compute_covariance_simd(&left, &right, start, end);
1790
1791 const EPSILON: f32 = 1e-4;
1792 assert!(
1793 (cov_xx - expected_xx).abs() < EPSILON,
1794 "Range [{}, {}): cov_xx mismatch: {} vs {}",
1795 start,
1796 end,
1797 cov_xx,
1798 expected_xx
1799 );
1800 assert!(
1801 (cov_yy - expected_yy).abs() < EPSILON,
1802 "Range [{}, {}): cov_yy mismatch: {} vs {}",
1803 start,
1804 end,
1805 cov_yy,
1806 expected_yy
1807 );
1808 assert!(
1809 (cov_xy.re - expected_xy.re).abs() < EPSILON,
1810 "Range [{}, {}): cov_xy.re mismatch: {} vs {}",
1811 start,
1812 end,
1813 cov_xy.re,
1814 expected_xy.re
1815 );
1816 assert!(
1817 (cov_xy.im - expected_xy.im).abs() < EPSILON,
1818 "Range [{}, {}): cov_xy.im mismatch: {} vs {}",
1819 start,
1820 end,
1821 cov_xy.im,
1822 expected_xy.im
1823 );
1824 }
1825 }
1826
1827 #[test]
1828 fn test_covariance_large_buffers() {
1829 use rustfft::num_complex::Complex;
1831
1832 for fft_size in [128, 256, 512, 1024, 2048, 4096] {
1833 let left: Vec<Complex<f32>> = (0..fft_size)
1834 .map(|i| {
1835 let phase = (i as f32) * 0.01;
1836 Complex::new(phase.cos(), phase.sin())
1837 })
1838 .collect();
1839
1840 let right: Vec<Complex<f32>> = (0..fft_size)
1841 .map(|i| {
1842 let phase = (i as f32) * 0.02;
1843 Complex::new(phase.sin(), phase.cos())
1844 })
1845 .collect();
1846
1847 let mut expected_xx = 0.0_f32;
1849 let mut expected_yy = 0.0_f32;
1850 let mut expected_xy = Complex::new(0.0, 0.0);
1851 for i in 0..fft_size {
1852 expected_xx += left[i].norm_sqr();
1853 expected_yy += right[i].norm_sqr();
1854 expected_xy += left[i] * right[i].conj();
1855 }
1856
1857 let (cov_xx, cov_yy, cov_xy) = compute_covariance_simd(&left, &right, 0, fft_size);
1859
1860 let rel_epsilon = 1e-4;
1862 assert!(
1863 (cov_xx - expected_xx).abs() < expected_xx * rel_epsilon,
1864 "FFT size {}: cov_xx mismatch",
1865 fft_size
1866 );
1867 assert!(
1868 (cov_yy - expected_yy).abs() < expected_yy * rel_epsilon,
1869 "FFT size {}: cov_yy mismatch",
1870 fft_size
1871 );
1872 assert!(
1873 (cov_xy.re - expected_xy.re).abs() < expected_xy.re.abs() * rel_epsilon + 1e-5,
1874 "FFT size {}: cov_xy.re mismatch",
1875 fft_size
1876 );
1877 assert!(
1878 (cov_xy.im - expected_xy.im).abs() < expected_xy.im.abs() * rel_epsilon + 1e-5,
1879 "FFT size {}: cov_xy.im mismatch",
1880 fft_size
1881 );
1882 }
1883 }
1884
1885 #[test]
1886 fn test_covariance_unaligned_ranges() {
1887 use rustfft::num_complex::Complex;
1889
1890 let left: Vec<Complex<f32>> = (0..50)
1891 .map(|i| Complex::new(i as f32 * 0.2, i as f32 * 0.3))
1892 .collect();
1893 let right: Vec<Complex<f32>> = (0..50)
1894 .map(|i| Complex::new(i as f32 * -0.1, i as f32 * 0.4))
1895 .collect();
1896
1897 for (start, end) in [(0, 1), (0, 3), (1, 4), (2, 7), (5, 11), (10, 23), (15, 37)] {
1899 let mut expected_xx = 0.0_f32;
1901 let mut expected_yy = 0.0_f32;
1902 let mut expected_xy = Complex::new(0.0, 0.0);
1903 for i in start..end {
1904 expected_xx += left[i].norm_sqr();
1905 expected_yy += right[i].norm_sqr();
1906 expected_xy += left[i] * right[i].conj();
1907 }
1908
1909 let (cov_xx, cov_yy, cov_xy) = compute_covariance_simd(&left, &right, start, end);
1911
1912 const EPSILON: f32 = 1e-5;
1913 assert!(
1914 (cov_xx - expected_xx).abs() < EPSILON,
1915 "Range [{}, {}): cov_xx mismatch",
1916 start,
1917 end
1918 );
1919 assert!(
1920 (cov_yy - expected_yy).abs() < EPSILON,
1921 "Range [{}, {}): cov_yy mismatch",
1922 start,
1923 end
1924 );
1925 assert!(
1926 (cov_xy.re - expected_xy.re).abs() < EPSILON,
1927 "Range [{}, {}): cov_xy.re mismatch",
1928 start,
1929 end
1930 );
1931 assert!(
1932 (cov_xy.im - expected_xy.im).abs() < EPSILON,
1933 "Range [{}, {}): cov_xy.im mismatch",
1934 start,
1935 end
1936 );
1937 }
1938 }
1939
1940 #[test]
1941 fn test_covariance_edge_cases() {
1942 use rustfft::num_complex::Complex;
1944
1945 let zero_left = vec![Complex::new(0.0, 0.0); 8];
1947 let zero_right = vec![Complex::new(0.0, 0.0); 8];
1948 let (cov_xx, cov_yy, cov_xy) = compute_covariance_simd(&zero_left, &zero_right, 0, 8);
1949 assert!(cov_xx.abs() < 1e-6, "Expected zero cov_xx");
1950 assert!(cov_yy.abs() < 1e-6, "Expected zero cov_yy");
1951 assert!(cov_xy.norm_sqr() < 1e-6, "Expected zero cov_xy");
1952
1953 let real_left: Vec<Complex<f32>> = (0..8).map(|i| Complex::new(i as f32, 0.0)).collect();
1955 let real_right: Vec<Complex<f32>> = (0..8)
1956 .map(|i| Complex::new((i as f32) * 0.5, 0.0))
1957 .collect();
1958 let (cov_xx, cov_yy, cov_xy) = compute_covariance_simd(&real_left, &real_right, 0, 8);
1959
1960 assert!(
1962 cov_xy.im.abs() < 1e-5,
1963 "Expected real cov_xy for real signals"
1964 );
1965
1966 let mut expected_xx = 0.0;
1968 let mut expected_yy = 0.0;
1969 for i in 0..8 {
1970 expected_xx += (i * i) as f32;
1971 expected_yy += ((i as f32) * 0.5).powi(2);
1972 }
1973 assert!((cov_xx - expected_xx).abs() < 1e-5);
1974 assert!((cov_yy - expected_yy).abs() < 1e-5);
1975
1976 let imag_left: Vec<Complex<f32>> = (0..8).map(|i| Complex::new(0.0, i as f32)).collect();
1978 let imag_right: Vec<Complex<f32>> = (0..8)
1979 .map(|i| Complex::new(0.0, (i as f32) * 2.0))
1980 .collect();
1981 let (_cov_xx, _cov_yy, cov_xy) = compute_covariance_simd(&imag_left, &imag_right, 0, 8);
1982
1983 assert!(
1985 cov_xy.im.abs() < 1e-5,
1986 "Expected real cov_xy for imaginary signals"
1987 );
1988
1989 let single_left = vec![Complex::new(3.0, 4.0)];
1991 let single_right = vec![Complex::new(1.0, 2.0)];
1992 let (cov_xx, cov_yy, cov_xy) = compute_covariance_simd(&single_left, &single_right, 0, 1);
1993 assert!((cov_xx - 25.0).abs() < 1e-5); assert!((cov_yy - 5.0).abs() < 1e-5); assert!((cov_xy.re - 11.0).abs() < 1e-5);
1997 assert!((cov_xy.im - (-2.0)).abs() < 1e-5);
1998 }
1999
2000 #[test]
2005 fn test_numerical_accuracy_small_values() {
2006 use rustfft::num_complex::Complex;
2008
2009 let small = 1e-20_f32;
2010 let src = vec![
2011 Complex::new(small, small),
2012 Complex::new(small * 2.0, small * 3.0),
2013 Complex::new(small * 4.0, small * 5.0),
2014 Complex::new(small * 6.0, small * 7.0),
2015 ];
2016
2017 let hrtf = vec![
2018 Complex::new(1.0, 0.5),
2019 Complex::new(2.0, -1.0),
2020 Complex::new(-0.5, 1.5),
2021 Complex::new(0.75, 0.25),
2022 ];
2023
2024 let expected: Vec<Complex<f32>> = src.iter().zip(hrtf.iter()).map(|(a, b)| a * b).collect();
2026
2027 let mut result = vec![Complex::new(0.0, 0.0); src.len()];
2029 complex_mul_simd(&mut result, &src, &hrtf);
2030
2031 for i in 0..src.len() {
2033 let re_diff = (result[i].re - expected[i].re).abs();
2034 let im_diff = (result[i].im - expected[i].im).abs();
2035
2036 if expected[i].re.abs() > 1e-15 {
2038 assert!(re_diff / expected[i].re.abs() < 1e-3);
2039 } else {
2040 assert!(re_diff < 1e-25);
2041 }
2042
2043 if expected[i].im.abs() > 1e-15 {
2044 assert!(im_diff / expected[i].im.abs() < 1e-3);
2045 } else {
2046 assert!(im_diff < 1e-25);
2047 }
2048 }
2049 }
2050
2051 #[test]
2052 fn test_numerical_accuracy_large_values() {
2053 use rustfft::num_complex::Complex;
2055
2056 let large = 1e10_f32;
2057 let src = vec![
2058 Complex::new(large, large * 0.5),
2059 Complex::new(large * 2.0, large * 1.5),
2060 Complex::new(large * 0.3, large * 0.7),
2061 Complex::new(large * 1.2, large * 0.8),
2062 ];
2063
2064 let hrtf = vec![
2065 Complex::new(1e-5, 5e-6),
2066 Complex::new(2e-5, -1e-5),
2067 Complex::new(-5e-6, 1.5e-5),
2068 Complex::new(7.5e-6, 2.5e-6),
2069 ];
2070
2071 let mut expected = vec![Complex::new(0.0, 0.0); src.len()];
2073 for i in 0..src.len() {
2074 expected[i] = src[i] * hrtf[i];
2075 }
2076
2077 let mut result = vec![Complex::new(0.0, 0.0); src.len()];
2079 complex_mul_simd(&mut result, &src, &hrtf);
2080
2081 for i in 0..src.len() {
2083 let re_rel_err = (result[i].re - expected[i].re).abs() / expected[i].re.abs().max(1.0);
2084 let im_rel_err = (result[i].im - expected[i].im).abs() / expected[i].im.abs().max(1.0);
2085 assert!(
2086 re_rel_err < 1e-5,
2087 "Index {}: re rel error too large: {}",
2088 i,
2089 re_rel_err
2090 );
2091 assert!(
2092 im_rel_err < 1e-5,
2093 "Index {}: im rel error too large: {}",
2094 i,
2095 im_rel_err
2096 );
2097 }
2098 }
2099
2100 #[test]
2101 fn test_accumulation_accuracy() {
2102 use rustfft::num_complex::Complex;
2104
2105 let src = vec![
2106 Complex::new(0.1, 0.2),
2107 Complex::new(0.3, 0.4),
2108 Complex::new(0.5, 0.6),
2109 Complex::new(0.7, 0.8),
2110 ];
2111
2112 let hrtf = vec![
2113 Complex::new(0.5, 0.25),
2114 Complex::new(-1.0, 1.5),
2115 Complex::new(2.0, -0.5),
2116 Complex::new(0.75, 0.75),
2117 ];
2118
2119 let mut expected = vec![Complex::new(0.0, 0.0); src.len()];
2121 for _ in 0..100 {
2122 for i in 0..src.len() {
2123 expected[i] += src[i] * hrtf[i];
2124 }
2125 }
2126
2127 let mut result = vec![Complex::new(0.0, 0.0); src.len()];
2129 for _ in 0..100 {
2130 complex_mul_add_simd(&mut result, &src, &hrtf);
2131 }
2132
2133 const REL_EPSILON: f32 = 1e-4;
2135 for i in 0..src.len() {
2136 let re_abs_err = (result[i].re - expected[i].re).abs();
2137 let im_abs_err = (result[i].im - expected[i].im).abs();
2138
2139 let re_err = if expected[i].re.abs() > 1e-6 {
2141 re_abs_err / expected[i].re.abs()
2142 } else {
2143 re_abs_err
2144 };
2145 let im_err = if expected[i].im.abs() > 1e-6 {
2146 im_abs_err / expected[i].im.abs()
2147 } else {
2148 im_abs_err
2149 };
2150
2151 assert!(
2152 re_err < REL_EPSILON,
2153 "Index {}: re accumulated error too large: {} (abs: {}, expected: {})",
2154 i,
2155 re_err,
2156 re_abs_err,
2157 expected[i].re
2158 );
2159 assert!(
2160 im_err < REL_EPSILON,
2161 "Index {}: im accumulated error too large: {} (abs: {}, expected: {})",
2162 i,
2163 im_err,
2164 im_abs_err,
2165 expected[i].im
2166 );
2167 }
2168 }
2169
2170 #[test]
2171 fn test_platform_specific_simd_widths() {
2172 use rustfft::num_complex::Complex;
2174
2175 let test_sizes = vec![
2179 1, 2, 3, 4, 5, 8, 9, 12, 16, ];
2189
2190 for size in test_sizes {
2191 let src: Vec<Complex<f32>> = (0..size)
2192 .map(|i| Complex::new(i as f32 * 0.3, i as f32 * -0.2))
2193 .collect();
2194 let hrtf: Vec<Complex<f32>> = (0..size)
2195 .map(|i| Complex::new(1.0 + i as f32 * 0.1, 0.5))
2196 .collect();
2197
2198 let mut result_add = vec![Complex::new(1.0, 2.0); size];
2202 let mut expected_add = result_add.clone();
2203 for i in 0..size {
2204 expected_add[i] += src[i] * hrtf[i];
2205 }
2206 complex_mul_add_simd(&mut result_add, &src, &hrtf);
2207 for i in 0..size {
2208 assert!(
2209 (result_add[i].re - expected_add[i].re).abs() < 1e-6,
2210 "mul_add size {}, index {}: re mismatch",
2211 size,
2212 i
2213 );
2214 assert!(
2215 (result_add[i].im - expected_add[i].im).abs() < 1e-6,
2216 "mul_add size {}, index {}: im mismatch",
2217 size,
2218 i
2219 );
2220 }
2221
2222 let mut result_mul = vec![Complex::new(0.0, 0.0); size];
2224 let expected_mul: Vec<Complex<f32>> =
2225 src.iter().zip(hrtf.iter()).map(|(a, b)| a * b).collect();
2226 complex_mul_simd(&mut result_mul, &src, &hrtf);
2227 for i in 0..size {
2228 assert!(
2229 (result_mul[i].re - expected_mul[i].re).abs() < 1e-6,
2230 "mul size {}, index {}: re mismatch",
2231 size,
2232 i
2233 );
2234 assert!(
2235 (result_mul[i].im - expected_mul[i].im).abs() < 1e-6,
2236 "mul size {}, index {}: im mismatch",
2237 size,
2238 i
2239 );
2240 }
2241
2242 let mut result_inplace = src.clone();
2244 let mut expected_inplace = src.clone();
2245 for i in 0..size {
2246 expected_inplace[i] *= hrtf[i];
2247 }
2248 complex_mul_inplace_simd(&mut result_inplace, &hrtf);
2249 for i in 0..size {
2250 assert!(
2251 (result_inplace[i].re - expected_inplace[i].re).abs() < 1e-6,
2252 "inplace size {}, index {}: re mismatch",
2253 size,
2254 i
2255 );
2256 assert!(
2257 (result_inplace[i].im - expected_inplace[i].im).abs() < 1e-6,
2258 "inplace size {}, index {}: im mismatch",
2259 size,
2260 i
2261 );
2262 }
2263 }
2264 }
2265
2266 #[test]
2267 fn test_stress_test_random_data() {
2268 use rustfft::num_complex::Complex;
2270
2271 let mut seed = 12345_u32;
2273 let lcg = |s: &mut u32| -> f32 {
2274 *s = s.wrapping_mul(1103515245).wrapping_add(12345);
2275 ((*s / 65536) % 32768) as f32 / 32768.0 - 0.5
2276 };
2277
2278 for size in [64, 128, 256, 512] {
2279 let src: Vec<Complex<f32>> = (0..size)
2280 .map(|_| Complex::new(lcg(&mut seed), lcg(&mut seed)))
2281 .collect();
2282 let hrtf: Vec<Complex<f32>> = (0..size)
2283 .map(|_| Complex::new(lcg(&mut seed), lcg(&mut seed)))
2284 .collect();
2285
2286 let expected: Vec<Complex<f32>> =
2288 src.iter().zip(hrtf.iter()).map(|(a, b)| a * b).collect();
2289
2290 let mut result = vec![Complex::new(0.0, 0.0); size];
2292 complex_mul_simd(&mut result, &src, &hrtf);
2293
2294 for i in 0..size {
2296 assert!(
2297 (result[i].re - expected[i].re).abs() < 1e-5,
2298 "Stress test size {}, index {}: re mismatch",
2299 size,
2300 i
2301 );
2302 assert!(
2303 (result[i].im - expected[i].im).abs() < 1e-5,
2304 "Stress test size {}, index {}: im mismatch",
2305 size,
2306 i
2307 );
2308 }
2309 }
2310 }
2311}
2312
2313pub fn compute_covariance_simd(
2331 left: &[Complex<f32>],
2332 right: &[Complex<f32>],
2333 start: usize,
2334 end: usize,
2335) -> (f32, f32, Complex<f32>) {
2336 assert_eq!(left.len(), right.len());
2337 assert!(end <= left.len());
2338 assert!(start < end);
2339
2340 #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
2341 let count = end - start;
2342
2343 #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
2344 {
2345 use std::arch::x86_64::*;
2346
2347 let mut cov_xx;
2348 let mut cov_yy;
2349 let mut cov_xy = Complex::new(0.0, 0.0);
2350
2351 let simd_len = (count / 4) * 4;
2353 let simd_end = start + simd_len;
2354
2355 unsafe {
2356 let mut sum_xx = _mm256_setzero_ps();
2357 let mut sum_yy = _mm256_setzero_ps();
2358 let mut sum_xy_re = _mm256_setzero_ps();
2359 let _sum_xy_im = _mm256_setzero_ps();
2360
2361 for i in (start..simd_end).step_by(4) {
2362 let left_ptr = left.as_ptr().add(i) as *const f32;
2363 let right_ptr = right.as_ptr().add(i) as *const f32;
2364
2365 let l = _mm256_loadu_ps(left_ptr);
2367 let r = _mm256_loadu_ps(right_ptr);
2368
2369 let l_sqr = _mm256_mul_ps(l, l);
2371 let r_sqr = _mm256_mul_ps(r, r);
2372
2373 let l_norm = _mm256_hadd_ps(l_sqr, l_sqr);
2375 let r_norm = _mm256_hadd_ps(r_sqr, r_sqr);
2376
2377 sum_xx = _mm256_add_ps(sum_xx, l_norm);
2378 sum_yy = _mm256_add_ps(sum_yy, r_norm);
2379
2380 let sign_mask = _mm256_set_ps(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0);
2382 let r_conj = _mm256_xor_ps(r, sign_mask);
2383
2384 let l_re = _mm256_moveldup_ps(l);
2386 let l_im = _mm256_movehdup_ps(l);
2387
2388 let ac_ad = _mm256_mul_ps(l_re, r_conj);
2389 let r_conj_swap = _mm256_shuffle_ps(r_conj, r_conj, 0b10110001);
2390 let bd_bc = _mm256_mul_ps(l_im, r_conj_swap);
2391
2392 let result = _mm256_addsub_ps(ac_ad, bd_bc);
2393
2394 sum_xy_re = _mm256_add_ps(sum_xy_re, result);
2396 }
2397
2398 let xx_arr = std::mem::transmute::<__m256, [f32; 8]>(sum_xx);
2400 let yy_arr = std::mem::transmute::<__m256, [f32; 8]>(sum_yy);
2401 let xy_arr = std::mem::transmute::<__m256, [f32; 8]>(sum_xy_re);
2402
2403 cov_xx = xx_arr[0] + xx_arr[1] + xx_arr[4] + xx_arr[5];
2406 cov_yy = yy_arr[0] + yy_arr[1] + yy_arr[4] + yy_arr[5];
2407
2408 cov_xy.re = xy_arr[0] + xy_arr[2] + xy_arr[4] + xy_arr[6];
2410 cov_xy.im = xy_arr[1] + xy_arr[3] + xy_arr[5] + xy_arr[7];
2411 }
2412
2413 for i in simd_end..end {
2415 let l = left[i];
2416 let r = right[i];
2417 cov_xx += l.norm_sqr();
2418 cov_yy += r.norm_sqr();
2419 cov_xy += l * r.conj();
2420 }
2421
2422 (cov_xx, cov_yy, cov_xy)
2423 }
2424
2425 #[cfg(not(all(target_arch = "x86_64", target_feature = "avx2")))]
2426 {
2427 let mut cov_xx = 0.0_f32;
2428 let mut cov_yy = 0.0_f32;
2429 let mut cov_xy = Complex::new(0.0, 0.0);
2430
2431 for i in start..end {
2432 let l = left[i];
2433 let r = right[i];
2434 cov_xx += l.norm_sqr();
2435 cov_yy += r.norm_sqr();
2436 cov_xy += l * r.conj();
2437 }
2438
2439 (cov_xx, cov_yy, cov_xy)
2440 }
2441}
2442
/// Multiplies every sample of `buffer` by `gain`, in place.
///
/// With AVX2 (x86_64) or NEON (aarch64) enabled at compile time, samples are
/// processed eight or four at a time respectively, and the leftover tail is scaled
/// one by one. Every other target scales sample by sample.
#[inline]
pub fn apply_gain_simd(buffer: &mut [f32], gain: f32) {
    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        use std::arch::x86_64::*;
        // SAFETY: AVX2 is enabled at compile time (cfg above).
        let splat = unsafe { _mm256_set1_ps(gain) };
        let mut blocks = buffer.chunks_exact_mut(8);
        for block in &mut blocks {
            // SAFETY: `block` is exactly eight contiguous f32, matching one unaligned
            // 256-bit load and store.
            unsafe {
                let p = block.as_mut_ptr();
                _mm256_storeu_ps(p, _mm256_mul_ps(_mm256_loadu_ps(p), splat));
            }
        }
        for sample in blocks.into_remainder() {
            *sample *= gain;
        }
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    {
        use std::arch::aarch64::*;
        // SAFETY: NEON is enabled at compile time (cfg above).
        let splat = unsafe { vdupq_n_f32(gain) };
        let mut blocks = buffer.chunks_exact_mut(4);
        for block in &mut blocks {
            // SAFETY: `block` is exactly four contiguous f32, matching one 128-bit
            // load and store.
            unsafe {
                let p = block.as_mut_ptr();
                vst1q_f32(p, vmulq_f32(vld1q_f32(p), splat));
            }
        }
        for sample in blocks.into_remainder() {
            *sample *= gain;
        }
    }

    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "avx2"),
        all(target_arch = "aarch64", target_feature = "neon")
    )))]
    {
        buffer.iter_mut().for_each(|sample| *sample *= gain);
    }
}
2494
/// Applies an independent gain to each channel of an interleaved buffer, in place.
///
/// `buffer` is treated as `buffer.len() / channels` whole frames of `channels`
/// interleaved samples, and sample `ch` of every frame is multiplied by `gains[ch]`.
/// Any trailing partial frame is left untouched. Stereo input (`channels == 2`)
/// uses AVX2 on x86_64 or NEON on aarch64 when those target features are enabled
/// at compile time. Extra entries in `gains` beyond `channels` are ignored.
///
/// # Panics
///
/// Panics if `channels == 0`, or if the buffer holds at least one whole frame and
/// `gains.len() < channels`. Both checks run before any sample is modified, so a
/// panic never leaves `buffer` partially scaled.
#[inline]
pub fn apply_per_channel_gain_simd(buffer: &mut [f32], channels: usize, gains: &[f32]) {
    assert!(
        channels > 0,
        "apply_per_channel_gain_simd: channels must be non-zero"
    );
    let num_frames = buffer.len() / channels;
    if num_frames == 0 {
        return;
    }
    assert!(
        gains.len() >= channels,
        "apply_per_channel_gain_simd: {} gains supplied for {} channels",
        gains.len(),
        channels
    );

    if channels == 2 {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
        {
            use std::arch::x86_64::*;
            // `_mm256_set_ps` lists lanes high-to-low, so lane 0 = gains[0], lane 1 = gains[1], ...
            // SAFETY: AVX2 is enabled at compile time (cfg above).
            let gains_vec = unsafe {
                _mm256_set_ps(
                    gains[1], gains[0], gains[1], gains[0], gains[1], gains[0], gains[1], gains[0],
                )
            };
            // Four stereo frames (eight floats) per 256-bit vector.
            let simd_len = (num_frames / 4) * 4;
            for i in (0..simd_len).step_by(4) {
                // SAFETY: i + 4 <= simd_len <= num_frames, so floats [2i, 2i + 8) lie
                // inside `buffer`.
                unsafe {
                    let ptr = buffer.as_mut_ptr().add(i * 2);
                    let v = _mm256_loadu_ps(ptr);
                    _mm256_storeu_ps(ptr, _mm256_mul_ps(v, gains_vec));
                }
            }
            for i in simd_len..num_frames {
                buffer[i * 2] *= gains[0];
                buffer[i * 2 + 1] *= gains[1];
            }
            return;
        }

        #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
        {
            use std::arch::aarch64::*;
            let pattern = [gains[0], gains[1], gains[0], gains[1]];
            // SAFETY: NEON is enabled at compile time; `pattern` holds four f32.
            let gains_vec = unsafe { vld1q_f32(pattern.as_ptr()) };
            // Two stereo frames (four floats) per 128-bit vector.
            let simd_len = (num_frames / 2) * 2;
            for i in (0..simd_len).step_by(2) {
                // SAFETY: i + 2 <= simd_len <= num_frames, so floats [2i, 2i + 4) lie
                // inside `buffer`.
                unsafe {
                    let ptr = buffer.as_mut_ptr().add(i * 2);
                    let v = vld1q_f32(ptr);
                    vst1q_f32(ptr, vmulq_f32(v, gains_vec));
                }
            }
            for i in simd_len..num_frames {
                buffer[i * 2] *= gains[0];
                buffer[i * 2 + 1] *= gains[1];
            }
            return;
        }
    }

    // Generic path: one whole frame at a time; the partial tail frame is skipped.
    let gains = &gains[..channels];
    for frame in buffer.chunks_exact_mut(channels) {
        for (sample, &gain) in frame.iter_mut().zip(gains) {
            *sample *= gain;
        }
    }
}
2560
/// Approximates `1.0 / x.sqrt()` with the classic bit-level initial guess (magic
/// constant `0x5f3759df`) refined by one Newton–Raphson step.
///
/// For positive finite `x` the relative error stays around 0.2% or below. The
/// result is not meaningful for zero, negative or NaN inputs. The bit subtraction
/// uses wrapping arithmetic, so such inputs return (meaningless) values instead of
/// panicking with an overflow in debug builds.
#[inline(always)]
pub fn fast_inv_sqrt(x: f32) -> f32 {
    let half = 0.5 * x;
    // Halving the bit pattern roughly halves the exponent; subtracting from the magic
    // constant gives a first guess for x^(-1/2). Wrapping keeps negative inputs (sign
    // bit set) from overflowing the subtraction.
    let guess_bits = 0x5f37_59df_u32.wrapping_sub(x.to_bits() >> 1);
    let y = f32::from_bits(guess_bits);
    // One Newton–Raphson iteration: y' = y * (1.5 - 0.5 * x * y^2).
    y * (1.5 - half * y * y)
}
2571
/// Returns the largest absolute value in `samples`, or `0.0` for an empty slice.
///
/// NaN samples are skipped on every code path, so a NaN can neither become the result
/// nor hide a larger finite peak. On aarch64 this holds for quiet NaNs, per FMAXNM
/// semantics. A slice containing only NaNs yields `0.0`.
///
/// Uses AVX2 (eight lanes) on x86_64 or NEON (four lanes) on aarch64 when those target
/// features are enabled at compile time. Leftover samples and all other targets use a
/// scalar scan.
#[inline]
pub fn find_max_abs_simd(samples: &[f32]) -> f32 {
    let len = samples.len();
    if len == 0 {
        return 0.0;
    }

    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        use std::arch::x86_64::*;
        let mut max_vec = unsafe { _mm256_setzero_ps() };
        // -0.0 has only the sign bit set; `andnot` with it clears the sign bit, giving |v|.
        let abs_mask = unsafe { _mm256_set1_ps(-0.0) };
        let simd_len = (len / 8) * 8;

        for i in (0..simd_len).step_by(8) {
            // SAFETY: AVX2 is enabled at compile time and i + 8 <= simd_len <= len, so the
            // unaligned 8-float load stays in bounds.
            unsafe {
                let ptr = samples.as_ptr().add(i);
                let v = _mm256_loadu_ps(ptr);
                let av = _mm256_andnot_ps(abs_mask, v);
                // `_mm256_max_ps(a, b)` returns `b` whenever either operand is NaN. With the
                // running maximum as `b`, a NaN sample is skipped instead of replacing (and
                // losing) the lane's peak so far.
                max_vec = _mm256_max_ps(av, max_vec);
            }
        }

        let mut lanes = [0.0_f32; 8];
        // SAFETY: AVX2 is enabled; `lanes` has room for exactly eight f32.
        unsafe { _mm256_storeu_ps(lanes.as_mut_ptr(), max_vec) };
        let mut max_val = 0.0_f32;
        for &v in &lanes {
            if v > max_val {
                max_val = v;
            }
        }

        // Scalar tail; `>` is false for NaN, so NaN samples are skipped here too.
        for &sample in &samples[simd_len..] {
            let v = sample.abs();
            if v > max_val {
                max_val = v;
            }
        }
        max_val
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    {
        use std::arch::aarch64::*;
        let mut max_vec = unsafe { vdupq_n_f32(0.0) };
        let simd_len = (len / 4) * 4;

        for i in (0..simd_len).step_by(4) {
            // SAFETY: NEON is enabled at compile time and i + 4 <= simd_len <= len, so the
            // 4-float load stays in bounds.
            unsafe {
                let ptr = samples.as_ptr().add(i);
                let av = vabsq_f32(vld1q_f32(ptr));
                // FMAXNM (IEEE 754 maxNum) returns the numeric operand when the other is a
                // quiet NaN, matching the NaN-skipping scalar comparison below.
                max_vec = vmaxnmq_f32(max_vec, av);
            }
        }

        // Horizontal maxNum reduction, also NaN-skipping.
        let mut max_val = unsafe { vmaxnmvq_f32(max_vec) };

        for &sample in &samples[simd_len..] {
            let v = sample.abs();
            if v > max_val {
                max_val = v;
            }
        }
        max_val
    }

    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "avx2"),
        all(target_arch = "aarch64", target_feature = "neon")
    )))]
    {
        // `>` is false for NaN, so NaN samples never replace the running maximum.
        let mut max_val = 0.0_f32;
        for &s in samples {
            let v = s.abs();
            if v > max_val {
                max_val = v;
            }
        }
        max_val
    }
}