Skip to main content

opus_rs/
mdct.rs

1use crate::kiss_fft::{KissCpx, KissFftState, opus_fft_impl};
2use std::f32::consts::PI;
3use std::mem::MaybeUninit;
4
5const MAX_N2: usize = 960;
6const MAX_N4: usize = 480;
7
8pub struct MdctLookup {
9    pub n: usize,
10    pub max_lm: usize,
11    kfft: Vec<Option<KissFftState>>,
12    trig: Vec<f32>,
13}
14
15impl MdctLookup {
16    pub fn new(n: usize, max_lm: usize) -> Self {
17        let mut kfft = Vec::new();
18        let mut trig = Vec::new();
19        let mut curr_n = n;
20
21        for shift in 0..=max_lm {
22            let n4 = curr_n / 4;
23
24            if shift == 0 {
25                kfft.push(KissFftState::new(n4));
26            } else if let Some(base) = kfft.first().unwrap().as_ref() {
27                kfft.push(KissFftState::new_sub(base, n4));
28            } else {
29                kfft.push(None);
30            }
31
32            let n2 = curr_n / 2;
33            for i in 0..n2 {
34                let angle = 2.0 * PI * (i as f32 + 0.125) / curr_n as f32;
35                trig.push(angle.cos());
36            }
37
38            curr_n >>= 1;
39        }
40
41        Self {
42            n,
43            max_lm,
44            kfft,
45            trig,
46        }
47    }
48
49    fn get_trig(&self, shift: usize) -> (&[f32], usize) {
50        let mut offset = 0;
51        let mut curr_n = self.n;
52        for _ in 0..shift {
53            offset += curr_n / 2;
54            curr_n >>= 1;
55        }
56        (&self.trig[offset..offset + curr_n / 2], curr_n / 4)
57    }
58
59    pub fn get_trig_debug(&self, shift: usize) -> &[f32] {
60        let (trig, _) = self.get_trig(shift);
61        trig
62    }
63
64    #[inline]
65    pub fn forward(
66        &self,
67        input: &[f32],
68        output: &mut [f32],
69        window: &[f32],
70        overlap: usize,
71        shift: usize,
72        stride: usize,
73    ) {
74        let st = self.kfft[shift]
75            .as_ref()
76            .expect("FFT state not initialized");
77        let n = self.n >> shift;
78        let n2 = n / 2;
79        let n4 = n / 4;
80        let scale = st.scale();
81
82        let (trig, _) = self.get_trig(shift);
83        let overlap2 = overlap / 2;
84
85        let mut f_buf = [MaybeUninit::<f32>::uninit(); MAX_N2];
86        let mut f2_buf = [MaybeUninit::<KissCpx>::uninit(); MAX_N4];
87
88        let f = unsafe { std::slice::from_raw_parts_mut(f_buf.as_mut_ptr() as *mut f32, n2) };
89        let f2 = unsafe { std::slice::from_raw_parts_mut(f2_buf.as_mut_ptr() as *mut KissCpx, n4) };
90
91        assert!(input.len() >= n2 + overlap2);
92        assert!(window.len() >= overlap);
93        assert!(
94            output.len() >= n2,
95            "MDCT forward: output buffer too small (need {}, have {})",
96            n2,
97            output.len()
98        );
99
100        {
101            let mut yp = 0usize;
102            let mut xp1 = overlap2;
103            let mut xp2 = n2 - 1 + overlap2;
104            let mut wp1 = overlap2;
105
106            let mut wp2 = overlap2.saturating_sub(1);
107
108            let limit = overlap.div_ceil(4);
109            let mid = n4.saturating_sub(limit);
110
111            let loop1_iters = limit.min(n4);
112            for _ in 0..loop1_iters {
113                let w1 = window[wp1];
114                let w2 = window[wp2];
115
116                f[yp] = input[xp1 + n2] * w2 + input[xp2] * w1;
117                yp += 1;
118
119                f[yp] = input[xp1] * w1 - input[xp2 - n2] * w2;
120                yp += 1;
121
122                xp1 += 2;
123                xp2 -= 2;
124                wp1 += 2;
125                wp2 = wp2.saturating_sub(2);
126            }
127
128            for _ in limit..mid {
129                f[yp] = input[xp2];
130                yp += 1;
131
132                f[yp] = input[xp1];
133                yp += 1;
134                xp1 += 2;
135                xp2 -= 2;
136            }
137
138            let loop3_iters = if mid > limit { n4 - mid } else { 0 };
139            let mut wp1_l3 = 0usize;
140            let mut wp2_l3 = overlap.saturating_sub(1);
141            for _ in 0..loop3_iters {
142                let w1 = window[wp1_l3];
143                let w2 = window[wp2_l3];
144
145                f[yp] = -input[xp1 - n2] * w1 + input[xp2] * w2;
146                yp += 1;
147
148                f[yp] = input[xp1] * w2 + input[xp2 + n2] * w1;
149                yp += 1;
150
151                xp1 += 2;
152                xp2 -= 2;
153                wp1_l3 += 2;
154                wp2_l3 -= 2;
155            }
156        }
157
158        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
159        unsafe {
160            if std::arch::is_x86_feature_detected!("avx") {
161                mdct_pre_rotation_avx(f, f2, trig, &st.bitrev[..n4], n4, scale);
162            } else {
163                for i in 0..n4 {
164                    let re = f[2 * i];
165                    let im = f[2 * i + 1];
166                    let t0 = trig[i];
167                    let t1 = trig[n4 + i];
168
169                    let yr = re * t0 - im * t1;
170                    let yi = im * t0 + re * t1;
171
172                    f2[st.bitrev[i] as usize] = KissCpx::new(yr * scale, yi * scale);
173                }
174            }
175        }
176        #[cfg(all(
177            not(any(target_arch = "x86", target_arch = "x86_64")),
178            target_arch = "aarch64"
179        ))]
180        {
181            mdct_pre_rotation_neon(f, f2, trig, &st.bitrev[..n4], n4, scale);
182        }
183        #[cfg(all(
184            not(any(target_arch = "x86", target_arch = "x86_64")),
185            not(target_arch = "aarch64")
186        ))]
187        for i in 0..n4 {
188            let re = f[2 * i];
189            let im = f[2 * i + 1];
190            let t0 = trig[i];
191            let t1 = trig[n4 + i];
192
193            let yr = re * t0 - im * t1;
194            let yi = im * t0 + re * t1;
195
196            f2[st.bitrev[i] as usize] = KissCpx::new(yr * scale, yi * scale);
197        }
198
199        opus_fft_impl(st, f2);
200
201        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
202        unsafe {
203            if std::arch::is_x86_feature_detected!("avx") {
204                mdct_post_rotation_avx(f2, trig, output, n4, n2, stride);
205            } else {
206                for i in 0..n4 {
207                    let fp = &f2[i];
208                    let t0 = trig[i];
209                    let t1 = trig[n4 + i];
210
211                    let yr = fp.i * t1 - fp.r * t0;
212                    let yi = fp.r * t1 + fp.i * t0;
213
214                    output[i * 2 * stride] = yr;
215                    output[stride * (n2 - 1 - 2 * i)] = yi;
216                }
217            }
218        }
219        #[cfg(all(
220            not(any(target_arch = "x86", target_arch = "x86_64")),
221            target_arch = "aarch64"
222        ))]
223        {
224            mdct_post_rotation_neon(f2, trig, output, n4, n2, stride);
225        }
226        #[cfg(all(
227            not(any(target_arch = "x86", target_arch = "x86_64")),
228            not(target_arch = "aarch64")
229        ))]
230        for i in 0..n4 {
231            let fp = &f2[i];
232            let t0 = trig[i];
233            let t1 = trig[n4 + i];
234
235            let yr = fp.i * t1 - fp.r * t0;
236            let yi = fp.r * t1 + fp.i * t0;
237
238            output[i * 2 * stride] = yr;
239            output[stride * (n2 - 1 - 2 * i)] = yi;
240        }
241    }
242
243    #[inline]
244    pub fn backward(
245        &self,
246        input: &[f32],
247        output: &mut [f32],
248        window: &[f32],
249        overlap: usize,
250        shift: usize,
251        stride: usize,
252    ) {
253        let st = self.kfft[shift]
254            .as_ref()
255            .expect("FFT state not initialized");
256        let n = self.n >> shift;
257        let n2 = n / 2;
258        let n4 = n / 4;
259        let overlap2 = overlap / 2;
260
261        let (trig, _) = self.get_trig(shift);
262
263        let mut f2_buf = [MaybeUninit::<KissCpx>::uninit(); MAX_N4];
264
265        let f2 = unsafe { std::slice::from_raw_parts_mut(f2_buf.as_mut_ptr() as *mut KissCpx, n4) };
266
267        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
268        unsafe {
269            if std::arch::is_x86_feature_detected!("avx") {
270                mdct_backward_pre_rotation_avx(input, f2, trig, &st.bitrev[..n4], n4, n2, stride);
271            } else {
272                for i in 0..n4 {
273                    let rev = st.bitrev[i] as usize;
274                    let x1 = input[2 * i * stride];
275                    let x2 = input[stride * (n2 - 1 - 2 * i)];
276                    let t0 = trig[i];
277                    let t1 = trig[n4 + i];
278
279                    let yr = x2 * t0 + x1 * t1;
280                    let yi = x1 * t0 - x2 * t1;
281
282                    f2[rev] = KissCpx::new(yi, yr);
283                }
284            }
285        }
286        #[cfg(all(
287            not(any(target_arch = "x86", target_arch = "x86_64")),
288            target_arch = "aarch64"
289        ))]
290        {
291            mdct_backward_pre_rotation_neon(input, f2, trig, &st.bitrev[..n4], n4, n2, stride);
292        }
293        #[cfg(all(
294            not(any(target_arch = "x86", target_arch = "x86_64")),
295            not(target_arch = "aarch64")
296        ))]
297        for i in 0..n4 {
298            let rev = st.bitrev[i] as usize;
299            let x1 = input[2 * i * stride];
300            let x2 = input[stride * (n2 - 1 - 2 * i)];
301            let t0 = trig[i];
302            let t1 = trig[n4 + i];
303
304            let yr = x2 * t0 + x1 * t1;
305            let yi = x1 * t0 - x2 * t1;
306
307            f2[rev] = KissCpx::new(yi, yr);
308        }
309
310        opus_fft_impl(st, f2);
311
312        assert!(output.len() >= overlap2 + n2);
313
314        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
315        unsafe {
316            if std::arch::is_x86_feature_detected!("avx") {
317                mdct_backward_post_rotation_avx(f2, trig, output, n4, n2, overlap2);
318            } else {
319                for i in 0..((n4 + 1) >> 1) {
320                    let im0 = f2[i].r;
321                    let re0 = f2[i].i;
322                    let t0_0 = trig[i];
323                    let t1_0 = trig[n4 + i];
324
325                    let yr0 = re0 * t0_0 + im0 * t1_0;
326                    let yi0 = re0 * t1_0 - im0 * t0_0;
327
328                    let j = n4 - 1 - i;
329                    let im1 = f2[j].r;
330                    let re1 = f2[j].i;
331                    let t0_1 = trig[j];
332                    let t1_1 = trig[n4 + j];
333
334                    let yr1 = re1 * t0_1 + im1 * t1_1;
335                    let yi1 = re1 * t1_1 - im1 * t0_1;
336
337                    output[overlap2 + 2 * i] = yr0;
338                    output[overlap2 + n2 - 1 - 2 * i] = yi0;
339                    output[overlap2 + n2 - 2 - 2 * i] = yr1;
340                    output[overlap2 + 2 * i + 1] = yi1;
341                }
342            }
343        }
344        #[cfg(all(
345            not(any(target_arch = "x86", target_arch = "x86_64")),
346            target_arch = "aarch64"
347        ))]
348        {
349            mdct_backward_post_rotation_neon(f2, trig, output, n4, n2, overlap2);
350        }
351        #[cfg(all(
352            not(any(target_arch = "x86", target_arch = "x86_64")),
353            not(target_arch = "aarch64")
354        ))]
355        for i in 0..((n4 + 1) >> 1) {
356            let im0 = f2[i].r;
357            let re0 = f2[i].i;
358            let t0_0 = trig[i];
359            let t1_0 = trig[n4 + i];
360
361            let yr0 = re0 * t0_0 + im0 * t1_0;
362            let yi0 = re0 * t1_0 - im0 * t0_0;
363
364            let j = n4 - 1 - i;
365            let im1 = f2[j].r;
366            let re1 = f2[j].i;
367            let t0_1 = trig[j];
368            let t1_1 = trig[n4 + j];
369
370            let yr1 = re1 * t0_1 + im1 * t1_1;
371            let yi1 = re1 * t1_1 - im1 * t0_1;
372
373            output[overlap2 + 2 * i] = yr0;
374            output[overlap2 + n2 - 1 - 2 * i] = yi0;
375            output[overlap2 + n2 - 2 - 2 * i] = yr1;
376            output[overlap2 + 2 * i + 1] = yi1;
377        }
378
379        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
380        unsafe {
381            if std::arch::is_x86_feature_detected!("avx") {
382                mdct_tdac_avx(output, window, overlap);
383            } else {
384                for i in 0..overlap2 {
385                    let x1 = output[overlap - 1 - i];
386                    let x2 = output[i];
387                    let wp1 = window[i];
388                    let wp2 = window[overlap - 1 - i];
389
390                    output[i] = x2 * wp2 - x1 * wp1;
391                    output[overlap - 1 - i] = x2 * wp1 + x1 * wp2;
392                }
393            }
394        }
395        #[cfg(all(
396            not(any(target_arch = "x86", target_arch = "x86_64")),
397            target_arch = "aarch64"
398        ))]
399        {
400            mdct_tdac_neon(output, window, overlap);
401        }
402        #[cfg(all(
403            not(any(target_arch = "x86", target_arch = "x86_64")),
404            not(target_arch = "aarch64")
405        ))]
406        for i in 0..overlap2 {
407            let x1 = output[overlap - 1 - i];
408            let x2 = output[i];
409            let wp1 = window[i];
410            let wp2 = window[overlap - 1 - i];
411
412            output[i] = x2 * wp2 - x1 * wp1;
413            output[overlap - 1 - i] = x2 * wp1 + x1 * wp2;
414        }
415    }
416}
417
418#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
419#[target_feature(enable = "avx")]
420unsafe fn mdct_pre_rotation_avx(
421    f: &[f32],
422    f2: &mut [KissCpx],
423    trig: &[f32],
424    bitrev: &[i16],
425    n4: usize,
426    scale: f32,
427) {
428    for i in 0..n4 {
429        let re = f[2 * i];
430        let im = f[2 * i + 1];
431        let t0 = trig[i];
432        let t1 = trig[n4 + i];
433
434        let yr = re * t0 - im * t1;
435        let yi = im * t0 + re * t1;
436
437        f2[bitrev[i] as usize] = KissCpx::new(yr * scale, yi * scale);
438    }
439}
440
441#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
442#[target_feature(enable = "avx")]
443unsafe fn mdct_post_rotation_avx(
444    f2: &[KissCpx],
445    trig: &[f32],
446    output: &mut [f32],
447    n4: usize,
448    n2: usize,
449    stride: usize,
450) {
451    for i in 0..n4 {
452        let fp = &f2[i];
453        let t0 = trig[i];
454        let t1 = trig[n4 + i];
455
456        let yr = fp.i * t1 - fp.r * t0;
457        let yi = fp.r * t1 + fp.i * t0;
458
459        output[i * 2 * stride] = yr;
460        output[stride * (n2 - 1 - 2 * i)] = yi;
461    }
462}
463
464#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
465#[target_feature(enable = "avx")]
466unsafe fn mdct_backward_pre_rotation_avx(
467    input: &[f32],
468    f2: &mut [KissCpx],
469    trig: &[f32],
470    bitrev: &[i16],
471    n4: usize,
472    n2: usize,
473    stride: usize,
474) {
475    for i in 0..n4 {
476        let rev = bitrev[i] as usize;
477        let x1 = input[2 * i * stride];
478        let x2 = input[stride * (n2 - 1 - 2 * i)];
479        let t0 = trig[i];
480        let t1 = trig[n4 + i];
481
482        let yr = x2 * t0 + x1 * t1;
483        let yi = x1 * t0 - x2 * t1;
484
485        f2[rev] = KissCpx::new(yi, yr);
486    }
487}
488
489#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
490#[target_feature(enable = "avx")]
491unsafe fn mdct_backward_post_rotation_avx(
492    f2: &[KissCpx],
493    trig: &[f32],
494    output: &mut [f32],
495    n4: usize,
496    n2: usize,
497    overlap2: usize,
498) {
499    for i in 0..((n4 + 1) >> 1) {
500        let im0 = f2[i].r;
501        let re0 = f2[i].i;
502        let t0_0 = trig[i];
503        let t1_0 = trig[n4 + i];
504
505        let yr0 = re0 * t0_0 + im0 * t1_0;
506        let yi0 = re0 * t1_0 - im0 * t0_0;
507
508        let j = n4 - 1 - i;
509        let im1 = f2[j].r;
510        let re1 = f2[j].i;
511        let t0_1 = trig[j];
512        let t1_1 = trig[n4 + j];
513
514        let yr1 = re1 * t0_1 + im1 * t1_1;
515        let yi1 = re1 * t1_1 - im1 * t0_1;
516
517        output[overlap2 + 2 * i] = yr0;
518        output[overlap2 + n2 - 1 - 2 * i] = yi0;
519        output[overlap2 + n2 - 2 - 2 * i] = yr1;
520        output[overlap2 + 2 * i + 1] = yi1;
521    }
522}
523
/// Applies the overlap window to the first `overlap` samples of `output`, in place.
///
/// For every `i < overlap / 2`, with `x2 = output[i]` and
/// `x1 = output[overlap - 1 - i]`:
///
/// * `output[i] = x2 * window[overlap - 1 - i] - x1 * window[i]`
/// * `output[overlap - 1 - i] = x2 * window[i] + x1 * window[overlap - 1 - i]`
///
/// Eight pairs are processed per iteration with AVX; the remainder is scalar.
///
/// # Panics
///
/// Panics if `overlap >= 2` and `output` or `window` holds fewer than
/// `overlap` elements.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx")]
unsafe fn mdct_tdac_avx(output: &mut [f32], window: &[f32], overlap: usize) {
    // The function is compiled for both x86 flavours; each has its own module.
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let overlap2 = overlap / 2;
    if overlap2 == 0 {
        return;
    }
    // The vector loads below use raw pointers, so validate the lengths first.
    assert!(
        output.len() >= overlap && window.len() >= overlap,
        "MDCT TDAC: output ({}) and window ({}) must hold at least overlap ({}) samples",
        output.len(),
        window.len(),
        overlap
    );

    let mut i = 0usize;

    while i + 8 <= overlap2 {
        // SAFETY (raw loads/stores in this loop): `i + 8 <= overlap2 <= overlap`
        // and both slices were checked to hold at least `overlap` elements.
        let x2 = _mm256_loadu_ps(output.as_ptr().add(i));

        // Gather the mirrored samples and window values in reversed order.
        let mut x1_tmp = [0f32; 8];
        let mut w2_tmp = [0f32; 8];
        for j in 0..8 {
            x1_tmp[j] = output[overlap - 1 - (i + j)];
            w2_tmp[j] = window[overlap - 1 - (i + j)];
        }
        let x1 = _mm256_loadu_ps(x1_tmp.as_ptr());

        let w1 = _mm256_loadu_ps(window.as_ptr().add(i));
        let w2 = _mm256_loadu_ps(w2_tmp.as_ptr());

        let out_fwd = _mm256_sub_ps(_mm256_mul_ps(x2, w2), _mm256_mul_ps(x1, w1));
        let out_rev = _mm256_add_ps(_mm256_mul_ps(x2, w1), _mm256_mul_ps(x1, w2));

        _mm256_storeu_ps(output.as_mut_ptr().add(i), out_fwd);

        // Scatter the mirrored results back in reversed order.
        let mut out_rev_tmp = [0f32; 8];
        _mm256_storeu_ps(out_rev_tmp.as_mut_ptr(), out_rev);
        for j in 0..8 {
            output[overlap - 1 - (i + j)] = out_rev_tmp[j];
        }

        i += 8;
    }

    // Scalar tail for the pairs that do not fill a whole vector.
    for i in i..overlap2 {
        let x1 = output[overlap - 1 - i];
        let x2 = output[i];
        let wp1 = window[i];
        let wp2 = window[overlap - 1 - i];
        output[i] = x2 * wp2 - x1 * wp1;
        output[overlap - 1 - i] = x2 * wp1 + x1 * wp2;
    }
}
569
/// Forward pre-rotation using NEON, four complex values per iteration.
///
/// For each `k < n4` the pair `(f[2k], f[2k + 1])` is rotated by the twiddles
/// `trig[k]` and `trig[n4 + k]`, multiplied by `scale`, and written to `f2`
/// at index `bitrev[k]`. `f2` is accessed as a flat `f32` array.
/// NOTE(review): this assumes `KissCpx` is two consecutive `f32` values,
/// real part first — confirm against its definition.
///
/// No slice lengths are checked here; the caller must pass slices that cover
/// every index derived from `n4`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
fn mdct_pre_rotation_neon(
    f: &[f32],
    f2: &mut [KissCpx],
    trig: &[f32],
    bitrev: &[i16],
    n4: usize,
    scale: f32,
) {
    use std::arch::aarch64::*;

    let src = f.as_ptr();
    let twiddle = trig.as_ptr();
    let perm = bitrev.as_ptr();
    let dst = f2.as_mut_ptr() as *mut f32;

    // Largest multiple of four not exceeding `n4`.
    let vector_end = n4 & !3;

    // SAFETY: relies on the caller providing `f` and `trig` with at least
    // `2 * n4` elements, `bitrev` with at least `n4` entries, and every
    // `bitrev` entry being a valid index into `f2`.
    unsafe {
        let gain = vdupq_n_f32(scale);

        for base in (0..vector_end).step_by(4) {
            let tw_a = vld1q_f32(twiddle.add(base));
            let tw_b = vld1q_f32(twiddle.add(n4 + base));

            let lo = vld1q_f32(src.add(2 * base));
            let hi = vld1q_f32(src.add(2 * base + 4));

            // De-interleave: even lanes hold the first values, odd lanes the second.
            let split = vuzpq_f32(lo, hi);
            let (re_v, im_v) = (split.0, split.1);

            let rot_re = vsubq_f32(vmulq_f32(re_v, tw_a), vmulq_f32(im_v, tw_b));
            let rot_im = vaddq_f32(vmulq_f32(im_v, tw_a), vmulq_f32(re_v, tw_b));

            let rot_re = vmulq_f32(rot_re, gain);
            let rot_im = vmulq_f32(rot_im, gain);

            let re_lanes: [f32; 4] = std::mem::transmute(rot_re);
            let im_lanes: [f32; 4] = std::mem::transmute(rot_im);

            // The destination is a permutation, so the store is done per lane.
            for lane in 0..4 {
                let slot = *perm.add(base + lane) as usize;
                *dst.add(2 * slot) = re_lanes[lane];
                *dst.add(2 * slot + 1) = im_lanes[lane];
            }
        }

        // Scalar tail for the last `n4 % 4` values.
        for k in vector_end..n4 {
            let x_re = *src.add(2 * k);
            let x_im = *src.add(2 * k + 1);
            let a = *twiddle.add(k);
            let b = *twiddle.add(n4 + k);
            let rot_re = x_re * a - x_im * b;
            let rot_im = x_im * a + x_re * b;
            let slot = *perm.add(k) as usize;
            *dst.add(2 * slot) = rot_re * scale;
            *dst.add(2 * slot + 1) = rot_im * scale;
        }
    }
}
634
/// Forward post-rotation using NEON when the output is contiguous.
///
/// Each FFT bin `f2[k]`, `k < n4`, is rotated by the twiddles `trig[k]` and
/// `trig[n4 + k]`. One result is written from the front of `output`, the
/// other from the back. For `stride > 1` a bounds-checked scalar loop is
/// used; otherwise four bins are processed per iteration with raw pointers.
/// NOTE(review): the vector path reads `f2` as a flat `f32` array and assumes
/// `KissCpx` is two consecutive `f32` values, real part first — confirm.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
fn mdct_post_rotation_neon(
    f2: &[KissCpx],
    trig: &[f32],
    output: &mut [f32],
    n4: usize,
    n2: usize,
    stride: usize,
) {
    use std::arch::aarch64::*;

    // Bounds-checked scalar rotation of a single bin.
    let rotate = |k: usize| -> (f32, f32) {
        let bin = &f2[k];
        let a = trig[k];
        let b = trig[n4 + k];
        (bin.i * b - bin.r * a, bin.r * b + bin.i * a)
    };

    if stride > 1 {
        for k in 0..n4 {
            let (front, back) = rotate(k);
            output[k * 2 * stride] = front;
            output[stride * (n2 - 1 - 2 * k)] = back;
        }
        return;
    }

    let vector_end = n4 & !3;

    // SAFETY: relies on the caller providing `f2` with at least `n4` bins,
    // `trig` with at least `2 * n4` entries and `output` with at least `n2`
    // elements, with `n2 >= 2 * n4`.
    unsafe {
        let bins = f2.as_ptr() as *const f32;
        let twiddle = trig.as_ptr();
        let dst = output.as_mut_ptr();

        for base in (0..vector_end).step_by(4) {
            let lo = vld1q_f32(bins.add(2 * base));
            let hi = vld1q_f32(bins.add(2 * base + 4));

            let tw_a = vld1q_f32(twiddle.add(base));
            let tw_b = vld1q_f32(twiddle.add(n4 + base));

            // De-interleave into the `.r` lanes and the `.i` lanes.
            let split = vuzpq_f32(lo, hi);
            let (r_v, i_v) = (split.0, split.1);

            let front_v = vsubq_f32(vmulq_f32(i_v, tw_b), vmulq_f32(r_v, tw_a));
            let back_v = vaddq_f32(vmulq_f32(r_v, tw_b), vmulq_f32(i_v, tw_a));

            let front: [f32; 4] = std::mem::transmute(front_v);
            let back: [f32; 4] = std::mem::transmute(back_v);

            for lane in 0..4 {
                *dst.add((base + lane) * 2) = front[lane];
                *dst.add(n2 - 1 - 2 * (base + lane)) = back[lane];
            }
        }
    }

    // Scalar tail for the last `n4 % 4` bins.
    for k in vector_end..n4 {
        let (front, back) = rotate(k);
        output[k * 2] = front;
        output[n2 - 1 - 2 * k] = back;
    }
}
705
/// Inverse pre-rotation using NEON when the input is contiguous.
///
/// For each `k < n4` the coefficient `input[2k * stride]` and the mirrored
/// coefficient `input[stride * (n2 - 1 - 2k)]` are rotated by the twiddles
/// `trig[k]` and `trig[n4 + k]` and stored in `f2` at index `bitrev[k]`.
/// For `stride != 1` a bounds-checked scalar loop is used; otherwise four
/// pairs are processed per iteration with raw pointers.
/// NOTE(review): the vector path writes `f2` as a flat `f32` array and
/// assumes `KissCpx` is two consecutive `f32` values, real part first — confirm.
///
/// # Panics
///
/// On the contiguous path, panics if `input` has fewer than `n2` elements
/// (for `n4 > 0`), `trig` fewer than `2 * n4`, or `bitrev` / `f2` fewer
/// than `n4`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
fn mdct_backward_pre_rotation_neon(
    input: &[f32],
    f2: &mut [KissCpx],
    trig: &[f32],
    bitrev: &[i16],
    n4: usize,
    n2: usize,
    stride: usize,
) {
    use std::arch::aarch64::*;

    if stride != 1 {
        for i in 0..n4 {
            let rev = bitrev[i] as usize;
            let x1 = input[2 * i * stride];
            let x2 = input[stride * (n2 - 1 - 2 * i)];
            let t0 = trig[i];
            let t1 = trig[n4 + i];
            let yr = x2 * t0 + x1 * t1;
            let yi = x1 * t0 - x2 * t1;
            f2[rev] = KissCpx::new(yi, yr);
        }
        return;
    }

    // The loops below use raw pointers, so validate every length they rely on.
    assert!(
        n4 == 0 || (input.len() >= n2 && n2 >= 2 * n4),
        "MDCT backward pre-rotation: input too small (need {}, have {})",
        n2,
        input.len()
    );
    assert!(trig.len() >= 2 * n4 && bitrev.len() >= n4 && f2.len() >= n4);

    // SAFETY: the asserts above bound every read of `input`, `trig` and
    // `bitrev`; the mirrored loads were chosen so that the highest index read
    // is `n2 - 1`. Writes to `f2` rely on every `bitrev` entry being a valid
    // index (checked in debug builds).
    unsafe {
        let in_ptr = input.as_ptr();
        let trig_ptr = trig.as_ptr();
        let bitrev_ptr = bitrev.as_ptr();
        let f2_ptr = f2.as_mut_ptr() as *mut f32;
        let f2_len = f2.len();

        let n4_vec = n4 & !3;
        let mut i = 0;

        while i < n4_vec {
            // Front coefficients: even lanes of input[2i .. 2i + 8].
            let f0 = vld1q_f32(in_ptr.add(2 * i));
            let f1 = vld1q_f32(in_ptr.add(2 * i + 4));
            let deint_x1 = vuzpq_f32(f0, f1);
            let x1_v = deint_x1.0;

            // Mirrored coefficients: load input[n2 - 8 - 2i .. n2 - 2i] and keep
            // the odd lanes, i.e. indices n2-7-2i, n2-5-2i, n2-3-2i, n2-1-2i.
            // Starting one element lower than the needed values keeps the load
            // from touching `input[n2]` when `i == 0`.
            let g0 = vld1q_f32(in_ptr.add(n2 - 8 - 2 * i));
            let g1 = vld1q_f32(in_ptr.add(n2 - 4 - 2 * i));
            let deint_x2 = vuzpq_f32(g0, g1);

            // Reverse the four lanes so lane 0 holds input[n2 - 1 - 2i].
            let x2_raw = deint_x2.1;
            let x2_v = vrev64q_f32(x2_raw);
            let x2_v = vextq_f32(x2_v, x2_v, 2);

            let t0 = vld1q_f32(trig_ptr.add(i));
            let t1 = vld1q_f32(trig_ptr.add(n4 + i));

            let yr = vaddq_f32(vmulq_f32(x2_v, t0), vmulq_f32(x1_v, t1));
            let yi = vsubq_f32(vmulq_f32(x1_v, t0), vmulq_f32(x2_v, t1));

            let yr_arr: [f32; 4] = std::mem::transmute(yr);
            let yi_arr: [f32; 4] = std::mem::transmute(yi);

            // The destination is a permutation, so the store is done per lane.
            for j in 0..4 {
                let rev = *bitrev_ptr.add(i + j) as usize;
                debug_assert!(rev < f2_len);
                *f2_ptr.add(2 * rev) = yi_arr[j];
                *f2_ptr.add(2 * rev + 1) = yr_arr[j];
            }

            i += 4;
        }

        // Scalar tail for the last `n4 % 4` pairs.
        for i in n4_vec..n4 {
            let rev = *bitrev_ptr.add(i) as usize;
            debug_assert!(rev < f2_len);
            let x1 = *in_ptr.add(2 * i);
            let x2 = *in_ptr.add(n2 - 1 - 2 * i);
            let t0 = *trig_ptr.add(i);
            let t1 = *trig_ptr.add(n4 + i);
            let yr = x2 * t0 + x1 * t1;
            let yi = x1 * t0 - x2 * t1;
            *f2_ptr.add(2 * rev) = yi;
            *f2_ptr.add(2 * rev + 1) = yr;
        }
    }
}
787
/// Inverse post-rotation for aarch64 builds.
///
/// Walks the FFT bins from both ends towards the middle. Bin `k` and its
/// mirror `n4 - 1 - k` are each rotated by their twiddles, and the four
/// results are written into `output` starting at offset `overlap2`. Two
/// front/back pairs are handled per loop iteration, with a single-pair tail.
///
/// `trig` and `output` are accessed through raw pointers without length
/// checks; the caller must pass slices that cover every index derived from
/// `n4`, `n2` and `overlap2`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
fn mdct_backward_post_rotation_neon(
    f2: &[KissCpx],
    trig: &[f32],
    output: &mut [f32],
    n4: usize,
    n2: usize,
    overlap2: usize,
) {
    // SAFETY: relies on the caller providing `trig` with at least `2 * n4`
    // entries and `output` with at least `overlap2 + n2` elements, with
    // `n2 >= 2 * n4`.
    unsafe {
        let twiddle = trig.as_ptr();
        let dst = output.as_mut_ptr().add(overlap2);

        // Rotates bin `k` and its mirror and stores the four resulting samples.
        let emit_pair = |k: usize| {
            let m = n4 - 1 - k;

            let lo_re = f2[k].i;
            let lo_im = f2[k].r;
            let lo_a = *twiddle.add(k);
            let lo_b = *twiddle.add(n4 + k);
            let lo_yr = lo_re * lo_a + lo_im * lo_b;
            let lo_yi = lo_re * lo_b - lo_im * lo_a;

            let hi_im = f2[m].r;
            let hi_re = f2[m].i;
            let hi_a = *twiddle.add(m);
            let hi_b = *twiddle.add(n4 + m);
            let hi_yr = hi_re * hi_a + hi_im * hi_b;
            let hi_yi = hi_re * hi_b - hi_im * hi_a;

            *dst.add(2 * k) = lo_yr;
            *dst.add(n2 - 1 - 2 * k) = lo_yi;
            *dst.add(n2 - 2 - 2 * k) = hi_yr;
            *dst.add(2 * k + 1) = hi_yi;
        };

        let pairs = (n4 + 1) >> 1;

        let mut k = 0;
        while k + 1 < pairs {
            emit_pair(k);
            emit_pair(k + 1);
            k += 2;
        }

        // Odd number of pairs: one is left over.
        if k < pairs {
            emit_pair(k);
        }
    }
}
873
/// Applies the overlap window to the first `overlap` samples of `output`, in place.
///
/// For every `i < overlap / 2`, with `x2 = output[i]` and
/// `x1 = output[overlap - 1 - i]`:
///
/// * `output[i] = x2 * window[overlap - 1 - i] - x1 * window[i]`
/// * `output[overlap - 1 - i] = x2 * window[i] + x1 * window[overlap - 1 - i]`
///
/// With fewer than four pairs a bounds-checked scalar loop is used; otherwise
/// four pairs are processed per iteration with NEON, followed by a scalar tail.
///
/// # Panics
///
/// Panics if `output` or `window` holds fewer than `overlap` elements while
/// there is at least one pair to process.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
fn mdct_tdac_neon(output: &mut [f32], window: &[f32], overlap: usize) {
    use std::arch::aarch64::*;

    let overlap2 = overlap / 2;
    if overlap2 < 4 {
        for i in 0..overlap2 {
            let x1 = output[overlap - 1 - i];
            let x2 = output[i];
            let wp1 = window[i];
            let wp2 = window[overlap - 1 - i];
            output[i] = x2 * wp2 - x1 * wp1;
            output[overlap - 1 - i] = x2 * wp1 + x1 * wp2;
        }
        return;
    }

    // The vector loop uses raw pointers, so validate the lengths first.
    assert!(
        output.len() >= overlap && window.len() >= overlap,
        "MDCT TDAC: output ({}) and window ({}) must hold at least overlap ({}) samples",
        output.len(),
        window.len(),
        overlap
    );

    let n4 = overlap2 & !3;

    // SAFETY: both slices hold at least `overlap` elements (asserted above).
    // The forward accesses cover `[i, i + 4)` with `i + 4 <= overlap2`, and the
    // mirrored accesses cover `[overlap - 4 - i, overlap - i)`, which starts at
    // or above `overlap - overlap2`.
    unsafe {
        let out_ptr = output.as_mut_ptr();
        let win_ptr = window.as_ptr();
        let mut i = 0;

        while i < n4 {
            let x2_fwd = vld1q_f32(out_ptr.add(i));
            let x1_rev = vld1q_f32(out_ptr.add(overlap - 4 - i));

            // Reverse the four lanes so lane 0 holds output[overlap - 1 - i].
            let x1 = vrev64q_f32(x1_rev);
            let x1 = vextq_f32(x1, x1, 2);

            let wp1 = vld1q_f32(win_ptr.add(i));
            let wp2_rev = vld1q_f32(win_ptr.add(overlap - 4 - i));
            let wp2 = vrev64q_f32(wp2_rev);
            let wp2 = vextq_f32(wp2, wp2, 2);

            let out_fwd = vsubq_f32(vmulq_f32(x2_fwd, wp2), vmulq_f32(x1, wp1));

            let out_rev = vaddq_f32(vmulq_f32(x2_fwd, wp1), vmulq_f32(x1, wp2));

            // Undo the lane reversal before storing the mirrored results.
            let out_rev = vrev64q_f32(out_rev);
            let out_rev = vextq_f32(out_rev, out_rev, 2);

            vst1q_f32(out_ptr.add(i), out_fwd);
            vst1q_f32(out_ptr.add(overlap - 4 - i), out_rev);

            i += 4;
        }
    }

    // Scalar tail for the pairs that do not fill a whole vector.
    for i in n4..overlap2 {
        let x1 = output[overlap - 1 - i];
        let x2 = output[i];
        output[i] = x2 * window[overlap - 1 - i] - x1 * window[i];
        output[overlap - 1 - i] = x2 * window[i] + x1 * window[overlap - 1 - i];
    }
}
932
#[cfg(test)]
mod mdct_tests {
    //! Regression tests for `MdctLookup::backward`.
    //!
    //! Sizes quoted in comments (e.g. "1920", "240") assume the default mode
    //! has `mdct.n == 1920` and `overlap == 120` — TODO confirm against
    //! `crate::modes`.

    use super::MdctLookup;
    use crate::kiss_fft::{opus_fft_impl, KissCpx};

    /// Builds a deterministic synthetic spectrum: `sin(0.01 * i) * amplitude`
    /// for `i` in `0..len`. This is only a smooth test signal, not a tone at
    /// any particular audio frequency.
    fn sine_spectrum(len: usize, amplitude: f32) -> Vec<f32> {
        (0..len)
            .map(|i| ((i as f32) * 0.01).sin() * amplitude)
            .collect()
    }

    /// Maximum of a stream of non-negative values that *propagates* NaN.
    ///
    /// `f32::max` returns the non-NaN operand, so a plain
    /// `fold(0.0, f32::max)` silently discards NaN samples; here a single NaN
    /// makes the whole result NaN so that `result < limit` fails.
    fn nan_aware_max(values: impl Iterator<Item = f32>) -> f32 {
        values.fold(0.0f32, |acc, v| {
            if acc.is_nan() || v.is_nan() {
                f32::NAN
            } else {
                acc.max(v)
            }
        })
    }

    /// Largest absolute sample in `samples` (NaN if any sample is NaN).
    fn peak_abs(samples: &[f32]) -> f32 {
        nan_aware_max(samples.iter().map(|s| s.abs()))
    }

    /// Largest element-wise absolute difference between `a` and `b`
    /// (NaN if any difference is NaN). Compares up to the shorter length.
    fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
        nan_aware_max(a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()))
    }

    /// Plain scalar re-implementation of the inverse transform, used as the
    /// reference that `MdctLookup::backward` is compared against.
    ///
    /// Steps: pre-rotate the strided input into bit-reversed order, run
    /// `opus_fft_impl`, post-rotate into `out[overlap/2..]`, then apply the
    /// windowed mirroring over the first `overlap` samples.
    ///
    /// * `freq` - strided frequency-domain input (read at multiples of `stride`).
    /// * `window` - overlap window; indexed over `0..overlap`.
    /// * `out_len` - length of the zero-initialised buffer that is returned.
    ///
    /// Panics if the FFT state for `shift` is missing or any buffer is too
    /// short for the indices derived from `mdct.n >> shift`.
    fn scalar_backward_reference(
        mdct: &MdctLookup,
        freq: &[f32],
        window: &[f32],
        overlap: usize,
        shift: usize,
        stride: usize,
        out_len: usize,
    ) -> Vec<f32> {
        let n = mdct.n >> shift;
        let n2 = n / 2;
        let n4 = n / 4;
        let overlap2 = overlap / 2;

        let st = mdct.kfft[shift].as_ref().unwrap();
        let (trig, _) = mdct.get_trig(shift);

        // Pre-rotation: each output is stored with real/imag swapped at its
        // bit-reversed position, ready for the in-place FFT.
        let mut f2 = vec![KissCpx::new(0.0, 0.0); n4];
        for i in 0..n4 {
            let rev = st.bitrev[i] as usize;
            let x1 = freq[2 * i * stride];
            let x2 = freq[stride * (n2 - 1 - 2 * i)];
            let t0 = trig[i];
            let t1 = trig[n4 + i];
            let yr = x2 * t0 + x1 * t1;
            let yi = x1 * t0 - x2 * t1;
            f2[rev] = KissCpx::new(yi, yr);
        }
        opus_fft_impl(st, &mut f2);

        // Post-rotation: walk inwards from both ends of the FFT output.
        // Write order is kept fixed because the two ends coincide on the
        // middle element when `n4` is odd.
        let mut out = vec![0.0f32; out_len];
        for i in 0..((n4 + 1) >> 1) {
            let im0 = f2[i].r;
            let re0 = f2[i].i;
            let t0_0 = trig[i];
            let t1_0 = trig[n4 + i];
            let yr0 = re0 * t0_0 + im0 * t1_0;
            let yi0 = re0 * t1_0 - im0 * t0_0;

            let j = n4 - 1 - i;
            let im1 = f2[j].r;
            let re1 = f2[j].i;
            let t0_1 = trig[j];
            let t1_1 = trig[n4 + j];
            let yr1 = re1 * t0_1 + im1 * t1_1;
            let yi1 = re1 * t1_1 - im1 * t0_1;

            out[overlap2 + 2 * i] = yr0;
            out[overlap2 + n2 - 1 - 2 * i] = yi0;
            out[overlap2 + n2 - 2 - 2 * i] = yr1;
            out[overlap2 + 2 * i + 1] = yi1;
        }

        // TDAC: mirror the first `overlap` samples through the window.
        for i in 0..overlap2 {
            let x1 = out[overlap - 1 - i];
            let x2 = out[i];
            let wp1 = window[i];
            let wp2 = window[overlap - 1 - i];
            out[i] = x2 * wp2 - x1 * wp1;
            out[overlap - 1 - i] = x2 * wp1 + x1 * wp2;
        }

        out
    }

    /// Runs `backward` with `shift = 3`, `stride = 8` on the sub-blocks that
    /// start at offsets 0 and 1 of a synthetic spectrum and checks that the
    /// output magnitude stays bounded (and contains no NaN).
    #[test]
    fn test_mdct_backward_transient_no_blowup() {
        let mode = crate::modes::default_mode();
        let shift = 3;
        let n = mode.mdct.n >> shift; // short-block transform size
        let overlap = mode.overlap;
        let stride = 8;

        let frame_size = 960usize;
        let freq = sine_spectrum(frame_size, 10.0);

        let out_len = n + overlap;
        let mut output0 = vec![0.0f32; out_len];
        let mut output1 = vec![0.0f32; out_len];

        mode.mdct.backward(
            &freq[0..],
            &mut output0,
            mode.window,
            overlap,
            shift,
            stride,
        );
        mode.mdct.backward(
            &freq[1..],
            &mut output1,
            mode.window,
            overlap,
            shift,
            stride,
        );

        // Bound the largest *magnitude*, not the largest signed value, so a
        // negative excursion or a NaN sample cannot slip through.
        let peak0 = peak_abs(&output0);
        let peak1 = peak_abs(&output1);
        eprintln!("sub0 max={} sub1 max={}", peak0, peak1);
        eprintln!("sub0[60..70]={:?}", &output0[60..70]);
        eprintln!("sub1[60..70]={:?}", &output1[60..70]);

        assert!(peak0 < 500.0, "sub0 blowup: {}", peak0);
        assert!(peak1 < 500.0, "sub1 blowup: {}", peak1);
    }

    /// Full-size (`shift = 0`, `stride = 1`) comparison of `backward` against
    /// the scalar reference. The test name suggests the production path may
    /// use NEON on some targets — not visible from this module.
    #[test]
    fn test_mdct_backward_stride1_neon_matches_scalar() {
        let mode = crate::modes::default_mode();
        let shift = 0; // non-transient full-size MDCT
        let n = mode.mdct.n >> shift;
        let n2 = n / 2;
        let overlap = mode.overlap;
        let overlap2 = overlap / 2;
        let stride = 1;

        // `n2` coefficients followed by four zero samples of padding.
        let freq_len = n2;
        let mut freq = sine_spectrum(freq_len, 4577.0);
        freq.resize(freq_len + 4, 0.0);

        // Only the first `out_len` samples are compared; the buffers carry
        // 100 extra samples of slack.
        let out_len = overlap2 + n2;
        let mut output_hw = vec![0.0f32; out_len + 100];
        mode.mdct.backward(
            &freq[..],
            &mut output_hw,
            mode.window,
            overlap,
            shift,
            stride,
        );

        let output_scalar = scalar_backward_reference(
            &mode.mdct,
            &freq,
            mode.window,
            overlap,
            shift,
            stride,
            out_len + 100,
        );

        let max_diff = max_abs_diff(&output_hw[..out_len], &output_scalar[..out_len]);
        assert!(
            max_diff < 0.5,
            "stride=1 NEON vs scalar mismatch: max_diff={}",
            max_diff
        );
    }

    /// Short-block (`shift = 3`, `stride = 8`) comparison of `backward`
    /// against the scalar reference over the whole output buffer.
    #[test]
    fn test_mdct_backward_neon_matches_scalar() {
        let mode = crate::modes::default_mode();
        let shift = 3;
        let n = mode.mdct.n >> shift;
        let overlap = mode.overlap;
        let stride = 8;

        // Smooth synthetic spectrum (not tied to any audio frequency).
        let frame_size = 960usize;
        let freq = sine_spectrum(frame_size, 200.0);

        let out_len = n + overlap;
        let mut output_hw = vec![0.0f32; out_len];
        mode.mdct.backward(
            &freq[0..],
            &mut output_hw,
            mode.window,
            overlap,
            shift,
            stride,
        );

        let output_scalar = scalar_backward_reference(
            &mode.mdct,
            &freq,
            mode.window,
            overlap,
            shift,
            stride,
            out_len,
        );

        // Report every individual outlier before the aggregate assertion so a
        // failure shows where the two implementations diverge.
        for (i, (hw, scalar)) in output_hw.iter().zip(output_scalar.iter()).enumerate() {
            let diff = (hw - scalar).abs();
            if diff > 1e-3 {
                eprintln!(
                    "Mismatch at output[{}]: hw={} scalar={} diff={}",
                    i, hw, scalar, diff
                );
            }
        }

        let max_diff = max_abs_diff(&output_hw, &output_scalar);
        assert!(
            max_diff < 0.1,
            "NEON/HW vs scalar mismatch: max_diff={}",
            max_diff
        );
    }
}