Skip to main content

math_audio_dsp/
simd.rs

1// ============================================================================
2// SIMD Optimizations for Complex Multiplication
3// ============================================================================
4//
5// These functions provide SIMD-accelerated complex multiplication for the
6// frequency-domain HRTF convolution hot paths. Complex multiplication:
7//   (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
8//
9// Platform support:
10// - x86-64: AVX2 (processes 4 complex f32 at once using 256-bit registers)
11// - aarch64 + FCMA: FCMLA instructions (2 ops per complex mul, ARMv8.3+, all Apple Silicon)
12// - aarch64: NEON fallback (processes 2 complex f32 at once using 128-bit registers)
13// - fallback: Scalar implementation for all other platforms
14//
15// The explicit AVX2/NEON/FCMA implementations are selected with compile-time
16// `target_feature` cfgs. Portable release builds therefore use the scalar
17// fallback unless built with flags such as `RUSTFLAGS="-C target-feature=+avx2"`
18// (or the corresponding aarch64 features).
19//
20// Performance gains: 2-4x speedup on supported platforms for FFT sizes >= 512
21
22use rustfft::num_complex::Complex;
23
/// Packed complex multiply-accumulate built on the FCMLA instruction
/// (ARMv8.3+ / Apple Silicon).
///
/// Each `float32x4_t` carries two complex `f32` values laid out as
/// `[re0, im0, re1, im1]`. Returns `r + a * b`, evaluated independently for
/// both complex values.
///
/// # Safety
/// Executes FCMLA through inline assembly, so the running CPU must implement
/// FCMA. The function is only compiled when the `fcma` target feature is
/// enabled for the build.
#[cfg(all(target_arch = "aarch64", target_feature = "fcma"))]
#[inline(always)]
unsafe fn fcmla_mul_acc(
    mut r: std::arch::aarch64::float32x4_t,
    a: std::arch::aarch64::float32x4_t,
    b: std::arch::aarch64::float32x4_t,
) -> std::arch::aarch64::float32x4_t {
    // SAFETY: the assembly reads and writes only the three vector registers
    // bound below; it accesses no memory and does not touch the stack.
    unsafe {
        std::arch::asm!(
            // The #0 and #90 rotations together accumulate the full complex
            // product of `a` and `b` into the accumulator register.
            "fcmla {acc:v}.4s, {lhs:v}.4s, {rhs:v}.4s, #0",
            "fcmla {acc:v}.4s, {lhs:v}.4s, {rhs:v}.4s, #90",
            acc = inout(vreg) r,
            lhs = in(vreg) a,
            rhs = in(vreg) b,
            options(pure, nomem, nostack),
        );
    }
    r
}
45
/// Immediate for `_mm256_shuffle_ps` that swaps the two floats of every
/// `[re, im]` pair, i.e. `[re, im] -> [im, re]`.
///
/// Read as four 2-bit selectors (high to low): `10 11 00 01`, which picks
/// source elements 1, 0, 3, 2 for output positions 0, 1, 2, 3 of each
/// 128-bit half.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
const SHUFFLE_SWAP_RE_IM: i32 = 0b10_11_00_01;
49
/// Multiply-accumulates four consecutive complex values using AVX2.
///
/// Computes `dst[i] += src[i] * hrtf[i]` for `i` in `start..start + 4`.
///
/// # Safety
/// Caller must ensure `dst`, `src` and `hrtf` have at least `start + 4` elements.
/// The requirement is verified with `debug_assert!` only; builds without debug
/// assertions perform unchecked 256-bit loads and stores.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
#[inline]
pub unsafe fn complex_mul_add_simd_chunk(
    dst: &mut [Complex<f32>],
    src: &[Complex<f32>],
    hrtf: &[Complex<f32>],
    start: usize,
) {
    use std::arch::x86_64::*;

    debug_assert!(start + 4 <= dst.len(), "dst too short for a 4-element chunk");
    debug_assert!(start + 4 <= src.len(), "src too short for a 4-element chunk");
    debug_assert!(start + 4 <= hrtf.len(), "hrtf too short for a 4-element chunk");

    // Process 4 complex numbers (8 floats) at once using AVX2.
    // Input layout: [re0, im0, re1, im1, re2, im2, re3, im3]
    //
    // SAFETY: the caller guarantees `start + 4` elements in every slice, so each
    // unaligned 8-float access below stays in bounds. The pointer casts view each
    // `Complex<f32>` as an adjacent `[re, im]` pair of `f32`.
    unsafe {
        let src_ptr = src.as_ptr().add(start) as *const f32;
        let hrtf_ptr = hrtf.as_ptr().add(start) as *const f32;
        let dst_ptr = dst.as_mut_ptr().add(start) as *mut f32;

        let a = _mm256_loadu_ps(src_ptr);
        let b = _mm256_loadu_ps(hrtf_ptr);
        let acc = _mm256_loadu_ps(dst_ptr);

        // Complex multiplication: (a + bi) * (c + di) = (ac - bd) + (ad + bc)i

        // moveldup duplicates even elements -> [re0, re0, re1, re1, ...]
        // movehdup duplicates odd elements  -> [im0, im0, im1, im1, ...]
        let a_re = _mm256_moveldup_ps(a);
        let a_im = _mm256_movehdup_ps(a);

        // a.re * [c, d] = [ac, ad, ...]
        let ac_ad = _mm256_mul_ps(a_re, b);

        // Swap re/im of the second operand: [d, c, ...]
        let b_swapped = _mm256_shuffle_ps(b, b, SHUFFLE_SWAP_RE_IM);

        // a.im * [d, c] = [bd, bc, ...]
        let bd_bc = _mm256_mul_ps(a_im, b_swapped);

        // addsub subtracts in even lanes and adds in odd lanes:
        // [(ac - bd), (ad + bc), ...] = [result.re, result.im, ...]
        let product = _mm256_addsub_ps(ac_ad, bd_bc);

        // Accumulate into the destination.
        _mm256_storeu_ps(dst_ptr, _mm256_add_ps(acc, product));
    }
}
101
/// Multiply-accumulates two consecutive complex values using FCMLA.
///
/// Computes `dst[i] += src[i] * hrtf[i]` for `i` in `start..start + 2`.
///
/// # Safety
/// Caller must ensure `dst`, `src` and `hrtf` have at least `start + 2` elements.
/// The requirement is verified with `debug_assert!` only; builds without debug
/// assertions perform unchecked 128-bit loads and stores.
#[cfg(all(target_arch = "aarch64", target_feature = "fcma"))]
#[inline]
pub unsafe fn complex_mul_add_simd_chunk(
    dst: &mut [Complex<f32>],
    src: &[Complex<f32>],
    hrtf: &[Complex<f32>],
    start: usize,
) {
    use std::arch::aarch64::*;

    debug_assert!(start + 2 <= dst.len(), "dst too short for a 2-element chunk");
    debug_assert!(start + 2 <= src.len(), "src too short for a 2-element chunk");
    debug_assert!(start + 2 <= hrtf.len(), "hrtf too short for a 2-element chunk");

    // SAFETY: the caller guarantees `start + 2` elements in every slice, so each
    // 4-float access below stays in bounds. The pointer casts view each
    // `Complex<f32>` as an adjacent `[re, im]` pair of `f32`.
    unsafe {
        let src_ptr = src.as_ptr().add(start) as *const f32;
        let hrtf_ptr = hrtf.as_ptr().add(start) as *const f32;
        let dst_ptr = dst.as_mut_ptr().add(start) as *mut f32;

        let a = vld1q_f32(src_ptr);
        let b = vld1q_f32(hrtf_ptr);
        // The current destination values seed the accumulator.
        let acc = vld1q_f32(dst_ptr);
        vst1q_f32(dst_ptr, fcmla_mul_acc(acc, a, b));
    }
}
126
/// Multiply-accumulates two consecutive complex values using plain NEON.
///
/// Computes `dst[i] += src[i] * hrtf[i]` for `i` in `start..start + 2`.
///
/// # Safety
/// Caller must ensure `dst`, `src` and `hrtf` have at least `start + 2` elements.
/// The requirement is verified with `debug_assert!` only; builds without debug
/// assertions perform unchecked 128-bit loads and stores.
#[cfg(all(
    target_arch = "aarch64",
    target_feature = "neon",
    not(target_feature = "fcma")
))]
#[inline]
pub unsafe fn complex_mul_add_simd_chunk(
    dst: &mut [Complex<f32>],
    src: &[Complex<f32>],
    hrtf: &[Complex<f32>],
    start: usize,
) {
    use std::arch::aarch64::*;

    debug_assert!(start + 2 <= dst.len(), "dst too short for a 2-element chunk");
    debug_assert!(start + 2 <= src.len(), "src too short for a 2-element chunk");
    debug_assert!(start + 2 <= hrtf.len(), "hrtf too short for a 2-element chunk");

    // SAFETY: the caller guarantees `start + 2` elements in every slice, so each
    // 4-float access below stays in bounds. The pointer casts view each
    // `Complex<f32>` as an adjacent `[re, im]` pair of `f32`.
    unsafe {
        let src_ptr = src.as_ptr().add(start) as *const f32;
        let hrtf_ptr = hrtf.as_ptr().add(start) as *const f32;
        let dst_ptr = dst.as_mut_ptr().add(start) as *mut f32;

        let a = vld1q_f32(src_ptr);
        let b = vld1q_f32(hrtf_ptr);
        let acc = vld1q_f32(dst_ptr);

        // (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
        let a_re = vtrn1q_f32(a, a); // [re0, re0, re1, re1]
        let a_im = vtrn2q_f32(a, a); // [im0, im0, im1, im1]
        let ac_ad = vmulq_f32(a_re, b);
        let b_swapped = vrev64q_f32(b); // [d0, c0, d1, c1]
        let bd_bc = vmulq_f32(a_im, b_swapped);

        // Sign bit set in lanes 0 and 2 only: flips `bd` (the real-part term)
        // while leaving `bc` (the imaginary-part term) untouched.
        let sign_bit: u32 = 0x80000000;
        let sign_mask = vsetq_lane_u32::<2>(
            sign_bit,
            vsetq_lane_u32::<0>(sign_bit, vdupq_n_u32(0)),
        );
        let neg_bd_bc = vreinterpretq_f32_u32(veorq_u32(vreinterpretq_u32_f32(bd_bc), sign_mask));

        let product = vaddq_f32(ac_ad, neg_bd_bc);
        vst1q_f32(dst_ptr, vaddq_f32(acc, product));
    }
}
174
175#[cfg(not(any(
176    all(target_arch = "x86_64", target_feature = "avx2"),
177    all(target_arch = "aarch64", target_feature = "neon")
178)))]
179#[inline]
180pub fn complex_mul_add_simd_chunk(
181    dst: &mut [Complex<f32>],
182    src: &[Complex<f32>],
183    hrtf: &[Complex<f32>],
184    start: usize,
185) {
186    // Scalar fallback (will be optimized by LLVM auto-vectorization)
187    dst[start] += src[start] * hrtf[start];
188}
189
190/// SIMD-optimized complex multiply-accumulate
191///
192/// Computes: `dst[i] += src[i] * hrtf[i]` for all `i`
193///
194/// Uses platform-specific SIMD instructions for maximum performance:
195/// - AVX2 on x86-64 (4 complex at once)
196/// - NEON on aarch64 (2 complex at once)
197/// - Scalar fallback with auto-vectorization hints
198#[inline]
199pub fn complex_mul_add_simd(dst: &mut [Complex<f32>], src: &[Complex<f32>], hrtf: &[Complex<f32>]) {
200    let len = dst.len();
201
202    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
203    {
204        // Process 4 complex at a time with AVX2
205        let simd_len = (len / 4) * 4;
206
207        for i in (0..simd_len).step_by(4) {
208            unsafe {
209                complex_mul_add_simd_chunk(dst, src, hrtf, i);
210            }
211        }
212
213        // Scalar remainder
214        for i in simd_len..len {
215            dst[i] += src[i] * hrtf[i];
216        }
217    }
218
219    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
220    {
221        // Process 2 complex at a time with NEON
222        let simd_len = (len / 2) * 2;
223
224        for i in (0..simd_len).step_by(2) {
225            unsafe {
226                complex_mul_add_simd_chunk(dst, src, hrtf, i);
227            }
228        }
229
230        // Scalar remainder
231        for i in simd_len..len {
232            dst[i] += src[i] * hrtf[i];
233        }
234    }
235
236    #[cfg(not(any(
237        all(target_arch = "x86_64", target_feature = "avx2"),
238        all(target_arch = "aarch64", target_feature = "neon")
239    )))]
240    {
241        // Scalar fallback
242        for i in 0..len {
243            dst[i] += src[i] * hrtf[i];
244        }
245    }
246}
247
248/// SIMD-optimized complex multiplication (without accumulation)
249///
250/// Computes: `dst[i] = src[i] * hrtf[i]` for all `i`
251#[inline]
252pub fn complex_mul_simd(dst: &mut [Complex<f32>], src: &[Complex<f32>], hrtf: &[Complex<f32>]) {
253    let len = dst.len();
254
255    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
256    {
257        use std::arch::x86_64::*;
258
259        let simd_len = (len / 4) * 4;
260
261        for i in (0..simd_len).step_by(4) {
262            unsafe {
263                let src_ptr = src.as_ptr().add(i) as *const f32;
264                let hrtf_ptr = hrtf.as_ptr().add(i) as *const f32;
265                let dst_ptr = dst.as_mut_ptr().add(i) as *mut f32;
266
267                let a = _mm256_loadu_ps(src_ptr);
268                let b = _mm256_loadu_ps(hrtf_ptr);
269
270                let a_re = _mm256_moveldup_ps(a);
271                let a_im = _mm256_movehdup_ps(a);
272                let ac_ad = _mm256_mul_ps(a_re, b);
273                let b_swapped = _mm256_shuffle_ps(b, b, SHUFFLE_SWAP_RE_IM);
274                let bd_bc = _mm256_mul_ps(a_im, b_swapped);
275                let result = _mm256_addsub_ps(ac_ad, bd_bc);
276
277                _mm256_storeu_ps(dst_ptr, result);
278            }
279        }
280
281        for i in simd_len..len {
282            dst[i] = src[i] * hrtf[i];
283        }
284    }
285
286    #[cfg(all(target_arch = "aarch64", target_feature = "fcma"))]
287    {
288        use std::arch::aarch64::*;
289
290        let simd_len = (len / 2) * 2;
291
292        for i in (0..simd_len).step_by(2) {
293            unsafe {
294                let src_ptr = src.as_ptr().add(i) as *const f32;
295                let hrtf_ptr = hrtf.as_ptr().add(i) as *const f32;
296                let dst_ptr = dst.as_mut_ptr().add(i) as *mut f32;
297
298                let a = vld1q_f32(src_ptr);
299                let b = vld1q_f32(hrtf_ptr);
300                let r = vdupq_n_f32(0.0);
301                let result = fcmla_mul_acc(r, a, b);
302                vst1q_f32(dst_ptr, result);
303            }
304        }
305
306        for i in simd_len..len {
307            dst[i] = src[i] * hrtf[i];
308        }
309    }
310
311    #[cfg(all(
312        target_arch = "aarch64",
313        target_feature = "neon",
314        not(target_feature = "fcma")
315    ))]
316    {
317        use std::arch::aarch64::*;
318
319        let simd_len = (len / 2) * 2;
320
321        for i in (0..simd_len).step_by(2) {
322            unsafe {
323                let src_ptr = src.as_ptr().add(i) as *const f32;
324                let hrtf_ptr = hrtf.as_ptr().add(i) as *const f32;
325                let dst_ptr = dst.as_mut_ptr().add(i) as *mut f32;
326
327                let a = vld1q_f32(src_ptr);
328                let b = vld1q_f32(hrtf_ptr);
329
330                let a_re = vtrn1q_f32(a, a);
331                let a_im = vtrn2q_f32(a, a);
332                let ac_ad = vmulq_f32(a_re, b);
333                let b_swapped = vrev64q_f32(b);
334                let bd_bc = vmulq_f32(a_im, b_swapped);
335
336                let sign_bit: u32 = 0x80000000;
337                let neg_mask = vreinterpretq_f32_u32(vsetq_lane_u32::<2>(
338                    sign_bit,
339                    vsetq_lane_u32::<0>(sign_bit, vdupq_n_u32(0)),
340                ));
341
342                let bd_bc_negated = vreinterpretq_f32_u32(veorq_u32(
343                    vreinterpretq_u32_f32(bd_bc),
344                    vreinterpretq_u32_f32(neg_mask),
345                ));
346
347                let result = vaddq_f32(ac_ad, bd_bc_negated);
348                vst1q_f32(dst_ptr, result);
349            }
350        }
351
352        for i in simd_len..len {
353            dst[i] = src[i] * hrtf[i];
354        }
355    }
356
357    #[cfg(not(any(
358        all(target_arch = "x86_64", target_feature = "avx2"),
359        all(target_arch = "aarch64", target_feature = "neon")
360    )))]
361    {
362        for i in 0..len {
363            dst[i] = src[i] * hrtf[i];
364        }
365    }
366}
367
368/// SIMD-optimized in-place complex multiplication
369///
370/// Computes: `dst[i] *= hrtf[i]` for all `i`
371#[inline]
372#[allow(dead_code)]
373pub fn complex_mul_inplace_simd(dst: &mut [Complex<f32>], hrtf: &[Complex<f32>]) {
374    let len = dst.len();
375
376    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
377    {
378        use std::arch::x86_64::*;
379
380        let simd_len = (len / 4) * 4;
381
382        for i in (0..simd_len).step_by(4) {
383            unsafe {
384                let dst_ptr = dst.as_mut_ptr().add(i) as *mut f32;
385                let hrtf_ptr = hrtf.as_ptr().add(i) as *const f32;
386
387                let a = _mm256_loadu_ps(dst_ptr);
388                let b = _mm256_loadu_ps(hrtf_ptr);
389
390                let a_re = _mm256_moveldup_ps(a);
391                let a_im = _mm256_movehdup_ps(a);
392                let ac_ad = _mm256_mul_ps(a_re, b);
393                let b_swapped = _mm256_shuffle_ps(b, b, SHUFFLE_SWAP_RE_IM);
394                let bd_bc = _mm256_mul_ps(a_im, b_swapped);
395                let result = _mm256_addsub_ps(ac_ad, bd_bc);
396
397                _mm256_storeu_ps(dst_ptr, result);
398            }
399        }
400
401        for i in simd_len..len {
402            dst[i] *= hrtf[i];
403        }
404    }
405
406    #[cfg(all(target_arch = "aarch64", target_feature = "fcma"))]
407    {
408        use std::arch::aarch64::*;
409
410        let simd_len = (len / 2) * 2;
411
412        for i in (0..simd_len).step_by(2) {
413            unsafe {
414                let dst_ptr = dst.as_mut_ptr().add(i) as *mut f32;
415                let hrtf_ptr = hrtf.as_ptr().add(i) as *const f32;
416
417                let a = vld1q_f32(dst_ptr);
418                let b = vld1q_f32(hrtf_ptr);
419                let r = vdupq_n_f32(0.0);
420                let result = fcmla_mul_acc(r, a, b);
421                vst1q_f32(dst_ptr, result);
422            }
423        }
424
425        for i in simd_len..len {
426            dst[i] *= hrtf[i];
427        }
428    }
429
430    #[cfg(all(
431        target_arch = "aarch64",
432        target_feature = "neon",
433        not(target_feature = "fcma")
434    ))]
435    {
436        use std::arch::aarch64::*;
437
438        let simd_len = (len / 2) * 2;
439
440        for i in (0..simd_len).step_by(2) {
441            unsafe {
442                let dst_ptr = dst.as_mut_ptr().add(i) as *mut f32;
443                let hrtf_ptr = hrtf.as_ptr().add(i) as *const f32;
444
445                let a = vld1q_f32(dst_ptr);
446                let b = vld1q_f32(hrtf_ptr);
447
448                let a_re = vtrn1q_f32(a, a);
449                let a_im = vtrn2q_f32(a, a);
450                let ac_ad = vmulq_f32(a_re, b);
451                let b_swapped = vrev64q_f32(b);
452                let bd_bc = vmulq_f32(a_im, b_swapped);
453
454                let sign_bit: u32 = 0x80000000;
455                let neg_mask = vreinterpretq_f32_u32(vsetq_lane_u32::<2>(
456                    sign_bit,
457                    vsetq_lane_u32::<0>(sign_bit, vdupq_n_u32(0)),
458                ));
459
460                let bd_bc_negated = vreinterpretq_f32_u32(veorq_u32(
461                    vreinterpretq_u32_f32(bd_bc),
462                    vreinterpretq_u32_f32(neg_mask),
463                ));
464
465                let result = vaddq_f32(ac_ad, bd_bc_negated);
466                vst1q_f32(dst_ptr, result);
467            }
468        }
469
470        for i in simd_len..len {
471            dst[i] *= hrtf[i];
472        }
473    }
474
475    #[cfg(not(any(
476        all(target_arch = "x86_64", target_feature = "avx2"),
477        all(target_arch = "aarch64", target_feature = "neon")
478    )))]
479    {
480        for i in 0..len {
481            dst[i] *= hrtf[i];
482        }
483    }
484}
485
486// ============================================================================
487// Real-valued SIMD operations for windowing and overlap-add
488// ============================================================================
489
/// SIMD-optimized multiply-accumulate for overlap-add synthesis (no window)
///
/// Computes: `dst[i] += src[i] * scale` for all `i` in `0..dst.len()`
///
/// # Panics
/// Panics if `src` is shorter than `dst`. With debug assertions enabled, any
/// length mismatch panics.
#[inline]
pub fn scale_add_simd(dst: &mut [f32], src: &[f32], scale: f32) {
    debug_assert_eq!(dst.len(), src.len());
    let len = dst.len();

    // Hard bounds check for builds without debug assertions: the SIMD paths
    // read `src` through raw pointers, so a short input must panic here instead
    // of being read out of bounds.
    let src = &src[..len];

    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        use std::arch::x86_64::*;

        let scale_vec = unsafe { _mm256_set1_ps(scale) };
        let simd_len = (len / 8) * 8;

        for i in (0..simd_len).step_by(8) {
            // SAFETY: `i + 8 <= simd_len <= len` and both slices hold at least
            // `len` elements.
            unsafe {
                let src_ptr = src.as_ptr().add(i);
                let dst_ptr = dst.as_mut_ptr().add(i);

                let s = _mm256_loadu_ps(src_ptr);
                let d = _mm256_loadu_ps(dst_ptr);

                // src * scale + dst
                let ss = _mm256_mul_ps(s, scale_vec);
                let result = _mm256_add_ps(d, ss);

                _mm256_storeu_ps(dst_ptr, result);
            }
        }

        // Scalar remainder
        for i in simd_len..len {
            dst[i] += src[i] * scale;
        }
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    {
        use std::arch::aarch64::*;

        let scale_vec = unsafe { vdupq_n_f32(scale) };
        let simd_len = (len / 4) * 4;

        for i in (0..simd_len).step_by(4) {
            // SAFETY: `i + 4 <= simd_len <= len` and both slices hold at least
            // `len` elements.
            unsafe {
                let src_ptr = src.as_ptr().add(i);
                let dst_ptr = dst.as_mut_ptr().add(i);

                let s = vld1q_f32(src_ptr);
                let d = vld1q_f32(dst_ptr);

                // FMA: dst + src * scale in a single fused instruction
                let result = vfmaq_f32(d, s, scale_vec);

                vst1q_f32(dst_ptr, result);
            }
        }

        // Scalar remainder
        for i in simd_len..len {
            dst[i] += src[i] * scale;
        }
    }

    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "avx2"),
        all(target_arch = "aarch64", target_feature = "neon")
    )))]
    {
        // Scalar fallback
        for i in 0..len {
            dst[i] += src[i] * scale;
        }
    }
}
565
/// Multiplies every sample of `data` by `scale`, in place.
///
/// Computes: `data[i] *= scale` for all `i`. Whole vector-width blocks go
/// through AVX2 (8 floats) or NEON (4 floats) when those features are enabled
/// at compile time; leftover samples and all other targets use a scalar loop.
#[inline]
pub fn scale_add_simd_inplace(data: &mut [f32], scale: f32) {
    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        use std::arch::x86_64::*;

        const LANES: usize = 8;
        let gain = unsafe { _mm256_set1_ps(scale) };
        let vector_len = data.len() / LANES * LANES;
        let (head, tail) = data.split_at_mut(vector_len);

        for block in head.chunks_exact_mut(LANES) {
            // Every block is exactly LANES floats long, so the unaligned
            // load/store pair stays inside the slice.
            unsafe {
                let p = block.as_mut_ptr();
                let scaled = _mm256_mul_ps(_mm256_loadu_ps(p), gain);
                _mm256_storeu_ps(p, scaled);
            }
        }

        for sample in tail {
            *sample *= scale;
        }
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    {
        use std::arch::aarch64::*;

        const LANES: usize = 4;
        let gain = unsafe { vdupq_n_f32(scale) };
        let vector_len = data.len() / LANES * LANES;
        let (head, tail) = data.split_at_mut(vector_len);

        for block in head.chunks_exact_mut(LANES) {
            // Every block is exactly LANES floats long, so the load/store
            // pair stays inside the slice.
            unsafe {
                let p = block.as_mut_ptr();
                let scaled = vmulq_f32(vld1q_f32(p), gain);
                vst1q_f32(p, scaled);
            }
        }

        for sample in tail {
            *sample *= scale;
        }
    }

    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "avx2"),
        all(target_arch = "aarch64", target_feature = "neon")
    )))]
    {
        data.iter_mut().for_each(|sample| *sample *= scale);
    }
}
623
/// SIMD-optimized linear interpolation (blend) between two buffers
///
/// Computes: `dst[i] = prev[i] + alpha * (dst[i] - prev[i])` for all `i` in
/// `0..dst.len()`.
/// Equivalent to: `dst[i] = (1 - alpha) * prev[i] + alpha * dst[i]`
///
/// Used for crossfading between old and new filter outputs in STFT plugins.
///
/// # Panics
/// Panics if `prev` is shorter than `dst`. With debug assertions enabled, any
/// length mismatch panics.
#[inline]
pub fn blend_simd(dst: &mut [f32], prev: &[f32], alpha: f32) {
    debug_assert_eq!(dst.len(), prev.len());
    let len = dst.len();

    // Hard bounds check for builds without debug assertions: the SIMD paths
    // read `prev` through raw pointers, so a short input must panic here
    // instead of being read out of bounds.
    let prev = &prev[..len];

    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        use std::arch::x86_64::*;

        let alpha_vec = unsafe { _mm256_set1_ps(alpha) };
        let simd_len = (len / 8) * 8;

        for i in (0..simd_len).step_by(8) {
            // SAFETY: `i + 8 <= simd_len <= len` and both slices hold at least
            // `len` elements.
            unsafe {
                let prev_ptr = prev.as_ptr().add(i);
                let dst_ptr = dst.as_mut_ptr().add(i);

                let p = _mm256_loadu_ps(prev_ptr);
                let d = _mm256_loadu_ps(dst_ptr);

                // prev + alpha * (dst - prev)
                let diff = _mm256_sub_ps(d, p);

                // `_mm256_fmadd_ps` needs the `fma` target feature, which this
                // block's `avx2` gate does not imply. Only use it when FMA is
                // enabled for the build; otherwise use separate multiply + add.
                #[cfg(target_feature = "fma")]
                let result = _mm256_fmadd_ps(alpha_vec, diff, p);
                #[cfg(not(target_feature = "fma"))]
                let result = _mm256_add_ps(p, _mm256_mul_ps(alpha_vec, diff));

                _mm256_storeu_ps(dst_ptr, result);
            }
        }

        // Scalar remainder
        for i in simd_len..len {
            dst[i] = prev[i] + alpha * (dst[i] - prev[i]);
        }
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    {
        use std::arch::aarch64::*;

        let alpha_vec = unsafe { vdupq_n_f32(alpha) };
        let simd_len = (len / 4) * 4;

        for i in (0..simd_len).step_by(4) {
            // SAFETY: `i + 4 <= simd_len <= len` and both slices hold at least
            // `len` elements.
            unsafe {
                let prev_ptr = prev.as_ptr().add(i);
                let dst_ptr = dst.as_mut_ptr().add(i);

                let p = vld1q_f32(prev_ptr);
                let d = vld1q_f32(dst_ptr);

                // prev + alpha * (dst - prev)
                let diff = vsubq_f32(d, p);
                let result = vfmaq_f32(p, alpha_vec, diff);

                vst1q_f32(dst_ptr, result);
            }
        }

        // Scalar remainder
        for i in simd_len..len {
            dst[i] = prev[i] + alpha * (dst[i] - prev[i]);
        }
    }

    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "avx2"),
        all(target_arch = "aarch64", target_feature = "neon")
    )))]
    {
        // Scalar fallback
        for i in 0..len {
            dst[i] = prev[i] + alpha * (dst[i] - prev[i]);
        }
    }
}
701
/// SIMD-optimized windowed copy (for FFT input preparation)
///
/// Computes: `dst[i] = src[i] * window[i]` for all `i` in `0..dst.len()`
///
/// # Panics
/// Panics if `src` or `window` is shorter than `dst`. With debug assertions
/// enabled, any length mismatch panics.
#[inline]
pub fn window_mul_simd(dst: &mut [f32], src: &[f32], window: &[f32]) {
    debug_assert_eq!(dst.len(), src.len());
    debug_assert_eq!(dst.len(), window.len());
    let len = dst.len();

    // Hard bounds check for builds without debug assertions: the SIMD paths
    // read `src` and `window` through raw pointers, so a short input must
    // panic here instead of being read out of bounds.
    let src = &src[..len];
    let window = &window[..len];

    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        use std::arch::x86_64::*;

        let simd_len = (len / 8) * 8;

        for i in (0..simd_len).step_by(8) {
            // SAFETY: `i + 8 <= simd_len <= len` and all three slices hold at
            // least `len` elements.
            unsafe {
                let src_ptr = src.as_ptr().add(i);
                let win_ptr = window.as_ptr().add(i);
                let dst_ptr = dst.as_mut_ptr().add(i);

                let s = _mm256_loadu_ps(src_ptr);
                let w = _mm256_loadu_ps(win_ptr);
                let result = _mm256_mul_ps(s, w);

                _mm256_storeu_ps(dst_ptr, result);
            }
        }

        // Scalar remainder
        for i in simd_len..len {
            dst[i] = src[i] * window[i];
        }
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    {
        use std::arch::aarch64::*;

        let simd_len = (len / 4) * 4;

        for i in (0..simd_len).step_by(4) {
            // SAFETY: `i + 4 <= simd_len <= len` and all three slices hold at
            // least `len` elements.
            unsafe {
                let src_ptr = src.as_ptr().add(i);
                let win_ptr = window.as_ptr().add(i);
                let dst_ptr = dst.as_mut_ptr().add(i);

                let s = vld1q_f32(src_ptr);
                let w = vld1q_f32(win_ptr);
                let result = vmulq_f32(s, w);

                vst1q_f32(dst_ptr, result);
            }
        }

        // Scalar remainder
        for i in simd_len..len {
            dst[i] = src[i] * window[i];
        }
    }

    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "avx2"),
        all(target_arch = "aarch64", target_feature = "neon")
    )))]
    {
        // Scalar fallback
        for i in 0..len {
            dst[i] = src[i] * window[i];
        }
    }
}
771
/// Applies a window to `data` in place.
///
/// Computes: `data[i] *= window[i]` for every `i` below the shorter of the two
/// slice lengths. Samples of `data` beyond `window.len()` are left untouched.
#[inline]
pub fn window_mul_simd_inplace(data: &mut [f32], window: &[f32]) {
    // Clamp both views to the common length so every access below is in range.
    let common = data.len().min(window.len());
    let data = &mut data[..common];
    let window = &window[..common];

    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        use std::arch::x86_64::*;

        const LANES: usize = 8;
        let vector_len = common / LANES * LANES;
        let (head, tail) = data.split_at_mut(vector_len);
        let (win_head, win_tail) = window.split_at(vector_len);

        for (block, gains) in head.chunks_exact_mut(LANES).zip(win_head.chunks_exact(LANES)) {
            // Both chunks are exactly LANES floats long.
            unsafe {
                let p = block.as_mut_ptr();
                let product = _mm256_mul_ps(_mm256_loadu_ps(p), _mm256_loadu_ps(gains.as_ptr()));
                _mm256_storeu_ps(p, product);
            }
        }

        for (sample, &gain) in tail.iter_mut().zip(win_tail) {
            *sample *= gain;
        }
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    {
        use std::arch::aarch64::*;

        const LANES: usize = 4;
        let vector_len = common / LANES * LANES;
        let (head, tail) = data.split_at_mut(vector_len);
        let (win_head, win_tail) = window.split_at(vector_len);

        for (block, gains) in head.chunks_exact_mut(LANES).zip(win_head.chunks_exact(LANES)) {
            // Both chunks are exactly LANES floats long.
            unsafe {
                let p = block.as_mut_ptr();
                let product = vmulq_f32(vld1q_f32(p), vld1q_f32(gains.as_ptr()));
                vst1q_f32(p, product);
            }
        }

        for (sample, &gain) in tail.iter_mut().zip(win_tail) {
            *sample *= gain;
        }
    }

    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "avx2"),
        all(target_arch = "aarch64", target_feature = "neon")
    )))]
    {
        for (sample, &gain) in data.iter_mut().zip(window) {
            *sample *= gain;
        }
    }
}
823
/// Deinterleave stereo buffer into separate L/R channels
///
/// Input: [L0, R0, L1, R1, L2, R2, ...]
/// Output: left = [L0, L1, L2, ...], right = [R0, R1, R2, ...]
///
/// `input` is expected to hold exactly `2 * left.len()` samples and `right` to
/// be as long as `left`; both are checked with `debug_assert_eq!`.
///
/// # Panics
/// Panics on a length mismatch when debug assertions are enabled. Without
/// them, the AVX2 path still panics if `input` or `right` is too short for
/// `left.len()` frames, and the scalar path panics on out-of-range indexing.
#[inline]
pub fn deinterleave_stereo(input: &[f32], left: &mut [f32], right: &mut [f32]) {
    debug_assert_eq!(input.len(), left.len() * 2);
    debug_assert_eq!(left.len(), right.len());

    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        use std::arch::x86_64::*;

        let len = left.len();

        // Hard bounds check for builds without debug assertions: the loop below
        // uses raw-pointer loads/stores, so undersized buffers must panic here
        // instead of being accessed out of bounds.
        let input = &input[..len * 2];
        let right = &mut right[..len];

        let simd_len = (len / 8) * 8;

        for i in (0..simd_len).step_by(8) {
            // SAFETY: `i + 8 <= simd_len <= len`; `left` and `right` hold at
            // least `len` floats and `input` at least `2 * len`, so the two
            // 8-float loads at `2 * i` and both 8-float stores are in bounds.
            unsafe {
                // Load 16 floats (8 stereo pairs)
                let in_ptr = input.as_ptr().add(i * 2);
                let v0 = _mm256_loadu_ps(in_ptr); // L0 R0 L1 R1 | L2 R2 L3 R3
                let v1 = _mm256_loadu_ps(in_ptr.add(8)); // L4 R4 L5 R5 | L6 R6 L7 R7

                // Within each 128-bit half, gather the L (even) or R (odd) samples.
                let shuf_l = _mm256_shuffle_ps(v0, v1, 0b10_00_10_00); // L0 L1 L4 L5 | L2 L3 L6 L7
                let shuf_r = _mm256_shuffle_ps(v0, v1, 0b11_01_11_01); // R0 R1 R4 R5 | R2 R3 R6 R7

                // Swap the two middle 64-bit groups to restore sample order.
                let left_vec =
                    _mm256_castpd_ps(_mm256_permute4x64_pd(_mm256_castps_pd(shuf_l), 0b11_01_10_00));
                let right_vec =
                    _mm256_castpd_ps(_mm256_permute4x64_pd(_mm256_castps_pd(shuf_r), 0b11_01_10_00));

                _mm256_storeu_ps(left.as_mut_ptr().add(i), left_vec);
                _mm256_storeu_ps(right.as_mut_ptr().add(i), right_vec);
            }
        }

        // Scalar remainder
        for i in simd_len..len {
            left[i] = input[i * 2];
            right[i] = input[i * 2 + 1];
        }
    }

    #[cfg(not(all(target_arch = "x86_64", target_feature = "avx2")))]
    {
        // Scalar fallback, one stereo frame per iteration.
        for (i, chunk) in input.chunks_exact(2).enumerate() {
            left[i] = chunk[0];
            right[i] = chunk[1];
        }
    }
}
889
/// Merges separate left/right channel buffers into one interleaved stereo buffer.
///
/// Input: left = [L0, L1, L2, ...], right = [R0, R1, R2, ...]
/// Output: [L0, R0, L1, R1, L2, R2, ...]
///
/// `right` is expected to match `left` in length and `output` to hold twice as
/// many samples; both are checked with `debug_assert_eq!`.
#[inline]
#[allow(dead_code)]
pub fn interleave_stereo(left: &[f32], right: &[f32], output: &mut [f32]) {
    debug_assert_eq!(left.len(), right.len());
    debug_assert_eq!(output.len(), left.len() * 2);

    // Plain scalar loop: frame `n` lands at output positions 2n and 2n + 1.
    for (frame, &l) in left.iter().enumerate() {
        let base = frame * 2;
        output[base] = l;
        output[base + 1] = right[frame];
    }
}
906
/// Replaces every denormal (subnormal) sample with `0.0`, in place.
///
/// A sample is zeroed when its magnitude is strictly below
/// `f32::MIN_POSITIVE` (≈ 1.175e-38), the smallest positive *normal* `f32`.
/// Samples at or above that magnitude are left unchanged, and NaN never
/// compares below the threshold, so it is also left as is.
#[inline]
pub fn flush_denormals_inplace(samples: &mut [f32]) {
    // Anything with abs() below the smallest normal value is a subnormal (or zero).
    const DENORM_THRESHOLD: f32 = f32::MIN_POSITIVE;

    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        use std::arch::x86_64::*;

        const LANES: usize = 8;
        let limit = unsafe { _mm256_set1_ps(DENORM_THRESHOLD) };
        let zeros = unsafe { _mm256_set1_ps(0.0) };
        let vector_len = samples.len() / LANES * LANES;
        let (head, tail) = samples.split_at_mut(vector_len);

        for block in head.chunks_exact_mut(LANES) {
            // Every block is exactly LANES floats long.
            unsafe {
                let p = block.as_mut_ptr();
                let current = _mm256_loadu_ps(p);
                // Clear the sign bit to get the magnitude.
                let magnitude = _mm256_andnot_ps(_mm256_set1_ps(-0.0), current);
                // Ordered less-than: lanes holding NaN are never selected.
                let is_tiny = _mm256_cmp_ps(magnitude, limit, _CMP_LT_OQ);
                _mm256_storeu_ps(p, _mm256_blendv_ps(current, zeros, is_tiny));
            }
        }

        for sample in tail {
            if sample.abs() < DENORM_THRESHOLD {
                *sample = 0.0;
            }
        }
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    {
        use std::arch::aarch64::*;

        const LANES: usize = 4;
        let limit = unsafe { vdupq_n_f32(DENORM_THRESHOLD) };
        let zeros = unsafe { vdupq_n_f32(0.0) };
        let vector_len = samples.len() / LANES * LANES;
        let (head, tail) = samples.split_at_mut(vector_len);

        for block in head.chunks_exact_mut(LANES) {
            // Every block is exactly LANES floats long.
            unsafe {
                let p = block.as_mut_ptr();
                let current = vld1q_f32(p);
                let is_tiny = vcltq_f32(vabsq_f32(current), limit);
                // Bit-select: zero where the mask is set, original value elsewhere.
                vst1q_f32(p, vbslq_f32(is_tiny, zeros, current));
            }
        }

        for sample in tail {
            if sample.abs() < DENORM_THRESHOLD {
                *sample = 0.0;
            }
        }
    }

    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "avx2"),
        all(target_arch = "aarch64", target_feature = "neon")
    )))]
    {
        samples
            .iter_mut()
            .filter(|sample| sample.abs() < DENORM_THRESHOLD)
            .for_each(|sample| *sample = 0.0);
    }
}
985
/// Turn on the CPU's FTZ (Flush To Zero) and DAZ (Denormals Are Zero) modes.
///
/// With these modes set the hardware flushes denormal floating-point numbers
/// to zero by itself, avoiding the severe slowdown that occurs when IIR filter
/// state variables (like biquad y1/y2) contain denormals.
///
/// The call only ORs bits into the control register, so repeating it is
/// harmless; call it once per audio processing thread. The register change
/// remains active on the calling thread; use [`ScopedFtz`] when the effect
/// must be limited to a lexical scope.
///
/// Returns `true` on x86-64 and AArch64, where the bits are set, and `false`
/// on every other architecture, where nothing is done.
#[inline]
pub fn enable_ftz_daz() -> bool {
    #[cfg(target_arch = "x86_64")]
    {
        // MXCSR bit 15: flush denormal results to zero.
        const FTZ: u32 = 1 << 15;
        // MXCSR bit 6: treat denormal inputs as zero.
        const DAZ: u32 = 1 << 6;

        let mut control: u32 = 0;
        unsafe {
            // Store MXCSR to memory, set both bits, load it back.
            std::arch::asm!("stmxcsr [{}]", in(reg) &mut control, options(nostack, preserves_flags));
            control |= FTZ | DAZ;
            std::arch::asm!("ldmxcsr [{}]", in(reg) &control, options(nostack, preserves_flags));
        }
        true
    }

    #[cfg(target_arch = "aarch64")]
    {
        // FPCR bit 24: FZ (Flush-to-Zero).
        // Note: AArch64 always treats denormal inputs as zero in FZ mode
        const FZ: u64 = 1 << 24;

        unsafe {
            let mut control: u64;
            std::arch::asm!("mrs {}, fpcr", out(reg) control);
            control |= FZ;
            std::arch::asm!("msr fpcr, {}", in(reg) control);
        }
        true
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        false
    }
}
1033
/// RAII guard that enables FTZ/DAZ on construction and restores the previous
/// FPU control register state on drop, scoping the effect to a block.
///
/// # Usage
/// ```ignore
/// {
///     let _guard = ScopedFtz::new();
///     // FTZ/DAZ active here
/// } // previous FPU state restored on drop
/// ```
///
/// On unsupported platforms this is a no-op. The guard is `!Send` (and
/// `!Sync`) because FPU control registers are per-thread state: dropping the
/// guard on a different thread would load the creating thread's saved register
/// value into that other thread.
#[must_use = "FTZ/DAZ is reverted as soon as the guard is dropped; bind it to a named variable"]
pub struct ScopedFtz {
    /// MXCSR value captured before FTZ/DAZ were switched on.
    #[cfg(target_arch = "x86_64")]
    saved_mxcsr: Option<u32>,
    /// FPCR value captured before FZ was switched on.
    #[cfg(target_arch = "aarch64")]
    saved_fpcr: Option<u64>,
    /// Raw-pointer marker that opts the guard out of the auto `Send`/`Sync`
    /// impls it would otherwise get from its plain integer fields.
    _not_send: std::marker::PhantomData<*mut ()>,
}

impl ScopedFtz {
    /// Enable FTZ/DAZ and save the current register state for restoration on drop.
    ///
    /// On x86-64 this sets MXCSR bits 15 (FTZ) and 6 (DAZ); on AArch64 it sets
    /// FPCR bit 24 (FZ). On any other architecture nothing is touched.
    pub fn new() -> Self {
        #[cfg(target_arch = "x86_64")]
        {
            let mut previous: u32 = 0;
            unsafe {
                std::arch::asm!(
                    "stmxcsr [{}]",
                    in(reg) &mut previous,
                    options(nostack, preserves_flags)
                );
                // Only the two mode bits are added; the exception flags read
                // above are written back unchanged.
                let enabled = previous | (1 << 15) | (1 << 6); // FTZ | DAZ
                std::arch::asm!(
                    "ldmxcsr [{}]",
                    in(reg) &enabled,
                    options(nostack, preserves_flags)
                );
            }
            Self {
                saved_mxcsr: Some(previous),
                _not_send: std::marker::PhantomData,
            }
        }

        #[cfg(target_arch = "aarch64")]
        {
            let saved = unsafe {
                let mut fpcr: u64;
                std::arch::asm!("mrs {}, fpcr", out(reg) fpcr);
                let new_fpcr = fpcr | (1u64 << 24); // FZ bit
                std::arch::asm!("msr fpcr, {}", in(reg) new_fpcr);
                fpcr
            };
            Self {
                saved_fpcr: Some(saved),
                _not_send: std::marker::PhantomData,
            }
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            Self {
                _not_send: std::marker::PhantomData,
            }
        }
    }
}

impl Default for ScopedFtz {
    /// Same as [`ScopedFtz::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl Drop for ScopedFtz {
    /// Write the register value captured by [`ScopedFtz::new`] back to the FPU.
    fn drop(&mut self) {
        #[cfg(target_arch = "x86_64")]
        if let Some(saved) = self.saved_mxcsr {
            unsafe {
                // The whole saved MXCSR is reloaded, including the sticky
                // exception flags as they were when the guard was created.
                // Those flags may have changed since, so this instruction must
                // not be declared `preserves_flags`.
                std::arch::asm!(
                    "ldmxcsr [{}]",
                    in(reg) &saved,
                    options(nostack)
                );
            }
        }

        #[cfg(target_arch = "aarch64")]
        if let Some(saved) = self.saved_fpcr {
            unsafe {
                std::arch::asm!("msr fpcr, {}", in(reg) saved);
            }
        }
    }
}
1126
1127/// Flush denormals in complex buffer (applies to both real and imaginary parts)
1128#[inline]
1129pub fn flush_denormals_complex_inplace(samples: &mut [Complex<f32>]) {
1130    // A Complex<f32> is just two f32s (re and im)
1131    // We can treat the whole buffer as a slice of f32 and use flush_denormals_inplace
1132    let len = samples.len() * 2;
1133    let ptr = samples.as_mut_ptr() as *mut f32;
1134    let f32_samples = unsafe { std::slice::from_raw_parts_mut(ptr, len) };
1135    flush_denormals_inplace(f32_samples);
1136}
1137
#[cfg(test)]
mod denorm_tests {
    use super::*;

    /// Mixed buffer: the two subnormals (±1e-39) are zeroed, while the normal
    /// values 1e-20, 1e-10, 1.0 and the exact zero pass through unchanged.
    #[test]
    fn test_flush_denormals_basic() {
        let mut buf = [1e-39_f32, 1e-20, 1e-10, 0.0, -1e-39_f32, 1.0];
        flush_denormals_inplace(&mut buf);

        let [pos_subnormal, tiny_normal, small_normal, zero, neg_subnormal, one] = buf;
        assert_eq!(pos_subnormal, 0.0, "subnormal 1e-39 must be zeroed");
        assert_eq!(tiny_normal, 1e-20);
        assert_eq!(small_normal, 1e-10);
        assert_eq!(zero, 0.0);
        assert_eq!(neg_subnormal, 0.0, "negative subnormal -1e-39 must be zeroed");
        assert_eq!(one, 1.0);
    }

    /// 1e-35 sits above `f32::MIN_POSITIVE` (~1.175e-38), so it is a normal
    /// value and has to survive the flush.
    #[test]
    fn test_flush_denormals_normal_small_not_zeroed() {
        let mut buf = [1e-35_f32];
        flush_denormals_inplace(&mut buf);

        let [survivor] = buf;
        assert!(
            survivor != 0.0,
            "normal value 1e-35 (above f32::MIN_POSITIVE) must not be zeroed"
        );
    }

    /// 1e-39 sits below `f32::MIN_POSITIVE` (~1.175e-38), so it is subnormal
    /// and has to be flushed.
    #[test]
    fn test_flush_denormals_subnormal_zeroed() {
        let mut buf = [1e-39_f32];
        flush_denormals_inplace(&mut buf);

        let [flushed] = buf;
        assert_eq!(
            flushed, 0.0,
            "subnormal value 1e-39 (below f32::MIN_POSITIVE) must be zeroed"
        );
    }

    /// Real and imaginary parts are flushed independently of each other.
    #[test]
    fn test_flush_denormals_complex() {
        use rustfft::num_complex::Complex;

        let mut bins = [
            // subnormal re, normal im
            Complex::new(1e-39_f32, 1e-30_f32),
            // normal re, subnormal im
            Complex::new(1.0, 1e-39_f32),
            Complex::new(0.0, 0.0),
        ];
        flush_denormals_complex_inplace(&mut bins);

        let [mixed, unit, origin] = bins;
        assert_eq!(mixed.re, 0.0, "subnormal re must be zeroed");
        assert!(mixed.im != 0.0, "normal im 1e-30 must be preserved");
        assert_eq!(unit.re, 1.0);
        assert_eq!(unit.im, 0.0, "subnormal im must be zeroed");
        assert_eq!(origin.re, 0.0);
        assert_eq!(origin.im, 0.0);
    }

    /// An empty buffer must be accepted without panicking.
    #[test]
    fn test_flush_denormals_empty() {
        let mut nothing = [0.0_f32; 0];
        flush_denormals_inplace(&mut nothing);
    }

    /// Seven samples is not a multiple of either vector width (8 or 4), so
    /// the scalar remainder path has to do part of the work.
    #[test]
    fn test_flush_denormals_unaligned() {
        let mut buf = [1e-39_f32; 7];
        flush_denormals_inplace(&mut buf);
        buf.iter().for_each(|s| assert_eq!(*s, 0.0));
    }
}
1213
1214#[cfg(test)]
1215#[allow(clippy::needless_range_loop)]
1216mod tests {
1217    use super::*;
1218
1219    // ============================================================================
1220    // Denormal Flushing Tests
1221    // ============================================================================
1222
1223    #[test]
1224    fn test_flush_denormals_basic() {
1225        // 1e-39 is subnormal (below f32::MIN_POSITIVE ~1.175e-38) and must be zeroed.
1226        let mut samples = [1e-39_f32, 1e-20, 1e-10, 0.0, -1e-39_f32, 1.0];
1227        flush_denormals_inplace(&mut samples);
1228        assert_eq!(samples[0], 0.0);
1229        assert_eq!(samples[1], 1e-20);
1230        assert_eq!(samples[2], 1e-10);
1231        assert_eq!(samples[3], 0.0);
1232        assert_eq!(samples[4], 0.0);
1233        assert_eq!(samples[5], 1.0);
1234    }
1235
1236    #[test]
1237    fn test_flush_denormals_complex() {
1238        use rustfft::num_complex::Complex;
1239        // re=1e-39 (subnormal) zeroed; im=1e-30 (normal) preserved.
1240        let mut samples = [
1241            Complex::new(1e-39_f32, 1e-30_f32),
1242            Complex::new(1.0, 1e-39_f32),
1243            Complex::new(0.0, 0.0),
1244        ];
1245        flush_denormals_complex_inplace(&mut samples);
1246        assert_eq!(samples[0].re, 0.0);
1247        assert!(samples[0].im != 0.0, "normal im 1e-30 must be preserved");
1248        assert_eq!(samples[1].re, 1.0);
1249        assert_eq!(samples[1].im, 0.0);
1250        assert_eq!(samples[2].re, 0.0);
1251        assert_eq!(samples[2].im, 0.0);
1252    }
1253
1254    #[test]
1255    fn test_flush_denormals_empty() {
1256        let mut samples: [f32; 0] = [];
1257        flush_denormals_inplace(&mut samples);
1258    }
1259
1260    #[test]
1261    fn test_flush_denormals_unaligned() {
1262        let mut samples = [1e-39_f32; 7];
1263        flush_denormals_inplace(&mut samples);
1264        for s in samples.iter() {
1265            assert_eq!(*s, 0.0);
1266        }
1267    }
1268
1269    #[test]
1270    fn test_enable_ftz_daz_does_not_panic() {
1271        // enable_ftz_daz() should complete without panicking on any platform.
1272        // On x86_64 or aarch64 it returns true; on other platforms false.
1273        let result = enable_ftz_daz();
1274        #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
1275        assert!(
1276            result,
1277            "enable_ftz_daz should return true on supported platforms"
1278        );
1279        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
1280        assert!(
1281            !result,
1282            "enable_ftz_daz should return false on unsupported platforms"
1283        );
1284    }
1285
1286    #[test]
1287    fn test_scoped_ftz_does_not_panic() {
1288        // ScopedFtz must construct and drop without panicking on any platform.
1289        {
1290            let _guard = ScopedFtz::new();
1291            // FTZ/DAZ active within this scope
1292        }
1293        // Previous FPU state restored on drop.
1294    }
1295
1296    #[test]
1297    fn test_apply_gain_simd_known_values() {
1298        // Apply gain of 2.0 to known input
1299        let mut buffer = vec![1.0, 2.0, 3.0, 4.0, -1.0, 0.5, -0.5, 0.0, 1.5];
1300        let expected: Vec<f32> = buffer.iter().map(|&x| x * 2.0).collect();
1301        apply_gain_simd(&mut buffer, 2.0);
1302        for (i, (&got, &exp)) in buffer.iter().zip(expected.iter()).enumerate() {
1303            assert!(
1304                (got - exp).abs() < 1e-6,
1305                "apply_gain_simd mismatch at index {}: got {}, expected {}",
1306                i,
1307                got,
1308                exp
1309            );
1310        }
1311    }
1312
1313    #[test]
1314    fn test_apply_gain_simd_zero_gain() {
1315        let mut buffer = vec![1.0, -2.0, 3.5, 0.7];
1316        apply_gain_simd(&mut buffer, 0.0);
1317        for (i, &v) in buffer.iter().enumerate() {
1318            assert_eq!(
1319                v, 0.0,
1320                "apply_gain_simd with zero gain: index {} not zero",
1321                i
1322            );
1323        }
1324    }
1325
1326    #[test]
1327    fn test_apply_gain_simd_unity_gain() {
1328        let original = vec![1.0, -2.0, 3.5, 0.7, 0.0, -0.1];
1329        let mut buffer = original.clone();
1330        apply_gain_simd(&mut buffer, 1.0);
1331        assert_eq!(buffer, original);
1332    }
1333
1334    #[test]
1335    fn test_apply_per_channel_gain_simd_stereo() {
1336        // Stereo buffer: apply different gains to L and R
1337        let mut buffer = vec![1.0, 2.0, 3.0, 4.0]; // 2 frames, 2 channels
1338        let gains = vec![0.5, 2.0]; // L=0.5, R=2.0
1339        apply_per_channel_gain_simd(&mut buffer, 2, &gains);
1340        assert!((buffer[0] - 0.5).abs() < 1e-6, "L frame 0");
1341        assert!((buffer[1] - 4.0).abs() < 1e-6, "R frame 0");
1342        assert!((buffer[2] - 1.5).abs() < 1e-6, "L frame 1");
1343        assert!((buffer[3] - 8.0).abs() < 1e-6, "R frame 1");
1344    }
1345
1346    // ============================================================================
1347    // SIMD Correctness Tests
1348    // ============================================================================
1349    //
1350    // These tests verify that SIMD-optimized complex multiplication produces
1351    // identical results to scalar computation
1352
1353    #[test]
1354    fn test_simd_complex_mul_add_correctness() {
1355        // Test SIMD complex multiply-accumulate against scalar reference
1356        use rustfft::num_complex::Complex;
1357
1358        // Test with 8 complex numbers (to test both AVX2 and NEON code paths)
1359        let src = vec![
1360            Complex::new(1.0, 2.0),
1361            Complex::new(3.0, 4.0),
1362            Complex::new(-1.0, 0.5),
1363            Complex::new(0.0, -2.0),
1364            Complex::new(2.5, -1.5),
1365            Complex::new(-3.5, 2.5),
1366            Complex::new(1.1, -0.9),
1367            Complex::new(-0.8, 1.2),
1368        ];
1369
1370        let hrtf = vec![
1371            Complex::new(0.5, 0.25),
1372            Complex::new(-1.0, 1.5),
1373            Complex::new(2.0, -0.5),
1374            Complex::new(0.75, 0.75),
1375            Complex::new(-0.5, 2.0),
1376            Complex::new(1.5, -1.0),
1377            Complex::new(0.9, 0.3),
1378            Complex::new(-1.1, 0.7),
1379        ];
1380
1381        let initial = vec![
1382            Complex::new(0.1, 0.2),
1383            Complex::new(0.3, 0.4),
1384            Complex::new(0.5, 0.6),
1385            Complex::new(0.7, 0.8),
1386            Complex::new(0.9, 1.0),
1387            Complex::new(1.1, 1.2),
1388            Complex::new(1.3, 1.4),
1389            Complex::new(1.5, 1.6),
1390        ];
1391
1392        // Scalar reference computation
1393        let mut expected = initial.clone();
1394        for i in 0..src.len() {
1395            expected[i] += src[i] * hrtf[i];
1396        }
1397
1398        // SIMD computation
1399        let mut result = initial.clone();
1400        complex_mul_add_simd(&mut result, &src, &hrtf);
1401
1402        // Compare results with tolerance for floating point errors
1403        const EPSILON: f32 = 1e-6;
1404        for i in 0..src.len() {
1405            assert!(
1406                (result[i].re - expected[i].re).abs() < EPSILON,
1407                "SIMD result[{}].re = {}, expected = {} (diff = {})",
1408                i,
1409                result[i].re,
1410                expected[i].re,
1411                (result[i].re - expected[i].re).abs()
1412            );
1413            assert!(
1414                (result[i].im - expected[i].im).abs() < EPSILON,
1415                "SIMD result[{}].im = {}, expected = {} (diff = {})",
1416                i,
1417                result[i].im,
1418                expected[i].im,
1419                (result[i].im - expected[i].im).abs()
1420            );
1421        }
1422    }
1423
1424    #[test]
1425    fn test_simd_complex_mul_correctness() {
1426        // Test SIMD complex multiplication (without accumulation)
1427        use rustfft::num_complex::Complex;
1428
1429        let src = vec![
1430            Complex::new(2.0, 3.0),
1431            Complex::new(-1.5, 2.5),
1432            Complex::new(0.5, -1.0),
1433            Complex::new(4.0, -2.0),
1434        ];
1435
1436        let hrtf = vec![
1437            Complex::new(1.0, 0.5),
1438            Complex::new(2.0, -1.0),
1439            Complex::new(-0.5, 1.5),
1440            Complex::new(0.75, 0.25),
1441        ];
1442
1443        // Scalar reference
1444        let expected: Vec<Complex<f32>> = src.iter().zip(hrtf.iter()).map(|(a, b)| a * b).collect();
1445
1446        // SIMD computation
1447        let mut result = vec![Complex::new(0.0, 0.0); src.len()];
1448        complex_mul_simd(&mut result, &src, &hrtf);
1449
1450        // Compare
1451        const EPSILON: f32 = 1e-6;
1452        for i in 0..src.len() {
1453            assert!(
1454                (result[i].re - expected[i].re).abs() < EPSILON,
1455                "SIMD result[{}].re = {}, expected = {}",
1456                i,
1457                result[i].re,
1458                expected[i].re
1459            );
1460            assert!(
1461                (result[i].im - expected[i].im).abs() < EPSILON,
1462                "SIMD result[{}].im = {}, expected = {}",
1463                i,
1464                result[i].im,
1465                expected[i].im
1466            );
1467        }
1468    }
1469
1470    #[test]
1471    fn test_simd_edge_cases() {
1472        // Test edge cases: zeros, ones, conjugates
1473        use rustfft::num_complex::Complex;
1474
1475        // Test 1: Multiply by zero
1476        let src = vec![
1477            Complex::new(1.0, 2.0),
1478            Complex::new(3.0, 4.0),
1479            Complex::new(5.0, 6.0),
1480            Complex::new(7.0, 8.0),
1481        ];
1482        let zero = vec![Complex::new(0.0, 0.0); 4];
1483        let mut result = src.clone();
1484        let input = result.clone();
1485        complex_mul_simd(&mut result, &input, &zero);
1486        for i in 0..4 {
1487            assert_eq!(result[i].re, 0.0);
1488            assert_eq!(result[i].im, 0.0);
1489        }
1490
1491        // Test 2: Multiply by one (identity)
1492        let one = vec![Complex::new(1.0, 0.0); 4];
1493        let mut result = vec![Complex::new(0.0, 0.0); 4];
1494        complex_mul_simd(&mut result, &src, &one);
1495        for i in 0..4 {
1496            assert!((result[i].re - src[i].re).abs() < 1e-6);
1497            assert!((result[i].im - src[i].im).abs() < 1e-6);
1498        }
1499
1500        // Test 3: Multiply by conjugate (should give real result)
1501        let a = Complex::new(3.0, 4.0);
1502        let a_conj = Complex::new(3.0, -4.0);
1503        let src = vec![a, a, a, a];
1504        let conj = vec![a_conj, a_conj, a_conj, a_conj];
1505        let mut result = vec![Complex::new(0.0, 0.0); 4];
1506        complex_mul_simd(&mut result, &src, &conj);
1507
1508        // a * conj(a) = |a|^2 = 3^2 + 4^2 = 25
1509        for i in 0..4 {
1510            assert!((result[i].re - 25.0).abs() < 1e-5);
1511            assert!(result[i].im.abs() < 1e-5); // Should be approximately zero
1512        }
1513    }
1514
1515    #[test]
1516    fn test_simd_large_buffer() {
1517        // Test with realistic FFT buffer sizes
1518        use rustfft::num_complex::Complex;
1519
1520        for fft_size in [512, 1024, 2048, 4096] {
1521            let mut src = Vec::with_capacity(fft_size);
1522            let mut hrtf = Vec::with_capacity(fft_size);
1523
1524            // Fill with test pattern
1525            for i in 0..fft_size {
1526                let phase = (i as f32) * 0.01;
1527                src.push(Complex::new(phase.cos(), phase.sin()));
1528                hrtf.push(Complex::new(0.5, 0.25));
1529            }
1530
1531            // Scalar reference
1532            let mut expected = vec![Complex::new(0.1, 0.2); fft_size];
1533            for i in 0..fft_size {
1534                expected[i] += src[i] * hrtf[i];
1535            }
1536
1537            // SIMD computation
1538            let mut result = vec![Complex::new(0.1, 0.2); fft_size];
1539            complex_mul_add_simd(&mut result, &src, &hrtf);
1540
1541            // Verify all elements match
1542            for i in 0..fft_size {
1543                assert!(
1544                    (result[i].re - expected[i].re).abs() < 1e-5,
1545                    "FFT size {}, index {}: SIMD mismatch",
1546                    fft_size,
1547                    i
1548                );
1549                assert!(
1550                    (result[i].im - expected[i].im).abs() < 1e-5,
1551                    "FFT size {}, index {}: SIMD mismatch",
1552                    fft_size,
1553                    i
1554                );
1555            }
1556        }
1557    }
1558
1559    #[test]
1560    fn test_simd_unaligned_sizes() {
1561        // Test with buffer sizes that don't align to SIMD width
1562        // This ensures the scalar remainder loop works correctly
1563        use rustfft::num_complex::Complex;
1564
1565        for size in [1, 3, 5, 7, 9, 13, 17] {
1566            let src: Vec<Complex<f32>> = (0..size)
1567                .map(|i| Complex::new(i as f32, (i as f32) * 0.5))
1568                .collect();
1569            let hrtf: Vec<Complex<f32>> = (0..size)
1570                .map(|i| Complex::new(0.5, (i as f32) * 0.1))
1571                .collect();
1572
1573            // Scalar reference
1574            let expected: Vec<Complex<f32>> =
1575                src.iter().zip(hrtf.iter()).map(|(a, b)| a * b).collect();
1576
1577            // SIMD computation
1578            let mut result = vec![Complex::new(0.0, 0.0); size];
1579            complex_mul_simd(&mut result, &src, &hrtf);
1580
1581            // Verify
1582            for i in 0..size {
1583                assert!(
1584                    (result[i].re - expected[i].re).abs() < 1e-6,
1585                    "Size {}, index {}: re mismatch",
1586                    size,
1587                    i
1588                );
1589                assert!(
1590                    (result[i].im - expected[i].im).abs() < 1e-6,
1591                    "Size {}, index {}: im mismatch",
1592                    size,
1593                    i
1594                );
1595            }
1596        }
1597    }
1598
1599    // ============================================================================
1600    // Comprehensive Tests for complex_mul_inplace_simd
1601    // ============================================================================
1602
1603    #[test]
1604    fn test_simd_complex_mul_inplace_correctness() {
1605        // Basic correctness test for in-place complex multiplication
1606        use rustfft::num_complex::Complex;
1607
1608        let src = vec![
1609            Complex::new(2.0, 3.0),
1610            Complex::new(-1.5, 2.5),
1611            Complex::new(0.5, -1.0),
1612            Complex::new(4.0, -2.0),
1613        ];
1614
1615        let hrtf = vec![
1616            Complex::new(1.0, 0.5),
1617            Complex::new(2.0, -1.0),
1618            Complex::new(-0.5, 1.5),
1619            Complex::new(0.75, 0.25),
1620        ];
1621
1622        // Scalar reference
1623        let mut expected = src.clone();
1624        for i in 0..expected.len() {
1625            expected[i] *= hrtf[i];
1626        }
1627
1628        // SIMD in-place computation
1629        let mut result = src.clone();
1630        complex_mul_inplace_simd(&mut result, &hrtf);
1631
1632        const EPSILON: f32 = 1e-6;
1633        for i in 0..result.len() {
1634            assert!(
1635                (result[i].re - expected[i].re).abs() < EPSILON,
1636                "Index {}: re mismatch {} vs {}",
1637                i,
1638                result[i].re,
1639                expected[i].re
1640            );
1641            assert!(
1642                (result[i].im - expected[i].im).abs() < EPSILON,
1643                "Index {}: im mismatch {} vs {}",
1644                i,
1645                result[i].im,
1646                expected[i].im
1647            );
1648        }
1649    }
1650
1651    #[test]
1652    fn test_simd_inplace_large_buffers() {
1653        // Test in-place multiplication with realistic FFT buffer sizes
1654        use rustfft::num_complex::Complex;
1655
1656        for fft_size in [128, 256, 512, 1024, 2048] {
1657            let mut src: Vec<Complex<f32>> = (0..fft_size)
1658                .map(|i| {
1659                    let phase = (i as f32) * 0.01;
1660                    Complex::new(phase.cos(), phase.sin())
1661                })
1662                .collect();
1663
1664            let hrtf: Vec<Complex<f32>> = (0..fft_size)
1665                .map(|i| Complex::new(0.5 + (i as f32) * 0.001, 0.25))
1666                .collect();
1667
1668            // Scalar reference
1669            let mut expected = src.clone();
1670            for i in 0..fft_size {
1671                expected[i] *= hrtf[i];
1672            }
1673
1674            // SIMD computation
1675            complex_mul_inplace_simd(&mut src, &hrtf);
1676
1677            // Verify all elements match
1678            for i in 0..fft_size {
1679                assert!(
1680                    (src[i].re - expected[i].re).abs() < 1e-5,
1681                    "FFT size {}, index {}: re mismatch",
1682                    fft_size,
1683                    i
1684                );
1685                assert!(
1686                    (src[i].im - expected[i].im).abs() < 1e-5,
1687                    "FFT size {}, index {}: im mismatch",
1688                    fft_size,
1689                    i
1690                );
1691            }
1692        }
1693    }
1694
1695    #[test]
1696    fn test_simd_inplace_unaligned() {
1697        // Test with sizes that don't align to SIMD width (4 for AVX2, 2 for NEON)
1698        use rustfft::num_complex::Complex;
1699
1700        for size in [1, 2, 3, 5, 6, 7, 9, 10, 11, 15, 17, 19, 23] {
1701            let mut src: Vec<Complex<f32>> = (0..size)
1702                .map(|i| Complex::new((i as f32) * 0.5, (i as f32) * -0.3))
1703                .collect();
1704
1705            let hrtf: Vec<Complex<f32>> = (0..size)
1706                .map(|i| Complex::new(1.0 + (i as f32) * 0.1, 0.5))
1707                .collect();
1708
1709            // Scalar reference
1710            let mut expected = src.clone();
1711            for i in 0..size {
1712                expected[i] *= hrtf[i];
1713            }
1714
1715            // SIMD computation
1716            complex_mul_inplace_simd(&mut src, &hrtf);
1717
1718            // Verify
1719            for i in 0..size {
1720                assert!(
1721                    (src[i].re - expected[i].re).abs() < 1e-6,
1722                    "Size {}, index {}: re mismatch",
1723                    size,
1724                    i
1725                );
1726                assert!(
1727                    (src[i].im - expected[i].im).abs() < 1e-6,
1728                    "Size {}, index {}: im mismatch",
1729                    size,
1730                    i
1731                );
1732            }
1733        }
1734    }
1735
1736    #[test]
1737    fn test_simd_inplace_edge_cases() {
1738        // Test edge cases: zeros, ones, conjugates, negative values
1739        use rustfft::num_complex::Complex;
1740
1741        // Test 1: Multiply by zero (should zero out the buffer)
1742        let mut src = vec![
1743            Complex::new(1.0, 2.0),
1744            Complex::new(3.0, 4.0),
1745            Complex::new(5.0, 6.0),
1746            Complex::new(7.0, 8.0),
1747        ];
1748        let zero = vec![Complex::new(0.0, 0.0); 4];
1749        complex_mul_inplace_simd(&mut src, &zero);
1750        for i in 0..4 {
1751            assert!(src[i].re.abs() < 1e-6, "Expected zero, got {}", src[i].re);
1752            assert!(src[i].im.abs() < 1e-6, "Expected zero, got {}", src[i].im);
1753        }
1754
1755        // Test 2: Multiply by one (should be identity)
1756        let original = vec![
1757            Complex::new(1.5, 2.5),
1758            Complex::new(-3.5, 4.5),
1759            Complex::new(5.5, -6.5),
1760            Complex::new(-7.5, -8.5),
1761        ];
1762        let mut src = original.clone();
1763        let one = vec![Complex::new(1.0, 0.0); 4];
1764        complex_mul_inplace_simd(&mut src, &one);
1765        for i in 0..4 {
1766            assert!((src[i].re - original[i].re).abs() < 1e-6);
1767            assert!((src[i].im - original[i].im).abs() < 1e-6);
1768        }
1769
1770        // Test 3: Multiply by conjugate (should give real result with magnitude |a|^2)
1771        let a = Complex::new(3.0, 4.0);
1772        let a_conj = Complex::new(3.0, -4.0);
1773        let mut src = vec![a; 8];
1774        let conj = vec![a_conj; 8];
1775        complex_mul_inplace_simd(&mut src, &conj);
1776        // a * conj(a) = |a|^2 = 3^2 + 4^2 = 25
1777        for i in 0..8 {
1778            assert!(
1779                (src[i].re - 25.0).abs() < 1e-5,
1780                "Expected 25.0, got {}",
1781                src[i].re
1782            );
1783            assert!(src[i].im.abs() < 1e-5, "Expected ~0, got {}", src[i].im);
1784        }
1785
1786        // Test 4: Multiply by i (rotation by 90 degrees)
1787        let mut src = vec![Complex::new(1.0, 0.0); 4];
1788        let i_val = vec![Complex::new(0.0, 1.0); 4];
1789        complex_mul_inplace_simd(&mut src, &i_val);
1790        for idx in 0..4 {
1791            assert!(src[idx].re.abs() < 1e-6, "Expected 0, got {}", src[idx].re);
1792            assert!(
1793                (src[idx].im - 1.0).abs() < 1e-6,
1794                "Expected 1, got {}",
1795                src[idx].im
1796            );
1797        }
1798    }
1799
1800    #[test]
1801    fn test_simd_inplace_negative_values() {
1802        // Test with all negative values
1803        use rustfft::num_complex::Complex;
1804
1805        let mut src = vec![
1806            Complex::new(-1.0, -2.0),
1807            Complex::new(-3.0, -4.0),
1808            Complex::new(-5.0, -6.0),
1809            Complex::new(-7.0, -8.0),
1810        ];
1811
1812        let hrtf = vec![
1813            Complex::new(-0.5, -0.25),
1814            Complex::new(-1.0, -1.5),
1815            Complex::new(-2.0, 0.5),
1816            Complex::new(0.75, -0.75),
1817        ];
1818
1819        // Scalar reference
1820        let mut expected = src.clone();
1821        for i in 0..expected.len() {
1822            expected[i] *= hrtf[i];
1823        }
1824
1825        // SIMD computation
1826        complex_mul_inplace_simd(&mut src, &hrtf);
1827
1828        const EPSILON: f32 = 1e-6;
1829        for i in 0..src.len() {
1830            assert!((src[i].re - expected[i].re).abs() < EPSILON);
1831            assert!((src[i].im - expected[i].im).abs() < EPSILON);
1832        }
1833    }
1834
1835    // ============================================================================
1836    // Comprehensive Tests for compute_covariance_simd
1837    // ============================================================================
1838
1839    #[test]
1840    fn test_covariance_basic_correctness() {
1841        // Test basic covariance computation against scalar reference
1842        use rustfft::num_complex::Complex;
1843
1844        let left = vec![
1845            Complex::new(1.0, 2.0),
1846            Complex::new(3.0, 4.0),
1847            Complex::new(-1.0, 0.5),
1848            Complex::new(0.0, -2.0),
1849            Complex::new(2.5, -1.5),
1850            Complex::new(-3.5, 2.5),
1851            Complex::new(1.1, -0.9),
1852            Complex::new(-0.8, 1.2),
1853        ];
1854
1855        let right = vec![
1856            Complex::new(0.5, 0.25),
1857            Complex::new(-1.0, 1.5),
1858            Complex::new(2.0, -0.5),
1859            Complex::new(0.75, 0.75),
1860            Complex::new(-0.5, 2.0),
1861            Complex::new(1.5, -1.0),
1862            Complex::new(0.9, 0.3),
1863            Complex::new(-1.1, 0.7),
1864        ];
1865
1866        // Scalar reference
1867        let mut expected_xx = 0.0_f32;
1868        let mut expected_yy = 0.0_f32;
1869        let mut expected_xy = Complex::new(0.0, 0.0);
1870        for i in 0..left.len() {
1871            expected_xx += left[i].norm_sqr();
1872            expected_yy += right[i].norm_sqr();
1873            expected_xy += left[i] * right[i].conj();
1874        }
1875
1876        // SIMD computation
1877        let (cov_xx, cov_yy, cov_xy) = compute_covariance_simd(&left, &right, 0, left.len());
1878
1879        const EPSILON: f32 = 1e-5;
1880        assert!(
1881            (cov_xx - expected_xx).abs() < EPSILON,
1882            "cov_xx mismatch: {} vs {}",
1883            cov_xx,
1884            expected_xx
1885        );
1886        assert!(
1887            (cov_yy - expected_yy).abs() < EPSILON,
1888            "cov_yy mismatch: {} vs {}",
1889            cov_yy,
1890            expected_yy
1891        );
1892        assert!(
1893            (cov_xy.re - expected_xy.re).abs() < EPSILON,
1894            "cov_xy.re mismatch: {} vs {}",
1895            cov_xy.re,
1896            expected_xy.re
1897        );
1898        assert!(
1899            (cov_xy.im - expected_xy.im).abs() < EPSILON,
1900            "cov_xy.im mismatch: {} vs {}",
1901            cov_xy.im,
1902            expected_xy.im
1903        );
1904    }
1905
1906    #[test]
1907    fn test_covariance_with_ranges() {
1908        // Test covariance computation with different start/end ranges
1909        use rustfft::num_complex::Complex;
1910
1911        let left: Vec<Complex<f32>> = (0..32)
1912            .map(|i| Complex::new(i as f32 * 0.5, i as f32 * -0.3))
1913            .collect();
1914        let right: Vec<Complex<f32>> = (0..32)
1915            .map(|i| Complex::new(i as f32 * -0.4, i as f32 * 0.6))
1916            .collect();
1917
1918        // Test various ranges
1919        for (start, end) in [(0, 8), (4, 12), (10, 20), (5, 25), (0, 32)] {
1920            // Scalar reference
1921            let mut expected_xx = 0.0_f32;
1922            let mut expected_yy = 0.0_f32;
1923            let mut expected_xy = Complex::new(0.0, 0.0);
1924            for i in start..end {
1925                expected_xx += left[i].norm_sqr();
1926                expected_yy += right[i].norm_sqr();
1927                expected_xy += left[i] * right[i].conj();
1928            }
1929
1930            // SIMD computation
1931            let (cov_xx, cov_yy, cov_xy) = compute_covariance_simd(&left, &right, start, end);
1932
1933            const EPSILON: f32 = 1e-4;
1934            assert!(
1935                (cov_xx - expected_xx).abs() < EPSILON,
1936                "Range [{}, {}): cov_xx mismatch: {} vs {}",
1937                start,
1938                end,
1939                cov_xx,
1940                expected_xx
1941            );
1942            assert!(
1943                (cov_yy - expected_yy).abs() < EPSILON,
1944                "Range [{}, {}): cov_yy mismatch: {} vs {}",
1945                start,
1946                end,
1947                cov_yy,
1948                expected_yy
1949            );
1950            assert!(
1951                (cov_xy.re - expected_xy.re).abs() < EPSILON,
1952                "Range [{}, {}): cov_xy.re mismatch: {} vs {}",
1953                start,
1954                end,
1955                cov_xy.re,
1956                expected_xy.re
1957            );
1958            assert!(
1959                (cov_xy.im - expected_xy.im).abs() < EPSILON,
1960                "Range [{}, {}): cov_xy.im mismatch: {} vs {}",
1961                start,
1962                end,
1963                cov_xy.im,
1964                expected_xy.im
1965            );
1966        }
1967    }
1968
1969    #[test]
1970    fn test_covariance_large_buffers() {
1971        // Test with realistic FFT buffer sizes
1972        use rustfft::num_complex::Complex;
1973
1974        for fft_size in [128, 256, 512, 1024, 2048, 4096] {
1975            let left: Vec<Complex<f32>> = (0..fft_size)
1976                .map(|i| {
1977                    let phase = (i as f32) * 0.01;
1978                    Complex::new(phase.cos(), phase.sin())
1979                })
1980                .collect();
1981
1982            let right: Vec<Complex<f32>> = (0..fft_size)
1983                .map(|i| {
1984                    let phase = (i as f32) * 0.02;
1985                    Complex::new(phase.sin(), phase.cos())
1986                })
1987                .collect();
1988
1989            // Scalar reference
1990            let mut expected_xx = 0.0_f32;
1991            let mut expected_yy = 0.0_f32;
1992            let mut expected_xy = Complex::new(0.0, 0.0);
1993            for i in 0..fft_size {
1994                expected_xx += left[i].norm_sqr();
1995                expected_yy += right[i].norm_sqr();
1996                expected_xy += left[i] * right[i].conj();
1997            }
1998
1999            // SIMD computation
2000            let (cov_xx, cov_yy, cov_xy) = compute_covariance_simd(&left, &right, 0, fft_size);
2001
2002            // Relative tolerance for larger sums
2003            let rel_epsilon = 1e-4;
2004            assert!(
2005                (cov_xx - expected_xx).abs() < expected_xx * rel_epsilon,
2006                "FFT size {}: cov_xx mismatch",
2007                fft_size
2008            );
2009            assert!(
2010                (cov_yy - expected_yy).abs() < expected_yy * rel_epsilon,
2011                "FFT size {}: cov_yy mismatch",
2012                fft_size
2013            );
2014            assert!(
2015                (cov_xy.re - expected_xy.re).abs() < expected_xy.re.abs() * rel_epsilon + 1e-5,
2016                "FFT size {}: cov_xy.re mismatch",
2017                fft_size
2018            );
2019            assert!(
2020                (cov_xy.im - expected_xy.im).abs() < expected_xy.im.abs() * rel_epsilon + 1e-5,
2021                "FFT size {}: cov_xy.im mismatch",
2022                fft_size
2023            );
2024        }
2025    }
2026
2027    #[test]
2028    fn test_covariance_unaligned_ranges() {
2029        // Test with ranges that don't align to SIMD width
2030        use rustfft::num_complex::Complex;
2031
2032        let left: Vec<Complex<f32>> = (0..50)
2033            .map(|i| Complex::new(i as f32 * 0.2, i as f32 * 0.3))
2034            .collect();
2035        let right: Vec<Complex<f32>> = (0..50)
2036            .map(|i| Complex::new(i as f32 * -0.1, i as f32 * 0.4))
2037            .collect();
2038
2039        // Test with various unaligned ranges
2040        for (start, end) in [(0, 1), (0, 3), (1, 4), (2, 7), (5, 11), (10, 23), (15, 37)] {
2041            // Scalar reference
2042            let mut expected_xx = 0.0_f32;
2043            let mut expected_yy = 0.0_f32;
2044            let mut expected_xy = Complex::new(0.0, 0.0);
2045            for i in start..end {
2046                expected_xx += left[i].norm_sqr();
2047                expected_yy += right[i].norm_sqr();
2048                expected_xy += left[i] * right[i].conj();
2049            }
2050
2051            // SIMD computation
2052            let (cov_xx, cov_yy, cov_xy) = compute_covariance_simd(&left, &right, start, end);
2053
2054            const EPSILON: f32 = 1e-5;
2055            assert!(
2056                (cov_xx - expected_xx).abs() < EPSILON,
2057                "Range [{}, {}): cov_xx mismatch",
2058                start,
2059                end
2060            );
2061            assert!(
2062                (cov_yy - expected_yy).abs() < EPSILON,
2063                "Range [{}, {}): cov_yy mismatch",
2064                start,
2065                end
2066            );
2067            assert!(
2068                (cov_xy.re - expected_xy.re).abs() < EPSILON,
2069                "Range [{}, {}): cov_xy.re mismatch",
2070                start,
2071                end
2072            );
2073            assert!(
2074                (cov_xy.im - expected_xy.im).abs() < EPSILON,
2075                "Range [{}, {}): cov_xy.im mismatch",
2076                start,
2077                end
2078            );
2079        }
2080    }
2081
2082    #[test]
2083    fn test_covariance_edge_cases() {
2084        // Test covariance with edge cases
2085        use rustfft::num_complex::Complex;
2086
2087        // Test 1: All zeros
2088        let zero_left = vec![Complex::new(0.0, 0.0); 8];
2089        let zero_right = vec![Complex::new(0.0, 0.0); 8];
2090        let (cov_xx, cov_yy, cov_xy) = compute_covariance_simd(&zero_left, &zero_right, 0, 8);
2091        assert!(cov_xx.abs() < 1e-6, "Expected zero cov_xx");
2092        assert!(cov_yy.abs() < 1e-6, "Expected zero cov_yy");
2093        assert!(cov_xy.norm_sqr() < 1e-6, "Expected zero cov_xy");
2094
2095        // Test 2: Real-only signals
2096        let real_left: Vec<Complex<f32>> = (0..8).map(|i| Complex::new(i as f32, 0.0)).collect();
2097        let real_right: Vec<Complex<f32>> = (0..8)
2098            .map(|i| Complex::new((i as f32) * 0.5, 0.0))
2099            .collect();
2100        let (cov_xx, cov_yy, cov_xy) = compute_covariance_simd(&real_left, &real_right, 0, 8);
2101
2102        // cov_xy should be real (imaginary part should be ~0)
2103        assert!(
2104            cov_xy.im.abs() < 1e-5,
2105            "Expected real cov_xy for real signals"
2106        );
2107
2108        // Verify values match scalar computation
2109        let mut expected_xx = 0.0;
2110        let mut expected_yy = 0.0;
2111        for i in 0..8 {
2112            expected_xx += (i * i) as f32;
2113            expected_yy += ((i as f32) * 0.5).powi(2);
2114        }
2115        assert!((cov_xx - expected_xx).abs() < 1e-5);
2116        assert!((cov_yy - expected_yy).abs() < 1e-5);
2117
2118        // Test 3: Imaginary-only signals
2119        let imag_left: Vec<Complex<f32>> = (0..8).map(|i| Complex::new(0.0, i as f32)).collect();
2120        let imag_right: Vec<Complex<f32>> = (0..8)
2121            .map(|i| Complex::new(0.0, (i as f32) * 2.0))
2122            .collect();
2123        let (_cov_xx, _cov_yy, cov_xy) = compute_covariance_simd(&imag_left, &imag_right, 0, 8);
2124
2125        // cov_xy should be real (because conj flips imaginary sign)
2126        assert!(
2127            cov_xy.im.abs() < 1e-5,
2128            "Expected real cov_xy for imaginary signals"
2129        );
2130
2131        // Test 4: Single element range
2132        let single_left = vec![Complex::new(3.0, 4.0)];
2133        let single_right = vec![Complex::new(1.0, 2.0)];
2134        let (cov_xx, cov_yy, cov_xy) = compute_covariance_simd(&single_left, &single_right, 0, 1);
2135        assert!((cov_xx - 25.0).abs() < 1e-5); // 3^2 + 4^2 = 25
2136        assert!((cov_yy - 5.0).abs() < 1e-5); // 1^2 + 2^2 = 5
2137        // (3 + 4i) * (1 - 2i) = 3 - 6i + 4i - 8i^2 = 3 - 2i + 8 = 11 - 2i
2138        assert!((cov_xy.re - 11.0).abs() < 1e-5);
2139        assert!((cov_xy.im - (-2.0)).abs() < 1e-5);
2140    }
2141
2142    // ============================================================================
2143    // Numerical Accuracy and Stress Tests
2144    // ============================================================================
2145
2146    #[test]
2147    fn test_numerical_accuracy_small_values() {
2148        // Test with very small values to check for denormal handling
2149        use rustfft::num_complex::Complex;
2150
2151        let small = 1e-20_f32;
2152        let src = vec![
2153            Complex::new(small, small),
2154            Complex::new(small * 2.0, small * 3.0),
2155            Complex::new(small * 4.0, small * 5.0),
2156            Complex::new(small * 6.0, small * 7.0),
2157        ];
2158
2159        let hrtf = vec![
2160            Complex::new(1.0, 0.5),
2161            Complex::new(2.0, -1.0),
2162            Complex::new(-0.5, 1.5),
2163            Complex::new(0.75, 0.25),
2164        ];
2165
2166        // Scalar reference
2167        let expected: Vec<Complex<f32>> = src.iter().zip(hrtf.iter()).map(|(a, b)| a * b).collect();
2168
2169        // SIMD computation
2170        let mut result = vec![Complex::new(0.0, 0.0); src.len()];
2171        complex_mul_simd(&mut result, &src, &hrtf);
2172
2173        // Relative tolerance for very small numbers
2174        for i in 0..src.len() {
2175            let re_diff = (result[i].re - expected[i].re).abs();
2176            let im_diff = (result[i].im - expected[i].im).abs();
2177
2178            // For very small numbers, check relative error or absolute error for near-zero
2179            if expected[i].re.abs() > 1e-15 {
2180                assert!(re_diff / expected[i].re.abs() < 1e-3);
2181            } else {
2182                assert!(re_diff < 1e-25);
2183            }
2184
2185            if expected[i].im.abs() > 1e-15 {
2186                assert!(im_diff / expected[i].im.abs() < 1e-3);
2187            } else {
2188                assert!(im_diff < 1e-25);
2189            }
2190        }
2191    }
2192
2193    #[test]
2194    fn test_numerical_accuracy_large_values() {
2195        // Test with large values to check for overflow handling
2196        use rustfft::num_complex::Complex;
2197
2198        let large = 1e10_f32;
2199        let src = vec![
2200            Complex::new(large, large * 0.5),
2201            Complex::new(large * 2.0, large * 1.5),
2202            Complex::new(large * 0.3, large * 0.7),
2203            Complex::new(large * 1.2, large * 0.8),
2204        ];
2205
2206        let hrtf = vec![
2207            Complex::new(1e-5, 5e-6),
2208            Complex::new(2e-5, -1e-5),
2209            Complex::new(-5e-6, 1.5e-5),
2210            Complex::new(7.5e-6, 2.5e-6),
2211        ];
2212
2213        // Scalar reference
2214        let mut expected = vec![Complex::new(0.0, 0.0); src.len()];
2215        for i in 0..src.len() {
2216            expected[i] = src[i] * hrtf[i];
2217        }
2218
2219        // SIMD computation
2220        let mut result = vec![Complex::new(0.0, 0.0); src.len()];
2221        complex_mul_simd(&mut result, &src, &hrtf);
2222
2223        // Relative tolerance
2224        for i in 0..src.len() {
2225            let re_rel_err = (result[i].re - expected[i].re).abs() / expected[i].re.abs().max(1.0);
2226            let im_rel_err = (result[i].im - expected[i].im).abs() / expected[i].im.abs().max(1.0);
2227            assert!(
2228                re_rel_err < 1e-5,
2229                "Index {}: re rel error too large: {}",
2230                i,
2231                re_rel_err
2232            );
2233            assert!(
2234                im_rel_err < 1e-5,
2235                "Index {}: im rel error too large: {}",
2236                i,
2237                im_rel_err
2238            );
2239        }
2240    }
2241
2242    #[test]
2243    fn test_accumulation_accuracy() {
2244        // Test that repeated accumulation doesn't degrade accuracy significantly
2245        use rustfft::num_complex::Complex;
2246
2247        let src = vec![
2248            Complex::new(0.1, 0.2),
2249            Complex::new(0.3, 0.4),
2250            Complex::new(0.5, 0.6),
2251            Complex::new(0.7, 0.8),
2252        ];
2253
2254        let hrtf = vec![
2255            Complex::new(0.5, 0.25),
2256            Complex::new(-1.0, 1.5),
2257            Complex::new(2.0, -0.5),
2258            Complex::new(0.75, 0.75),
2259        ];
2260
2261        // Scalar reference: accumulate 100 times
2262        let mut expected = vec![Complex::new(0.0, 0.0); src.len()];
2263        for _ in 0..100 {
2264            for i in 0..src.len() {
2265                expected[i] += src[i] * hrtf[i];
2266            }
2267        }
2268
2269        // SIMD computation: accumulate 100 times
2270        let mut result = vec![Complex::new(0.0, 0.0); src.len()];
2271        for _ in 0..100 {
2272            complex_mul_add_simd(&mut result, &src, &hrtf);
2273        }
2274
2275        // Check relative error
2276        const REL_EPSILON: f32 = 1e-4;
2277        for i in 0..src.len() {
2278            let re_abs_err = (result[i].re - expected[i].re).abs();
2279            let im_abs_err = (result[i].im - expected[i].im).abs();
2280
2281            // Use relative error if expected value is large enough, otherwise use absolute error
2282            let re_err = if expected[i].re.abs() > 1e-6 {
2283                re_abs_err / expected[i].re.abs()
2284            } else {
2285                re_abs_err
2286            };
2287            let im_err = if expected[i].im.abs() > 1e-6 {
2288                im_abs_err / expected[i].im.abs()
2289            } else {
2290                im_abs_err
2291            };
2292
2293            assert!(
2294                re_err < REL_EPSILON,
2295                "Index {}: re accumulated error too large: {} (abs: {}, expected: {})",
2296                i,
2297                re_err,
2298                re_abs_err,
2299                expected[i].re
2300            );
2301            assert!(
2302                im_err < REL_EPSILON,
2303                "Index {}: im accumulated error too large: {} (abs: {}, expected: {})",
2304                i,
2305                im_err,
2306                im_abs_err,
2307                expected[i].im
2308            );
2309        }
2310    }
2311
2312    #[test]
2313    fn test_platform_specific_simd_widths() {
2314        // Test that SIMD code paths are correctly exercised
2315        use rustfft::num_complex::Complex;
2316
2317        // Test sizes that exercise SIMD boundaries:
2318        // - AVX2 processes 4 complex at a time
2319        // - NEON processes 2 complex at a time
2320        let test_sizes = vec![
2321            1,  // No SIMD, scalar only
2322            2,  // NEON: 1 iteration, AVX2: scalar only
2323            3,  // NEON: 1 iteration + 1 scalar, AVX2: scalar only
2324            4,  // NEON: 2 iterations, AVX2: 1 iteration
2325            5,  // NEON: 2 iterations + 1 scalar, AVX2: 1 iteration + 1 scalar
2326            8,  // NEON: 4 iterations, AVX2: 2 iterations
2327            9,  // Mixed
2328            12, // NEON: 6 iterations, AVX2: 3 iterations
2329            16, // NEON: 8 iterations, AVX2: 4 iterations
2330        ];
2331
2332        for size in test_sizes {
2333            let src: Vec<Complex<f32>> = (0..size)
2334                .map(|i| Complex::new(i as f32 * 0.3, i as f32 * -0.2))
2335                .collect();
2336            let hrtf: Vec<Complex<f32>> = (0..size)
2337                .map(|i| Complex::new(1.0 + i as f32 * 0.1, 0.5))
2338                .collect();
2339
2340            // Test all three functions
2341
2342            // 1. complex_mul_add_simd
2343            let mut result_add = vec![Complex::new(1.0, 2.0); size];
2344            let mut expected_add = result_add.clone();
2345            for i in 0..size {
2346                expected_add[i] += src[i] * hrtf[i];
2347            }
2348            complex_mul_add_simd(&mut result_add, &src, &hrtf);
2349            for i in 0..size {
2350                assert!(
2351                    (result_add[i].re - expected_add[i].re).abs() < 1e-6,
2352                    "mul_add size {}, index {}: re mismatch",
2353                    size,
2354                    i
2355                );
2356                assert!(
2357                    (result_add[i].im - expected_add[i].im).abs() < 1e-6,
2358                    "mul_add size {}, index {}: im mismatch",
2359                    size,
2360                    i
2361                );
2362            }
2363
2364            // 2. complex_mul_simd
2365            let mut result_mul = vec![Complex::new(0.0, 0.0); size];
2366            let expected_mul: Vec<Complex<f32>> =
2367                src.iter().zip(hrtf.iter()).map(|(a, b)| a * b).collect();
2368            complex_mul_simd(&mut result_mul, &src, &hrtf);
2369            for i in 0..size {
2370                assert!(
2371                    (result_mul[i].re - expected_mul[i].re).abs() < 1e-6,
2372                    "mul size {}, index {}: re mismatch",
2373                    size,
2374                    i
2375                );
2376                assert!(
2377                    (result_mul[i].im - expected_mul[i].im).abs() < 1e-6,
2378                    "mul size {}, index {}: im mismatch",
2379                    size,
2380                    i
2381                );
2382            }
2383
2384            // 3. complex_mul_inplace_simd
2385            let mut result_inplace = src.clone();
2386            let mut expected_inplace = src.clone();
2387            for i in 0..size {
2388                expected_inplace[i] *= hrtf[i];
2389            }
2390            complex_mul_inplace_simd(&mut result_inplace, &hrtf);
2391            for i in 0..size {
2392                assert!(
2393                    (result_inplace[i].re - expected_inplace[i].re).abs() < 1e-6,
2394                    "inplace size {}, index {}: re mismatch",
2395                    size,
2396                    i
2397                );
2398                assert!(
2399                    (result_inplace[i].im - expected_inplace[i].im).abs() < 1e-6,
2400                    "inplace size {}, index {}: im mismatch",
2401                    size,
2402                    i
2403                );
2404            }
2405        }
2406    }
2407
2408    #[test]
2409    fn test_stress_test_random_data() {
2410        // Stress test with pseudo-random data
2411        use rustfft::num_complex::Complex;
2412
2413        // Simple LCG for deterministic "random" values
2414        let mut seed = 12345_u32;
2415        let lcg = |s: &mut u32| -> f32 {
2416            *s = s.wrapping_mul(1103515245).wrapping_add(12345);
2417            ((*s / 65536) % 32768) as f32 / 32768.0 - 0.5
2418        };
2419
2420        for size in [64, 128, 256, 512] {
2421            let src: Vec<Complex<f32>> = (0..size)
2422                .map(|_| Complex::new(lcg(&mut seed), lcg(&mut seed)))
2423                .collect();
2424            let hrtf: Vec<Complex<f32>> = (0..size)
2425                .map(|_| Complex::new(lcg(&mut seed), lcg(&mut seed)))
2426                .collect();
2427
2428            // Scalar reference
2429            let expected: Vec<Complex<f32>> =
2430                src.iter().zip(hrtf.iter()).map(|(a, b)| a * b).collect();
2431
2432            // SIMD computation
2433            let mut result = vec![Complex::new(0.0, 0.0); size];
2434            complex_mul_simd(&mut result, &src, &hrtf);
2435
2436            // Verify
2437            for i in 0..size {
2438                assert!(
2439                    (result[i].re - expected[i].re).abs() < 1e-5,
2440                    "Stress test size {}, index {}: re mismatch",
2441                    size,
2442                    i
2443                );
2444                assert!(
2445                    (result[i].im - expected[i].im).abs() < 1e-5,
2446                    "Stress test size {}, index {}: im mismatch",
2447                    size,
2448                    i
2449                );
2450            }
2451        }
2452    }
2453}
2454
2455// ============================================================================
2456// SIMD Covariance Calculation for ERB Band Processing
2457// ============================================================================
2458//
2459// Computes covariance statistics for two complex arrays (left/right channels):
2460//   cov_xx = sum(|left[i]|^2)
2461//   cov_yy = sum(|right[i]|^2)
2462//   cov_xy = sum(left[i] * conj(right[i]))
2463//
2464// These are used for PCA-based direct/ambient decomposition in the upmixer.
2465
2466/// SIMD-accelerated covariance calculation for ERB bands
2467///
2468/// Returns (cov_xx, cov_yy, cov_xy) where:
2469/// - cov_xx: sum of left channel energy
2470/// - cov_yy: sum of right channel energy
2471/// - cov_xy: complex cross-correlation
2472pub fn compute_covariance_simd(
2473    left: &[Complex<f32>],
2474    right: &[Complex<f32>],
2475    start: usize,
2476    end: usize,
2477) -> (f32, f32, Complex<f32>) {
2478    assert_eq!(left.len(), right.len());
2479    assert!(end <= left.len());
2480    assert!(start < end);
2481
2482    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
2483    let count = end - start;
2484
2485    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
2486    {
2487        use std::arch::x86_64::*;
2488
2489        let mut cov_xx;
2490        let mut cov_yy;
2491        let mut cov_xy = Complex::new(0.0, 0.0);
2492
2493        // SIMD path: process 4 complex numbers at once
2494        let simd_len = (count / 4) * 4;
2495        let simd_end = start + simd_len;
2496
2497        unsafe {
2498            let mut sum_xx = _mm256_setzero_ps();
2499            let mut sum_yy = _mm256_setzero_ps();
2500            let mut sum_xy_re = _mm256_setzero_ps();
2501            let _sum_xy_im = _mm256_setzero_ps();
2502
2503            for i in (start..simd_end).step_by(4) {
2504                let left_ptr = left.as_ptr().add(i) as *const f32;
2505                let right_ptr = right.as_ptr().add(i) as *const f32;
2506
2507                // Load 4 complex numbers: [re0, im0, re1, im1, re2, im2, re3, im3]
2508                let l = _mm256_loadu_ps(left_ptr);
2509                let r = _mm256_loadu_ps(right_ptr);
2510
2511                // Compute norm_sqr: re^2 + im^2
2512                let l_sqr = _mm256_mul_ps(l, l);
2513                let r_sqr = _mm256_mul_ps(r, r);
2514
2515                // Horizontal add pairs: [re0^2 + im0^2, re1^2 + im1^2, ...]
2516                let l_norm = _mm256_hadd_ps(l_sqr, l_sqr);
2517                let r_norm = _mm256_hadd_ps(r_sqr, r_sqr);
2518
2519                sum_xx = _mm256_add_ps(sum_xx, l_norm);
2520                sum_yy = _mm256_add_ps(sum_yy, r_norm);
2521
2522                // Compute cross-correlation: left * conj(right)
2523                let sign_mask = _mm256_set_ps(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0);
2524                let r_conj = _mm256_xor_ps(r, sign_mask);
2525
2526                // Complex multiplication: left * r_conj
2527                let l_re = _mm256_moveldup_ps(l);
2528                let l_im = _mm256_movehdup_ps(l);
2529
2530                let ac_ad = _mm256_mul_ps(l_re, r_conj);
2531                let r_conj_swap = _mm256_shuffle_ps(r_conj, r_conj, 0b10110001);
2532                let bd_bc = _mm256_mul_ps(l_im, r_conj_swap);
2533
2534                let result = _mm256_addsub_ps(ac_ad, bd_bc);
2535
2536                // Accumulate real parts (even indices) and imaginary parts (odd indices)
2537                sum_xy_re = _mm256_add_ps(sum_xy_re, result);
2538            }
2539
2540            // Horizontal reduction to scalars
2541            let xx_arr = std::mem::transmute::<__m256, [f32; 8]>(sum_xx);
2542            let yy_arr = std::mem::transmute::<__m256, [f32; 8]>(sum_yy);
2543            let xy_arr = std::mem::transmute::<__m256, [f32; 8]>(sum_xy_re);
2544
2545            // Sum hadd results (each norm appears twice in each lane)
2546            // Lane 0: [norm0, norm1, norm0, norm1], Lane 1: [norm2, norm3, norm2, norm3]
2547            cov_xx = xx_arr[0] + xx_arr[1] + xx_arr[4] + xx_arr[5];
2548            cov_yy = yy_arr[0] + yy_arr[1] + yy_arr[4] + yy_arr[5];
2549
2550            // Sum re (even) and im (odd) separately
2551            cov_xy.re = xy_arr[0] + xy_arr[2] + xy_arr[4] + xy_arr[6];
2552            cov_xy.im = xy_arr[1] + xy_arr[3] + xy_arr[5] + xy_arr[7];
2553        }
2554
2555        // Scalar tail for remaining elements
2556        for i in simd_end..end {
2557            let l = left[i];
2558            let r = right[i];
2559            cov_xx += l.norm_sqr();
2560            cov_yy += r.norm_sqr();
2561            cov_xy += l * r.conj();
2562        }
2563
2564        (cov_xx, cov_yy, cov_xy)
2565    }
2566
2567    #[cfg(not(all(target_arch = "x86_64", target_feature = "avx2")))]
2568    {
2569        let mut cov_xx = 0.0_f32;
2570        let mut cov_yy = 0.0_f32;
2571        let mut cov_xy = Complex::new(0.0, 0.0);
2572
2573        for i in start..end {
2574            let l = left[i];
2575            let r = right[i];
2576            cov_xx += l.norm_sqr();
2577            cov_yy += r.norm_sqr();
2578            cov_xy += l * r.conj();
2579        }
2580
2581        (cov_xx, cov_yy, cov_xy)
2582    }
2583}
2584
/// Multiplies every sample of `buffer` by `gain`, in place.
///
/// When compiled for x86-64 with the `avx2` target feature, eight samples are
/// processed per iteration; when compiled for aarch64 with `neon`, four. Any
/// samples left over, and every sample on all other builds, are scaled one at
/// a time. An empty buffer is left untouched.
#[inline]
pub fn apply_gain_simd(buffer: &mut [f32], gain: f32) {
    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        use std::arch::x86_64::*;

        // SAFETY: the intrinsic only broadcasts a value into a register, and
        // this block is only compiled when the `avx2` feature is enabled.
        let gain_vec = unsafe { _mm256_set1_ps(gain) };

        let mut chunks = buffer.chunks_exact_mut(8);
        for chunk in &mut chunks {
            // SAFETY: `chunks_exact_mut(8)` yields slices of exactly eight
            // contiguous `f32`, so the 256-bit load and store stay inside
            // `chunk`; both are the unaligned variants.
            unsafe {
                let ptr = chunk.as_mut_ptr();
                let scaled = _mm256_mul_ps(_mm256_loadu_ps(ptr), gain_vec);
                _mm256_storeu_ps(ptr, scaled);
            }
        }
        // Fewer than eight samples remain; scale them individually.
        for sample in chunks.into_remainder() {
            *sample *= gain;
        }
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    {
        use std::arch::aarch64::*;

        // SAFETY: the intrinsic only broadcasts a value into a register, and
        // this block is only compiled when the `neon` feature is enabled.
        let gain_vec = unsafe { vdupq_n_f32(gain) };

        let mut chunks = buffer.chunks_exact_mut(4);
        for chunk in &mut chunks {
            // SAFETY: `chunks_exact_mut(4)` yields slices of exactly four
            // contiguous `f32`, so the 128-bit load and store stay inside
            // `chunk`.
            unsafe {
                let ptr = chunk.as_mut_ptr();
                let scaled = vmulq_f32(vld1q_f32(ptr), gain_vec);
                vst1q_f32(ptr, scaled);
            }
        }
        // Fewer than four samples remain; scale them individually.
        for sample in chunks.into_remainder() {
            *sample *= gain;
        }
    }

    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "avx2"),
        all(target_arch = "aarch64", target_feature = "neon")
    )))]
    {
        // Portable path: no length bookkeeping is needed here, so nothing is
        // bound outside the cfg-specific blocks.
        for val in buffer.iter_mut() {
            *val *= gain;
        }
    }
}
2636
/// Applies a separate gain to each channel of an interleaved buffer, in place.
///
/// `buffer` is treated as consecutive frames of `channels` samples each, and
/// sample `ch` of every frame is multiplied by `gains[ch]`. Only whole frames
/// are processed: if `buffer.len()` is not a multiple of `channels`, the
/// trailing partial frame is left untouched.
///
/// Stereo (`channels == 2`) has dedicated AVX2 / NEON paths when those
/// features are enabled at compile time; every other layout uses a scalar loop.
/// `channels == 0` is a no-op.
///
/// # Panics
///
/// Panics if `gains` has fewer than `channels` entries and at least one frame
/// has to be processed. On the stereo SIMD paths the two gains are read up
/// front, so a too-short `gains` panics there even for an empty buffer.
#[inline]
pub fn apply_per_channel_gain_simd(buffer: &mut [f32], channels: usize, gains: &[f32]) {
    // With no channels there are no frames to scale; returning here also keeps
    // the frame splitting below from dividing by zero.
    if channels == 0 {
        return;
    }

    // Stereo is the one layout with a hand-written SIMD path: the gain pattern
    // [g0, g1, g0, g1, ...] lines up with whole vector registers.
    if channels == 2 {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
        {
            use std::arch::x86_64::*;

            let (g0, g1) = (gains[0], gains[1]);
            // `_mm256_set_ps` lists lanes from highest to lowest, so lane 0 is `g0`.
            // SAFETY: AVX2 is statically enabled by the surrounding `cfg`.
            let gains_vec = unsafe { _mm256_set_ps(g1, g0, g1, g0, g1, g0, g1, g0) };

            // 8 floats = 4 stereo frames per step.
            let mut blocks = buffer.chunks_exact_mut(8);
            for block in blocks.by_ref() {
                let ptr = block.as_mut_ptr();
                // SAFETY: `block` is exactly 8 contiguous f32s; unaligned access is allowed.
                unsafe {
                    let scaled = _mm256_mul_ps(_mm256_loadu_ps(ptr), gains_vec);
                    _mm256_storeu_ps(ptr, scaled);
                }
            }
            // Remaining whole frames; a stray final sample is skipped.
            for frame in blocks.into_remainder().chunks_exact_mut(2) {
                frame[0] *= g0;
                frame[1] *= g1;
            }
            return;
        }

        #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
        {
            use std::arch::aarch64::*;

            let (g0, g1) = (gains[0], gains[1]);
            let pattern = [g0, g1, g0, g1];
            // SAFETY: `pattern` is 4 contiguous f32s; NEON is enabled by the `cfg`.
            let gains_vec = unsafe { vld1q_f32(pattern.as_ptr()) };

            // 4 floats = 2 stereo frames per step.
            let mut blocks = buffer.chunks_exact_mut(4);
            for block in blocks.by_ref() {
                let ptr = block.as_mut_ptr();
                // SAFETY: `block` is exactly 4 contiguous f32s.
                unsafe {
                    let scaled = vmulq_f32(vld1q_f32(ptr), gains_vec);
                    vst1q_f32(ptr, scaled);
                }
            }
            // Remaining whole frames; a stray final sample is skipped.
            for frame in blocks.into_remainder().chunks_exact_mut(2) {
                frame[0] *= g0;
                frame[1] *= g1;
            }
            return;
        }
    }

    // Generic scalar path: walk whole frames and scale each channel slot.
    for frame in buffer.chunks_exact_mut(channels) {
        for (ch, sample) in frame.iter_mut().enumerate() {
            *sample *= gains[ch];
        }
    }
}
2702
/// Fast approximate inverse square root (`1/sqrt(x)`): a bit-level initial
/// estimate refined by one Newton-Raphson iteration.
///
/// For positive, normal, finite `x` the relative error stays under 0.2%,
/// which is enough for normalizing all-pass filter magnitudes.
///
/// The result is not meaningful for zero, subnormal, negative, or NaN inputs
/// (it does not yield the `inf`/`NaN` an exact `1/sqrt` would), but the
/// function never panics on them.
#[inline(always)]
pub fn fast_inv_sqrt(x: f32) -> f32 {
    let half = 0.5 * x;
    // Magic-number estimate. `wrapping_sub` matters: for inputs with the sign
    // bit set, `bits >> 1` can exceed the constant, and a plain `-` would panic
    // in overflow-checked (debug) builds while silently wrapping in release.
    let estimate_bits = 0x5f37_59df_u32.wrapping_sub(x.to_bits() >> 1);
    let y = f32::from_bits(estimate_bits);
    // One Newton-Raphson refinement of y ~= 1/sqrt(x).
    y * (1.5 - half * y * y)
}
2713
/// Returns the peak level of `samples`: the largest absolute value, or `0.0`
/// for an empty slice.
///
/// NaN samples are ignored on every code path (AVX2, NEON and scalar): a NaN
/// never replaces the running maximum, so the result is the peak of the
/// non-NaN samples.
#[inline]
pub fn find_max_abs_simd(samples: &[f32]) -> f32 {
    if samples.is_empty() {
        return 0.0;
    }

    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        use std::arch::x86_64::*;

        // SAFETY: AVX2 is statically enabled by the surrounding `cfg`.
        let (sign_mask, mut max_vec) = unsafe { (_mm256_set1_ps(-0.0), _mm256_setzero_ps()) };
        let mut blocks = samples.chunks_exact(8);
        for block in blocks.by_ref() {
            // SAFETY: `block` is exactly 8 contiguous f32s; `loadu` has no
            // alignment requirement.
            unsafe {
                // Clearing the sign bit yields |v| per lane.
                let av = _mm256_andnot_ps(sign_mask, _mm256_loadu_ps(block.as_ptr()));
                // `max_ps` returns its SECOND operand when either input is NaN.
                // Keeping the running maximum second means a NaN sample leaves
                // it unchanged instead of overwriting (and losing) it.
                max_vec = _mm256_max_ps(av, max_vec);
            }
        }

        // SAFETY: `__m256` and `[f32; 8]` have the same size, and every bit
        // pattern is a valid f32.
        let lanes = unsafe { std::mem::transmute::<__m256, [f32; 8]>(max_vec) };
        // Reduce the 8 lane maxima together with the (< 8 sample) tail.
        // `f32::max` returns the non-NaN argument, so NaN is skipped here too.
        lanes
            .iter()
            .chain(blocks.remainder())
            .fold(0.0_f32, |peak, &s| peak.max(s.abs()))
    }

    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    {
        use std::arch::aarch64::*;

        // SAFETY: NEON is statically enabled by the surrounding `cfg`.
        let mut max_vec = unsafe { vdupq_n_f32(0.0) };
        let mut blocks = samples.chunks_exact(4);
        for block in blocks.by_ref() {
            // SAFETY: `block` is exactly 4 contiguous f32s.
            unsafe {
                let av = vabsq_f32(vld1q_f32(block.as_ptr()));
                // `vmaxq_f32` would propagate NaN into the accumulator. An
                // explicit "greater than" compare is false for NaN, so the
                // select keeps the running maximum in that lane.
                let is_greater = vcgtq_f32(av, max_vec);
                max_vec = vbslq_f32(is_greater, av, max_vec);
            }
        }

        // The accumulator never holds NaN, so a plain horizontal max is safe.
        // SAFETY: NEON is statically enabled by the surrounding `cfg`.
        let head_peak = unsafe { vmaxvq_f32(max_vec) };
        // Fold in the (< 4 sample) tail; `f32::max` skips NaN.
        blocks
            .remainder()
            .iter()
            .fold(head_peak, |peak, &s| peak.max(s.abs()))
    }

    #[cfg(not(any(
        all(target_arch = "x86_64", target_feature = "avx2"),
        all(target_arch = "aarch64", target_feature = "neon")
    )))]
    {
        // `f32::max` returns the non-NaN argument, so NaN samples are skipped.
        samples.iter().fold(0.0_f32, |peak, &s| peak.max(s.abs()))
    }
}