Skip to main content

math_audio_dsp/
simd.rs

1// ============================================================================
2// SIMD Optimizations for Complex Multiplication
3// ============================================================================
4//
5// These functions provide SIMD-accelerated complex multiplication for the
6// frequency-domain HRTF convolution hot paths. Complex multiplication:
7//   (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
8//
9// Platform support:
10// - x86-64: AVX2 (processes 4 complex f32 at once using 256-bit registers)
11// - aarch64 + FCMA: FCMLA instructions (2 ops per complex mul, ARMv8.3+, all Apple Silicon)
12// - aarch64: NEON fallback (processes 2 complex f32 at once using 128-bit registers)
13// - fallback: Scalar implementation for all other platforms
14//
15// Performance gains: 2-4x speedup on supported platforms for FFT sizes >= 512
16
17use rustfft::num_complex::Complex;
18
/// Complex multiply-accumulate on two packed `Complex<f32>` values using the
/// ARMv8.3 `FCMLA` instruction (present on all Apple Silicon).
///
/// Every `float32x4_t` carries two complex numbers as `[re0, im0, re1, im1]`.
/// The `#0` rotation accumulates `a.re * b`, the `#90` rotation accumulates
/// `i * a.im * b`; together they produce `r[i] + a[i] * b[i]`.
///
/// # Safety
/// Only compiled when the `fcma` target feature is enabled; the CPU executing
/// the code must actually implement that extension.
#[cfg(all(target_arch = "aarch64", target_feature = "fcma"))]
#[inline(always)]
unsafe fn fcmla_mul_acc(
    mut r: std::arch::aarch64::float32x4_t,
    a: std::arch::aarch64::float32x4_t,
    b: std::arch::aarch64::float32x4_t,
) -> std::arch::aarch64::float32x4_t {
    // SAFETY: the two instructions only operate on vector registers; they touch
    // neither memory nor the stack, as declared in `options`.
    unsafe {
        std::arch::asm!(
            "fcmla {acc:v}.4s, {lhs:v}.4s, {rhs:v}.4s, #0",
            "fcmla {acc:v}.4s, {lhs:v}.4s, {rhs:v}.4s, #90",
            acc = inout(vreg) r,
            lhs = in(vreg) a,
            rhs = in(vreg) b,
            options(pure, nomem, nostack),
        );
    }
    r
}
40
/// Immediate for `_mm256_shuffle_ps` that exchanges the two floats of every
/// complex value, turning `[re, im]` into `[im, re]`.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
const SHUFFLE_SWAP_RE_IM: i32 = 0b10_11_00_01; // element selectors 2, 3, 0, 1 (high to low)
44
/// AVX2 kernel: `dst[start + k] += src[start + k] * hrtf[start + k]` for `k` in `0..4`.
///
/// The four complex values are processed as eight interleaved floats
/// `[re0, im0, re1, im1, re2, im2, re3, im3]`.
///
/// # Safety
/// Caller must ensure `dst`, `src` and `hrtf` have at least `start + 4` elements.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
#[inline]
pub unsafe fn complex_mul_add_simd_chunk(
    dst: &mut [Complex<f32>],
    src: &[Complex<f32>],
    hrtf: &[Complex<f32>],
    start: usize,
) {
    use std::arch::x86_64::*;

    // SAFETY: the caller guarantees elements `start..start + 4` exist in all three
    // slices, so every unaligned 8-float load/store below stays in bounds.
    unsafe {
        let acc_ptr = dst.as_mut_ptr().add(start).cast::<f32>();
        let x = _mm256_loadu_ps(src.as_ptr().add(start).cast::<f32>());
        let h = _mm256_loadu_ps(hrtf.as_ptr().add(start).cast::<f32>());
        let acc = _mm256_loadu_ps(acc_ptr);

        // With x = a + bi and h = c + di:
        //   dup_even(x) * h       = [ac, ad, ...]
        //   dup_odd(x)  * swap(h) = [bd, bc, ...]
        let re_terms = _mm256_mul_ps(_mm256_moveldup_ps(x), h);
        let h_swapped = _mm256_shuffle_ps(h, h, SHUFFLE_SWAP_RE_IM);
        let im_terms = _mm256_mul_ps(_mm256_movehdup_ps(x), h_swapped);

        // addsub subtracts in even lanes and adds in odd lanes:
        // [(ac - bd), (ad + bc), ...] = [product.re, product.im, ...]
        let product = _mm256_addsub_ps(re_terms, im_terms);

        // Accumulate into the destination.
        _mm256_storeu_ps(acc_ptr, _mm256_add_ps(acc, product));
    }
}
96
/// FCMLA kernel: `dst[start + k] += src[start + k] * hrtf[start + k]` for `k` in `0..2`.
///
/// # Safety
/// Caller must ensure `dst`, `src` and `hrtf` have at least `start + 2` elements.
#[cfg(all(target_arch = "aarch64", target_feature = "fcma"))]
#[inline]
pub unsafe fn complex_mul_add_simd_chunk(
    dst: &mut [Complex<f32>],
    src: &[Complex<f32>],
    hrtf: &[Complex<f32>],
    start: usize,
) {
    use std::arch::aarch64::{vld1q_f32, vst1q_f32};

    // SAFETY: the caller guarantees elements `start..start + 2` exist in all three
    // slices, so each 4-float load/store below stays in bounds.
    unsafe {
        let acc_ptr = dst.as_mut_ptr().add(start).cast::<f32>();
        let x = vld1q_f32(src.as_ptr().add(start).cast::<f32>());
        let h = vld1q_f32(hrtf.as_ptr().add(start).cast::<f32>());
        let acc = vld1q_f32(acc_ptr);

        // The accumulator goes in as the FCMLA destination, so the add is fused.
        vst1q_f32(acc_ptr, fcmla_mul_acc(acc, x, h));
    }
}
121
/// NEON kernel (no FCMA): `dst[start + k] += src[start + k] * hrtf[start + k]`
/// for `k` in `0..2`.
///
/// # Safety
/// Caller must ensure `dst`, `src` and `hrtf` have at least `start + 2` elements.
#[cfg(all(
    target_arch = "aarch64",
    target_feature = "neon",
    not(target_feature = "fcma")
))]
#[inline]
pub unsafe fn complex_mul_add_simd_chunk(
    dst: &mut [Complex<f32>],
    src: &[Complex<f32>],
    hrtf: &[Complex<f32>],
    start: usize,
) {
    use std::arch::aarch64::*;

    const SIGN_BIT: u32 = 0x8000_0000;

    // SAFETY: the caller guarantees elements `start..start + 2` exist in all three
    // slices, so each 4-float load/store below stays in bounds.
    unsafe {
        let acc_ptr = dst.as_mut_ptr().add(start).cast::<f32>();
        let x = vld1q_f32(src.as_ptr().add(start).cast::<f32>());
        let h = vld1q_f32(hrtf.as_ptr().add(start).cast::<f32>());
        let acc = vld1q_f32(acc_ptr);

        // With x = a + bi and h = c + di:
        //   [a, a, ...] * [c, d, ...] = [ac, ad, ...]
        //   [b, b, ...] * [d, c, ...] = [bd, bc, ...]
        let re_terms = vmulq_f32(vtrn1q_f32(x, x), h);
        let im_terms = vmulq_f32(vtrn2q_f32(x, x), vrev64q_f32(h));

        // Flip the sign of the even lanes so the add below yields (ac - bd).
        let flip_even =
            vsetq_lane_u32::<2>(SIGN_BIT, vsetq_lane_u32::<0>(SIGN_BIT, vdupq_n_u32(0)));
        let im_signed =
            vreinterpretq_f32_u32(veorq_u32(vreinterpretq_u32_f32(im_terms), flip_even));

        let product = vaddq_f32(re_terms, im_signed);
        vst1q_f32(acc_ptr, vaddq_f32(acc, product));
    }
}
169
170#[cfg(not(any(
171    all(target_arch = "x86_64", target_feature = "avx2"),
172    all(target_arch = "aarch64", target_feature = "neon")
173)))]
174#[inline]
175pub fn complex_mul_add_simd_chunk(
176    dst: &mut [Complex<f32>],
177    src: &[Complex<f32>],
178    hrtf: &[Complex<f32>],
179    start: usize,
180) {
181    // Scalar fallback (will be optimized by LLVM auto-vectorization)
182    dst[start] += src[start] * hrtf[start];
183}
184
185/// SIMD-optimized complex multiply-accumulate
186///
187/// Computes: `dst[i] += src[i] * hrtf[i]` for all `i`
188///
189/// Uses platform-specific SIMD instructions for maximum performance:
190/// - AVX2 on x86-64 (4 complex at once)
191/// - NEON on aarch64 (2 complex at once)
192/// - Scalar fallback with auto-vectorization hints
193#[inline]
194pub fn complex_mul_add_simd(dst: &mut [Complex<f32>], src: &[Complex<f32>], hrtf: &[Complex<f32>]) {
195    let len = dst.len();
196
197    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
198    {
199        // Process 4 complex at a time with AVX2
200        let simd_len = (len / 4) * 4;
201
202        for i in (0..simd_len).step_by(4) {
203            unsafe {
204                complex_mul_add_simd_chunk(dst, src, hrtf, i);
205            }
206        }
207
208        // Scalar remainder
209        for i in simd_len..len {
210            dst[i] += src[i] * hrtf[i];
211        }
212    }
213
214    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
215    {
216        // Process 2 complex at a time with NEON
217        let simd_len = (len / 2) * 2;
218
219        for i in (0..simd_len).step_by(2) {
220            unsafe {
221                complex_mul_add_simd_chunk(dst, src, hrtf, i);
222            }
223        }
224
225        // Scalar remainder
226        for i in simd_len..len {
227            dst[i] += src[i] * hrtf[i];
228        }
229    }
230
231    #[cfg(not(any(
232        all(target_arch = "x86_64", target_feature = "avx2"),
233        all(target_arch = "aarch64", target_feature = "neon")
234    )))]
235    {
236        // Scalar fallback
237        for i in 0..len {
238            dst[i] += src[i] * hrtf[i];
239        }
240    }
241}
242
243/// SIMD-optimized complex multiplication (without accumulation)
244///
245/// Computes: `dst[i] = src[i] * hrtf[i]` for all `i`
246#[inline]
247pub fn complex_mul_simd(dst: &mut [Complex<f32>], src: &[Complex<f32>], hrtf: &[Complex<f32>]) {
248    let len = dst.len();
249
250    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
251    {
252        use std::arch::x86_64::*;
253
254        let simd_len = (len / 4) * 4;
255
256        for i in (0..simd_len).step_by(4) {
257            unsafe {
258                let src_ptr = src.as_ptr().add(i) as *const f32;
259                let hrtf_ptr = hrtf.as_ptr().add(i) as *const f32;
260                let dst_ptr = dst.as_mut_ptr().add(i) as *mut f32;
261
262                let a = _mm256_loadu_ps(src_ptr);
263                let b = _mm256_loadu_ps(hrtf_ptr);
264
265                let a_re = _mm256_moveldup_ps(a);
266                let a_im = _mm256_movehdup_ps(a);
267                let ac_ad = _mm256_mul_ps(a_re, b);
268                let b_swapped = _mm256_shuffle_ps(b, b, SHUFFLE_SWAP_RE_IM);
269                let bd_bc = _mm256_mul_ps(a_im, b_swapped);
270                let result = _mm256_addsub_ps(ac_ad, bd_bc);
271
272                _mm256_storeu_ps(dst_ptr, result);
273            }
274        }
275
276        for i in simd_len..len {
277            dst[i] = src[i] * hrtf[i];
278        }
279    }
280
281    #[cfg(all(target_arch = "aarch64", target_feature = "fcma"))]
282    {
283        use std::arch::aarch64::*;
284
285        let simd_len = (len / 2) * 2;
286
287        for i in (0..simd_len).step_by(2) {
288            unsafe {
289                let src_ptr = src.as_ptr().add(i) as *const f32;
290                let hrtf_ptr = hrtf.as_ptr().add(i) as *const f32;
291                let dst_ptr = dst.as_mut_ptr().add(i) as *mut f32;
292
293                let a = vld1q_f32(src_ptr);
294                let b = vld1q_f32(hrtf_ptr);
295                let r = vdupq_n_f32(0.0);
296                let result = fcmla_mul_acc(r, a, b);
297                vst1q_f32(dst_ptr, result);
298            }
299        }
300
301        for i in simd_len..len {
302            dst[i] = src[i] * hrtf[i];
303        }
304    }
305
306    #[cfg(all(
307        target_arch = "aarch64",
308        target_feature = "neon",
309        not(target_feature = "fcma")
310    ))]
311    {
312        use std::arch::aarch64::*;
313
314        let simd_len = (len / 2) * 2;
315
316        for i in (0..simd_len).step_by(2) {
317            unsafe {
318                let src_ptr = src.as_ptr().add(i) as *const f32;
319                let hrtf_ptr = hrtf.as_ptr().add(i) as *const f32;
320                let dst_ptr = dst.as_mut_ptr().add(i) as *mut f32;
321
322                let a = vld1q_f32(src_ptr);
323                let b = vld1q_f32(hrtf_ptr);
324
325                let a_re = vtrn1q_f32(a, a);
326                let a_im = vtrn2q_f32(a, a);
327                let ac_ad = vmulq_f32(a_re, b);
328                let b_swapped = vrev64q_f32(b);
329                let bd_bc = vmulq_f32(a_im, b_swapped);
330
331                let sign_bit: u32 = 0x80000000;
332                let neg_mask = vreinterpretq_f32_u32(vsetq_lane_u32::<2>(
333                    sign_bit,
334                    vsetq_lane_u32::<0>(sign_bit, vdupq_n_u32(0)),
335                ));
336
337                let bd_bc_negated = vreinterpretq_f32_u32(veorq_u32(
338                    vreinterpretq_u32_f32(bd_bc),
339                    vreinterpretq_u32_f32(neg_mask),
340                ));
341
342                let result = vaddq_f32(ac_ad, bd_bc_negated);
343                vst1q_f32(dst_ptr, result);
344            }
345        }
346
347        for i in simd_len..len {
348            dst[i] = src[i] * hrtf[i];
349        }
350    }
351
352    #[cfg(not(any(
353        all(target_arch = "x86_64", target_feature = "avx2"),
354        all(target_arch = "aarch64", target_feature = "neon")
355    )))]
356    {
357        for i in 0..len {
358            dst[i] = src[i] * hrtf[i];
359        }
360    }
361}
362
363/// SIMD-optimized in-place complex multiplication
364///
365/// Computes: `dst[i] *= hrtf[i]` for all `i`
366#[inline]
367#[allow(dead_code)]
368pub fn complex_mul_inplace_simd(dst: &mut [Complex<f32>], hrtf: &[Complex<f32>]) {
369    let len = dst.len();
370
371    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
372    {
373        use std::arch::x86_64::*;
374
375        let simd_len = (len / 4) * 4;
376
377        for i in (0..simd_len).step_by(4) {
378            unsafe {
379                let dst_ptr = dst.as_mut_ptr().add(i) as *mut f32;
380                let hrtf_ptr = hrtf.as_ptr().add(i) as *const f32;
381
382                let a = _mm256_loadu_ps(dst_ptr);
383                let b = _mm256_loadu_ps(hrtf_ptr);
384
385                let a_re = _mm256_moveldup_ps(a);
386                let a_im = _mm256_movehdup_ps(a);
387                let ac_ad = _mm256_mul_ps(a_re, b);
388                let b_swapped = _mm256_shuffle_ps(b, b, SHUFFLE_SWAP_RE_IM);
389                let bd_bc = _mm256_mul_ps(a_im, b_swapped);
390                let result = _mm256_addsub_ps(ac_ad, bd_bc);
391
392                _mm256_storeu_ps(dst_ptr, result);
393            }
394        }
395
396        for i in simd_len..len {
397            dst[i] *= hrtf[i];
398        }
399    }
400
401    #[cfg(all(target_arch = "aarch64", target_feature = "fcma"))]
402    {
403        use std::arch::aarch64::*;
404
405        let simd_len = (len / 2) * 2;
406
407        for i in (0..simd_len).step_by(2) {
408            unsafe {
409                let dst_ptr = dst.as_mut_ptr().add(i) as *mut f32;
410                let hrtf_ptr = hrtf.as_ptr().add(i) as *const f32;
411
412                let a = vld1q_f32(dst_ptr);
413                let b = vld1q_f32(hrtf_ptr);
414                let r = vdupq_n_f32(0.0);
415                let result = fcmla_mul_acc(r, a, b);
416                vst1q_f32(dst_ptr, result);
417            }
418        }
419
420        for i in simd_len..len {
421            dst[i] *= hrtf[i];
422        }
423    }
424
425    #[cfg(all(
426        target_arch = "aarch64",
427        target_feature = "neon",
428        not(target_feature = "fcma")
429    ))]
430    {
431        use std::arch::aarch64::*;
432
433        let simd_len = (len / 2) * 2;
434
435        for i in (0..simd_len).step_by(2) {
436            unsafe {
437                let dst_ptr = dst.as_mut_ptr().add(i) as *mut f32;
438                let hrtf_ptr = hrtf.as_ptr().add(i) as *const f32;
439
440                let a = vld1q_f32(dst_ptr);
441                let b = vld1q_f32(hrtf_ptr);
442
443                let a_re = vtrn1q_f32(a, a);
444                let a_im = vtrn2q_f32(a, a);
445                let ac_ad = vmulq_f32(a_re, b);
446                let b_swapped = vrev64q_f32(b);
447                let bd_bc = vmulq_f32(a_im, b_swapped);
448
449                let sign_bit: u32 = 0x80000000;
450                let neg_mask = vreinterpretq_f32_u32(vsetq_lane_u32::<2>(
451                    sign_bit,
452                    vsetq_lane_u32::<0>(sign_bit, vdupq_n_u32(0)),
453                ));
454
455                let bd_bc_negated = vreinterpretq_f32_u32(veorq_u32(
456                    vreinterpretq_u32_f32(bd_bc),
457                    vreinterpretq_u32_f32(neg_mask),
458                ));
459
460                let result = vaddq_f32(ac_ad, bd_bc_negated);
461                vst1q_f32(dst_ptr, result);
462            }
463        }
464
465        for i in simd_len..len {
466            dst[i] *= hrtf[i];
467        }
468    }
469
470    #[cfg(not(any(
471        all(target_arch = "x86_64", target_feature = "avx2"),
472        all(target_arch = "aarch64", target_feature = "neon")
473    )))]
474    {
475        for i in 0..len {
476            dst[i] *= hrtf[i];
477        }
478    }
479}
480
481// ============================================================================
482// Real-valued SIMD operations for windowing and overlap-add
483// ============================================================================
484
/// SIMD-optimized multiply-accumulate for overlap-add synthesis (no window).
///
/// Computes `dst[i] += src[i] * scale` for every `i` in `0..dst.len()`.
///
/// The slices are expected to have equal length (checked in debug builds).
///
/// # Panics
/// Panics if `src` is shorter than `dst`; in debug builds also when the two
/// lengths differ at all.
#[inline]
pub fn scale_add_simd(dst: &mut [f32], src: &[f32], scale: f32) {
    debug_assert_eq!(dst.len(), src.len());
    let len = dst.len();

    // The SIMD paths read `src` through a raw pointer, so this must hold in
    // release builds too; the scalar path panics on the same condition.
    assert!(src.len() >= len, "scale_add_simd: `src` is shorter than `dst`");

    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        use std::arch::x86_64::*;

        // SAFETY: broadcasting a scalar accesses no memory.
        let scale_vec = unsafe { _mm256_set1_ps(scale) };
        let simd_len = (len / 8) * 8;

        for i in (0..simd_len).step_by(8) {
            // SAFETY: `i + 8 <= simd_len <= len` and both slices hold at least
            // `len` floats.
            unsafe {
                let src_ptr = src.as_ptr().add(i);
                let dst_ptr = dst.as_mut_ptr().add(i);

                let s = _mm256_loadu_ps(src_ptr);
                let d = _mm256_loadu_ps(dst_ptr);

                // src * scale + dst
                let ss = _mm256_mul_ps(s, scale_vec);
                let result = _mm256_add_ps(d, ss);

                _mm256_storeu_ps(dst_ptr, result);
            }
        }

        // Scalar remainder
        for i in simd_len..len {
            dst[i] += src[i] * scale;
        }
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    {
        use std::arch::aarch64::*;

        // SAFETY: broadcasting a scalar accesses no memory.
        let scale_vec = unsafe { vdupq_n_f32(scale) };
        let simd_len = (len / 4) * 4;

        for i in (0..simd_len).step_by(4) {
            // SAFETY: `i + 4 <= simd_len <= len` and both slices hold at least
            // `len` floats.
            unsafe {
                let src_ptr = src.as_ptr().add(i);
                let dst_ptr = dst.as_mut_ptr().add(i);

                let s = vld1q_f32(src_ptr);
                let d = vld1q_f32(dst_ptr);

                // FMA: dst + src * scale in a single fused instruction
                let result = vfmaq_f32(d, s, scale_vec);

                vst1q_f32(dst_ptr, result);
            }
        }

        // Scalar remainder
        for i in simd_len..len {
            dst[i] += src[i] * scale;
        }
    }

    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "avx2"),
        all(target_arch = "aarch64", target_feature = "neon")
    )))]
    {
        // Scalar fallback; `zip` stops after `dst.len()` elements.
        for (d, s) in dst.iter_mut().zip(src) {
            *d += *s * scale;
        }
    }
}
560
/// Multiplies every sample of `data` by `scale`, in place.
///
/// Uses 8-wide AVX2 or 4-wide NEON multiplies when those target features are
/// enabled at compile time; leftover samples and all other targets take a
/// scalar loop.
#[inline]
pub fn scale_add_simd_inplace(data: &mut [f32], scale: f32) {
    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        use std::arch::x86_64::*;

        // SAFETY: broadcasting a scalar accesses no memory.
        let factor = unsafe { _mm256_set1_ps(scale) };

        let mut blocks = data.chunks_exact_mut(8);
        for block in &mut blocks {
            let p = block.as_mut_ptr();
            // SAFETY: `block` is exactly 8 contiguous, initialized f32 values.
            unsafe { _mm256_storeu_ps(p, _mm256_mul_ps(_mm256_loadu_ps(p), factor)) };
        }

        // Fewer than 8 samples left over.
        for sample in blocks.into_remainder() {
            *sample *= scale;
        }
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    {
        use std::arch::aarch64::*;

        // SAFETY: broadcasting a scalar accesses no memory.
        let factor = unsafe { vdupq_n_f32(scale) };

        let mut blocks = data.chunks_exact_mut(4);
        for block in &mut blocks {
            let p = block.as_mut_ptr();
            // SAFETY: `block` is exactly 4 contiguous, initialized f32 values.
            unsafe { vst1q_f32(p, vmulq_f32(vld1q_f32(p), factor)) };
        }

        // Fewer than 4 samples left over.
        for sample in blocks.into_remainder() {
            *sample *= scale;
        }
    }

    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "avx2"),
        all(target_arch = "aarch64", target_feature = "neon")
    )))]
    {
        data.iter_mut().for_each(|sample| *sample *= scale);
    }
}
618
/// SIMD-optimized linear interpolation (blend) between two buffers.
///
/// Computes `dst[i] = prev[i] + alpha * (dst[i] - prev[i])` for every `i` in
/// `0..dst.len()`, which is equivalent to
/// `dst[i] = (1 - alpha) * prev[i] + alpha * dst[i]`.
///
/// Used for crossfading between old and new filter outputs in STFT plugins.
/// The slices are expected to have equal length (checked in debug builds).
///
/// # Panics
/// Panics if `prev` is shorter than `dst`; in debug builds also when the two
/// lengths differ at all.
#[inline]
pub fn blend_simd(dst: &mut [f32], prev: &[f32], alpha: f32) {
    debug_assert_eq!(dst.len(), prev.len());
    let len = dst.len();

    // The SIMD paths read `prev` through a raw pointer, so this must hold in
    // release builds too; the scalar path panics on the same condition.
    assert!(prev.len() >= len, "blend_simd: `prev` is shorter than `dst`");

    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        use std::arch::x86_64::*;

        // SAFETY: broadcasting a scalar accesses no memory.
        let alpha_vec = unsafe { _mm256_set1_ps(alpha) };
        let simd_len = (len / 8) * 8;

        for i in (0..simd_len).step_by(8) {
            // SAFETY: `i + 8 <= simd_len <= len` and both slices hold at least
            // `len` floats.
            unsafe {
                let prev_ptr = prev.as_ptr().add(i);
                let dst_ptr = dst.as_mut_ptr().add(i);

                let p = _mm256_loadu_ps(prev_ptr);
                let d = _mm256_loadu_ps(dst_ptr);

                // prev + alpha * (dst - prev)
                let diff = _mm256_sub_ps(d, p);

                // `_mm256_fmadd_ps` belongs to the separate FMA extension, which
                // the `avx2` cfg alone does not guarantee; use it only when the
                // `fma` target feature is enabled as well.
                #[cfg(target_feature = "fma")]
                let result = _mm256_fmadd_ps(alpha_vec, diff, p);
                #[cfg(not(target_feature = "fma"))]
                let result = _mm256_add_ps(p, _mm256_mul_ps(alpha_vec, diff));

                _mm256_storeu_ps(dst_ptr, result);
            }
        }

        // Scalar remainder
        for i in simd_len..len {
            dst[i] = prev[i] + alpha * (dst[i] - prev[i]);
        }
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    {
        use std::arch::aarch64::*;

        // SAFETY: broadcasting a scalar accesses no memory.
        let alpha_vec = unsafe { vdupq_n_f32(alpha) };
        let simd_len = (len / 4) * 4;

        for i in (0..simd_len).step_by(4) {
            // SAFETY: `i + 4 <= simd_len <= len` and both slices hold at least
            // `len` floats.
            unsafe {
                let prev_ptr = prev.as_ptr().add(i);
                let dst_ptr = dst.as_mut_ptr().add(i);

                let p = vld1q_f32(prev_ptr);
                let d = vld1q_f32(dst_ptr);

                // prev + alpha * (dst - prev)
                let diff = vsubq_f32(d, p);
                let result = vfmaq_f32(p, alpha_vec, diff);

                vst1q_f32(dst_ptr, result);
            }
        }

        // Scalar remainder
        for i in simd_len..len {
            dst[i] = prev[i] + alpha * (dst[i] - prev[i]);
        }
    }

    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "avx2"),
        all(target_arch = "aarch64", target_feature = "neon")
    )))]
    {
        // Scalar fallback; `zip` stops after `dst.len()` elements.
        for (d, p) in dst.iter_mut().zip(prev) {
            *d = *p + alpha * (*d - *p);
        }
    }
}
696
/// SIMD-optimized windowed copy (for FFT input preparation).
///
/// Computes `dst[i] = src[i] * window[i]` for every `i` in `0..dst.len()`.
///
/// All three slices are expected to have equal length (checked in debug builds).
///
/// # Panics
/// Panics if `src` or `window` is shorter than `dst`; in debug builds also when
/// the lengths differ at all.
#[inline]
pub fn window_mul_simd(dst: &mut [f32], src: &[f32], window: &[f32]) {
    debug_assert_eq!(dst.len(), src.len());
    debug_assert_eq!(dst.len(), window.len());
    let len = dst.len();

    // The SIMD paths read `src` and `window` through raw pointers, so this must
    // hold in release builds too; the scalar path panics on the same condition.
    assert!(src.len() >= len, "window_mul_simd: `src` is shorter than `dst`");
    assert!(window.len() >= len, "window_mul_simd: `window` is shorter than `dst`");

    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        use std::arch::x86_64::*;

        let simd_len = (len / 8) * 8;

        for i in (0..simd_len).step_by(8) {
            // SAFETY: `i + 8 <= simd_len <= len` and all three slices hold at
            // least `len` floats.
            unsafe {
                let src_ptr = src.as_ptr().add(i);
                let win_ptr = window.as_ptr().add(i);
                let dst_ptr = dst.as_mut_ptr().add(i);

                let s = _mm256_loadu_ps(src_ptr);
                let w = _mm256_loadu_ps(win_ptr);
                let result = _mm256_mul_ps(s, w);

                _mm256_storeu_ps(dst_ptr, result);
            }
        }

        // Scalar remainder
        for i in simd_len..len {
            dst[i] = src[i] * window[i];
        }
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    {
        use std::arch::aarch64::*;

        let simd_len = (len / 4) * 4;

        for i in (0..simd_len).step_by(4) {
            // SAFETY: `i + 4 <= simd_len <= len` and all three slices hold at
            // least `len` floats.
            unsafe {
                let src_ptr = src.as_ptr().add(i);
                let win_ptr = window.as_ptr().add(i);
                let dst_ptr = dst.as_mut_ptr().add(i);

                let s = vld1q_f32(src_ptr);
                let w = vld1q_f32(win_ptr);
                let result = vmulq_f32(s, w);

                vst1q_f32(dst_ptr, result);
            }
        }

        // Scalar remainder
        for i in simd_len..len {
            dst[i] = src[i] * window[i];
        }
    }

    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "avx2"),
        all(target_arch = "aarch64", target_feature = "neon")
    )))]
    {
        // Scalar fallback; `zip` stops after `dst.len()` elements.
        for ((d, s), w) in dst.iter_mut().zip(src).zip(window) {
            *d = *s * *w;
        }
    }
}
766
/// Applies a window to `data` in place: `data[i] *= window[i]`.
///
/// Only the first `min(data.len(), window.len())` samples are touched; any
/// excess in either slice is ignored.
#[inline]
pub fn window_mul_simd_inplace(data: &mut [f32], window: &[f32]) {
    // Trim both slices to the common length once so every path below can walk
    // them in lock-step.
    let shared = data.len().min(window.len());
    let (data, window) = (&mut data[..shared], &window[..shared]);

    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        use std::arch::x86_64::*;

        let mut d_blocks = data.chunks_exact_mut(8);
        let mut w_blocks = window.chunks_exact(8);
        for (d, w) in (&mut d_blocks).zip(&mut w_blocks) {
            // SAFETY: both chunks are exactly 8 contiguous f32 values.
            unsafe {
                let p = d.as_mut_ptr();
                let scaled = _mm256_mul_ps(_mm256_loadu_ps(p), _mm256_loadu_ps(w.as_ptr()));
                _mm256_storeu_ps(p, scaled);
            }
        }

        // Fewer than 8 samples left over.
        for (d, w) in d_blocks.into_remainder().iter_mut().zip(w_blocks.remainder()) {
            *d *= *w;
        }
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    {
        use std::arch::aarch64::*;

        let mut d_blocks = data.chunks_exact_mut(4);
        let mut w_blocks = window.chunks_exact(4);
        for (d, w) in (&mut d_blocks).zip(&mut w_blocks) {
            // SAFETY: both chunks are exactly 4 contiguous f32 values.
            unsafe {
                let p = d.as_mut_ptr();
                vst1q_f32(p, vmulq_f32(vld1q_f32(p), vld1q_f32(w.as_ptr())));
            }
        }

        // Fewer than 4 samples left over.
        for (d, w) in d_blocks.into_remainder().iter_mut().zip(w_blocks.remainder()) {
            *d *= *w;
        }
    }

    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "avx2"),
        all(target_arch = "aarch64", target_feature = "neon")
    )))]
    {
        for (d, w) in data.iter_mut().zip(window) {
            *d *= *w;
        }
    }
}
818
/// Deinterleave a stereo buffer into separate L/R channels.
///
/// Input: `[L0, R0, L1, R1, L2, R2, ...]`
/// Output: `left = [L0, L1, L2, ...]`, `right = [R0, R1, R2, ...]`
///
/// Exactly `left.len()` frames are converted on every target. The buffers are
/// expected to satisfy `input.len() == 2 * left.len()` and
/// `left.len() == right.len()` (checked in debug builds).
///
/// # Panics
/// Panics if `input` holds fewer than `left.len()` frames or if `right` is
/// shorter than `left`; in debug builds also when the lengths do not match
/// exactly.
#[inline]
pub fn deinterleave_stereo(input: &[f32], left: &mut [f32], right: &mut [f32]) {
    debug_assert_eq!(input.len(), left.len() * 2);
    debug_assert_eq!(left.len(), right.len());

    let len = left.len();

    // The AVX2 path reads `input` and writes `right` through raw pointers, so
    // these bounds must hold in release builds as well.
    assert!(
        input.len() >= len * 2,
        "deinterleave_stereo: `input` holds fewer than `left.len()` frames"
    );
    assert!(right.len() >= len, "deinterleave_stereo: `right` is shorter than `left`");

    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        use std::arch::x86_64::*;

        let simd_len = (len / 8) * 8;

        for i in (0..simd_len).step_by(8) {
            // SAFETY: `i + 8 <= simd_len <= len`; `input` holds at least `2 * len`
            // floats and `left`/`right` hold at least `len` floats each.
            unsafe {
                // Load 16 floats (8 stereo pairs)
                let in_ptr = input.as_ptr().add(i * 2);
                let v0 = _mm256_loadu_ps(in_ptr); // L0 R0 L1 R1 L2 R2 L3 R3
                let v1 = _mm256_loadu_ps(in_ptr.add(8)); // L4 R4 L5 R5 L6 R6 L7 R7

                // Shuffle within each 128-bit lane to group L and R samples
                let shuf_l = _mm256_shuffle_ps(v0, v1, 0b10_00_10_00); // L0 L1 L4 L5 | L2 L3 L6 L7
                let shuf_r = _mm256_shuffle_ps(v0, v1, 0b11_01_11_01); // R0 R1 R4 R5 | R2 R3 R6 R7

                // Swap the two middle 64-bit pairs to restore sample order
                let left_vec = _mm256_permute4x64_pd(_mm256_castps_pd(shuf_l), 0b11_01_10_00);
                let right_vec = _mm256_permute4x64_pd(_mm256_castps_pd(shuf_r), 0b11_01_10_00);

                _mm256_storeu_ps(left.as_mut_ptr().add(i), _mm256_castpd_ps(left_vec));
                _mm256_storeu_ps(right.as_mut_ptr().add(i), _mm256_castpd_ps(right_vec));
            }
        }

        // Scalar remainder
        for i in simd_len..len {
            left[i] = input[i * 2];
            right[i] = input[i * 2 + 1];
        }
    }

    #[cfg(not(all(target_arch = "x86_64", target_feature = "avx2")))]
    {
        // Scalar fallback (compiler will auto-vectorize for NEON)
        let frames = input[..len * 2].chunks_exact(2);
        for ((l, r), frame) in left.iter_mut().zip(right.iter_mut()).zip(frames) {
            *l = frame[0];
            *r = frame[1];
        }
    }
}
884
/// Merge separate L/R channels into one interleaved stereo buffer.
///
/// Given `left = [L0, L1, ...]` and `right = [R0, R1, ...]`, writes
/// `[L0, R0, L1, R1, ...]` into `output`.
///
/// Debug builds assert that `left` and `right` have equal length and that
/// `output` is exactly twice as long; indexing is bounds-checked in all builds.
#[inline]
#[allow(dead_code)]
pub fn interleave_stereo(left: &[f32], right: &[f32], output: &mut [f32]) {
    debug_assert_eq!(left.len(), right.len());
    debug_assert_eq!(output.len(), left.len() * 2);

    // Plain scalar loop; vectorization is left to the compiler.
    for (frame, &l) in left.iter().enumerate() {
        let base = frame * 2;
        output[base] = l;
        output[base + 1] = right[frame];
    }
}
901
/// Replace very small samples with exactly `0.0` to avoid denormal-induced CPU
/// spikes and the audio glitches that follow from them.
///
/// Every sample whose magnitude is below `1e-30` is overwritten with `0.0`.
/// That cut-off lies well above the true `f32` subnormal range, so tiny normal
/// values are flushed as well. NaN samples compare false and are left as-is.
#[inline]
pub fn flush_denormals_inplace(samples: &mut [f32]) {
    const DENORM_THRESHOLD: f32 = 1e-30;

    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        use std::arch::x86_64::*;

        // SAFETY: broadcasting constants accesses no memory.
        let (limit, zeros, sign_bits) = unsafe {
            (
                _mm256_set1_ps(DENORM_THRESHOLD),
                _mm256_set1_ps(0.0),
                _mm256_set1_ps(-0.0),
            )
        };

        let mut blocks = samples.chunks_exact_mut(8);
        for block in &mut blocks {
            let p = block.as_mut_ptr();
            // SAFETY: `block` is exactly 8 contiguous, initialized f32 values.
            unsafe {
                let v = _mm256_loadu_ps(p);
                // Clearing the sign bit yields |v|.
                let magnitude = _mm256_andnot_ps(sign_bits, v);
                let tiny = _mm256_cmp_ps(magnitude, limit, _CMP_LT_OQ);
                // Take zero wherever the comparison mask is set.
                _mm256_storeu_ps(p, _mm256_blendv_ps(v, zeros, tiny));
            }
        }

        for sample in blocks.into_remainder() {
            if sample.abs() < DENORM_THRESHOLD {
                *sample = 0.0;
            }
        }
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    {
        use std::arch::aarch64::*;

        // SAFETY: broadcasting constants accesses no memory.
        let (limit, zeros) = unsafe { (vdupq_n_f32(DENORM_THRESHOLD), vdupq_n_f32(0.0)) };

        let mut blocks = samples.chunks_exact_mut(4);
        for block in &mut blocks {
            let p = block.as_mut_ptr();
            // SAFETY: `block` is exactly 4 contiguous, initialized f32 values.
            unsafe {
                let v = vld1q_f32(p);
                let tiny = vcltq_f32(vabsq_f32(v), limit);
                // Bit-select: zero where the mask is set, original value otherwise.
                vst1q_f32(p, vbslq_f32(tiny, zeros, v));
            }
        }

        for sample in blocks.into_remainder() {
            if sample.abs() < DENORM_THRESHOLD {
                *sample = 0.0;
            }
        }
    }

    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "avx2"),
        all(target_arch = "aarch64", target_feature = "neon")
    )))]
    {
        samples
            .iter_mut()
            .filter(|sample| sample.abs() < DENORM_THRESHOLD)
            .for_each(|sample| *sample = 0.0);
    }
}
977
/// Switch the CPU into flush-to-zero mode for denormal floating-point values.
///
/// With FTZ (Flush To Zero) and DAZ (Denormals Are Zero) active, the hardware
/// replaces denormal numbers with zero automatically. This avoids the severe
/// slowdown that occurs when IIR filter state variables (like biquad y1/y2)
/// contain denormals.
///
/// The function only ORs bits into the control register, so calling it again
/// is harmless; it should be called once per audio processing thread.
///
/// Returns `true` on x86-64 and AArch64, where the control register is
/// written, and `false` on every other architecture, where nothing happens.
/// The written value is not read back for verification.
#[inline]
pub fn enable_ftz_daz() -> bool {
    #[cfg(target_arch = "x86_64")]
    {
        // MXCSR control bits.
        const FTZ_BIT: u32 = 1 << 15; // Flush To Zero: denormal results become zero
        const DAZ_BIT: u32 = 1 << 6; // Denormals Are Zero: denormal inputs read as zero

        let mut mxcsr: u32 = 0;
        // SAFETY: `stmxcsr` stores the 32-bit MXCSR into the live local `mxcsr`.
        unsafe {
            std::arch::asm!(
                "stmxcsr [{}]",
                in(reg) &mut mxcsr as *mut u32,
                options(nostack, preserves_flags),
            );
        }

        mxcsr |= FTZ_BIT | DAZ_BIT;

        // SAFETY: `ldmxcsr` reads 32 bits from the same live local.
        unsafe {
            std::arch::asm!(
                "ldmxcsr [{}]",
                in(reg) &mxcsr as *const u32,
                options(nostack, preserves_flags),
            );
        }
        true
    }

    #[cfg(target_arch = "aarch64")]
    {
        // FPCR bit 24 is FZ (Flush-to-Zero).
        // Note: AArch64 always treats denormal inputs as zero in FZ mode.
        const FZ_BIT: u64 = 1 << 24;

        let current: u64;
        // SAFETY: reads the FPCR system register into a general-purpose register.
        unsafe {
            std::arch::asm!("mrs {}, fpcr", out(reg) current);
        }

        let updated = current | FZ_BIT;

        // SAFETY: writes back the previous FPCR value with only the FZ bit added.
        unsafe {
            std::arch::asm!("msr fpcr, {}", in(reg) updated);
        }
        true
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        false
    }
}
1023
1024/// Flush denormals in complex buffer (applies to both real and imaginary parts)
1025#[inline]
1026pub fn flush_denormals_complex_inplace(samples: &mut [Complex<f32>]) {
1027    // A Complex<f32> is just two f32s (re and im)
1028    // We can treat the whole buffer as a slice of f32 and use flush_denormals_inplace
1029    let len = samples.len() * 2;
1030    let ptr = samples.as_mut_ptr() as *mut f32;
1031    let f32_samples = unsafe { std::slice::from_raw_parts_mut(ptr, len) };
1032    flush_denormals_inplace(f32_samples);
1033}
1034
#[cfg(test)]
mod denorm_tests {
    use super::*;

    /// Values below the cutoff (either sign) become zero; everything else is kept.
    #[test]
    fn test_flush_denormals_basic() {
        let mut buf = [1e-31_f32, 1e-20, 1e-10, 0.0, -1e-31, 1.0];
        let expected: [f32; 6] = [0.0, 1e-20, 1e-10, 0.0, 0.0, 1.0];

        flush_denormals_inplace(&mut buf);

        for (got, want) in buf.iter().zip(expected.iter()) {
            assert_eq!(got, want);
        }
    }

    /// Real and imaginary parts are flushed independently of each other.
    #[test]
    fn test_flush_denormals_complex() {
        use rustfft::num_complex::Complex;

        let mut bins = [
            Complex::new(1e-31, 1e-30),
            Complex::new(1.0, 1e-31),
            Complex::new(0.0, 0.0),
        ];
        flush_denormals_complex_inplace(&mut bins);

        let [first, second, third] = bins;
        // 1e-30 sits exactly on the cutoff and must survive.
        assert_eq!(first.re, 0.0);
        assert!((first.im - 1e-30).abs() < 1e-35);
        assert_eq!(second.re, 1.0);
        assert_eq!(second.im, 0.0);
        assert_eq!(third.re, 0.0);
        assert_eq!(third.im, 0.0);
    }

    /// A zero-length slice is accepted without panicking.
    #[test]
    fn test_flush_denormals_empty() {
        let mut nothing = [0.0_f32; 0];
        flush_denormals_inplace(&mut nothing);
    }

    /// A length that is not a multiple of the SIMD width is fully processed.
    #[test]
    fn test_flush_denormals_unaligned() {
        let mut buf = [1e-31_f32; 7];
        flush_denormals_inplace(&mut buf);
        buf.iter().for_each(|s| assert_eq!(*s, 0.0));
    }
}
1083
1084#[cfg(test)]
1085#[allow(clippy::needless_range_loop)]
1086mod tests {
1087    use super::*;
1088
1089    // ============================================================================
1090    // Denormal Flushing Tests
1091    // ============================================================================
1092
1093    #[test]
1094    fn test_flush_denormals_basic() {
1095        let mut samples = [1e-31_f32, 1e-20, 1e-10, 0.0, -1e-31, 1.0];
1096        flush_denormals_inplace(&mut samples);
1097        assert_eq!(samples[0], 0.0);
1098        assert_eq!(samples[1], 1e-20);
1099        assert_eq!(samples[2], 1e-10);
1100        assert_eq!(samples[3], 0.0);
1101        assert_eq!(samples[4], 0.0);
1102        assert_eq!(samples[5], 1.0);
1103    }
1104
1105    #[test]
1106    fn test_flush_denormals_complex() {
1107        use rustfft::num_complex::Complex;
1108        let mut samples = [
1109            Complex::new(1e-31, 1e-30),
1110            Complex::new(1.0, 1e-31),
1111            Complex::new(0.0, 0.0),
1112        ];
1113        flush_denormals_complex_inplace(&mut samples);
1114        assert_eq!(samples[0].re, 0.0);
1115        assert!((samples[0].im - 1e-30).abs() < 1e-35);
1116        assert_eq!(samples[1].re, 1.0);
1117        assert_eq!(samples[1].im, 0.0);
1118        assert_eq!(samples[2].re, 0.0);
1119        assert_eq!(samples[2].im, 0.0);
1120    }
1121
1122    #[test]
1123    fn test_flush_denormals_empty() {
1124        let mut samples: [f32; 0] = [];
1125        flush_denormals_inplace(&mut samples);
1126    }
1127
1128    #[test]
1129    fn test_flush_denormals_unaligned() {
1130        let mut samples = [1e-31_f32; 7];
1131        flush_denormals_inplace(&mut samples);
1132        for s in samples.iter() {
1133            assert_eq!(*s, 0.0);
1134        }
1135    }
1136
1137    #[test]
1138    fn test_enable_ftz_daz_does_not_panic() {
1139        // enable_ftz_daz() should complete without panicking on any platform.
1140        // On x86_64 or aarch64 it returns true; on other platforms false.
1141        let result = enable_ftz_daz();
1142        #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
1143        assert!(
1144            result,
1145            "enable_ftz_daz should return true on supported platforms"
1146        );
1147        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
1148        assert!(
1149            !result,
1150            "enable_ftz_daz should return false on unsupported platforms"
1151        );
1152    }
1153
1154    #[test]
1155    fn test_apply_gain_simd_known_values() {
1156        // Apply gain of 2.0 to known input
1157        let mut buffer = vec![1.0, 2.0, 3.0, 4.0, -1.0, 0.5, -0.5, 0.0, 1.5];
1158        let expected: Vec<f32> = buffer.iter().map(|&x| x * 2.0).collect();
1159        apply_gain_simd(&mut buffer, 2.0);
1160        for (i, (&got, &exp)) in buffer.iter().zip(expected.iter()).enumerate() {
1161            assert!(
1162                (got - exp).abs() < 1e-6,
1163                "apply_gain_simd mismatch at index {}: got {}, expected {}",
1164                i,
1165                got,
1166                exp
1167            );
1168        }
1169    }
1170
1171    #[test]
1172    fn test_apply_gain_simd_zero_gain() {
1173        let mut buffer = vec![1.0, -2.0, 3.5, 0.7];
1174        apply_gain_simd(&mut buffer, 0.0);
1175        for (i, &v) in buffer.iter().enumerate() {
1176            assert_eq!(
1177                v, 0.0,
1178                "apply_gain_simd with zero gain: index {} not zero",
1179                i
1180            );
1181        }
1182    }
1183
1184    #[test]
1185    fn test_apply_gain_simd_unity_gain() {
1186        let original = vec![1.0, -2.0, 3.5, 0.7, 0.0, -0.1];
1187        let mut buffer = original.clone();
1188        apply_gain_simd(&mut buffer, 1.0);
1189        assert_eq!(buffer, original);
1190    }
1191
1192    #[test]
1193    fn test_apply_per_channel_gain_simd_stereo() {
1194        // Stereo buffer: apply different gains to L and R
1195        let mut buffer = vec![1.0, 2.0, 3.0, 4.0]; // 2 frames, 2 channels
1196        let gains = vec![0.5, 2.0]; // L=0.5, R=2.0
1197        apply_per_channel_gain_simd(&mut buffer, 2, &gains);
1198        assert!((buffer[0] - 0.5).abs() < 1e-6, "L frame 0");
1199        assert!((buffer[1] - 4.0).abs() < 1e-6, "R frame 0");
1200        assert!((buffer[2] - 1.5).abs() < 1e-6, "L frame 1");
1201        assert!((buffer[3] - 8.0).abs() < 1e-6, "R frame 1");
1202    }
1203
1204    // ============================================================================
1205    // SIMD Correctness Tests
1206    // ============================================================================
1207    //
1208    // These tests verify that SIMD-optimized complex multiplication produces
1209    // identical results to scalar computation
1210
1211    #[test]
1212    fn test_simd_complex_mul_add_correctness() {
1213        // Test SIMD complex multiply-accumulate against scalar reference
1214        use rustfft::num_complex::Complex;
1215
1216        // Test with 8 complex numbers (to test both AVX2 and NEON code paths)
1217        let src = vec![
1218            Complex::new(1.0, 2.0),
1219            Complex::new(3.0, 4.0),
1220            Complex::new(-1.0, 0.5),
1221            Complex::new(0.0, -2.0),
1222            Complex::new(2.5, -1.5),
1223            Complex::new(-3.5, 2.5),
1224            Complex::new(1.1, -0.9),
1225            Complex::new(-0.8, 1.2),
1226        ];
1227
1228        let hrtf = vec![
1229            Complex::new(0.5, 0.25),
1230            Complex::new(-1.0, 1.5),
1231            Complex::new(2.0, -0.5),
1232            Complex::new(0.75, 0.75),
1233            Complex::new(-0.5, 2.0),
1234            Complex::new(1.5, -1.0),
1235            Complex::new(0.9, 0.3),
1236            Complex::new(-1.1, 0.7),
1237        ];
1238
1239        let initial = vec![
1240            Complex::new(0.1, 0.2),
1241            Complex::new(0.3, 0.4),
1242            Complex::new(0.5, 0.6),
1243            Complex::new(0.7, 0.8),
1244            Complex::new(0.9, 1.0),
1245            Complex::new(1.1, 1.2),
1246            Complex::new(1.3, 1.4),
1247            Complex::new(1.5, 1.6),
1248        ];
1249
1250        // Scalar reference computation
1251        let mut expected = initial.clone();
1252        for i in 0..src.len() {
1253            expected[i] += src[i] * hrtf[i];
1254        }
1255
1256        // SIMD computation
1257        let mut result = initial.clone();
1258        complex_mul_add_simd(&mut result, &src, &hrtf);
1259
1260        // Compare results with tolerance for floating point errors
1261        const EPSILON: f32 = 1e-6;
1262        for i in 0..src.len() {
1263            assert!(
1264                (result[i].re - expected[i].re).abs() < EPSILON,
1265                "SIMD result[{}].re = {}, expected = {} (diff = {})",
1266                i,
1267                result[i].re,
1268                expected[i].re,
1269                (result[i].re - expected[i].re).abs()
1270            );
1271            assert!(
1272                (result[i].im - expected[i].im).abs() < EPSILON,
1273                "SIMD result[{}].im = {}, expected = {} (diff = {})",
1274                i,
1275                result[i].im,
1276                expected[i].im,
1277                (result[i].im - expected[i].im).abs()
1278            );
1279        }
1280    }
1281
1282    #[test]
1283    fn test_simd_complex_mul_correctness() {
1284        // Test SIMD complex multiplication (without accumulation)
1285        use rustfft::num_complex::Complex;
1286
1287        let src = vec![
1288            Complex::new(2.0, 3.0),
1289            Complex::new(-1.5, 2.5),
1290            Complex::new(0.5, -1.0),
1291            Complex::new(4.0, -2.0),
1292        ];
1293
1294        let hrtf = vec![
1295            Complex::new(1.0, 0.5),
1296            Complex::new(2.0, -1.0),
1297            Complex::new(-0.5, 1.5),
1298            Complex::new(0.75, 0.25),
1299        ];
1300
1301        // Scalar reference
1302        let expected: Vec<Complex<f32>> = src.iter().zip(hrtf.iter()).map(|(a, b)| a * b).collect();
1303
1304        // SIMD computation
1305        let mut result = vec![Complex::new(0.0, 0.0); src.len()];
1306        complex_mul_simd(&mut result, &src, &hrtf);
1307
1308        // Compare
1309        const EPSILON: f32 = 1e-6;
1310        for i in 0..src.len() {
1311            assert!(
1312                (result[i].re - expected[i].re).abs() < EPSILON,
1313                "SIMD result[{}].re = {}, expected = {}",
1314                i,
1315                result[i].re,
1316                expected[i].re
1317            );
1318            assert!(
1319                (result[i].im - expected[i].im).abs() < EPSILON,
1320                "SIMD result[{}].im = {}, expected = {}",
1321                i,
1322                result[i].im,
1323                expected[i].im
1324            );
1325        }
1326    }
1327
1328    #[test]
1329    fn test_simd_edge_cases() {
1330        // Test edge cases: zeros, ones, conjugates
1331        use rustfft::num_complex::Complex;
1332
1333        // Test 1: Multiply by zero
1334        let src = vec![
1335            Complex::new(1.0, 2.0),
1336            Complex::new(3.0, 4.0),
1337            Complex::new(5.0, 6.0),
1338            Complex::new(7.0, 8.0),
1339        ];
1340        let zero = vec![Complex::new(0.0, 0.0); 4];
1341        let mut result = src.clone();
1342        let input = result.clone();
1343        complex_mul_simd(&mut result, &input, &zero);
1344        for i in 0..4 {
1345            assert_eq!(result[i].re, 0.0);
1346            assert_eq!(result[i].im, 0.0);
1347        }
1348
1349        // Test 2: Multiply by one (identity)
1350        let one = vec![Complex::new(1.0, 0.0); 4];
1351        let mut result = vec![Complex::new(0.0, 0.0); 4];
1352        complex_mul_simd(&mut result, &src, &one);
1353        for i in 0..4 {
1354            assert!((result[i].re - src[i].re).abs() < 1e-6);
1355            assert!((result[i].im - src[i].im).abs() < 1e-6);
1356        }
1357
1358        // Test 3: Multiply by conjugate (should give real result)
1359        let a = Complex::new(3.0, 4.0);
1360        let a_conj = Complex::new(3.0, -4.0);
1361        let src = vec![a, a, a, a];
1362        let conj = vec![a_conj, a_conj, a_conj, a_conj];
1363        let mut result = vec![Complex::new(0.0, 0.0); 4];
1364        complex_mul_simd(&mut result, &src, &conj);
1365
1366        // a * conj(a) = |a|^2 = 3^2 + 4^2 = 25
1367        for i in 0..4 {
1368            assert!((result[i].re - 25.0).abs() < 1e-5);
1369            assert!(result[i].im.abs() < 1e-5); // Should be approximately zero
1370        }
1371    }
1372
1373    #[test]
1374    fn test_simd_large_buffer() {
1375        // Test with realistic FFT buffer sizes
1376        use rustfft::num_complex::Complex;
1377
1378        for fft_size in [512, 1024, 2048, 4096] {
1379            let mut src = Vec::with_capacity(fft_size);
1380            let mut hrtf = Vec::with_capacity(fft_size);
1381
1382            // Fill with test pattern
1383            for i in 0..fft_size {
1384                let phase = (i as f32) * 0.01;
1385                src.push(Complex::new(phase.cos(), phase.sin()));
1386                hrtf.push(Complex::new(0.5, 0.25));
1387            }
1388
1389            // Scalar reference
1390            let mut expected = vec![Complex::new(0.1, 0.2); fft_size];
1391            for i in 0..fft_size {
1392                expected[i] += src[i] * hrtf[i];
1393            }
1394
1395            // SIMD computation
1396            let mut result = vec![Complex::new(0.1, 0.2); fft_size];
1397            complex_mul_add_simd(&mut result, &src, &hrtf);
1398
1399            // Verify all elements match
1400            for i in 0..fft_size {
1401                assert!(
1402                    (result[i].re - expected[i].re).abs() < 1e-5,
1403                    "FFT size {}, index {}: SIMD mismatch",
1404                    fft_size,
1405                    i
1406                );
1407                assert!(
1408                    (result[i].im - expected[i].im).abs() < 1e-5,
1409                    "FFT size {}, index {}: SIMD mismatch",
1410                    fft_size,
1411                    i
1412                );
1413            }
1414        }
1415    }
1416
1417    #[test]
1418    fn test_simd_unaligned_sizes() {
1419        // Test with buffer sizes that don't align to SIMD width
1420        // This ensures the scalar remainder loop works correctly
1421        use rustfft::num_complex::Complex;
1422
1423        for size in [1, 3, 5, 7, 9, 13, 17] {
1424            let src: Vec<Complex<f32>> = (0..size)
1425                .map(|i| Complex::new(i as f32, (i as f32) * 0.5))
1426                .collect();
1427            let hrtf: Vec<Complex<f32>> = (0..size)
1428                .map(|i| Complex::new(0.5, (i as f32) * 0.1))
1429                .collect();
1430
1431            // Scalar reference
1432            let expected: Vec<Complex<f32>> =
1433                src.iter().zip(hrtf.iter()).map(|(a, b)| a * b).collect();
1434
1435            // SIMD computation
1436            let mut result = vec![Complex::new(0.0, 0.0); size];
1437            complex_mul_simd(&mut result, &src, &hrtf);
1438
1439            // Verify
1440            for i in 0..size {
1441                assert!(
1442                    (result[i].re - expected[i].re).abs() < 1e-6,
1443                    "Size {}, index {}: re mismatch",
1444                    size,
1445                    i
1446                );
1447                assert!(
1448                    (result[i].im - expected[i].im).abs() < 1e-6,
1449                    "Size {}, index {}: im mismatch",
1450                    size,
1451                    i
1452                );
1453            }
1454        }
1455    }
1456
1457    // ============================================================================
1458    // Comprehensive Tests for complex_mul_inplace_simd
1459    // ============================================================================
1460
1461    #[test]
1462    fn test_simd_complex_mul_inplace_correctness() {
1463        // Basic correctness test for in-place complex multiplication
1464        use rustfft::num_complex::Complex;
1465
1466        let src = vec![
1467            Complex::new(2.0, 3.0),
1468            Complex::new(-1.5, 2.5),
1469            Complex::new(0.5, -1.0),
1470            Complex::new(4.0, -2.0),
1471        ];
1472
1473        let hrtf = vec![
1474            Complex::new(1.0, 0.5),
1475            Complex::new(2.0, -1.0),
1476            Complex::new(-0.5, 1.5),
1477            Complex::new(0.75, 0.25),
1478        ];
1479
1480        // Scalar reference
1481        let mut expected = src.clone();
1482        for i in 0..expected.len() {
1483            expected[i] *= hrtf[i];
1484        }
1485
1486        // SIMD in-place computation
1487        let mut result = src.clone();
1488        complex_mul_inplace_simd(&mut result, &hrtf);
1489
1490        const EPSILON: f32 = 1e-6;
1491        for i in 0..result.len() {
1492            assert!(
1493                (result[i].re - expected[i].re).abs() < EPSILON,
1494                "Index {}: re mismatch {} vs {}",
1495                i,
1496                result[i].re,
1497                expected[i].re
1498            );
1499            assert!(
1500                (result[i].im - expected[i].im).abs() < EPSILON,
1501                "Index {}: im mismatch {} vs {}",
1502                i,
1503                result[i].im,
1504                expected[i].im
1505            );
1506        }
1507    }
1508
1509    #[test]
1510    fn test_simd_inplace_large_buffers() {
1511        // Test in-place multiplication with realistic FFT buffer sizes
1512        use rustfft::num_complex::Complex;
1513
1514        for fft_size in [128, 256, 512, 1024, 2048] {
1515            let mut src: Vec<Complex<f32>> = (0..fft_size)
1516                .map(|i| {
1517                    let phase = (i as f32) * 0.01;
1518                    Complex::new(phase.cos(), phase.sin())
1519                })
1520                .collect();
1521
1522            let hrtf: Vec<Complex<f32>> = (0..fft_size)
1523                .map(|i| Complex::new(0.5 + (i as f32) * 0.001, 0.25))
1524                .collect();
1525
1526            // Scalar reference
1527            let mut expected = src.clone();
1528            for i in 0..fft_size {
1529                expected[i] *= hrtf[i];
1530            }
1531
1532            // SIMD computation
1533            complex_mul_inplace_simd(&mut src, &hrtf);
1534
1535            // Verify all elements match
1536            for i in 0..fft_size {
1537                assert!(
1538                    (src[i].re - expected[i].re).abs() < 1e-5,
1539                    "FFT size {}, index {}: re mismatch",
1540                    fft_size,
1541                    i
1542                );
1543                assert!(
1544                    (src[i].im - expected[i].im).abs() < 1e-5,
1545                    "FFT size {}, index {}: im mismatch",
1546                    fft_size,
1547                    i
1548                );
1549            }
1550        }
1551    }
1552
1553    #[test]
1554    fn test_simd_inplace_unaligned() {
1555        // Test with sizes that don't align to SIMD width (4 for AVX2, 2 for NEON)
1556        use rustfft::num_complex::Complex;
1557
1558        for size in [1, 2, 3, 5, 6, 7, 9, 10, 11, 15, 17, 19, 23] {
1559            let mut src: Vec<Complex<f32>> = (0..size)
1560                .map(|i| Complex::new((i as f32) * 0.5, (i as f32) * -0.3))
1561                .collect();
1562
1563            let hrtf: Vec<Complex<f32>> = (0..size)
1564                .map(|i| Complex::new(1.0 + (i as f32) * 0.1, 0.5))
1565                .collect();
1566
1567            // Scalar reference
1568            let mut expected = src.clone();
1569            for i in 0..size {
1570                expected[i] *= hrtf[i];
1571            }
1572
1573            // SIMD computation
1574            complex_mul_inplace_simd(&mut src, &hrtf);
1575
1576            // Verify
1577            for i in 0..size {
1578                assert!(
1579                    (src[i].re - expected[i].re).abs() < 1e-6,
1580                    "Size {}, index {}: re mismatch",
1581                    size,
1582                    i
1583                );
1584                assert!(
1585                    (src[i].im - expected[i].im).abs() < 1e-6,
1586                    "Size {}, index {}: im mismatch",
1587                    size,
1588                    i
1589                );
1590            }
1591        }
1592    }
1593
1594    #[test]
1595    fn test_simd_inplace_edge_cases() {
1596        // Test edge cases: zeros, ones, conjugates, negative values
1597        use rustfft::num_complex::Complex;
1598
1599        // Test 1: Multiply by zero (should zero out the buffer)
1600        let mut src = vec![
1601            Complex::new(1.0, 2.0),
1602            Complex::new(3.0, 4.0),
1603            Complex::new(5.0, 6.0),
1604            Complex::new(7.0, 8.0),
1605        ];
1606        let zero = vec![Complex::new(0.0, 0.0); 4];
1607        complex_mul_inplace_simd(&mut src, &zero);
1608        for i in 0..4 {
1609            assert!(src[i].re.abs() < 1e-6, "Expected zero, got {}", src[i].re);
1610            assert!(src[i].im.abs() < 1e-6, "Expected zero, got {}", src[i].im);
1611        }
1612
1613        // Test 2: Multiply by one (should be identity)
1614        let original = vec![
1615            Complex::new(1.5, 2.5),
1616            Complex::new(-3.5, 4.5),
1617            Complex::new(5.5, -6.5),
1618            Complex::new(-7.5, -8.5),
1619        ];
1620        let mut src = original.clone();
1621        let one = vec![Complex::new(1.0, 0.0); 4];
1622        complex_mul_inplace_simd(&mut src, &one);
1623        for i in 0..4 {
1624            assert!((src[i].re - original[i].re).abs() < 1e-6);
1625            assert!((src[i].im - original[i].im).abs() < 1e-6);
1626        }
1627
1628        // Test 3: Multiply by conjugate (should give real result with magnitude |a|^2)
1629        let a = Complex::new(3.0, 4.0);
1630        let a_conj = Complex::new(3.0, -4.0);
1631        let mut src = vec![a; 8];
1632        let conj = vec![a_conj; 8];
1633        complex_mul_inplace_simd(&mut src, &conj);
1634        // a * conj(a) = |a|^2 = 3^2 + 4^2 = 25
1635        for i in 0..8 {
1636            assert!(
1637                (src[i].re - 25.0).abs() < 1e-5,
1638                "Expected 25.0, got {}",
1639                src[i].re
1640            );
1641            assert!(src[i].im.abs() < 1e-5, "Expected ~0, got {}", src[i].im);
1642        }
1643
1644        // Test 4: Multiply by i (rotation by 90 degrees)
1645        let mut src = vec![Complex::new(1.0, 0.0); 4];
1646        let i_val = vec![Complex::new(0.0, 1.0); 4];
1647        complex_mul_inplace_simd(&mut src, &i_val);
1648        for idx in 0..4 {
1649            assert!(src[idx].re.abs() < 1e-6, "Expected 0, got {}", src[idx].re);
1650            assert!(
1651                (src[idx].im - 1.0).abs() < 1e-6,
1652                "Expected 1, got {}",
1653                src[idx].im
1654            );
1655        }
1656    }
1657
1658    #[test]
1659    fn test_simd_inplace_negative_values() {
1660        // Test with all negative values
1661        use rustfft::num_complex::Complex;
1662
1663        let mut src = vec![
1664            Complex::new(-1.0, -2.0),
1665            Complex::new(-3.0, -4.0),
1666            Complex::new(-5.0, -6.0),
1667            Complex::new(-7.0, -8.0),
1668        ];
1669
1670        let hrtf = vec![
1671            Complex::new(-0.5, -0.25),
1672            Complex::new(-1.0, -1.5),
1673            Complex::new(-2.0, 0.5),
1674            Complex::new(0.75, -0.75),
1675        ];
1676
1677        // Scalar reference
1678        let mut expected = src.clone();
1679        for i in 0..expected.len() {
1680            expected[i] *= hrtf[i];
1681        }
1682
1683        // SIMD computation
1684        complex_mul_inplace_simd(&mut src, &hrtf);
1685
1686        const EPSILON: f32 = 1e-6;
1687        for i in 0..src.len() {
1688            assert!((src[i].re - expected[i].re).abs() < EPSILON);
1689            assert!((src[i].im - expected[i].im).abs() < EPSILON);
1690        }
1691    }
1692
1693    // ============================================================================
1694    // Comprehensive Tests for compute_covariance_simd
1695    // ============================================================================
1696
1697    #[test]
1698    fn test_covariance_basic_correctness() {
1699        // Test basic covariance computation against scalar reference
1700        use rustfft::num_complex::Complex;
1701
1702        let left = vec![
1703            Complex::new(1.0, 2.0),
1704            Complex::new(3.0, 4.0),
1705            Complex::new(-1.0, 0.5),
1706            Complex::new(0.0, -2.0),
1707            Complex::new(2.5, -1.5),
1708            Complex::new(-3.5, 2.5),
1709            Complex::new(1.1, -0.9),
1710            Complex::new(-0.8, 1.2),
1711        ];
1712
1713        let right = vec![
1714            Complex::new(0.5, 0.25),
1715            Complex::new(-1.0, 1.5),
1716            Complex::new(2.0, -0.5),
1717            Complex::new(0.75, 0.75),
1718            Complex::new(-0.5, 2.0),
1719            Complex::new(1.5, -1.0),
1720            Complex::new(0.9, 0.3),
1721            Complex::new(-1.1, 0.7),
1722        ];
1723
1724        // Scalar reference
1725        let mut expected_xx = 0.0_f32;
1726        let mut expected_yy = 0.0_f32;
1727        let mut expected_xy = Complex::new(0.0, 0.0);
1728        for i in 0..left.len() {
1729            expected_xx += left[i].norm_sqr();
1730            expected_yy += right[i].norm_sqr();
1731            expected_xy += left[i] * right[i].conj();
1732        }
1733
1734        // SIMD computation
1735        let (cov_xx, cov_yy, cov_xy) = compute_covariance_simd(&left, &right, 0, left.len());
1736
1737        const EPSILON: f32 = 1e-5;
1738        assert!(
1739            (cov_xx - expected_xx).abs() < EPSILON,
1740            "cov_xx mismatch: {} vs {}",
1741            cov_xx,
1742            expected_xx
1743        );
1744        assert!(
1745            (cov_yy - expected_yy).abs() < EPSILON,
1746            "cov_yy mismatch: {} vs {}",
1747            cov_yy,
1748            expected_yy
1749        );
1750        assert!(
1751            (cov_xy.re - expected_xy.re).abs() < EPSILON,
1752            "cov_xy.re mismatch: {} vs {}",
1753            cov_xy.re,
1754            expected_xy.re
1755        );
1756        assert!(
1757            (cov_xy.im - expected_xy.im).abs() < EPSILON,
1758            "cov_xy.im mismatch: {} vs {}",
1759            cov_xy.im,
1760            expected_xy.im
1761        );
1762    }
1763
1764    #[test]
1765    fn test_covariance_with_ranges() {
1766        // Test covariance computation with different start/end ranges
1767        use rustfft::num_complex::Complex;
1768
1769        let left: Vec<Complex<f32>> = (0..32)
1770            .map(|i| Complex::new(i as f32 * 0.5, i as f32 * -0.3))
1771            .collect();
1772        let right: Vec<Complex<f32>> = (0..32)
1773            .map(|i| Complex::new(i as f32 * -0.4, i as f32 * 0.6))
1774            .collect();
1775
1776        // Test various ranges
1777        for (start, end) in [(0, 8), (4, 12), (10, 20), (5, 25), (0, 32)] {
1778            // Scalar reference
1779            let mut expected_xx = 0.0_f32;
1780            let mut expected_yy = 0.0_f32;
1781            let mut expected_xy = Complex::new(0.0, 0.0);
1782            for i in start..end {
1783                expected_xx += left[i].norm_sqr();
1784                expected_yy += right[i].norm_sqr();
1785                expected_xy += left[i] * right[i].conj();
1786            }
1787
1788            // SIMD computation
1789            let (cov_xx, cov_yy, cov_xy) = compute_covariance_simd(&left, &right, start, end);
1790
1791            const EPSILON: f32 = 1e-4;
1792            assert!(
1793                (cov_xx - expected_xx).abs() < EPSILON,
1794                "Range [{}, {}): cov_xx mismatch: {} vs {}",
1795                start,
1796                end,
1797                cov_xx,
1798                expected_xx
1799            );
1800            assert!(
1801                (cov_yy - expected_yy).abs() < EPSILON,
1802                "Range [{}, {}): cov_yy mismatch: {} vs {}",
1803                start,
1804                end,
1805                cov_yy,
1806                expected_yy
1807            );
1808            assert!(
1809                (cov_xy.re - expected_xy.re).abs() < EPSILON,
1810                "Range [{}, {}): cov_xy.re mismatch: {} vs {}",
1811                start,
1812                end,
1813                cov_xy.re,
1814                expected_xy.re
1815            );
1816            assert!(
1817                (cov_xy.im - expected_xy.im).abs() < EPSILON,
1818                "Range [{}, {}): cov_xy.im mismatch: {} vs {}",
1819                start,
1820                end,
1821                cov_xy.im,
1822                expected_xy.im
1823            );
1824        }
1825    }
1826
1827    #[test]
1828    fn test_covariance_large_buffers() {
1829        // Test with realistic FFT buffer sizes
1830        use rustfft::num_complex::Complex;
1831
1832        for fft_size in [128, 256, 512, 1024, 2048, 4096] {
1833            let left: Vec<Complex<f32>> = (0..fft_size)
1834                .map(|i| {
1835                    let phase = (i as f32) * 0.01;
1836                    Complex::new(phase.cos(), phase.sin())
1837                })
1838                .collect();
1839
1840            let right: Vec<Complex<f32>> = (0..fft_size)
1841                .map(|i| {
1842                    let phase = (i as f32) * 0.02;
1843                    Complex::new(phase.sin(), phase.cos())
1844                })
1845                .collect();
1846
1847            // Scalar reference
1848            let mut expected_xx = 0.0_f32;
1849            let mut expected_yy = 0.0_f32;
1850            let mut expected_xy = Complex::new(0.0, 0.0);
1851            for i in 0..fft_size {
1852                expected_xx += left[i].norm_sqr();
1853                expected_yy += right[i].norm_sqr();
1854                expected_xy += left[i] * right[i].conj();
1855            }
1856
1857            // SIMD computation
1858            let (cov_xx, cov_yy, cov_xy) = compute_covariance_simd(&left, &right, 0, fft_size);
1859
1860            // Relative tolerance for larger sums
1861            let rel_epsilon = 1e-4;
1862            assert!(
1863                (cov_xx - expected_xx).abs() < expected_xx * rel_epsilon,
1864                "FFT size {}: cov_xx mismatch",
1865                fft_size
1866            );
1867            assert!(
1868                (cov_yy - expected_yy).abs() < expected_yy * rel_epsilon,
1869                "FFT size {}: cov_yy mismatch",
1870                fft_size
1871            );
1872            assert!(
1873                (cov_xy.re - expected_xy.re).abs() < expected_xy.re.abs() * rel_epsilon + 1e-5,
1874                "FFT size {}: cov_xy.re mismatch",
1875                fft_size
1876            );
1877            assert!(
1878                (cov_xy.im - expected_xy.im).abs() < expected_xy.im.abs() * rel_epsilon + 1e-5,
1879                "FFT size {}: cov_xy.im mismatch",
1880                fft_size
1881            );
1882        }
1883    }
1884
1885    #[test]
1886    fn test_covariance_unaligned_ranges() {
1887        // Test with ranges that don't align to SIMD width
1888        use rustfft::num_complex::Complex;
1889
1890        let left: Vec<Complex<f32>> = (0..50)
1891            .map(|i| Complex::new(i as f32 * 0.2, i as f32 * 0.3))
1892            .collect();
1893        let right: Vec<Complex<f32>> = (0..50)
1894            .map(|i| Complex::new(i as f32 * -0.1, i as f32 * 0.4))
1895            .collect();
1896
1897        // Test with various unaligned ranges
1898        for (start, end) in [(0, 1), (0, 3), (1, 4), (2, 7), (5, 11), (10, 23), (15, 37)] {
1899            // Scalar reference
1900            let mut expected_xx = 0.0_f32;
1901            let mut expected_yy = 0.0_f32;
1902            let mut expected_xy = Complex::new(0.0, 0.0);
1903            for i in start..end {
1904                expected_xx += left[i].norm_sqr();
1905                expected_yy += right[i].norm_sqr();
1906                expected_xy += left[i] * right[i].conj();
1907            }
1908
1909            // SIMD computation
1910            let (cov_xx, cov_yy, cov_xy) = compute_covariance_simd(&left, &right, start, end);
1911
1912            const EPSILON: f32 = 1e-5;
1913            assert!(
1914                (cov_xx - expected_xx).abs() < EPSILON,
1915                "Range [{}, {}): cov_xx mismatch",
1916                start,
1917                end
1918            );
1919            assert!(
1920                (cov_yy - expected_yy).abs() < EPSILON,
1921                "Range [{}, {}): cov_yy mismatch",
1922                start,
1923                end
1924            );
1925            assert!(
1926                (cov_xy.re - expected_xy.re).abs() < EPSILON,
1927                "Range [{}, {}): cov_xy.re mismatch",
1928                start,
1929                end
1930            );
1931            assert!(
1932                (cov_xy.im - expected_xy.im).abs() < EPSILON,
1933                "Range [{}, {}): cov_xy.im mismatch",
1934                start,
1935                end
1936            );
1937        }
1938    }
1939
1940    #[test]
1941    fn test_covariance_edge_cases() {
1942        // Test covariance with edge cases
1943        use rustfft::num_complex::Complex;
1944
1945        // Test 1: All zeros
1946        let zero_left = vec![Complex::new(0.0, 0.0); 8];
1947        let zero_right = vec![Complex::new(0.0, 0.0); 8];
1948        let (cov_xx, cov_yy, cov_xy) = compute_covariance_simd(&zero_left, &zero_right, 0, 8);
1949        assert!(cov_xx.abs() < 1e-6, "Expected zero cov_xx");
1950        assert!(cov_yy.abs() < 1e-6, "Expected zero cov_yy");
1951        assert!(cov_xy.norm_sqr() < 1e-6, "Expected zero cov_xy");
1952
1953        // Test 2: Real-only signals
1954        let real_left: Vec<Complex<f32>> = (0..8).map(|i| Complex::new(i as f32, 0.0)).collect();
1955        let real_right: Vec<Complex<f32>> = (0..8)
1956            .map(|i| Complex::new((i as f32) * 0.5, 0.0))
1957            .collect();
1958        let (cov_xx, cov_yy, cov_xy) = compute_covariance_simd(&real_left, &real_right, 0, 8);
1959
1960        // cov_xy should be real (imaginary part should be ~0)
1961        assert!(
1962            cov_xy.im.abs() < 1e-5,
1963            "Expected real cov_xy for real signals"
1964        );
1965
1966        // Verify values match scalar computation
1967        let mut expected_xx = 0.0;
1968        let mut expected_yy = 0.0;
1969        for i in 0..8 {
1970            expected_xx += (i * i) as f32;
1971            expected_yy += ((i as f32) * 0.5).powi(2);
1972        }
1973        assert!((cov_xx - expected_xx).abs() < 1e-5);
1974        assert!((cov_yy - expected_yy).abs() < 1e-5);
1975
1976        // Test 3: Imaginary-only signals
1977        let imag_left: Vec<Complex<f32>> = (0..8).map(|i| Complex::new(0.0, i as f32)).collect();
1978        let imag_right: Vec<Complex<f32>> = (0..8)
1979            .map(|i| Complex::new(0.0, (i as f32) * 2.0))
1980            .collect();
1981        let (_cov_xx, _cov_yy, cov_xy) = compute_covariance_simd(&imag_left, &imag_right, 0, 8);
1982
1983        // cov_xy should be real (because conj flips imaginary sign)
1984        assert!(
1985            cov_xy.im.abs() < 1e-5,
1986            "Expected real cov_xy for imaginary signals"
1987        );
1988
1989        // Test 4: Single element range
1990        let single_left = vec![Complex::new(3.0, 4.0)];
1991        let single_right = vec![Complex::new(1.0, 2.0)];
1992        let (cov_xx, cov_yy, cov_xy) = compute_covariance_simd(&single_left, &single_right, 0, 1);
1993        assert!((cov_xx - 25.0).abs() < 1e-5); // 3^2 + 4^2 = 25
1994        assert!((cov_yy - 5.0).abs() < 1e-5); // 1^2 + 2^2 = 5
1995        // (3 + 4i) * (1 - 2i) = 3 - 6i + 4i - 8i^2 = 3 - 2i + 8 = 11 - 2i
1996        assert!((cov_xy.re - 11.0).abs() < 1e-5);
1997        assert!((cov_xy.im - (-2.0)).abs() < 1e-5);
1998    }
1999
2000    // ============================================================================
2001    // Numerical Accuracy and Stress Tests
2002    // ============================================================================
2003
2004    #[test]
2005    fn test_numerical_accuracy_small_values() {
2006        // Test with very small values to check for denormal handling
2007        use rustfft::num_complex::Complex;
2008
2009        let small = 1e-20_f32;
2010        let src = vec![
2011            Complex::new(small, small),
2012            Complex::new(small * 2.0, small * 3.0),
2013            Complex::new(small * 4.0, small * 5.0),
2014            Complex::new(small * 6.0, small * 7.0),
2015        ];
2016
2017        let hrtf = vec![
2018            Complex::new(1.0, 0.5),
2019            Complex::new(2.0, -1.0),
2020            Complex::new(-0.5, 1.5),
2021            Complex::new(0.75, 0.25),
2022        ];
2023
2024        // Scalar reference
2025        let expected: Vec<Complex<f32>> = src.iter().zip(hrtf.iter()).map(|(a, b)| a * b).collect();
2026
2027        // SIMD computation
2028        let mut result = vec![Complex::new(0.0, 0.0); src.len()];
2029        complex_mul_simd(&mut result, &src, &hrtf);
2030
2031        // Relative tolerance for very small numbers
2032        for i in 0..src.len() {
2033            let re_diff = (result[i].re - expected[i].re).abs();
2034            let im_diff = (result[i].im - expected[i].im).abs();
2035
2036            // For very small numbers, check relative error or absolute error for near-zero
2037            if expected[i].re.abs() > 1e-15 {
2038                assert!(re_diff / expected[i].re.abs() < 1e-3);
2039            } else {
2040                assert!(re_diff < 1e-25);
2041            }
2042
2043            if expected[i].im.abs() > 1e-15 {
2044                assert!(im_diff / expected[i].im.abs() < 1e-3);
2045            } else {
2046                assert!(im_diff < 1e-25);
2047            }
2048        }
2049    }
2050
2051    #[test]
2052    fn test_numerical_accuracy_large_values() {
2053        // Test with large values to check for overflow handling
2054        use rustfft::num_complex::Complex;
2055
2056        let large = 1e10_f32;
2057        let src = vec![
2058            Complex::new(large, large * 0.5),
2059            Complex::new(large * 2.0, large * 1.5),
2060            Complex::new(large * 0.3, large * 0.7),
2061            Complex::new(large * 1.2, large * 0.8),
2062        ];
2063
2064        let hrtf = vec![
2065            Complex::new(1e-5, 5e-6),
2066            Complex::new(2e-5, -1e-5),
2067            Complex::new(-5e-6, 1.5e-5),
2068            Complex::new(7.5e-6, 2.5e-6),
2069        ];
2070
2071        // Scalar reference
2072        let mut expected = vec![Complex::new(0.0, 0.0); src.len()];
2073        for i in 0..src.len() {
2074            expected[i] = src[i] * hrtf[i];
2075        }
2076
2077        // SIMD computation
2078        let mut result = vec![Complex::new(0.0, 0.0); src.len()];
2079        complex_mul_simd(&mut result, &src, &hrtf);
2080
2081        // Relative tolerance
2082        for i in 0..src.len() {
2083            let re_rel_err = (result[i].re - expected[i].re).abs() / expected[i].re.abs().max(1.0);
2084            let im_rel_err = (result[i].im - expected[i].im).abs() / expected[i].im.abs().max(1.0);
2085            assert!(
2086                re_rel_err < 1e-5,
2087                "Index {}: re rel error too large: {}",
2088                i,
2089                re_rel_err
2090            );
2091            assert!(
2092                im_rel_err < 1e-5,
2093                "Index {}: im rel error too large: {}",
2094                i,
2095                im_rel_err
2096            );
2097        }
2098    }
2099
2100    #[test]
2101    fn test_accumulation_accuracy() {
2102        // Test that repeated accumulation doesn't degrade accuracy significantly
2103        use rustfft::num_complex::Complex;
2104
2105        let src = vec![
2106            Complex::new(0.1, 0.2),
2107            Complex::new(0.3, 0.4),
2108            Complex::new(0.5, 0.6),
2109            Complex::new(0.7, 0.8),
2110        ];
2111
2112        let hrtf = vec![
2113            Complex::new(0.5, 0.25),
2114            Complex::new(-1.0, 1.5),
2115            Complex::new(2.0, -0.5),
2116            Complex::new(0.75, 0.75),
2117        ];
2118
2119        // Scalar reference: accumulate 100 times
2120        let mut expected = vec![Complex::new(0.0, 0.0); src.len()];
2121        for _ in 0..100 {
2122            for i in 0..src.len() {
2123                expected[i] += src[i] * hrtf[i];
2124            }
2125        }
2126
2127        // SIMD computation: accumulate 100 times
2128        let mut result = vec![Complex::new(0.0, 0.0); src.len()];
2129        for _ in 0..100 {
2130            complex_mul_add_simd(&mut result, &src, &hrtf);
2131        }
2132
2133        // Check relative error
2134        const REL_EPSILON: f32 = 1e-4;
2135        for i in 0..src.len() {
2136            let re_abs_err = (result[i].re - expected[i].re).abs();
2137            let im_abs_err = (result[i].im - expected[i].im).abs();
2138
2139            // Use relative error if expected value is large enough, otherwise use absolute error
2140            let re_err = if expected[i].re.abs() > 1e-6 {
2141                re_abs_err / expected[i].re.abs()
2142            } else {
2143                re_abs_err
2144            };
2145            let im_err = if expected[i].im.abs() > 1e-6 {
2146                im_abs_err / expected[i].im.abs()
2147            } else {
2148                im_abs_err
2149            };
2150
2151            assert!(
2152                re_err < REL_EPSILON,
2153                "Index {}: re accumulated error too large: {} (abs: {}, expected: {})",
2154                i,
2155                re_err,
2156                re_abs_err,
2157                expected[i].re
2158            );
2159            assert!(
2160                im_err < REL_EPSILON,
2161                "Index {}: im accumulated error too large: {} (abs: {}, expected: {})",
2162                i,
2163                im_err,
2164                im_abs_err,
2165                expected[i].im
2166            );
2167        }
2168    }
2169
2170    #[test]
2171    fn test_platform_specific_simd_widths() {
2172        // Test that SIMD code paths are correctly exercised
2173        use rustfft::num_complex::Complex;
2174
2175        // Test sizes that exercise SIMD boundaries:
2176        // - AVX2 processes 4 complex at a time
2177        // - NEON processes 2 complex at a time
2178        let test_sizes = vec![
2179            1,  // No SIMD, scalar only
2180            2,  // NEON: 1 iteration, AVX2: scalar only
2181            3,  // NEON: 1 iteration + 1 scalar, AVX2: scalar only
2182            4,  // NEON: 2 iterations, AVX2: 1 iteration
2183            5,  // NEON: 2 iterations + 1 scalar, AVX2: 1 iteration + 1 scalar
2184            8,  // NEON: 4 iterations, AVX2: 2 iterations
2185            9,  // Mixed
2186            12, // NEON: 6 iterations, AVX2: 3 iterations
2187            16, // NEON: 8 iterations, AVX2: 4 iterations
2188        ];
2189
2190        for size in test_sizes {
2191            let src: Vec<Complex<f32>> = (0..size)
2192                .map(|i| Complex::new(i as f32 * 0.3, i as f32 * -0.2))
2193                .collect();
2194            let hrtf: Vec<Complex<f32>> = (0..size)
2195                .map(|i| Complex::new(1.0 + i as f32 * 0.1, 0.5))
2196                .collect();
2197
2198            // Test all three functions
2199
2200            // 1. complex_mul_add_simd
2201            let mut result_add = vec![Complex::new(1.0, 2.0); size];
2202            let mut expected_add = result_add.clone();
2203            for i in 0..size {
2204                expected_add[i] += src[i] * hrtf[i];
2205            }
2206            complex_mul_add_simd(&mut result_add, &src, &hrtf);
2207            for i in 0..size {
2208                assert!(
2209                    (result_add[i].re - expected_add[i].re).abs() < 1e-6,
2210                    "mul_add size {}, index {}: re mismatch",
2211                    size,
2212                    i
2213                );
2214                assert!(
2215                    (result_add[i].im - expected_add[i].im).abs() < 1e-6,
2216                    "mul_add size {}, index {}: im mismatch",
2217                    size,
2218                    i
2219                );
2220            }
2221
2222            // 2. complex_mul_simd
2223            let mut result_mul = vec![Complex::new(0.0, 0.0); size];
2224            let expected_mul: Vec<Complex<f32>> =
2225                src.iter().zip(hrtf.iter()).map(|(a, b)| a * b).collect();
2226            complex_mul_simd(&mut result_mul, &src, &hrtf);
2227            for i in 0..size {
2228                assert!(
2229                    (result_mul[i].re - expected_mul[i].re).abs() < 1e-6,
2230                    "mul size {}, index {}: re mismatch",
2231                    size,
2232                    i
2233                );
2234                assert!(
2235                    (result_mul[i].im - expected_mul[i].im).abs() < 1e-6,
2236                    "mul size {}, index {}: im mismatch",
2237                    size,
2238                    i
2239                );
2240            }
2241
2242            // 3. complex_mul_inplace_simd
2243            let mut result_inplace = src.clone();
2244            let mut expected_inplace = src.clone();
2245            for i in 0..size {
2246                expected_inplace[i] *= hrtf[i];
2247            }
2248            complex_mul_inplace_simd(&mut result_inplace, &hrtf);
2249            for i in 0..size {
2250                assert!(
2251                    (result_inplace[i].re - expected_inplace[i].re).abs() < 1e-6,
2252                    "inplace size {}, index {}: re mismatch",
2253                    size,
2254                    i
2255                );
2256                assert!(
2257                    (result_inplace[i].im - expected_inplace[i].im).abs() < 1e-6,
2258                    "inplace size {}, index {}: im mismatch",
2259                    size,
2260                    i
2261                );
2262            }
2263        }
2264    }
2265
2266    #[test]
2267    fn test_stress_test_random_data() {
2268        // Stress test with pseudo-random data
2269        use rustfft::num_complex::Complex;
2270
2271        // Simple LCG for deterministic "random" values
2272        let mut seed = 12345_u32;
2273        let lcg = |s: &mut u32| -> f32 {
2274            *s = s.wrapping_mul(1103515245).wrapping_add(12345);
2275            ((*s / 65536) % 32768) as f32 / 32768.0 - 0.5
2276        };
2277
2278        for size in [64, 128, 256, 512] {
2279            let src: Vec<Complex<f32>> = (0..size)
2280                .map(|_| Complex::new(lcg(&mut seed), lcg(&mut seed)))
2281                .collect();
2282            let hrtf: Vec<Complex<f32>> = (0..size)
2283                .map(|_| Complex::new(lcg(&mut seed), lcg(&mut seed)))
2284                .collect();
2285
2286            // Scalar reference
2287            let expected: Vec<Complex<f32>> =
2288                src.iter().zip(hrtf.iter()).map(|(a, b)| a * b).collect();
2289
2290            // SIMD computation
2291            let mut result = vec![Complex::new(0.0, 0.0); size];
2292            complex_mul_simd(&mut result, &src, &hrtf);
2293
2294            // Verify
2295            for i in 0..size {
2296                assert!(
2297                    (result[i].re - expected[i].re).abs() < 1e-5,
2298                    "Stress test size {}, index {}: re mismatch",
2299                    size,
2300                    i
2301                );
2302                assert!(
2303                    (result[i].im - expected[i].im).abs() < 1e-5,
2304                    "Stress test size {}, index {}: im mismatch",
2305                    size,
2306                    i
2307                );
2308            }
2309        }
2310    }
2311}
2312
2313// ============================================================================
2314// SIMD Covariance Calculation for ERB Band Processing
2315// ============================================================================
2316//
2317// Computes covariance statistics for two complex arrays (left/right channels):
2318//   cov_xx = sum(|left[i]|^2)
2319//   cov_yy = sum(|right[i]|^2)
2320//   cov_xy = sum(left[i] * conj(right[i]))
2321//
2322// These are used for PCA-based direct/ambient decomposition in the upmixer.
2323
2324/// SIMD-accelerated covariance calculation for ERB bands
2325///
2326/// Returns (cov_xx, cov_yy, cov_xy) where:
2327/// - cov_xx: sum of left channel energy
2328/// - cov_yy: sum of right channel energy
2329/// - cov_xy: complex cross-correlation
2330pub fn compute_covariance_simd(
2331    left: &[Complex<f32>],
2332    right: &[Complex<f32>],
2333    start: usize,
2334    end: usize,
2335) -> (f32, f32, Complex<f32>) {
2336    assert_eq!(left.len(), right.len());
2337    assert!(end <= left.len());
2338    assert!(start < end);
2339
2340    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
2341    let count = end - start;
2342
2343    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
2344    {
2345        use std::arch::x86_64::*;
2346
2347        let mut cov_xx;
2348        let mut cov_yy;
2349        let mut cov_xy = Complex::new(0.0, 0.0);
2350
2351        // SIMD path: process 4 complex numbers at once
2352        let simd_len = (count / 4) * 4;
2353        let simd_end = start + simd_len;
2354
2355        unsafe {
2356            let mut sum_xx = _mm256_setzero_ps();
2357            let mut sum_yy = _mm256_setzero_ps();
2358            let mut sum_xy_re = _mm256_setzero_ps();
2359            let _sum_xy_im = _mm256_setzero_ps();
2360
2361            for i in (start..simd_end).step_by(4) {
2362                let left_ptr = left.as_ptr().add(i) as *const f32;
2363                let right_ptr = right.as_ptr().add(i) as *const f32;
2364
2365                // Load 4 complex numbers: [re0, im0, re1, im1, re2, im2, re3, im3]
2366                let l = _mm256_loadu_ps(left_ptr);
2367                let r = _mm256_loadu_ps(right_ptr);
2368
2369                // Compute norm_sqr: re^2 + im^2
2370                let l_sqr = _mm256_mul_ps(l, l);
2371                let r_sqr = _mm256_mul_ps(r, r);
2372
2373                // Horizontal add pairs: [re0^2 + im0^2, re1^2 + im1^2, ...]
2374                let l_norm = _mm256_hadd_ps(l_sqr, l_sqr);
2375                let r_norm = _mm256_hadd_ps(r_sqr, r_sqr);
2376
2377                sum_xx = _mm256_add_ps(sum_xx, l_norm);
2378                sum_yy = _mm256_add_ps(sum_yy, r_norm);
2379
2380                // Compute cross-correlation: left * conj(right)
2381                let sign_mask = _mm256_set_ps(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0);
2382                let r_conj = _mm256_xor_ps(r, sign_mask);
2383
2384                // Complex multiplication: left * r_conj
2385                let l_re = _mm256_moveldup_ps(l);
2386                let l_im = _mm256_movehdup_ps(l);
2387
2388                let ac_ad = _mm256_mul_ps(l_re, r_conj);
2389                let r_conj_swap = _mm256_shuffle_ps(r_conj, r_conj, 0b10110001);
2390                let bd_bc = _mm256_mul_ps(l_im, r_conj_swap);
2391
2392                let result = _mm256_addsub_ps(ac_ad, bd_bc);
2393
2394                // Accumulate real parts (even indices) and imaginary parts (odd indices)
2395                sum_xy_re = _mm256_add_ps(sum_xy_re, result);
2396            }
2397
2398            // Horizontal reduction to scalars
2399            let xx_arr = std::mem::transmute::<__m256, [f32; 8]>(sum_xx);
2400            let yy_arr = std::mem::transmute::<__m256, [f32; 8]>(sum_yy);
2401            let xy_arr = std::mem::transmute::<__m256, [f32; 8]>(sum_xy_re);
2402
2403            // Sum hadd results (each norm appears twice in each lane)
2404            // Lane 0: [norm0, norm1, norm0, norm1], Lane 1: [norm2, norm3, norm2, norm3]
2405            cov_xx = xx_arr[0] + xx_arr[1] + xx_arr[4] + xx_arr[5];
2406            cov_yy = yy_arr[0] + yy_arr[1] + yy_arr[4] + yy_arr[5];
2407
2408            // Sum re (even) and im (odd) separately
2409            cov_xy.re = xy_arr[0] + xy_arr[2] + xy_arr[4] + xy_arr[6];
2410            cov_xy.im = xy_arr[1] + xy_arr[3] + xy_arr[5] + xy_arr[7];
2411        }
2412
2413        // Scalar tail for remaining elements
2414        for i in simd_end..end {
2415            let l = left[i];
2416            let r = right[i];
2417            cov_xx += l.norm_sqr();
2418            cov_yy += r.norm_sqr();
2419            cov_xy += l * r.conj();
2420        }
2421
2422        (cov_xx, cov_yy, cov_xy)
2423    }
2424
2425    #[cfg(not(all(target_arch = "x86_64", target_feature = "avx2")))]
2426    {
2427        let mut cov_xx = 0.0_f32;
2428        let mut cov_yy = 0.0_f32;
2429        let mut cov_xy = Complex::new(0.0, 0.0);
2430
2431        for i in start..end {
2432            let l = left[i];
2433            let r = right[i];
2434            cov_xx += l.norm_sqr();
2435            cov_yy += r.norm_sqr();
2436            cov_xy += l * r.conj();
2437        }
2438
2439        (cov_xx, cov_yy, cov_xy)
2440    }
2441}
2442
/// Multiplies every sample in `buffer` by `gain`, in place.
///
/// Uses 8-wide AVX2 or 4-wide NEON multiplies when those target features are
/// enabled at compile time; samples left over after the last full vector, and
/// all samples on other targets, are scaled one at a time.
#[inline]
pub fn apply_gain_simd(buffer: &mut [f32], gain: f32) {
    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        use std::arch::x86_64::*;
        let gain_lanes = unsafe { _mm256_set1_ps(gain) };
        let mut blocks = buffer.chunks_exact_mut(8);
        for block in &mut blocks {
            // SAFETY: `chunks_exact_mut(8)` yields slices of exactly 8 f32, which
            // is what one unaligned 256-bit load/store touches.
            unsafe {
                let ptr = block.as_mut_ptr();
                let scaled = _mm256_mul_ps(_mm256_loadu_ps(ptr), gain_lanes);
                _mm256_storeu_ps(ptr, scaled);
            }
        }
        for sample in blocks.into_remainder() {
            *sample *= gain;
        }
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    {
        use std::arch::aarch64::*;
        let gain_lanes = unsafe { vdupq_n_f32(gain) };
        let mut blocks = buffer.chunks_exact_mut(4);
        for block in &mut blocks {
            // SAFETY: `chunks_exact_mut(4)` yields slices of exactly 4 f32, which
            // is what one 128-bit load/store touches.
            unsafe {
                let ptr = block.as_mut_ptr();
                let scaled = vmulq_f32(vld1q_f32(ptr), gain_lanes);
                vst1q_f32(ptr, scaled);
            }
        }
        for sample in blocks.into_remainder() {
            *sample *= gain;
        }
    }

    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "avx2"),
        all(target_arch = "aarch64", target_feature = "neon")
    )))]
    {
        buffer.iter_mut().for_each(|sample| *sample *= gain);
    }
}
2494
/// Applies a separate gain to each channel of an interleaved buffer, in place.
///
/// `buffer` holds interleaved frames of `channels` samples; sample `ch` of
/// every complete frame is multiplied by `gains[ch]`. Trailing samples that do
/// not form a complete frame are left untouched. When `channels` is zero there
/// are no frames to process and the call is a no-op.
///
/// Stereo buffers use 256-bit AVX2 or 128-bit NEON multiplies when those
/// target features are enabled at compile time; all other cases use a scalar
/// loop.
///
/// # Panics
///
/// Panics if `gains` has fewer than `channels` entries and there is at least
/// one frame to process (on the stereo SIMD paths, whenever `gains` has fewer
/// than two entries).
#[inline]
pub fn apply_per_channel_gain_simd(buffer: &mut [f32], channels: usize, gains: &[f32]) {
    // Zero channels would make the frame-count division below panic.
    if channels == 0 {
        return;
    }

    let len = buffer.len();
    let num_frames = len / channels;

    // This is harder to SIMD generically for any channel count, so only the
    // stereo layout gets a vector path; everything else uses the scalar loop.
    if channels == 2 {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
        {
            use std::arch::x86_64::*;
            // `_mm256_set_ps` lists lanes from highest to lowest, so this is
            // [g0, g1, g0, g1, ...] in memory order: four stereo frames.
            let gains_vec = unsafe {
                _mm256_set_ps(
                    gains[1], gains[0], gains[1], gains[0], gains[1], gains[0], gains[1], gains[0],
                )
            };
            let simd_len = (num_frames / 4) * 4;
            for i in (0..simd_len).step_by(4) {
                // SAFETY: `i + 4 <= simd_len <= num_frames`, so the 8 floats
                // starting at `i * 2` end at or before `num_frames * 2 <= len`.
                unsafe {
                    let ptr = buffer.as_mut_ptr().add(i * 2);
                    let v = _mm256_loadu_ps(ptr);
                    let res = _mm256_mul_ps(v, gains_vec);
                    _mm256_storeu_ps(ptr, res);
                }
            }
            // Remaining whole frames, one at a time.
            for frame in buffer[simd_len * 2..num_frames * 2].chunks_exact_mut(2) {
                frame[0] *= gains[0];
                frame[1] *= gains[1];
            }
            return;
        }

        #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
        {
            use std::arch::aarch64::*;
            // Two stereo frames per vector: [g0, g1, g0, g1].
            let gains_vec = unsafe {
                let g = [gains[0], gains[1], gains[0], gains[1]];
                vld1q_f32(g.as_ptr())
            };
            let simd_len = (num_frames / 2) * 2;
            for i in (0..simd_len).step_by(2) {
                // SAFETY: `i + 2 <= simd_len <= num_frames`, so the 4 floats
                // starting at `i * 2` end at or before `num_frames * 2 <= len`.
                unsafe {
                    let ptr = buffer.as_mut_ptr().add(i * 2);
                    let v = vld1q_f32(ptr);
                    let res = vmulq_f32(v, gains_vec);
                    vst1q_f32(ptr, res);
                }
            }
            // Remaining whole frames, one at a time.
            for frame in buffer[simd_len * 2..num_frames * 2].chunks_exact_mut(2) {
                frame[0] *= gains[0];
                frame[1] *= gains[1];
            }
            return;
        }
    }

    // Fallback scalar loop over complete frames; `chunks_exact_mut` skips any
    // trailing partial frame.
    for frame in buffer.chunks_exact_mut(channels) {
        for (ch, sample) in frame.iter_mut().enumerate() {
            *sample *= gains[ch];
        }
    }
}
2560
/// Fast inverse square root (`1/sqrt(x)`) using a bit-level initial estimate
/// followed by one Newton-Raphson iteration.
///
/// For positive, normal `x` the relative error is on the order of 0.1-0.2%,
/// which is suitable for normalizing all-pass filter magnitudes.
///
/// The result is only meaningful for positive finite inputs. Other inputs
/// (negative values, zero, NaN, infinity) do not panic but return an
/// unspecified value.
#[inline(always)]
pub fn fast_inv_sqrt(x: f32) -> f32 {
    let half = 0.5 * x;
    let bits = x.to_bits();
    // Initial "magic number" estimate. `wrapping_sub` keeps debug and release
    // builds in agreement: with the sign bit set, `bits >> 1` can exceed the
    // constant, and a plain `-` would trip the debug overflow check.
    let estimate = f32::from_bits(0x5f37_59df_u32.wrapping_sub(bits >> 1));
    // One Newton-Raphson refinement of y ~= 1/sqrt(x).
    estimate * (1.5 - half * estimate * estimate)
}
2571
/// Returns the largest absolute value found in `samples` (peak detection).
///
/// An empty slice yields `0.0`. Uses 8-wide AVX2 or 4-wide NEON max operations
/// when those target features are enabled at compile time, with a scalar pass
/// over the leftover samples; other targets use the scalar pass throughout.
#[inline]
pub fn find_max_abs_simd(samples: &[f32]) -> f32 {
    if samples.is_empty() {
        return 0.0;
    }

    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        use std::arch::x86_64::*;
        // -0.0 has only the sign bit set; and-not with it clears the sign of each lane.
        let sign_bit = unsafe { _mm256_set1_ps(-0.0) };
        let mut lane_max = unsafe { _mm256_setzero_ps() };

        let mut blocks = samples.chunks_exact(8);
        for block in &mut blocks {
            // SAFETY: `chunks_exact(8)` yields slices of exactly 8 f32, which is
            // what one unaligned 256-bit load reads.
            unsafe {
                let v = _mm256_loadu_ps(block.as_ptr());
                let magnitude = _mm256_andnot_ps(sign_bit, v);
                lane_max = _mm256_max_ps(lane_max, magnitude);
            }
        }

        // Reduce the 8 per-lane maxima to one value, starting from zero.
        // SAFETY: `__m256` and `[f32; 8]` are both 32 bytes of plain f32 data.
        let lanes = unsafe { std::mem::transmute::<__m256, [f32; 8]>(lane_max) };
        let mut peak = 0.0_f32;
        for &lane in &lanes {
            if lane > peak {
                peak = lane;
            }
        }

        // Fold in the samples after the last full vector.
        for sample in blocks.remainder() {
            let magnitude = sample.abs();
            if magnitude > peak {
                peak = magnitude;
            }
        }
        peak
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    {
        use std::arch::aarch64::*;
        let mut lane_max = unsafe { vdupq_n_f32(0.0) };

        let mut blocks = samples.chunks_exact(4);
        for block in &mut blocks {
            // SAFETY: `chunks_exact(4)` yields slices of exactly 4 f32, which is
            // what one 128-bit load reads.
            unsafe {
                let v = vld1q_f32(block.as_ptr());
                lane_max = vmaxq_f32(lane_max, vabsq_f32(v));
            }
        }

        // Across-lane maximum of the vector accumulator.
        let mut peak = unsafe { vmaxvq_f32(lane_max) };

        // Fold in the samples after the last full vector.
        for sample in blocks.remainder() {
            let magnitude = sample.abs();
            if magnitude > peak {
                peak = magnitude;
            }
        }
        peak
    }

    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "avx2"),
        all(target_arch = "aarch64", target_feature = "neon")
    )))]
    {
        samples.iter().fold(0.0_f32, |peak, sample| {
            let magnitude = sample.abs();
            if magnitude > peak {
                magnitude
            } else {
                peak
            }
        })
    }
}