1use rustfft::num_complex::Complex;
23
/// Complex multiply-accumulate on packed single-precision values using two
/// `FCMLA` instructions (rotations `#0` and `#90`).
///
/// Each 128-bit vector is treated as two interleaved `(re, im)` pairs. The
/// two rotations together add the complex product `a * b` to `r`, lane pair
/// by lane pair, and the updated accumulator is returned.
///
/// # Safety
///
/// The function is only compiled when the `fcma` target feature is enabled
/// at build time; the executing CPU must actually implement it.
#[cfg(all(target_arch = "aarch64", target_feature = "fcma"))]
#[inline(always)]
unsafe fn fcmla_mul_acc(
    r: std::arch::aarch64::float32x4_t,
    a: std::arch::aarch64::float32x4_t,
    b: std::arch::aarch64::float32x4_t,
) -> std::arch::aarch64::float32x4_t {
    // The accumulator is both read and written by the instruction pair.
    let mut acc = r;
    unsafe {
        std::arch::asm!(
            "fcmla {acc:v}.4s, {lhs:v}.4s, {rhs:v}.4s, #0",
            "fcmla {acc:v}.4s, {lhs:v}.4s, {rhs:v}.4s, #90",
            acc = inout(vreg) acc,
            lhs = in(vreg) a,
            rhs = in(vreg) b,
            options(pure, nomem, nostack),
        );
    }
    acc
}
45
/// Immediate for `_mm256_shuffle_ps` that exchanges the two `f32` lanes of
/// every adjacent pair, turning each `[re, im]` into `[im, re]`.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
const SHUFFLE_SWAP_RE_IM: i32 = 0b10110001;

/// AVX2 kernel: accumulates four complex products,
/// `dst[start + k] += src[start + k] * hrtf[start + k]` for `k` in `0..4`.
///
/// The slices are accessed through raw `f32` pointers, i.e. every
/// `Complex<f32>` is read as two consecutive `f32` values (re, im).
///
/// # Safety
///
/// No bounds checks are performed. The caller must guarantee that
/// `start + 4` does not exceed the length of `dst`, `src` or `hrtf`.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
#[inline]
pub unsafe fn complex_mul_add_simd_chunk(
    dst: &mut [Complex<f32>],
    src: &[Complex<f32>],
    hrtf: &[Complex<f32>],
    start: usize,
) {
    use std::arch::x86_64::*;

    unsafe {
        let acc_ptr = dst.as_mut_ptr().add(start).cast::<f32>();

        // Eight floats = four interleaved complex values per operand.
        let lhs = _mm256_loadu_ps(src.as_ptr().add(start).cast::<f32>());
        let rhs = _mm256_loadu_ps(hrtf.as_ptr().add(start).cast::<f32>());
        let acc = _mm256_loadu_ps(acc_ptr);

        // Broadcast the real / imaginary part of `lhs` across each pair.
        let lhs_re = _mm256_moveldup_ps(lhs);
        let lhs_im = _mm256_movehdup_ps(lhs);

        // [a.re*b.re, a.re*b.im]
        let direct = _mm256_mul_ps(lhs_re, rhs);
        // [a.im*b.im, a.im*b.re]
        let rhs_swapped = _mm256_shuffle_ps(rhs, rhs, SHUFFLE_SWAP_RE_IM);
        let cross = _mm256_mul_ps(lhs_im, rhs_swapped);

        // Even lanes subtract, odd lanes add: the complex product.
        let product = _mm256_addsub_ps(direct, cross);

        _mm256_storeu_ps(acc_ptr, _mm256_add_ps(acc, product));
    }
}
101
/// FCMA kernel: accumulates two complex products,
/// `dst[start + k] += src[start + k] * hrtf[start + k]` for `k` in `0..2`,
/// by delegating to `fcmla_mul_acc`.
///
/// # Safety
///
/// No bounds checks are performed. The caller must guarantee that
/// `start + 2` does not exceed the length of `dst`, `src` or `hrtf`.
#[cfg(all(target_arch = "aarch64", target_feature = "fcma"))]
#[inline]
pub unsafe fn complex_mul_add_simd_chunk(
    dst: &mut [Complex<f32>],
    src: &[Complex<f32>],
    hrtf: &[Complex<f32>],
    start: usize,
) {
    use std::arch::aarch64::*;

    unsafe {
        let acc_ptr = dst.as_mut_ptr().add(start).cast::<f32>();

        // Four floats = two interleaved complex values per operand.
        let lhs = vld1q_f32(src.as_ptr().add(start).cast::<f32>());
        let rhs = vld1q_f32(hrtf.as_ptr().add(start).cast::<f32>());
        let acc = vld1q_f32(acc_ptr);

        vst1q_f32(acc_ptr, fcmla_mul_acc(acc, lhs, rhs));
    }
}
126
/// NEON kernel (no FCMA): accumulates two complex products,
/// `dst[start + k] += src[start + k] * hrtf[start + k]` for `k` in `0..2`.
///
/// # Safety
///
/// No bounds checks are performed. The caller must guarantee that
/// `start + 2` does not exceed the length of `dst`, `src` or `hrtf`.
#[cfg(all(
    target_arch = "aarch64",
    target_feature = "neon",
    not(target_feature = "fcma")
))]
#[inline]
pub unsafe fn complex_mul_add_simd_chunk(
    dst: &mut [Complex<f32>],
    src: &[Complex<f32>],
    hrtf: &[Complex<f32>],
    start: usize,
) {
    use std::arch::aarch64::*;

    const SIGN_BIT: u32 = 0x80000000;

    unsafe {
        let acc_ptr = dst.as_mut_ptr().add(start).cast::<f32>();

        let lhs = vld1q_f32(src.as_ptr().add(start).cast::<f32>());
        let rhs = vld1q_f32(hrtf.as_ptr().add(start).cast::<f32>());
        let acc = vld1q_f32(acc_ptr);

        // Duplicate the real (even) / imaginary (odd) lanes of `lhs`.
        let lhs_re = vtrn1q_f32(lhs, lhs);
        let lhs_im = vtrn2q_f32(lhs, lhs);

        // [a.re*b.re, a.re*b.im] and [a.im*b.im, a.im*b.re]
        let direct = vmulq_f32(lhs_re, rhs);
        let cross = vmulq_f32(lhs_im, vrev64q_f32(rhs));

        // Flip the sign of the even (real) lanes of `cross` so that a plain
        // add yields `re = ac - bd`, `im = ad + bc`.
        let even_sign = vsetq_lane_u32::<2>(
            SIGN_BIT,
            vsetq_lane_u32::<0>(SIGN_BIT, vdupq_n_u32(0)),
        );
        let cross_signed =
            vreinterpretq_f32_u32(veorq_u32(vreinterpretq_u32_f32(cross), even_sign));

        let product = vaddq_f32(direct, cross_signed);
        vst1q_f32(acc_ptr, vaddq_f32(acc, product));
    }
}
174
175#[cfg(not(any(
176 all(target_arch = "x86_64", target_feature = "avx2"),
177 all(target_arch = "aarch64", target_feature = "neon")
178)))]
179#[inline]
180pub fn complex_mul_add_simd_chunk(
181 dst: &mut [Complex<f32>],
182 src: &[Complex<f32>],
183 hrtf: &[Complex<f32>],
184 start: usize,
185) {
186 dst[start] += src[start] * hrtf[start];
188}
189
190#[inline]
199pub fn complex_mul_add_simd(dst: &mut [Complex<f32>], src: &[Complex<f32>], hrtf: &[Complex<f32>]) {
200 let len = dst.len();
201
202 #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
203 {
204 let simd_len = (len / 4) * 4;
206
207 for i in (0..simd_len).step_by(4) {
208 unsafe {
209 complex_mul_add_simd_chunk(dst, src, hrtf, i);
210 }
211 }
212
213 for i in simd_len..len {
215 dst[i] += src[i] * hrtf[i];
216 }
217 }
218
219 #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
220 {
221 let simd_len = (len / 2) * 2;
223
224 for i in (0..simd_len).step_by(2) {
225 unsafe {
226 complex_mul_add_simd_chunk(dst, src, hrtf, i);
227 }
228 }
229
230 for i in simd_len..len {
232 dst[i] += src[i] * hrtf[i];
233 }
234 }
235
236 #[cfg(not(any(
237 all(target_arch = "x86_64", target_feature = "avx2"),
238 all(target_arch = "aarch64", target_feature = "neon")
239 )))]
240 {
241 for i in 0..len {
243 dst[i] += src[i] * hrtf[i];
244 }
245 }
246}
247
248#[inline]
252pub fn complex_mul_simd(dst: &mut [Complex<f32>], src: &[Complex<f32>], hrtf: &[Complex<f32>]) {
253 let len = dst.len();
254
255 #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
256 {
257 use std::arch::x86_64::*;
258
259 let simd_len = (len / 4) * 4;
260
261 for i in (0..simd_len).step_by(4) {
262 unsafe {
263 let src_ptr = src.as_ptr().add(i) as *const f32;
264 let hrtf_ptr = hrtf.as_ptr().add(i) as *const f32;
265 let dst_ptr = dst.as_mut_ptr().add(i) as *mut f32;
266
267 let a = _mm256_loadu_ps(src_ptr);
268 let b = _mm256_loadu_ps(hrtf_ptr);
269
270 let a_re = _mm256_moveldup_ps(a);
271 let a_im = _mm256_movehdup_ps(a);
272 let ac_ad = _mm256_mul_ps(a_re, b);
273 let b_swapped = _mm256_shuffle_ps(b, b, SHUFFLE_SWAP_RE_IM);
274 let bd_bc = _mm256_mul_ps(a_im, b_swapped);
275 let result = _mm256_addsub_ps(ac_ad, bd_bc);
276
277 _mm256_storeu_ps(dst_ptr, result);
278 }
279 }
280
281 for i in simd_len..len {
282 dst[i] = src[i] * hrtf[i];
283 }
284 }
285
286 #[cfg(all(target_arch = "aarch64", target_feature = "fcma"))]
287 {
288 use std::arch::aarch64::*;
289
290 let simd_len = (len / 2) * 2;
291
292 for i in (0..simd_len).step_by(2) {
293 unsafe {
294 let src_ptr = src.as_ptr().add(i) as *const f32;
295 let hrtf_ptr = hrtf.as_ptr().add(i) as *const f32;
296 let dst_ptr = dst.as_mut_ptr().add(i) as *mut f32;
297
298 let a = vld1q_f32(src_ptr);
299 let b = vld1q_f32(hrtf_ptr);
300 let r = vdupq_n_f32(0.0);
301 let result = fcmla_mul_acc(r, a, b);
302 vst1q_f32(dst_ptr, result);
303 }
304 }
305
306 for i in simd_len..len {
307 dst[i] = src[i] * hrtf[i];
308 }
309 }
310
311 #[cfg(all(
312 target_arch = "aarch64",
313 target_feature = "neon",
314 not(target_feature = "fcma")
315 ))]
316 {
317 use std::arch::aarch64::*;
318
319 let simd_len = (len / 2) * 2;
320
321 for i in (0..simd_len).step_by(2) {
322 unsafe {
323 let src_ptr = src.as_ptr().add(i) as *const f32;
324 let hrtf_ptr = hrtf.as_ptr().add(i) as *const f32;
325 let dst_ptr = dst.as_mut_ptr().add(i) as *mut f32;
326
327 let a = vld1q_f32(src_ptr);
328 let b = vld1q_f32(hrtf_ptr);
329
330 let a_re = vtrn1q_f32(a, a);
331 let a_im = vtrn2q_f32(a, a);
332 let ac_ad = vmulq_f32(a_re, b);
333 let b_swapped = vrev64q_f32(b);
334 let bd_bc = vmulq_f32(a_im, b_swapped);
335
336 let sign_bit: u32 = 0x80000000;
337 let neg_mask = vreinterpretq_f32_u32(vsetq_lane_u32::<2>(
338 sign_bit,
339 vsetq_lane_u32::<0>(sign_bit, vdupq_n_u32(0)),
340 ));
341
342 let bd_bc_negated = vreinterpretq_f32_u32(veorq_u32(
343 vreinterpretq_u32_f32(bd_bc),
344 vreinterpretq_u32_f32(neg_mask),
345 ));
346
347 let result = vaddq_f32(ac_ad, bd_bc_negated);
348 vst1q_f32(dst_ptr, result);
349 }
350 }
351
352 for i in simd_len..len {
353 dst[i] = src[i] * hrtf[i];
354 }
355 }
356
357 #[cfg(not(any(
358 all(target_arch = "x86_64", target_feature = "avx2"),
359 all(target_arch = "aarch64", target_feature = "neon")
360 )))]
361 {
362 for i in 0..len {
363 dst[i] = src[i] * hrtf[i];
364 }
365 }
366}
367
368#[inline]
372#[allow(dead_code)]
373pub fn complex_mul_inplace_simd(dst: &mut [Complex<f32>], hrtf: &[Complex<f32>]) {
374 let len = dst.len();
375
376 #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
377 {
378 use std::arch::x86_64::*;
379
380 let simd_len = (len / 4) * 4;
381
382 for i in (0..simd_len).step_by(4) {
383 unsafe {
384 let dst_ptr = dst.as_mut_ptr().add(i) as *mut f32;
385 let hrtf_ptr = hrtf.as_ptr().add(i) as *const f32;
386
387 let a = _mm256_loadu_ps(dst_ptr);
388 let b = _mm256_loadu_ps(hrtf_ptr);
389
390 let a_re = _mm256_moveldup_ps(a);
391 let a_im = _mm256_movehdup_ps(a);
392 let ac_ad = _mm256_mul_ps(a_re, b);
393 let b_swapped = _mm256_shuffle_ps(b, b, SHUFFLE_SWAP_RE_IM);
394 let bd_bc = _mm256_mul_ps(a_im, b_swapped);
395 let result = _mm256_addsub_ps(ac_ad, bd_bc);
396
397 _mm256_storeu_ps(dst_ptr, result);
398 }
399 }
400
401 for i in simd_len..len {
402 dst[i] *= hrtf[i];
403 }
404 }
405
406 #[cfg(all(target_arch = "aarch64", target_feature = "fcma"))]
407 {
408 use std::arch::aarch64::*;
409
410 let simd_len = (len / 2) * 2;
411
412 for i in (0..simd_len).step_by(2) {
413 unsafe {
414 let dst_ptr = dst.as_mut_ptr().add(i) as *mut f32;
415 let hrtf_ptr = hrtf.as_ptr().add(i) as *const f32;
416
417 let a = vld1q_f32(dst_ptr);
418 let b = vld1q_f32(hrtf_ptr);
419 let r = vdupq_n_f32(0.0);
420 let result = fcmla_mul_acc(r, a, b);
421 vst1q_f32(dst_ptr, result);
422 }
423 }
424
425 for i in simd_len..len {
426 dst[i] *= hrtf[i];
427 }
428 }
429
430 #[cfg(all(
431 target_arch = "aarch64",
432 target_feature = "neon",
433 not(target_feature = "fcma")
434 ))]
435 {
436 use std::arch::aarch64::*;
437
438 let simd_len = (len / 2) * 2;
439
440 for i in (0..simd_len).step_by(2) {
441 unsafe {
442 let dst_ptr = dst.as_mut_ptr().add(i) as *mut f32;
443 let hrtf_ptr = hrtf.as_ptr().add(i) as *const f32;
444
445 let a = vld1q_f32(dst_ptr);
446 let b = vld1q_f32(hrtf_ptr);
447
448 let a_re = vtrn1q_f32(a, a);
449 let a_im = vtrn2q_f32(a, a);
450 let ac_ad = vmulq_f32(a_re, b);
451 let b_swapped = vrev64q_f32(b);
452 let bd_bc = vmulq_f32(a_im, b_swapped);
453
454 let sign_bit: u32 = 0x80000000;
455 let neg_mask = vreinterpretq_f32_u32(vsetq_lane_u32::<2>(
456 sign_bit,
457 vsetq_lane_u32::<0>(sign_bit, vdupq_n_u32(0)),
458 ));
459
460 let bd_bc_negated = vreinterpretq_f32_u32(veorq_u32(
461 vreinterpretq_u32_f32(bd_bc),
462 vreinterpretq_u32_f32(neg_mask),
463 ));
464
465 let result = vaddq_f32(ac_ad, bd_bc_negated);
466 vst1q_f32(dst_ptr, result);
467 }
468 }
469
470 for i in simd_len..len {
471 dst[i] *= hrtf[i];
472 }
473 }
474
475 #[cfg(not(any(
476 all(target_arch = "x86_64", target_feature = "avx2"),
477 all(target_arch = "aarch64", target_feature = "neon")
478 )))]
479 {
480 for i in 0..len {
481 dst[i] *= hrtf[i];
482 }
483 }
484}
485
/// Scaled accumulate: `dst[i] += src[i] * scale` for every `i` in
/// `0..dst.len()`.
///
/// AVX2 builds process 8 floats per iteration (separate multiply and add),
/// NEON builds 4 (fused multiply-add via `vfmaq_f32`); leftover elements and
/// all elements on other targets use the scalar expression.
///
/// # Panics
///
/// Panics if `src` is shorter than `dst` (checked in release builds too,
/// because the SIMD paths use unchecked raw-pointer loads). Debug builds
/// additionally require the two lengths to be equal.
#[inline]
pub fn scale_add_simd(dst: &mut [f32], src: &[f32], scale: f32) {
    debug_assert_eq!(dst.len(), src.len());
    let len = dst.len();
    assert!(
        src.len() >= len,
        "scale_add_simd: src (len {}) is shorter than dst (len {})",
        src.len(),
        len
    );

    // Number of leading elements handled by the vector code.
    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    let simd_len = {
        use std::arch::x86_64::*;

        let scale_vec = unsafe { _mm256_set1_ps(scale) };
        let simd_len = (len / 8) * 8;

        for i in (0..simd_len).step_by(8) {
            // SAFETY: `i + 8 <= simd_len <= len <= src.len()`.
            unsafe {
                let src_ptr = src.as_ptr().add(i);
                let dst_ptr = dst.as_mut_ptr().add(i);

                let s = _mm256_loadu_ps(src_ptr);
                let d = _mm256_loadu_ps(dst_ptr);

                let ss = _mm256_mul_ps(s, scale_vec);
                let result = _mm256_add_ps(d, ss);

                _mm256_storeu_ps(dst_ptr, result);
            }
        }

        simd_len
    };

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    let simd_len = {
        use std::arch::aarch64::*;

        let scale_vec = unsafe { vdupq_n_f32(scale) };
        let simd_len = (len / 4) * 4;

        for i in (0..simd_len).step_by(4) {
            // SAFETY: `i + 4 <= simd_len <= len <= src.len()`.
            unsafe {
                let src_ptr = src.as_ptr().add(i);
                let dst_ptr = dst.as_mut_ptr().add(i);

                let s = vld1q_f32(src_ptr);
                let d = vld1q_f32(dst_ptr);

                // d + s * scale in a single fused operation.
                let result = vfmaq_f32(d, s, scale_vec);

                vst1q_f32(dst_ptr, result);
            }
        }

        simd_len
    };

    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "avx2"),
        all(target_arch = "aarch64", target_feature = "neon")
    )))]
    let simd_len = 0;

    // Scalar tail (the whole slice when no SIMD path is compiled in).
    for i in simd_len..len {
        dst[i] += src[i] * scale;
    }
}
565
/// Multiplies every element of `data` by `scale`, in place.
///
/// AVX2 builds process 8 floats per iteration, NEON builds 4; leftover
/// elements (and all elements on other targets) are scaled one by one.
#[inline]
pub fn scale_add_simd_inplace(data: &mut [f32], scale: f32) {
    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        use std::arch::x86_64::*;

        let factor = unsafe { _mm256_set1_ps(scale) };
        let mut blocks = data.chunks_exact_mut(8);

        for block in &mut blocks {
            // SAFETY: `chunks_exact_mut(8)` yields slices of exactly 8 floats.
            unsafe {
                let p = block.as_mut_ptr();
                _mm256_storeu_ps(p, _mm256_mul_ps(_mm256_loadu_ps(p), factor));
            }
        }

        for sample in blocks.into_remainder() {
            *sample *= scale;
        }
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    {
        use std::arch::aarch64::*;

        let factor = unsafe { vdupq_n_f32(scale) };
        let mut blocks = data.chunks_exact_mut(4);

        for block in &mut blocks {
            // SAFETY: `chunks_exact_mut(4)` yields slices of exactly 4 floats.
            unsafe {
                let p = block.as_mut_ptr();
                vst1q_f32(p, vmulq_f32(vld1q_f32(p), factor));
            }
        }

        for sample in blocks.into_remainder() {
            *sample *= scale;
        }
    }

    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "avx2"),
        all(target_arch = "aarch64", target_feature = "neon")
    )))]
    {
        data.iter_mut().for_each(|sample| *sample *= scale);
    }
}
623
/// Linear blend towards the current buffer:
/// `dst[i] = prev[i] + alpha * (dst[i] - prev[i])` for every `i` in
/// `0..dst.len()`. With `alpha == 0` the result is `prev`, with `alpha == 1`
/// it is (up to rounding) the original `dst`.
///
/// AVX2 builds process 8 floats per iteration, NEON builds 4; leftover
/// elements and all elements on other targets use the scalar expression.
/// The x86 path only uses the fused `_mm256_fmadd_ps` when the `fma` target
/// feature is enabled at compile time, since `avx2` alone does not imply it.
///
/// # Panics
///
/// Panics if `prev` is shorter than `dst` (checked in release builds too,
/// because the SIMD paths use unchecked raw-pointer loads). Debug builds
/// additionally require the two lengths to be equal.
#[inline]
pub fn blend_simd(dst: &mut [f32], prev: &[f32], alpha: f32) {
    debug_assert_eq!(dst.len(), prev.len());
    let len = dst.len();
    assert!(
        prev.len() >= len,
        "blend_simd: prev (len {}) is shorter than dst (len {})",
        prev.len(),
        len
    );

    // Number of leading elements handled by the vector code.
    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    let simd_len = {
        use std::arch::x86_64::*;

        let alpha_vec = unsafe { _mm256_set1_ps(alpha) };
        let simd_len = (len / 8) * 8;

        for i in (0..simd_len).step_by(8) {
            // SAFETY: `i + 8 <= simd_len <= len <= prev.len()`.
            unsafe {
                let prev_ptr = prev.as_ptr().add(i);
                let dst_ptr = dst.as_mut_ptr().add(i);

                let p = _mm256_loadu_ps(prev_ptr);
                let d = _mm256_loadu_ps(dst_ptr);

                let diff = _mm256_sub_ps(d, p);

                // FMA is a separate CPU feature from AVX2: only emit the
                // fused instruction when the build actually enables it.
                #[cfg(target_feature = "fma")]
                let result = _mm256_fmadd_ps(alpha_vec, diff, p);
                #[cfg(not(target_feature = "fma"))]
                let result = _mm256_add_ps(p, _mm256_mul_ps(alpha_vec, diff));

                _mm256_storeu_ps(dst_ptr, result);
            }
        }

        simd_len
    };

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    let simd_len = {
        use std::arch::aarch64::*;

        let alpha_vec = unsafe { vdupq_n_f32(alpha) };
        let simd_len = (len / 4) * 4;

        for i in (0..simd_len).step_by(4) {
            // SAFETY: `i + 4 <= simd_len <= len <= prev.len()`.
            unsafe {
                let prev_ptr = prev.as_ptr().add(i);
                let dst_ptr = dst.as_mut_ptr().add(i);

                let p = vld1q_f32(prev_ptr);
                let d = vld1q_f32(dst_ptr);

                let diff = vsubq_f32(d, p);
                // p + alpha * diff in a single fused operation.
                let result = vfmaq_f32(p, alpha_vec, diff);

                vst1q_f32(dst_ptr, result);
            }
        }

        simd_len
    };

    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "avx2"),
        all(target_arch = "aarch64", target_feature = "neon")
    )))]
    let simd_len = 0;

    // Scalar tail (the whole slice when no SIMD path is compiled in).
    for i in simd_len..len {
        dst[i] = prev[i] + alpha * (dst[i] - prev[i]);
    }
}
701
/// Applies a window: `dst[i] = src[i] * window[i]` for every `i` in
/// `0..dst.len()`. Previous contents of `dst` are overwritten.
///
/// AVX2 builds process 8 floats per iteration, NEON builds 4; leftover
/// elements and all elements on other targets use the scalar expression.
///
/// # Panics
///
/// Panics if `src` or `window` is shorter than `dst` (checked in release
/// builds too, because the SIMD paths use unchecked raw-pointer loads).
/// Debug builds additionally require all three lengths to be equal.
#[inline]
pub fn window_mul_simd(dst: &mut [f32], src: &[f32], window: &[f32]) {
    debug_assert_eq!(dst.len(), src.len());
    debug_assert_eq!(dst.len(), window.len());
    let len = dst.len();
    assert!(
        src.len() >= len && window.len() >= len,
        "window_mul_simd: src (len {}) and window (len {}) must be at least as long as dst (len {})",
        src.len(),
        window.len(),
        len
    );

    // Number of leading elements handled by the vector code.
    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    let simd_len = {
        use std::arch::x86_64::*;

        let simd_len = (len / 8) * 8;

        for i in (0..simd_len).step_by(8) {
            // SAFETY: `i + 8 <= simd_len <= len`, and `src` / `window` hold
            // at least `len` elements (asserted above).
            unsafe {
                let src_ptr = src.as_ptr().add(i);
                let win_ptr = window.as_ptr().add(i);
                let dst_ptr = dst.as_mut_ptr().add(i);

                let s = _mm256_loadu_ps(src_ptr);
                let w = _mm256_loadu_ps(win_ptr);
                let result = _mm256_mul_ps(s, w);

                _mm256_storeu_ps(dst_ptr, result);
            }
        }

        simd_len
    };

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    let simd_len = {
        use std::arch::aarch64::*;

        let simd_len = (len / 4) * 4;

        for i in (0..simd_len).step_by(4) {
            // SAFETY: `i + 4 <= simd_len <= len`, and `src` / `window` hold
            // at least `len` elements (asserted above).
            unsafe {
                let src_ptr = src.as_ptr().add(i);
                let win_ptr = window.as_ptr().add(i);
                let dst_ptr = dst.as_mut_ptr().add(i);

                let s = vld1q_f32(src_ptr);
                let w = vld1q_f32(win_ptr);
                let result = vmulq_f32(s, w);

                vst1q_f32(dst_ptr, result);
            }
        }

        simd_len
    };

    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "avx2"),
        all(target_arch = "aarch64", target_feature = "neon")
    )))]
    let simd_len = 0;

    // Scalar tail (the whole slice when no SIMD path is compiled in).
    for i in simd_len..len {
        dst[i] = src[i] * window[i];
    }
}
771
/// Applies a window in place: `data[i] *= window[i]` for every index that is
/// valid in both slices. Elements of `data` beyond `window.len()` are left
/// untouched.
///
/// AVX2 builds process 8 floats per iteration, NEON builds 4; the remaining
/// elements (all of them on other targets) are handled one by one.
#[inline]
pub fn window_mul_simd_inplace(data: &mut [f32], window: &[f32]) {
    // Only the overlapping prefix is processed.
    let count = data.len().min(window.len());

    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    let vectorized = {
        use std::arch::x86_64::*;

        let vectorized = count - count % 8;
        let mut offset = 0;
        while offset < vectorized {
            // SAFETY: `offset + 8 <= vectorized <= count`, which is within
            // both `data` and `window`.
            unsafe {
                let out = data.as_mut_ptr().add(offset);
                let coeffs = _mm256_loadu_ps(window.as_ptr().add(offset));
                _mm256_storeu_ps(out, _mm256_mul_ps(_mm256_loadu_ps(out), coeffs));
            }
            offset += 8;
        }
        vectorized
    };

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    let vectorized = {
        use std::arch::aarch64::*;

        let vectorized = count - count % 4;
        let mut offset = 0;
        while offset < vectorized {
            // SAFETY: `offset + 4 <= vectorized <= count`, which is within
            // both `data` and `window`.
            unsafe {
                let out = data.as_mut_ptr().add(offset);
                let coeffs = vld1q_f32(window.as_ptr().add(offset));
                vst1q_f32(out, vmulq_f32(vld1q_f32(out), coeffs));
            }
            offset += 4;
        }
        vectorized
    };

    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "avx2"),
        all(target_arch = "aarch64", target_feature = "neon")
    )))]
    let vectorized = 0;

    let tail = data[vectorized..count].iter_mut();
    for (sample, &coeff) in tail.zip(&window[vectorized..count]) {
        *sample *= coeff;
    }
}
823
/// Splits interleaved stereo samples `[L0, R0, L1, R1, ...]` into separate
/// `left` and `right` buffers.
///
/// AVX2 builds de-interleave 8 frames per iteration with shuffles; all other
/// targets copy frame by frame.
///
/// # Panics
///
/// Debug builds require `input.len() == 2 * left.len()` and
/// `left.len() == right.len()`. On AVX2 builds the function also panics in
/// release mode if `input` holds fewer than `2 * left.len()` samples or
/// `right` is shorter than `left`, because that path uses unchecked
/// raw-pointer accesses. The scalar path panics if `left` or `right` cannot
/// hold `input.len() / 2` frames.
#[inline]
pub fn deinterleave_stereo(input: &[f32], left: &mut [f32], right: &mut [f32]) {
    debug_assert_eq!(input.len(), left.len() * 2);
    debug_assert_eq!(left.len(), right.len());

    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        use std::arch::x86_64::*;

        let len = left.len();
        assert!(
            input.len() >= len * 2 && right.len() >= len,
            "deinterleave_stereo: input (len {}) / right (len {}) too short for {} frames",
            input.len(),
            right.len(),
            len
        );
        let simd_len = (len / 8) * 8;

        for i in (0..simd_len).step_by(8) {
            // SAFETY: `i + 8 <= simd_len <= len`; `input` holds at least
            // `2 * len` floats and `right` at least `len` (asserted above).
            unsafe {
                let in_ptr = input.as_ptr().add(i * 2);
                // v0 = [L0 R0 L1 R1 | L2 R2 L3 R3], v1 = [L4 R4 L5 R5 | L6 R6 L7 R7]
                let v0 = _mm256_loadu_ps(in_ptr);
                let v1 = _mm256_loadu_ps(in_ptr.add(8));

                // Per 128-bit half: pick the even (left) or odd (right) lanes.
                let shuf_l = _mm256_shuffle_ps(v0, v1, 0b10_00_10_00);
                let shuf_r = _mm256_shuffle_ps(v0, v1, 0b11_01_11_01);

                // The shuffles leave the 64-bit groups in order 0,2,1,3 across
                // the two halves; swap the middle groups to restore order.
                let left_vec = _mm256_permute4x64_pd(_mm256_castps_pd(shuf_l), 0b11_01_10_00);
                let right_vec = _mm256_permute4x64_pd(_mm256_castps_pd(shuf_r), 0b11_01_10_00);

                _mm256_storeu_ps(left.as_mut_ptr().add(i), _mm256_castpd_ps(left_vec));
                _mm256_storeu_ps(right.as_mut_ptr().add(i), _mm256_castpd_ps(right_vec));
            }
        }

        for i in simd_len..len {
            left[i] = input[i * 2];
            right[i] = input[i * 2 + 1];
        }
    }

    #[cfg(not(all(target_arch = "x86_64", target_feature = "avx2")))]
    {
        for (i, chunk) in input.chunks_exact(2).enumerate() {
            left[i] = chunk[0];
            right[i] = chunk[1];
        }
    }
}
889
/// Merges separate `left` / `right` buffers into interleaved stereo samples
/// `[L0, R0, L1, R1, ...]` written to `output`.
///
/// # Panics
///
/// Debug builds require `left.len() == right.len()` and
/// `output.len() == 2 * left.len()`. In any build, indexing panics if
/// `right` or `output` is too short for `left.len()` frames.
#[inline]
#[allow(dead_code)]
pub fn interleave_stereo(left: &[f32], right: &[f32], output: &mut [f32]) {
    debug_assert_eq!(left.len(), right.len());
    debug_assert_eq!(output.len(), left.len() * 2);

    for (frame, &l) in left.iter().enumerate() {
        let base = frame * 2;
        output[base] = l;
        output[base + 1] = right[frame];
    }
}
906
/// Replaces every sample whose magnitude is below `f32::MIN_POSITIVE`
/// (i.e. every subnormal, as well as zero itself) with `0.0`, in place.
/// Samples for which the comparison is false, including NaN, are kept.
///
/// AVX2 builds process 8 floats per iteration, NEON builds 4; the remaining
/// elements (all of them on other targets) are checked one by one.
#[inline]
pub fn flush_denormals_inplace(samples: &mut [f32]) {
    const DENORM_THRESHOLD: f32 = f32::MIN_POSITIVE;

    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        use std::arch::x86_64::*;

        let limit = unsafe { _mm256_set1_ps(DENORM_THRESHOLD) };
        let zeros = unsafe { _mm256_set1_ps(0.0) };
        let mut blocks = samples.chunks_exact_mut(8);

        for block in &mut blocks {
            // SAFETY: `chunks_exact_mut(8)` yields slices of exactly 8 floats.
            unsafe {
                let p = block.as_mut_ptr();
                let v = _mm256_loadu_ps(p);
                // Clear the sign bit to obtain |v|.
                let magnitude = _mm256_andnot_ps(_mm256_set1_ps(-0.0), v);
                let is_tiny = _mm256_cmp_ps(magnitude, limit, _CMP_LT_OQ);
                _mm256_storeu_ps(p, _mm256_blendv_ps(v, zeros, is_tiny));
            }
        }

        for sample in blocks.into_remainder() {
            if sample.abs() < DENORM_THRESHOLD {
                *sample = 0.0;
            }
        }
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    {
        use std::arch::aarch64::*;

        let limit = unsafe { vdupq_n_f32(DENORM_THRESHOLD) };
        let zeros = unsafe { vdupq_n_f32(0.0) };
        let mut blocks = samples.chunks_exact_mut(4);

        for block in &mut blocks {
            // SAFETY: `chunks_exact_mut(4)` yields slices of exactly 4 floats.
            unsafe {
                let p = block.as_mut_ptr();
                let v = vld1q_f32(p);
                let is_tiny = vcltq_f32(vabsq_f32(v), limit);
                // Select zero where the mask is set, the original otherwise.
                vst1q_f32(p, vbslq_f32(is_tiny, zeros, v));
            }
        }

        for sample in blocks.into_remainder() {
            if sample.abs() < DENORM_THRESHOLD {
                *sample = 0.0;
            }
        }
    }

    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "avx2"),
        all(target_arch = "aarch64", target_feature = "neon")
    )))]
    {
        samples
            .iter_mut()
            .filter(|sample| sample.abs() < DENORM_THRESHOLD)
            .for_each(|sample| *sample = 0.0);
    }
}
985
/// Turns on hardware flushing of denormal floats for the calling thread's
/// floating-point control register and reports whether that was possible.
///
/// * x86_64: sets bits 15 and 6 of MXCSR (flush-to-zero and
///   denormals-are-zero) and returns `true`.
/// * aarch64: sets bit 24 of FPCR (flush-to-zero) and returns `true`.
/// * other architectures: does nothing and returns `false`.
///
/// The previous register value is not saved.
#[inline]
pub fn enable_ftz_daz() -> bool {
    #[cfg(target_arch = "x86_64")]
    {
        const FLUSH_TO_ZERO: u32 = 1 << 15;
        const DENORMALS_ARE_ZERO: u32 = 1 << 6;

        let mut control: u32 = 0;
        unsafe {
            // Read-modify-write of MXCSR through a stack slot.
            std::arch::asm!(
                "stmxcsr [{}]",
                in(reg) &mut control,
                options(nostack, preserves_flags)
            );
            control |= FLUSH_TO_ZERO | DENORMALS_ARE_ZERO;
            std::arch::asm!(
                "ldmxcsr [{}]",
                in(reg) &control,
                options(nostack, preserves_flags)
            );
        }
        true
    }

    #[cfg(target_arch = "aarch64")]
    {
        const FLUSH_TO_ZERO: u64 = 1 << 24;

        unsafe {
            // Read-modify-write of FPCR.
            let mut control: u64;
            std::arch::asm!("mrs {}, fpcr", out(reg) control);
            control |= FLUSH_TO_ZERO;
            std::arch::asm!("msr fpcr, {}", in(reg) control);
        }
        true
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        false
    }
}
1033
/// Scope guard that enables denormal flushing on construction and restores
/// the previous floating-point control register value when dropped.
///
/// * x86_64: saves MXCSR, then sets bits 15 and 6 (flush-to-zero and
///   denormals-are-zero).
/// * aarch64: saves FPCR, then sets bit 24 (flush-to-zero).
/// * other architectures: the guard carries no state and does nothing.
///
/// The guard must be bound to a variable (for example `let _guard = ...`);
/// an unbound temporary is dropped immediately, which restores the previous
/// mode right away.
#[must_use = "the previous floating-point mode is restored as soon as the guard is dropped"]
pub struct ScopedFtz {
    /// MXCSR value captured before the flush bits were set.
    #[cfg(target_arch = "x86_64")]
    saved_mxcsr: Option<u32>,
    /// FPCR value captured before the flush bit was set.
    #[cfg(target_arch = "aarch64")]
    saved_fpcr: Option<u64>,
}

impl ScopedFtz {
    /// Saves the current control register and enables denormal flushing.
    // Each cfg-gated block has to leave the function on its own, hence the
    // explicit `return`s that clippy would otherwise flag.
    #[allow(clippy::needless_return)]
    pub fn new() -> Self {
        #[cfg(target_arch = "x86_64")]
        {
            let saved = unsafe {
                let mut mxcsr: u32 = 0;
                std::arch::asm!(
                    "stmxcsr [{}]",
                    in(reg) &mut mxcsr,
                    options(nostack, preserves_flags)
                );
                // Only control bits 15 and 6 change; the exception flags in
                // the low six bits are written back exactly as just read.
                let new_mxcsr = mxcsr | (1 << 15) | (1 << 6);
                std::arch::asm!(
                    "ldmxcsr [{}]",
                    in(reg) &new_mxcsr,
                    options(nostack, preserves_flags)
                );
                mxcsr
            };
            return Self {
                saved_mxcsr: Some(saved),
            };
        }

        #[cfg(target_arch = "aarch64")]
        {
            let saved = unsafe {
                let mut fpcr: u64;
                std::arch::asm!("mrs {}, fpcr", out(reg) fpcr);
                let new_fpcr = fpcr | (1u64 << 24);
                std::arch::asm!("msr fpcr, {}", in(reg) new_fpcr);
                fpcr
            };
            return Self {
                saved_fpcr: Some(saved),
            };
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        Self {}
    }
}

impl Default for ScopedFtz {
    /// Equivalent to [`ScopedFtz::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl Drop for ScopedFtz {
    /// Writes the saved control register value back.
    fn drop(&mut self) {
        #[cfg(target_arch = "x86_64")]
        if let Some(saved) = self.saved_mxcsr {
            unsafe {
                // No `preserves_flags` here: reloading the saved value can
                // change the exception flags held in MXCSR[5:0], which that
                // option promises to leave untouched.
                std::arch::asm!(
                    "ldmxcsr [{}]",
                    in(reg) &saved,
                    options(nostack)
                );
            }
        }

        #[cfg(target_arch = "aarch64")]
        if let Some(saved) = self.saved_fpcr {
            unsafe {
                std::arch::asm!("msr fpcr, {}", in(reg) saved);
            }
        }
    }
}
1126
1127#[inline]
1129pub fn flush_denormals_complex_inplace(samples: &mut [Complex<f32>]) {
1130 let len = samples.len() * 2;
1133 let ptr = samples.as_mut_ptr() as *mut f32;
1134 let f32_samples = unsafe { std::slice::from_raw_parts_mut(ptr, len) };
1135 flush_denormals_inplace(f32_samples);
1136}
1137
/// Tests for the software denormal-flushing helpers.
#[cfg(test)]
mod denorm_tests {
    use super::*;

    /// Mixed buffer: subnormals of either sign are zeroed, everything else
    /// is left as it was.
    #[test]
    fn test_flush_denormals_basic() {
        let mut buf = [1e-39_f32, 1e-20, 1e-10, 0.0, -1e-39_f32, 1.0];
        flush_denormals_inplace(&mut buf);

        let [pos_subnormal, small, medium, zero, neg_subnormal, one] = buf;
        assert_eq!(pos_subnormal, 0.0, "subnormal 1e-39 must be zeroed");
        assert_eq!(small, 1e-20);
        assert_eq!(medium, 1e-10);
        assert_eq!(zero, 0.0);
        assert_eq!(neg_subnormal, 0.0, "negative subnormal -1e-39 must be zeroed");
        assert_eq!(one, 1.0);
    }

    /// A small but normal value must survive.
    #[test]
    fn test_flush_denormals_normal_small_not_zeroed() {
        let mut buf = [1e-35_f32];
        flush_denormals_inplace(&mut buf);
        let survivor = buf[0];
        assert!(
            survivor != 0.0,
            "normal value 1e-35 (above f32::MIN_POSITIVE) must not be zeroed"
        );
    }

    /// A lone subnormal is flushed.
    #[test]
    fn test_flush_denormals_subnormal_zeroed() {
        let mut buf = [1e-39_f32];
        flush_denormals_inplace(&mut buf);
        let flushed = buf[0];
        assert_eq!(
            flushed, 0.0,
            "subnormal value 1e-39 (below f32::MIN_POSITIVE) must be zeroed"
        );
    }

    /// Real and imaginary parts are flushed independently.
    #[test]
    fn test_flush_denormals_complex() {
        use rustfft::num_complex::Complex;

        let mut buf = [
            Complex::new(1e-39_f32, 1e-30_f32),
            Complex::new(1.0, 1e-39_f32),
            Complex::new(0.0, 0.0),
        ];
        flush_denormals_complex_inplace(&mut buf);

        let [first, second, third] = buf;
        assert_eq!(first.re, 0.0, "subnormal re must be zeroed");
        assert!(first.im != 0.0, "normal im 1e-30 must be preserved");
        assert_eq!(second.re, 1.0);
        assert_eq!(second.im, 0.0, "subnormal im must be zeroed");
        assert_eq!(third.re, 0.0);
        assert_eq!(third.im, 0.0);
    }

    /// An empty buffer is accepted.
    #[test]
    fn test_flush_denormals_empty() {
        let mut buf: [f32; 0] = [];
        flush_denormals_inplace(&mut buf);
    }

    /// Seven elements: not a multiple of any SIMD width used by the
    /// implementation, so the scalar remainder handling is exercised.
    #[test]
    fn test_flush_denormals_unaligned() {
        let mut buf = [1e-39_f32; 7];
        flush_denormals_inplace(&mut buf);
        for &value in &buf {
            assert_eq!(value, 0.0);
        }
    }
}
1213
1214#[cfg(test)]
1215#[allow(clippy::needless_range_loop)]
1216mod tests {
1217 use super::*;
1218
1219 #[test]
1224 fn test_flush_denormals_basic() {
1225 let mut samples = [1e-39_f32, 1e-20, 1e-10, 0.0, -1e-39_f32, 1.0];
1227 flush_denormals_inplace(&mut samples);
1228 assert_eq!(samples[0], 0.0);
1229 assert_eq!(samples[1], 1e-20);
1230 assert_eq!(samples[2], 1e-10);
1231 assert_eq!(samples[3], 0.0);
1232 assert_eq!(samples[4], 0.0);
1233 assert_eq!(samples[5], 1.0);
1234 }
1235
1236 #[test]
1237 fn test_flush_denormals_complex() {
1238 use rustfft::num_complex::Complex;
1239 let mut samples = [
1241 Complex::new(1e-39_f32, 1e-30_f32),
1242 Complex::new(1.0, 1e-39_f32),
1243 Complex::new(0.0, 0.0),
1244 ];
1245 flush_denormals_complex_inplace(&mut samples);
1246 assert_eq!(samples[0].re, 0.0);
1247 assert!(samples[0].im != 0.0, "normal im 1e-30 must be preserved");
1248 assert_eq!(samples[1].re, 1.0);
1249 assert_eq!(samples[1].im, 0.0);
1250 assert_eq!(samples[2].re, 0.0);
1251 assert_eq!(samples[2].im, 0.0);
1252 }
1253
1254 #[test]
1255 fn test_flush_denormals_empty() {
1256 let mut samples: [f32; 0] = [];
1257 flush_denormals_inplace(&mut samples);
1258 }
1259
1260 #[test]
1261 fn test_flush_denormals_unaligned() {
1262 let mut samples = [1e-39_f32; 7];
1263 flush_denormals_inplace(&mut samples);
1264 for s in samples.iter() {
1265 assert_eq!(*s, 0.0);
1266 }
1267 }
1268
1269 #[test]
1270 fn test_enable_ftz_daz_does_not_panic() {
1271 let result = enable_ftz_daz();
1274 #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
1275 assert!(
1276 result,
1277 "enable_ftz_daz should return true on supported platforms"
1278 );
1279 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
1280 assert!(
1281 !result,
1282 "enable_ftz_daz should return false on unsupported platforms"
1283 );
1284 }
1285
1286 #[test]
1287 fn test_scoped_ftz_does_not_panic() {
1288 {
1290 let _guard = ScopedFtz::new();
1291 }
1293 }
1295
1296 #[test]
1297 fn test_apply_gain_simd_known_values() {
1298 let mut buffer = vec![1.0, 2.0, 3.0, 4.0, -1.0, 0.5, -0.5, 0.0, 1.5];
1300 let expected: Vec<f32> = buffer.iter().map(|&x| x * 2.0).collect();
1301 apply_gain_simd(&mut buffer, 2.0);
1302 for (i, (&got, &exp)) in buffer.iter().zip(expected.iter()).enumerate() {
1303 assert!(
1304 (got - exp).abs() < 1e-6,
1305 "apply_gain_simd mismatch at index {}: got {}, expected {}",
1306 i,
1307 got,
1308 exp
1309 );
1310 }
1311 }
1312
1313 #[test]
1314 fn test_apply_gain_simd_zero_gain() {
1315 let mut buffer = vec![1.0, -2.0, 3.5, 0.7];
1316 apply_gain_simd(&mut buffer, 0.0);
1317 for (i, &v) in buffer.iter().enumerate() {
1318 assert_eq!(
1319 v, 0.0,
1320 "apply_gain_simd with zero gain: index {} not zero",
1321 i
1322 );
1323 }
1324 }
1325
1326 #[test]
1327 fn test_apply_gain_simd_unity_gain() {
1328 let original = vec![1.0, -2.0, 3.5, 0.7, 0.0, -0.1];
1329 let mut buffer = original.clone();
1330 apply_gain_simd(&mut buffer, 1.0);
1331 assert_eq!(buffer, original);
1332 }
1333
1334 #[test]
1335 fn test_apply_per_channel_gain_simd_stereo() {
1336 let mut buffer = vec![1.0, 2.0, 3.0, 4.0]; let gains = vec![0.5, 2.0]; apply_per_channel_gain_simd(&mut buffer, 2, &gains);
1340 assert!((buffer[0] - 0.5).abs() < 1e-6, "L frame 0");
1341 assert!((buffer[1] - 4.0).abs() < 1e-6, "R frame 0");
1342 assert!((buffer[2] - 1.5).abs() < 1e-6, "L frame 1");
1343 assert!((buffer[3] - 8.0).abs() < 1e-6, "R frame 1");
1344 }
1345
1346 #[test]
1354 fn test_simd_complex_mul_add_correctness() {
1355 use rustfft::num_complex::Complex;
1357
1358 let src = vec![
1360 Complex::new(1.0, 2.0),
1361 Complex::new(3.0, 4.0),
1362 Complex::new(-1.0, 0.5),
1363 Complex::new(0.0, -2.0),
1364 Complex::new(2.5, -1.5),
1365 Complex::new(-3.5, 2.5),
1366 Complex::new(1.1, -0.9),
1367 Complex::new(-0.8, 1.2),
1368 ];
1369
1370 let hrtf = vec![
1371 Complex::new(0.5, 0.25),
1372 Complex::new(-1.0, 1.5),
1373 Complex::new(2.0, -0.5),
1374 Complex::new(0.75, 0.75),
1375 Complex::new(-0.5, 2.0),
1376 Complex::new(1.5, -1.0),
1377 Complex::new(0.9, 0.3),
1378 Complex::new(-1.1, 0.7),
1379 ];
1380
1381 let initial = vec![
1382 Complex::new(0.1, 0.2),
1383 Complex::new(0.3, 0.4),
1384 Complex::new(0.5, 0.6),
1385 Complex::new(0.7, 0.8),
1386 Complex::new(0.9, 1.0),
1387 Complex::new(1.1, 1.2),
1388 Complex::new(1.3, 1.4),
1389 Complex::new(1.5, 1.6),
1390 ];
1391
1392 let mut expected = initial.clone();
1394 for i in 0..src.len() {
1395 expected[i] += src[i] * hrtf[i];
1396 }
1397
1398 let mut result = initial.clone();
1400 complex_mul_add_simd(&mut result, &src, &hrtf);
1401
1402 const EPSILON: f32 = 1e-6;
1404 for i in 0..src.len() {
1405 assert!(
1406 (result[i].re - expected[i].re).abs() < EPSILON,
1407 "SIMD result[{}].re = {}, expected = {} (diff = {})",
1408 i,
1409 result[i].re,
1410 expected[i].re,
1411 (result[i].re - expected[i].re).abs()
1412 );
1413 assert!(
1414 (result[i].im - expected[i].im).abs() < EPSILON,
1415 "SIMD result[{}].im = {}, expected = {} (diff = {})",
1416 i,
1417 result[i].im,
1418 expected[i].im,
1419 (result[i].im - expected[i].im).abs()
1420 );
1421 }
1422 }
1423
1424 #[test]
1425 fn test_simd_complex_mul_correctness() {
1426 use rustfft::num_complex::Complex;
1428
1429 let src = vec![
1430 Complex::new(2.0, 3.0),
1431 Complex::new(-1.5, 2.5),
1432 Complex::new(0.5, -1.0),
1433 Complex::new(4.0, -2.0),
1434 ];
1435
1436 let hrtf = vec![
1437 Complex::new(1.0, 0.5),
1438 Complex::new(2.0, -1.0),
1439 Complex::new(-0.5, 1.5),
1440 Complex::new(0.75, 0.25),
1441 ];
1442
1443 let expected: Vec<Complex<f32>> = src.iter().zip(hrtf.iter()).map(|(a, b)| a * b).collect();
1445
1446 let mut result = vec![Complex::new(0.0, 0.0); src.len()];
1448 complex_mul_simd(&mut result, &src, &hrtf);
1449
1450 const EPSILON: f32 = 1e-6;
1452 for i in 0..src.len() {
1453 assert!(
1454 (result[i].re - expected[i].re).abs() < EPSILON,
1455 "SIMD result[{}].re = {}, expected = {}",
1456 i,
1457 result[i].re,
1458 expected[i].re
1459 );
1460 assert!(
1461 (result[i].im - expected[i].im).abs() < EPSILON,
1462 "SIMD result[{}].im = {}, expected = {}",
1463 i,
1464 result[i].im,
1465 expected[i].im
1466 );
1467 }
1468 }
1469
1470 #[test]
1471 fn test_simd_edge_cases() {
1472 use rustfft::num_complex::Complex;
1474
1475 let src = vec![
1477 Complex::new(1.0, 2.0),
1478 Complex::new(3.0, 4.0),
1479 Complex::new(5.0, 6.0),
1480 Complex::new(7.0, 8.0),
1481 ];
1482 let zero = vec![Complex::new(0.0, 0.0); 4];
1483 let mut result = src.clone();
1484 let input = result.clone();
1485 complex_mul_simd(&mut result, &input, &zero);
1486 for i in 0..4 {
1487 assert_eq!(result[i].re, 0.0);
1488 assert_eq!(result[i].im, 0.0);
1489 }
1490
1491 let one = vec![Complex::new(1.0, 0.0); 4];
1493 let mut result = vec![Complex::new(0.0, 0.0); 4];
1494 complex_mul_simd(&mut result, &src, &one);
1495 for i in 0..4 {
1496 assert!((result[i].re - src[i].re).abs() < 1e-6);
1497 assert!((result[i].im - src[i].im).abs() < 1e-6);
1498 }
1499
1500 let a = Complex::new(3.0, 4.0);
1502 let a_conj = Complex::new(3.0, -4.0);
1503 let src = vec![a, a, a, a];
1504 let conj = vec![a_conj, a_conj, a_conj, a_conj];
1505 let mut result = vec![Complex::new(0.0, 0.0); 4];
1506 complex_mul_simd(&mut result, &src, &conj);
1507
1508 for i in 0..4 {
1510 assert!((result[i].re - 25.0).abs() < 1e-5);
1511 assert!(result[i].im.abs() < 1e-5); }
1513 }
1514
1515 #[test]
1516 fn test_simd_large_buffer() {
1517 use rustfft::num_complex::Complex;
1519
1520 for fft_size in [512, 1024, 2048, 4096] {
1521 let mut src = Vec::with_capacity(fft_size);
1522 let mut hrtf = Vec::with_capacity(fft_size);
1523
1524 for i in 0..fft_size {
1526 let phase = (i as f32) * 0.01;
1527 src.push(Complex::new(phase.cos(), phase.sin()));
1528 hrtf.push(Complex::new(0.5, 0.25));
1529 }
1530
1531 let mut expected = vec![Complex::new(0.1, 0.2); fft_size];
1533 for i in 0..fft_size {
1534 expected[i] += src[i] * hrtf[i];
1535 }
1536
1537 let mut result = vec![Complex::new(0.1, 0.2); fft_size];
1539 complex_mul_add_simd(&mut result, &src, &hrtf);
1540
1541 for i in 0..fft_size {
1543 assert!(
1544 (result[i].re - expected[i].re).abs() < 1e-5,
1545 "FFT size {}, index {}: SIMD mismatch",
1546 fft_size,
1547 i
1548 );
1549 assert!(
1550 (result[i].im - expected[i].im).abs() < 1e-5,
1551 "FFT size {}, index {}: SIMD mismatch",
1552 fft_size,
1553 i
1554 );
1555 }
1556 }
1557 }
1558
1559 #[test]
1560 fn test_simd_unaligned_sizes() {
1561 use rustfft::num_complex::Complex;
1564
1565 for size in [1, 3, 5, 7, 9, 13, 17] {
1566 let src: Vec<Complex<f32>> = (0..size)
1567 .map(|i| Complex::new(i as f32, (i as f32) * 0.5))
1568 .collect();
1569 let hrtf: Vec<Complex<f32>> = (0..size)
1570 .map(|i| Complex::new(0.5, (i as f32) * 0.1))
1571 .collect();
1572
1573 let expected: Vec<Complex<f32>> =
1575 src.iter().zip(hrtf.iter()).map(|(a, b)| a * b).collect();
1576
1577 let mut result = vec![Complex::new(0.0, 0.0); size];
1579 complex_mul_simd(&mut result, &src, &hrtf);
1580
1581 for i in 0..size {
1583 assert!(
1584 (result[i].re - expected[i].re).abs() < 1e-6,
1585 "Size {}, index {}: re mismatch",
1586 size,
1587 i
1588 );
1589 assert!(
1590 (result[i].im - expected[i].im).abs() < 1e-6,
1591 "Size {}, index {}: im mismatch",
1592 size,
1593 i
1594 );
1595 }
1596 }
1597 }
1598
1599 #[test]
1604 fn test_simd_complex_mul_inplace_correctness() {
1605 use rustfft::num_complex::Complex;
1607
1608 let src = vec![
1609 Complex::new(2.0, 3.0),
1610 Complex::new(-1.5, 2.5),
1611 Complex::new(0.5, -1.0),
1612 Complex::new(4.0, -2.0),
1613 ];
1614
1615 let hrtf = vec![
1616 Complex::new(1.0, 0.5),
1617 Complex::new(2.0, -1.0),
1618 Complex::new(-0.5, 1.5),
1619 Complex::new(0.75, 0.25),
1620 ];
1621
1622 let mut expected = src.clone();
1624 for i in 0..expected.len() {
1625 expected[i] *= hrtf[i];
1626 }
1627
1628 let mut result = src.clone();
1630 complex_mul_inplace_simd(&mut result, &hrtf);
1631
1632 const EPSILON: f32 = 1e-6;
1633 for i in 0..result.len() {
1634 assert!(
1635 (result[i].re - expected[i].re).abs() < EPSILON,
1636 "Index {}: re mismatch {} vs {}",
1637 i,
1638 result[i].re,
1639 expected[i].re
1640 );
1641 assert!(
1642 (result[i].im - expected[i].im).abs() < EPSILON,
1643 "Index {}: im mismatch {} vs {}",
1644 i,
1645 result[i].im,
1646 expected[i].im
1647 );
1648 }
1649 }
1650
1651 #[test]
1652 fn test_simd_inplace_large_buffers() {
1653 use rustfft::num_complex::Complex;
1655
1656 for fft_size in [128, 256, 512, 1024, 2048] {
1657 let mut src: Vec<Complex<f32>> = (0..fft_size)
1658 .map(|i| {
1659 let phase = (i as f32) * 0.01;
1660 Complex::new(phase.cos(), phase.sin())
1661 })
1662 .collect();
1663
1664 let hrtf: Vec<Complex<f32>> = (0..fft_size)
1665 .map(|i| Complex::new(0.5 + (i as f32) * 0.001, 0.25))
1666 .collect();
1667
1668 let mut expected = src.clone();
1670 for i in 0..fft_size {
1671 expected[i] *= hrtf[i];
1672 }
1673
1674 complex_mul_inplace_simd(&mut src, &hrtf);
1676
1677 for i in 0..fft_size {
1679 assert!(
1680 (src[i].re - expected[i].re).abs() < 1e-5,
1681 "FFT size {}, index {}: re mismatch",
1682 fft_size,
1683 i
1684 );
1685 assert!(
1686 (src[i].im - expected[i].im).abs() < 1e-5,
1687 "FFT size {}, index {}: im mismatch",
1688 fft_size,
1689 i
1690 );
1691 }
1692 }
1693 }
1694
1695 #[test]
1696 fn test_simd_inplace_unaligned() {
1697 use rustfft::num_complex::Complex;
1699
1700 for size in [1, 2, 3, 5, 6, 7, 9, 10, 11, 15, 17, 19, 23] {
1701 let mut src: Vec<Complex<f32>> = (0..size)
1702 .map(|i| Complex::new((i as f32) * 0.5, (i as f32) * -0.3))
1703 .collect();
1704
1705 let hrtf: Vec<Complex<f32>> = (0..size)
1706 .map(|i| Complex::new(1.0 + (i as f32) * 0.1, 0.5))
1707 .collect();
1708
1709 let mut expected = src.clone();
1711 for i in 0..size {
1712 expected[i] *= hrtf[i];
1713 }
1714
1715 complex_mul_inplace_simd(&mut src, &hrtf);
1717
1718 for i in 0..size {
1720 assert!(
1721 (src[i].re - expected[i].re).abs() < 1e-6,
1722 "Size {}, index {}: re mismatch",
1723 size,
1724 i
1725 );
1726 assert!(
1727 (src[i].im - expected[i].im).abs() < 1e-6,
1728 "Size {}, index {}: im mismatch",
1729 size,
1730 i
1731 );
1732 }
1733 }
1734 }
1735
1736 #[test]
1737 fn test_simd_inplace_edge_cases() {
1738 use rustfft::num_complex::Complex;
1740
1741 let mut src = vec![
1743 Complex::new(1.0, 2.0),
1744 Complex::new(3.0, 4.0),
1745 Complex::new(5.0, 6.0),
1746 Complex::new(7.0, 8.0),
1747 ];
1748 let zero = vec![Complex::new(0.0, 0.0); 4];
1749 complex_mul_inplace_simd(&mut src, &zero);
1750 for i in 0..4 {
1751 assert!(src[i].re.abs() < 1e-6, "Expected zero, got {}", src[i].re);
1752 assert!(src[i].im.abs() < 1e-6, "Expected zero, got {}", src[i].im);
1753 }
1754
1755 let original = vec![
1757 Complex::new(1.5, 2.5),
1758 Complex::new(-3.5, 4.5),
1759 Complex::new(5.5, -6.5),
1760 Complex::new(-7.5, -8.5),
1761 ];
1762 let mut src = original.clone();
1763 let one = vec![Complex::new(1.0, 0.0); 4];
1764 complex_mul_inplace_simd(&mut src, &one);
1765 for i in 0..4 {
1766 assert!((src[i].re - original[i].re).abs() < 1e-6);
1767 assert!((src[i].im - original[i].im).abs() < 1e-6);
1768 }
1769
1770 let a = Complex::new(3.0, 4.0);
1772 let a_conj = Complex::new(3.0, -4.0);
1773 let mut src = vec![a; 8];
1774 let conj = vec![a_conj; 8];
1775 complex_mul_inplace_simd(&mut src, &conj);
1776 for i in 0..8 {
1778 assert!(
1779 (src[i].re - 25.0).abs() < 1e-5,
1780 "Expected 25.0, got {}",
1781 src[i].re
1782 );
1783 assert!(src[i].im.abs() < 1e-5, "Expected ~0, got {}", src[i].im);
1784 }
1785
1786 let mut src = vec![Complex::new(1.0, 0.0); 4];
1788 let i_val = vec![Complex::new(0.0, 1.0); 4];
1789 complex_mul_inplace_simd(&mut src, &i_val);
1790 for idx in 0..4 {
1791 assert!(src[idx].re.abs() < 1e-6, "Expected 0, got {}", src[idx].re);
1792 assert!(
1793 (src[idx].im - 1.0).abs() < 1e-6,
1794 "Expected 1, got {}",
1795 src[idx].im
1796 );
1797 }
1798 }
1799
1800 #[test]
1801 fn test_simd_inplace_negative_values() {
1802 use rustfft::num_complex::Complex;
1804
1805 let mut src = vec![
1806 Complex::new(-1.0, -2.0),
1807 Complex::new(-3.0, -4.0),
1808 Complex::new(-5.0, -6.0),
1809 Complex::new(-7.0, -8.0),
1810 ];
1811
1812 let hrtf = vec![
1813 Complex::new(-0.5, -0.25),
1814 Complex::new(-1.0, -1.5),
1815 Complex::new(-2.0, 0.5),
1816 Complex::new(0.75, -0.75),
1817 ];
1818
1819 let mut expected = src.clone();
1821 for i in 0..expected.len() {
1822 expected[i] *= hrtf[i];
1823 }
1824
1825 complex_mul_inplace_simd(&mut src, &hrtf);
1827
1828 const EPSILON: f32 = 1e-6;
1829 for i in 0..src.len() {
1830 assert!((src[i].re - expected[i].re).abs() < EPSILON);
1831 assert!((src[i].im - expected[i].im).abs() < EPSILON);
1832 }
1833 }
1834
1835 #[test]
1840 fn test_covariance_basic_correctness() {
1841 use rustfft::num_complex::Complex;
1843
1844 let left = vec![
1845 Complex::new(1.0, 2.0),
1846 Complex::new(3.0, 4.0),
1847 Complex::new(-1.0, 0.5),
1848 Complex::new(0.0, -2.0),
1849 Complex::new(2.5, -1.5),
1850 Complex::new(-3.5, 2.5),
1851 Complex::new(1.1, -0.9),
1852 Complex::new(-0.8, 1.2),
1853 ];
1854
1855 let right = vec![
1856 Complex::new(0.5, 0.25),
1857 Complex::new(-1.0, 1.5),
1858 Complex::new(2.0, -0.5),
1859 Complex::new(0.75, 0.75),
1860 Complex::new(-0.5, 2.0),
1861 Complex::new(1.5, -1.0),
1862 Complex::new(0.9, 0.3),
1863 Complex::new(-1.1, 0.7),
1864 ];
1865
1866 let mut expected_xx = 0.0_f32;
1868 let mut expected_yy = 0.0_f32;
1869 let mut expected_xy = Complex::new(0.0, 0.0);
1870 for i in 0..left.len() {
1871 expected_xx += left[i].norm_sqr();
1872 expected_yy += right[i].norm_sqr();
1873 expected_xy += left[i] * right[i].conj();
1874 }
1875
1876 let (cov_xx, cov_yy, cov_xy) = compute_covariance_simd(&left, &right, 0, left.len());
1878
1879 const EPSILON: f32 = 1e-5;
1880 assert!(
1881 (cov_xx - expected_xx).abs() < EPSILON,
1882 "cov_xx mismatch: {} vs {}",
1883 cov_xx,
1884 expected_xx
1885 );
1886 assert!(
1887 (cov_yy - expected_yy).abs() < EPSILON,
1888 "cov_yy mismatch: {} vs {}",
1889 cov_yy,
1890 expected_yy
1891 );
1892 assert!(
1893 (cov_xy.re - expected_xy.re).abs() < EPSILON,
1894 "cov_xy.re mismatch: {} vs {}",
1895 cov_xy.re,
1896 expected_xy.re
1897 );
1898 assert!(
1899 (cov_xy.im - expected_xy.im).abs() < EPSILON,
1900 "cov_xy.im mismatch: {} vs {}",
1901 cov_xy.im,
1902 expected_xy.im
1903 );
1904 }
1905
1906 #[test]
1907 fn test_covariance_with_ranges() {
1908 use rustfft::num_complex::Complex;
1910
1911 let left: Vec<Complex<f32>> = (0..32)
1912 .map(|i| Complex::new(i as f32 * 0.5, i as f32 * -0.3))
1913 .collect();
1914 let right: Vec<Complex<f32>> = (0..32)
1915 .map(|i| Complex::new(i as f32 * -0.4, i as f32 * 0.6))
1916 .collect();
1917
1918 for (start, end) in [(0, 8), (4, 12), (10, 20), (5, 25), (0, 32)] {
1920 let mut expected_xx = 0.0_f32;
1922 let mut expected_yy = 0.0_f32;
1923 let mut expected_xy = Complex::new(0.0, 0.0);
1924 for i in start..end {
1925 expected_xx += left[i].norm_sqr();
1926 expected_yy += right[i].norm_sqr();
1927 expected_xy += left[i] * right[i].conj();
1928 }
1929
1930 let (cov_xx, cov_yy, cov_xy) = compute_covariance_simd(&left, &right, start, end);
1932
1933 const EPSILON: f32 = 1e-4;
1934 assert!(
1935 (cov_xx - expected_xx).abs() < EPSILON,
1936 "Range [{}, {}): cov_xx mismatch: {} vs {}",
1937 start,
1938 end,
1939 cov_xx,
1940 expected_xx
1941 );
1942 assert!(
1943 (cov_yy - expected_yy).abs() < EPSILON,
1944 "Range [{}, {}): cov_yy mismatch: {} vs {}",
1945 start,
1946 end,
1947 cov_yy,
1948 expected_yy
1949 );
1950 assert!(
1951 (cov_xy.re - expected_xy.re).abs() < EPSILON,
1952 "Range [{}, {}): cov_xy.re mismatch: {} vs {}",
1953 start,
1954 end,
1955 cov_xy.re,
1956 expected_xy.re
1957 );
1958 assert!(
1959 (cov_xy.im - expected_xy.im).abs() < EPSILON,
1960 "Range [{}, {}): cov_xy.im mismatch: {} vs {}",
1961 start,
1962 end,
1963 cov_xy.im,
1964 expected_xy.im
1965 );
1966 }
1967 }
1968
1969 #[test]
1970 fn test_covariance_large_buffers() {
1971 use rustfft::num_complex::Complex;
1973
1974 for fft_size in [128, 256, 512, 1024, 2048, 4096] {
1975 let left: Vec<Complex<f32>> = (0..fft_size)
1976 .map(|i| {
1977 let phase = (i as f32) * 0.01;
1978 Complex::new(phase.cos(), phase.sin())
1979 })
1980 .collect();
1981
1982 let right: Vec<Complex<f32>> = (0..fft_size)
1983 .map(|i| {
1984 let phase = (i as f32) * 0.02;
1985 Complex::new(phase.sin(), phase.cos())
1986 })
1987 .collect();
1988
1989 let mut expected_xx = 0.0_f32;
1991 let mut expected_yy = 0.0_f32;
1992 let mut expected_xy = Complex::new(0.0, 0.0);
1993 for i in 0..fft_size {
1994 expected_xx += left[i].norm_sqr();
1995 expected_yy += right[i].norm_sqr();
1996 expected_xy += left[i] * right[i].conj();
1997 }
1998
1999 let (cov_xx, cov_yy, cov_xy) = compute_covariance_simd(&left, &right, 0, fft_size);
2001
2002 let rel_epsilon = 1e-4;
2004 assert!(
2005 (cov_xx - expected_xx).abs() < expected_xx * rel_epsilon,
2006 "FFT size {}: cov_xx mismatch",
2007 fft_size
2008 );
2009 assert!(
2010 (cov_yy - expected_yy).abs() < expected_yy * rel_epsilon,
2011 "FFT size {}: cov_yy mismatch",
2012 fft_size
2013 );
2014 assert!(
2015 (cov_xy.re - expected_xy.re).abs() < expected_xy.re.abs() * rel_epsilon + 1e-5,
2016 "FFT size {}: cov_xy.re mismatch",
2017 fft_size
2018 );
2019 assert!(
2020 (cov_xy.im - expected_xy.im).abs() < expected_xy.im.abs() * rel_epsilon + 1e-5,
2021 "FFT size {}: cov_xy.im mismatch",
2022 fft_size
2023 );
2024 }
2025 }
2026
2027 #[test]
2028 fn test_covariance_unaligned_ranges() {
2029 use rustfft::num_complex::Complex;
2031
2032 let left: Vec<Complex<f32>> = (0..50)
2033 .map(|i| Complex::new(i as f32 * 0.2, i as f32 * 0.3))
2034 .collect();
2035 let right: Vec<Complex<f32>> = (0..50)
2036 .map(|i| Complex::new(i as f32 * -0.1, i as f32 * 0.4))
2037 .collect();
2038
2039 for (start, end) in [(0, 1), (0, 3), (1, 4), (2, 7), (5, 11), (10, 23), (15, 37)] {
2041 let mut expected_xx = 0.0_f32;
2043 let mut expected_yy = 0.0_f32;
2044 let mut expected_xy = Complex::new(0.0, 0.0);
2045 for i in start..end {
2046 expected_xx += left[i].norm_sqr();
2047 expected_yy += right[i].norm_sqr();
2048 expected_xy += left[i] * right[i].conj();
2049 }
2050
2051 let (cov_xx, cov_yy, cov_xy) = compute_covariance_simd(&left, &right, start, end);
2053
2054 const EPSILON: f32 = 1e-5;
2055 assert!(
2056 (cov_xx - expected_xx).abs() < EPSILON,
2057 "Range [{}, {}): cov_xx mismatch",
2058 start,
2059 end
2060 );
2061 assert!(
2062 (cov_yy - expected_yy).abs() < EPSILON,
2063 "Range [{}, {}): cov_yy mismatch",
2064 start,
2065 end
2066 );
2067 assert!(
2068 (cov_xy.re - expected_xy.re).abs() < EPSILON,
2069 "Range [{}, {}): cov_xy.re mismatch",
2070 start,
2071 end
2072 );
2073 assert!(
2074 (cov_xy.im - expected_xy.im).abs() < EPSILON,
2075 "Range [{}, {}): cov_xy.im mismatch",
2076 start,
2077 end
2078 );
2079 }
2080 }
2081
2082 #[test]
2083 fn test_covariance_edge_cases() {
2084 use rustfft::num_complex::Complex;
2086
2087 let zero_left = vec![Complex::new(0.0, 0.0); 8];
2089 let zero_right = vec![Complex::new(0.0, 0.0); 8];
2090 let (cov_xx, cov_yy, cov_xy) = compute_covariance_simd(&zero_left, &zero_right, 0, 8);
2091 assert!(cov_xx.abs() < 1e-6, "Expected zero cov_xx");
2092 assert!(cov_yy.abs() < 1e-6, "Expected zero cov_yy");
2093 assert!(cov_xy.norm_sqr() < 1e-6, "Expected zero cov_xy");
2094
2095 let real_left: Vec<Complex<f32>> = (0..8).map(|i| Complex::new(i as f32, 0.0)).collect();
2097 let real_right: Vec<Complex<f32>> = (0..8)
2098 .map(|i| Complex::new((i as f32) * 0.5, 0.0))
2099 .collect();
2100 let (cov_xx, cov_yy, cov_xy) = compute_covariance_simd(&real_left, &real_right, 0, 8);
2101
2102 assert!(
2104 cov_xy.im.abs() < 1e-5,
2105 "Expected real cov_xy for real signals"
2106 );
2107
2108 let mut expected_xx = 0.0;
2110 let mut expected_yy = 0.0;
2111 for i in 0..8 {
2112 expected_xx += (i * i) as f32;
2113 expected_yy += ((i as f32) * 0.5).powi(2);
2114 }
2115 assert!((cov_xx - expected_xx).abs() < 1e-5);
2116 assert!((cov_yy - expected_yy).abs() < 1e-5);
2117
2118 let imag_left: Vec<Complex<f32>> = (0..8).map(|i| Complex::new(0.0, i as f32)).collect();
2120 let imag_right: Vec<Complex<f32>> = (0..8)
2121 .map(|i| Complex::new(0.0, (i as f32) * 2.0))
2122 .collect();
2123 let (_cov_xx, _cov_yy, cov_xy) = compute_covariance_simd(&imag_left, &imag_right, 0, 8);
2124
2125 assert!(
2127 cov_xy.im.abs() < 1e-5,
2128 "Expected real cov_xy for imaginary signals"
2129 );
2130
2131 let single_left = vec![Complex::new(3.0, 4.0)];
2133 let single_right = vec![Complex::new(1.0, 2.0)];
2134 let (cov_xx, cov_yy, cov_xy) = compute_covariance_simd(&single_left, &single_right, 0, 1);
2135 assert!((cov_xx - 25.0).abs() < 1e-5); assert!((cov_yy - 5.0).abs() < 1e-5); assert!((cov_xy.re - 11.0).abs() < 1e-5);
2139 assert!((cov_xy.im - (-2.0)).abs() < 1e-5);
2140 }
2141
2142 #[test]
2147 fn test_numerical_accuracy_small_values() {
2148 use rustfft::num_complex::Complex;
2150
2151 let small = 1e-20_f32;
2152 let src = vec![
2153 Complex::new(small, small),
2154 Complex::new(small * 2.0, small * 3.0),
2155 Complex::new(small * 4.0, small * 5.0),
2156 Complex::new(small * 6.0, small * 7.0),
2157 ];
2158
2159 let hrtf = vec![
2160 Complex::new(1.0, 0.5),
2161 Complex::new(2.0, -1.0),
2162 Complex::new(-0.5, 1.5),
2163 Complex::new(0.75, 0.25),
2164 ];
2165
2166 let expected: Vec<Complex<f32>> = src.iter().zip(hrtf.iter()).map(|(a, b)| a * b).collect();
2168
2169 let mut result = vec![Complex::new(0.0, 0.0); src.len()];
2171 complex_mul_simd(&mut result, &src, &hrtf);
2172
2173 for i in 0..src.len() {
2175 let re_diff = (result[i].re - expected[i].re).abs();
2176 let im_diff = (result[i].im - expected[i].im).abs();
2177
2178 if expected[i].re.abs() > 1e-15 {
2180 assert!(re_diff / expected[i].re.abs() < 1e-3);
2181 } else {
2182 assert!(re_diff < 1e-25);
2183 }
2184
2185 if expected[i].im.abs() > 1e-15 {
2186 assert!(im_diff / expected[i].im.abs() < 1e-3);
2187 } else {
2188 assert!(im_diff < 1e-25);
2189 }
2190 }
2191 }
2192
2193 #[test]
2194 fn test_numerical_accuracy_large_values() {
2195 use rustfft::num_complex::Complex;
2197
2198 let large = 1e10_f32;
2199 let src = vec![
2200 Complex::new(large, large * 0.5),
2201 Complex::new(large * 2.0, large * 1.5),
2202 Complex::new(large * 0.3, large * 0.7),
2203 Complex::new(large * 1.2, large * 0.8),
2204 ];
2205
2206 let hrtf = vec![
2207 Complex::new(1e-5, 5e-6),
2208 Complex::new(2e-5, -1e-5),
2209 Complex::new(-5e-6, 1.5e-5),
2210 Complex::new(7.5e-6, 2.5e-6),
2211 ];
2212
2213 let mut expected = vec![Complex::new(0.0, 0.0); src.len()];
2215 for i in 0..src.len() {
2216 expected[i] = src[i] * hrtf[i];
2217 }
2218
2219 let mut result = vec![Complex::new(0.0, 0.0); src.len()];
2221 complex_mul_simd(&mut result, &src, &hrtf);
2222
2223 for i in 0..src.len() {
2225 let re_rel_err = (result[i].re - expected[i].re).abs() / expected[i].re.abs().max(1.0);
2226 let im_rel_err = (result[i].im - expected[i].im).abs() / expected[i].im.abs().max(1.0);
2227 assert!(
2228 re_rel_err < 1e-5,
2229 "Index {}: re rel error too large: {}",
2230 i,
2231 re_rel_err
2232 );
2233 assert!(
2234 im_rel_err < 1e-5,
2235 "Index {}: im rel error too large: {}",
2236 i,
2237 im_rel_err
2238 );
2239 }
2240 }
2241
2242 #[test]
2243 fn test_accumulation_accuracy() {
2244 use rustfft::num_complex::Complex;
2246
2247 let src = vec![
2248 Complex::new(0.1, 0.2),
2249 Complex::new(0.3, 0.4),
2250 Complex::new(0.5, 0.6),
2251 Complex::new(0.7, 0.8),
2252 ];
2253
2254 let hrtf = vec![
2255 Complex::new(0.5, 0.25),
2256 Complex::new(-1.0, 1.5),
2257 Complex::new(2.0, -0.5),
2258 Complex::new(0.75, 0.75),
2259 ];
2260
2261 let mut expected = vec![Complex::new(0.0, 0.0); src.len()];
2263 for _ in 0..100 {
2264 for i in 0..src.len() {
2265 expected[i] += src[i] * hrtf[i];
2266 }
2267 }
2268
2269 let mut result = vec![Complex::new(0.0, 0.0); src.len()];
2271 for _ in 0..100 {
2272 complex_mul_add_simd(&mut result, &src, &hrtf);
2273 }
2274
2275 const REL_EPSILON: f32 = 1e-4;
2277 for i in 0..src.len() {
2278 let re_abs_err = (result[i].re - expected[i].re).abs();
2279 let im_abs_err = (result[i].im - expected[i].im).abs();
2280
2281 let re_err = if expected[i].re.abs() > 1e-6 {
2283 re_abs_err / expected[i].re.abs()
2284 } else {
2285 re_abs_err
2286 };
2287 let im_err = if expected[i].im.abs() > 1e-6 {
2288 im_abs_err / expected[i].im.abs()
2289 } else {
2290 im_abs_err
2291 };
2292
2293 assert!(
2294 re_err < REL_EPSILON,
2295 "Index {}: re accumulated error too large: {} (abs: {}, expected: {})",
2296 i,
2297 re_err,
2298 re_abs_err,
2299 expected[i].re
2300 );
2301 assert!(
2302 im_err < REL_EPSILON,
2303 "Index {}: im accumulated error too large: {} (abs: {}, expected: {})",
2304 i,
2305 im_err,
2306 im_abs_err,
2307 expected[i].im
2308 );
2309 }
2310 }
2311
2312 #[test]
2313 fn test_platform_specific_simd_widths() {
2314 use rustfft::num_complex::Complex;
2316
2317 let test_sizes = vec![
2321 1, 2, 3, 4, 5, 8, 9, 12, 16, ];
2331
2332 for size in test_sizes {
2333 let src: Vec<Complex<f32>> = (0..size)
2334 .map(|i| Complex::new(i as f32 * 0.3, i as f32 * -0.2))
2335 .collect();
2336 let hrtf: Vec<Complex<f32>> = (0..size)
2337 .map(|i| Complex::new(1.0 + i as f32 * 0.1, 0.5))
2338 .collect();
2339
2340 let mut result_add = vec![Complex::new(1.0, 2.0); size];
2344 let mut expected_add = result_add.clone();
2345 for i in 0..size {
2346 expected_add[i] += src[i] * hrtf[i];
2347 }
2348 complex_mul_add_simd(&mut result_add, &src, &hrtf);
2349 for i in 0..size {
2350 assert!(
2351 (result_add[i].re - expected_add[i].re).abs() < 1e-6,
2352 "mul_add size {}, index {}: re mismatch",
2353 size,
2354 i
2355 );
2356 assert!(
2357 (result_add[i].im - expected_add[i].im).abs() < 1e-6,
2358 "mul_add size {}, index {}: im mismatch",
2359 size,
2360 i
2361 );
2362 }
2363
2364 let mut result_mul = vec![Complex::new(0.0, 0.0); size];
2366 let expected_mul: Vec<Complex<f32>> =
2367 src.iter().zip(hrtf.iter()).map(|(a, b)| a * b).collect();
2368 complex_mul_simd(&mut result_mul, &src, &hrtf);
2369 for i in 0..size {
2370 assert!(
2371 (result_mul[i].re - expected_mul[i].re).abs() < 1e-6,
2372 "mul size {}, index {}: re mismatch",
2373 size,
2374 i
2375 );
2376 assert!(
2377 (result_mul[i].im - expected_mul[i].im).abs() < 1e-6,
2378 "mul size {}, index {}: im mismatch",
2379 size,
2380 i
2381 );
2382 }
2383
2384 let mut result_inplace = src.clone();
2386 let mut expected_inplace = src.clone();
2387 for i in 0..size {
2388 expected_inplace[i] *= hrtf[i];
2389 }
2390 complex_mul_inplace_simd(&mut result_inplace, &hrtf);
2391 for i in 0..size {
2392 assert!(
2393 (result_inplace[i].re - expected_inplace[i].re).abs() < 1e-6,
2394 "inplace size {}, index {}: re mismatch",
2395 size,
2396 i
2397 );
2398 assert!(
2399 (result_inplace[i].im - expected_inplace[i].im).abs() < 1e-6,
2400 "inplace size {}, index {}: im mismatch",
2401 size,
2402 i
2403 );
2404 }
2405 }
2406 }
2407
2408 #[test]
2409 fn test_stress_test_random_data() {
2410 use rustfft::num_complex::Complex;
2412
2413 let mut seed = 12345_u32;
2415 let lcg = |s: &mut u32| -> f32 {
2416 *s = s.wrapping_mul(1103515245).wrapping_add(12345);
2417 ((*s / 65536) % 32768) as f32 / 32768.0 - 0.5
2418 };
2419
2420 for size in [64, 128, 256, 512] {
2421 let src: Vec<Complex<f32>> = (0..size)
2422 .map(|_| Complex::new(lcg(&mut seed), lcg(&mut seed)))
2423 .collect();
2424 let hrtf: Vec<Complex<f32>> = (0..size)
2425 .map(|_| Complex::new(lcg(&mut seed), lcg(&mut seed)))
2426 .collect();
2427
2428 let expected: Vec<Complex<f32>> =
2430 src.iter().zip(hrtf.iter()).map(|(a, b)| a * b).collect();
2431
2432 let mut result = vec![Complex::new(0.0, 0.0); size];
2434 complex_mul_simd(&mut result, &src, &hrtf);
2435
2436 for i in 0..size {
2438 assert!(
2439 (result[i].re - expected[i].re).abs() < 1e-5,
2440 "Stress test size {}, index {}: re mismatch",
2441 size,
2442 i
2443 );
2444 assert!(
2445 (result[i].im - expected[i].im).abs() < 1e-5,
2446 "Stress test size {}, index {}: im mismatch",
2447 size,
2448 i
2449 );
2450 }
2451 }
2452 }
2453}
2454
2455pub fn compute_covariance_simd(
2473 left: &[Complex<f32>],
2474 right: &[Complex<f32>],
2475 start: usize,
2476 end: usize,
2477) -> (f32, f32, Complex<f32>) {
2478 assert_eq!(left.len(), right.len());
2479 assert!(end <= left.len());
2480 assert!(start < end);
2481
2482 #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
2483 let count = end - start;
2484
2485 #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
2486 {
2487 use std::arch::x86_64::*;
2488
2489 let mut cov_xx;
2490 let mut cov_yy;
2491 let mut cov_xy = Complex::new(0.0, 0.0);
2492
2493 let simd_len = (count / 4) * 4;
2495 let simd_end = start + simd_len;
2496
2497 unsafe {
2498 let mut sum_xx = _mm256_setzero_ps();
2499 let mut sum_yy = _mm256_setzero_ps();
2500 let mut sum_xy_re = _mm256_setzero_ps();
2501 let _sum_xy_im = _mm256_setzero_ps();
2502
2503 for i in (start..simd_end).step_by(4) {
2504 let left_ptr = left.as_ptr().add(i) as *const f32;
2505 let right_ptr = right.as_ptr().add(i) as *const f32;
2506
2507 let l = _mm256_loadu_ps(left_ptr);
2509 let r = _mm256_loadu_ps(right_ptr);
2510
2511 let l_sqr = _mm256_mul_ps(l, l);
2513 let r_sqr = _mm256_mul_ps(r, r);
2514
2515 let l_norm = _mm256_hadd_ps(l_sqr, l_sqr);
2517 let r_norm = _mm256_hadd_ps(r_sqr, r_sqr);
2518
2519 sum_xx = _mm256_add_ps(sum_xx, l_norm);
2520 sum_yy = _mm256_add_ps(sum_yy, r_norm);
2521
2522 let sign_mask = _mm256_set_ps(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0);
2524 let r_conj = _mm256_xor_ps(r, sign_mask);
2525
2526 let l_re = _mm256_moveldup_ps(l);
2528 let l_im = _mm256_movehdup_ps(l);
2529
2530 let ac_ad = _mm256_mul_ps(l_re, r_conj);
2531 let r_conj_swap = _mm256_shuffle_ps(r_conj, r_conj, 0b10110001);
2532 let bd_bc = _mm256_mul_ps(l_im, r_conj_swap);
2533
2534 let result = _mm256_addsub_ps(ac_ad, bd_bc);
2535
2536 sum_xy_re = _mm256_add_ps(sum_xy_re, result);
2538 }
2539
2540 let xx_arr = std::mem::transmute::<__m256, [f32; 8]>(sum_xx);
2542 let yy_arr = std::mem::transmute::<__m256, [f32; 8]>(sum_yy);
2543 let xy_arr = std::mem::transmute::<__m256, [f32; 8]>(sum_xy_re);
2544
2545 cov_xx = xx_arr[0] + xx_arr[1] + xx_arr[4] + xx_arr[5];
2548 cov_yy = yy_arr[0] + yy_arr[1] + yy_arr[4] + yy_arr[5];
2549
2550 cov_xy.re = xy_arr[0] + xy_arr[2] + xy_arr[4] + xy_arr[6];
2552 cov_xy.im = xy_arr[1] + xy_arr[3] + xy_arr[5] + xy_arr[7];
2553 }
2554
2555 for i in simd_end..end {
2557 let l = left[i];
2558 let r = right[i];
2559 cov_xx += l.norm_sqr();
2560 cov_yy += r.norm_sqr();
2561 cov_xy += l * r.conj();
2562 }
2563
2564 (cov_xx, cov_yy, cov_xy)
2565 }
2566
2567 #[cfg(not(all(target_arch = "x86_64", target_feature = "avx2")))]
2568 {
2569 let mut cov_xx = 0.0_f32;
2570 let mut cov_yy = 0.0_f32;
2571 let mut cov_xy = Complex::new(0.0, 0.0);
2572
2573 for i in start..end {
2574 let l = left[i];
2575 let r = right[i];
2576 cov_xx += l.norm_sqr();
2577 cov_yy += r.norm_sqr();
2578 cov_xy += l * r.conj();
2579 }
2580
2581 (cov_xx, cov_yy, cov_xy)
2582 }
2583}
2584
/// Multiplies every sample in `buffer` by `gain`, in place.
///
/// Uses 8-wide AVX2 or 4-wide NEON multiplies when the corresponding target
/// feature is enabled at compile time, with a scalar loop for the leftover
/// samples; otherwise the whole buffer is scaled by a scalar loop.
#[inline]
pub fn apply_gain_simd(buffer: &mut [f32], gain: f32) {
    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        use std::arch::x86_64::*;

        // SAFETY: AVX is available because this block is only compiled with AVX2 enabled.
        let factor = unsafe { _mm256_set1_ps(gain) };
        let mut blocks = buffer.chunks_exact_mut(8);
        for block in blocks.by_ref() {
            // SAFETY: `block` is exactly 8 contiguous f32 values and the
            // load/store pair is the unaligned variant.
            unsafe {
                let p = block.as_mut_ptr();
                let scaled = _mm256_mul_ps(_mm256_loadu_ps(p), factor);
                _mm256_storeu_ps(p, scaled);
            }
        }
        for sample in blocks.into_remainder() {
            *sample *= gain;
        }
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    {
        use std::arch::aarch64::*;

        // SAFETY: NEON is available because this block is only compiled with it enabled.
        let factor = unsafe { vdupq_n_f32(gain) };
        let mut blocks = buffer.chunks_exact_mut(4);
        for block in blocks.by_ref() {
            // SAFETY: `block` is exactly 4 contiguous f32 values.
            unsafe {
                let p = block.as_mut_ptr();
                let scaled = vmulq_f32(vld1q_f32(p), factor);
                vst1q_f32(p, scaled);
            }
        }
        for sample in blocks.into_remainder() {
            *sample *= gain;
        }
    }

    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "avx2"),
        all(target_arch = "aarch64", target_feature = "neon")
    )))]
    {
        buffer.iter_mut().for_each(|sample| *sample *= gain);
    }
}
2636
/// Scales interleaved audio in place, multiplying channel `ch` of every frame by `gains[ch]`.
///
/// `buffer` is read as consecutive frames of `channels` samples each. Samples past
/// the last whole frame are left untouched. Stereo buffers (`channels == 2`) use an
/// AVX2 or NEON path when the corresponding target feature is enabled at compile
/// time; every other case uses a scalar loop.
///
/// A `channels` value of zero describes no frames, so the call is a no-op (it
/// previously panicked with a division by zero).
///
/// # Panics
///
/// Panics if a gain that is needed lies outside `gains`: on the scalar path when
/// there is at least one frame and `gains.len() < channels`, and on the stereo SIMD
/// paths whenever `gains.len() < 2`.
#[inline]
pub fn apply_per_channel_gain_simd(buffer: &mut [f32], channels: usize, gains: &[f32]) {
    // Guard the division below; with no channels there is nothing to scale.
    if channels == 0 {
        return;
    }
    let num_frames = buffer.len() / channels;

    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        if channels == 2 {
            use std::arch::x86_64::*;
            // `_mm256_set_ps` takes lanes from highest to lowest, so this lays the
            // gains out in memory order as [g0, g1, g0, g1, ...].
            let gains_vec = unsafe {
                _mm256_set_ps(
                    gains[1], gains[0], gains[1], gains[0], gains[1], gains[0], gains[1], gains[0],
                )
            };
            // Four stereo frames (8 floats) per vector.
            let simd_len = (num_frames / 4) * 4;
            for i in (0..simd_len).step_by(4) {
                // SAFETY: frames `i..i + 4` occupy floats `2i..2i + 8`, and
                // `2 * (i + 4) <= 2 * num_frames <= buffer.len()`. Unaligned
                // load/store variants are used.
                unsafe {
                    let ptr = buffer.as_mut_ptr().add(i * 2);
                    let v = _mm256_loadu_ps(ptr);
                    let res = _mm256_mul_ps(v, gains_vec);
                    _mm256_storeu_ps(ptr, res);
                }
            }
            for i in simd_len..num_frames {
                buffer[i * 2] *= gains[0];
                buffer[i * 2 + 1] *= gains[1];
            }
            return;
        }
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    {
        if channels == 2 {
            use std::arch::aarch64::*;
            let gains_vec = unsafe {
                let g = [gains[0], gains[1], gains[0], gains[1]];
                vld1q_f32(g.as_ptr())
            };
            // Two stereo frames (4 floats) per vector.
            let simd_len = (num_frames / 2) * 2;
            for i in (0..simd_len).step_by(2) {
                // SAFETY: frames `i..i + 2` occupy floats `2i..2i + 4`, and
                // `2 * (i + 2) <= 2 * num_frames <= buffer.len()`.
                unsafe {
                    let ptr = buffer.as_mut_ptr().add(i * 2);
                    let v = vld1q_f32(ptr);
                    let res = vmulq_f32(v, gains_vec);
                    vst1q_f32(ptr, res);
                }
            }
            for i in simd_len..num_frames {
                buffer[i * 2] *= gains[0];
                buffer[i * 2 + 1] *= gains[1];
            }
            return;
        }
    }

    // Generic path: walk whole frames only; `chunks_exact_mut` skips the partial tail.
    for frame in buffer.chunks_exact_mut(channels).take(num_frames) {
        for (ch, sample) in frame.iter_mut().enumerate() {
            *sample *= gains[ch];
        }
    }
}
2702
/// Approximates `1 / sqrt(x)` with the bit-level "fast inverse square root" trick
/// followed by one Newton-Raphson refinement step.
///
/// The result is only meaningful for positive, finite `x`. For other inputs
/// (negative values, zero, NaN, infinity) the returned value is unspecified, but
/// the function never panics: the integer step uses wrapping arithmetic, so debug
/// and release builds behave identically.
#[inline(always)]
pub fn fast_inv_sqrt(x: f32) -> f32 {
    // Bias constant of the classic bit-level initial guess.
    const MAGIC: u32 = 0x5f37_59df;

    let half = 0.5 * x;
    // `bits >> 1` exceeds MAGIC for sufficiently negative inputs; a plain `-`
    // would panic on overflow in debug builds, so wrap explicitly.
    let guess = f32::from_bits(MAGIC.wrapping_sub(x.to_bits() >> 1));
    // One Newton-Raphson iteration: y * (3/2 - (x/2) * y^2).
    guess * (1.5 - half * guess * guess)
}
2713
/// Returns the largest absolute value in `samples`, or `0.0` for an empty slice.
///
/// Uses 8-wide AVX2 or 4-wide NEON reductions when the corresponding target
/// feature is enabled at compile time, with a scalar pass over the leftover
/// samples; otherwise the whole slice is scanned by a scalar loop.
///
/// NaN samples are skipped on every path, so the result is never NaN. The SIMD
/// paths are written so that a NaN can neither replace nor poison the running
/// maximum, matching the scalar comparison `v > max`.
/// NOTE(review): on NEON this relies on the `maxNum` behaviour of `vmaxnmq_f32`
/// for quiet NaNs; signalling NaNs have not been verified.
#[inline]
pub fn find_max_abs_simd(samples: &[f32]) -> f32 {
    if samples.is_empty() {
        return 0.0;
    }

    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        use std::arch::x86_64::*;

        // SAFETY: AVX is available because this block is only compiled with AVX2 enabled.
        let sign_bit = unsafe { _mm256_set1_ps(-0.0) };
        let mut lane_max = unsafe { _mm256_setzero_ps() };

        let mut blocks = samples.chunks_exact(8);
        for block in blocks.by_ref() {
            // SAFETY: `block` is exactly 8 contiguous f32 values; the load is unaligned.
            unsafe {
                let v = _mm256_loadu_ps(block.as_ptr());
                // Clearing the sign bit yields |v|.
                let magnitude = _mm256_andnot_ps(sign_bit, v);
                // `max(a, b)` returns `b` whenever either operand is NaN. Keeping the
                // running maximum as the second operand means a NaN sample is dropped
                // instead of overwriting (and later losing) the maximum.
                lane_max = _mm256_max_ps(magnitude, lane_max);
            }
        }

        // SAFETY: `__m256` and `[f32; 8]` have the same size; this only reinterprets bits.
        let lanes = unsafe { std::mem::transmute::<__m256, [f32; 8]>(lane_max) };

        let mut max_val = 0.0_f32;
        let tail = blocks.remainder().iter().map(|s| s.abs());
        for v in lanes.iter().copied().chain(tail) {
            if v > max_val {
                max_val = v;
            }
        }
        max_val
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    {
        use std::arch::aarch64::*;

        // SAFETY: NEON is available because this block is only compiled with it enabled.
        let mut lane_max = unsafe { vdupq_n_f32(0.0) };

        let mut blocks = samples.chunks_exact(4);
        for block in blocks.by_ref() {
            // SAFETY: `block` is exactly 4 contiguous f32 values.
            unsafe {
                let magnitude = vabsq_f32(vld1q_f32(block.as_ptr()));
                // The "nm" variant returns the numeric operand when the other is a
                // quiet NaN, so NaN samples do not stick in the accumulator.
                lane_max = vmaxnmq_f32(lane_max, magnitude);
            }
        }

        // SAFETY: horizontal reduction of a valid vector value.
        let mut max_val = unsafe { vmaxnmvq_f32(lane_max) };

        for sample in blocks.remainder() {
            let v = sample.abs();
            if v > max_val {
                max_val = v;
            }
        }
        max_val
    }

    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "avx2"),
        all(target_arch = "aarch64", target_feature = "neon")
    )))]
    {
        // `f32::max` returns the non-NaN operand, so NaN samples are skipped here too.
        samples
            .iter()
            .fold(0.0_f32, |max_val, sample| max_val.max(sample.abs()))
    }
}