Skip to main content

opus_rs/
mdct.rs

1use crate::kiss_fft::{KissCpx, KissFftState, opus_fft_impl};
2use std::f32::consts::PI;
3use std::mem::MaybeUninit;
4
/// Capacity, in `f32` values, of the stack scratch buffer that holds the
/// folded forward-MDCT input (`n / 2` values for the largest transform).
const MAX_N2: usize = 960;
/// Capacity, in complex values, of the stack FFT scratch buffer (`n / 4`);
/// always half of [`MAX_N2`].
const MAX_N4: usize = MAX_N2 / 2;
7
/// Precomputed FFT states and twiddle tables for an MDCT of size `n` and its
/// halved variants `n >> shift` for every `shift` in `0..=max_lm`.
pub struct MdctLookup {
    /// Size of the largest (shift 0) transform.
    pub n: usize,
    /// Largest `shift` for which tables were built.
    pub max_lm: usize,
    /// One FFT state per shift, of size `(n >> shift) / 4`; an entry is `None`
    /// when the underlying `KissFftState` constructor returned `None`.
    kfft: Vec<Option<KissFftState>>,
    /// Twiddle tables for all shifts stored back to back,
    /// `(n >> shift) / 2` entries per shift.
    trig: Vec<f32>,
}
14
15impl MdctLookup {
16    pub fn new(n: usize, max_lm: usize) -> Self {
17        let mut kfft = Vec::new();
18        let mut trig = Vec::new();
19        let mut curr_n = n;
20
21        for shift in 0..=max_lm {
22            let n4 = curr_n / 4;
23
24            if shift == 0 {
25                kfft.push(KissFftState::new(n4));
26            } else if let Some(base) = kfft.first().unwrap().as_ref() {
27                kfft.push(KissFftState::new_sub(base, n4));
28            } else {
29                kfft.push(None);
30            }
31
32            let n2 = curr_n / 2;
33            for i in 0..n2 {
34                let angle = 2.0 * PI * (i as f32 + 0.125) / curr_n as f32;
35                trig.push(angle.cos());
36            }
37
38            curr_n >>= 1;
39        }
40
41        Self {
42            n,
43            max_lm,
44            kfft,
45            trig,
46        }
47    }
48
49    fn get_trig(&self, shift: usize) -> (&[f32], usize) {
50        let mut offset = 0;
51        let mut curr_n = self.n;
52        for _ in 0..shift {
53            offset += curr_n / 2;
54            curr_n >>= 1;
55        }
56        (&self.trig[offset..offset + curr_n / 2], curr_n / 4)
57    }
58
59    pub fn get_trig_debug(&self, shift: usize) -> &[f32] {
60        let (trig, _) = self.get_trig(shift);
61        trig
62    }
63
64    #[inline]
65    pub fn forward(
66        &self,
67        input: &[f32],
68        output: &mut [f32],
69        window: &[f32],
70        overlap: usize,
71        shift: usize,
72        stride: usize,
73    ) {
74        let st = self.kfft[shift]
75            .as_ref()
76            .expect("FFT state not initialized");
77        let n = self.n >> shift;
78        let n2 = n / 2;
79        let n4 = n / 4;
80        let scale = st.scale();
81
82        let (trig, _) = self.get_trig(shift);
83        let overlap2 = overlap / 2;
84
85        let mut f_buf = [MaybeUninit::<f32>::uninit(); MAX_N2];
86        let mut f2_buf = [MaybeUninit::<KissCpx>::uninit(); MAX_N4];
87
88        let f = unsafe { std::slice::from_raw_parts_mut(f_buf.as_mut_ptr() as *mut f32, n2) };
89        let f2 = unsafe { std::slice::from_raw_parts_mut(f2_buf.as_mut_ptr() as *mut KissCpx, n4) };
90
91        assert!(input.len() >= n2 + overlap2);
92        assert!(window.len() >= overlap);
93
94        {
95            let mut yp = 0usize;
96            let mut xp1 = overlap2;
97            let mut xp2 = n2 - 1 + overlap2;
98            let mut wp1 = overlap2;
99
100            let mut wp2 = overlap2.saturating_sub(1);
101
102            let limit = overlap.div_ceil(4);
103            let mid = n4.saturating_sub(limit);
104
105            let loop1_iters = limit.min(n4);
106            for _ in 0..loop1_iters {
107                let w1 = window[wp1];
108                let w2 = window[wp2];
109
110                f[yp] = input[xp1 + n2] * w2 + input[xp2] * w1;
111                yp += 1;
112
113                f[yp] = input[xp1] * w1 - input[xp2 - n2] * w2;
114                yp += 1;
115
116                xp1 += 2;
117                xp2 -= 2;
118                wp1 += 2;
119                wp2 = wp2.saturating_sub(2);
120            }
121
122            for _ in limit..mid {
123                f[yp] = input[xp2];
124                yp += 1;
125
126                f[yp] = input[xp1];
127                yp += 1;
128                xp1 += 2;
129                xp2 -= 2;
130            }
131
132            let loop3_iters = if mid > limit { n4 - mid } else { 0 };
133            let mut wp1_l3 = 0usize;
134            let mut wp2_l3 = overlap.saturating_sub(1);
135            for _ in 0..loop3_iters {
136                let w1 = window[wp1_l3];
137                let w2 = window[wp2_l3];
138
139                f[yp] = -input[xp1 - n2] * w1 + input[xp2] * w2;
140                yp += 1;
141
142                f[yp] = input[xp1] * w2 + input[xp2 + n2] * w1;
143                yp += 1;
144
145                xp1 += 2;
146                xp2 -= 2;
147                wp1_l3 += 2;
148                wp2_l3 -= 2;
149            }
150        }
151
152        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
153        unsafe {
154            if std::arch::is_x86_feature_detected!("avx") {
155                mdct_pre_rotation_avx(f, f2, trig, &st.bitrev[..n4], n4, scale);
156            } else {
157                for i in 0..n4 {
158                    let re = f[2 * i];
159                    let im = f[2 * i + 1];
160                    let t0 = trig[i];
161                    let t1 = trig[n4 + i];
162
163                    let yr = re * t0 - im * t1;
164                    let yi = im * t0 + re * t1;
165
166                    f2[st.bitrev[i] as usize] = KissCpx::new(yr * scale, yi * scale);
167                }
168            }
169        }
170        #[cfg(all(
171            not(any(target_arch = "x86", target_arch = "x86_64")),
172            target_arch = "aarch64"
173        ))]
174        {
175            mdct_pre_rotation_neon(f, f2, trig, &st.bitrev[..n4], n4, scale);
176        }
177        #[cfg(all(
178            not(any(target_arch = "x86", target_arch = "x86_64")),
179            not(target_arch = "aarch64")
180        ))]
181        for i in 0..n4 {
182            let re = f[2 * i];
183            let im = f[2 * i + 1];
184            let t0 = trig[i];
185            let t1 = trig[n4 + i];
186
187            let yr = re * t0 - im * t1;
188            let yi = im * t0 + re * t1;
189
190            f2[st.bitrev[i] as usize] = KissCpx::new(yr * scale, yi * scale);
191        }
192
193        opus_fft_impl(st, f2);
194
195        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
196        unsafe {
197            if std::arch::is_x86_feature_detected!("avx") {
198                mdct_post_rotation_avx(f2, trig, output, n4, n2, stride);
199            } else {
200                for i in 0..n4 {
201                    let fp = &f2[i];
202                    let t0 = trig[i];
203                    let t1 = trig[n4 + i];
204
205                    let yr = fp.i * t1 - fp.r * t0;
206                    let yi = fp.r * t1 + fp.i * t0;
207
208                    output[i * 2 * stride] = yr;
209                    output[stride * (n2 - 1 - 2 * i)] = yi;
210                }
211            }
212        }
213        #[cfg(all(
214            not(any(target_arch = "x86", target_arch = "x86_64")),
215            target_arch = "aarch64"
216        ))]
217        {
218            mdct_post_rotation_neon(f2, trig, output, n4, n2, stride);
219        }
220        #[cfg(all(
221            not(any(target_arch = "x86", target_arch = "x86_64")),
222            not(target_arch = "aarch64")
223        ))]
224        for i in 0..n4 {
225            let fp = &f2[i];
226            let t0 = trig[i];
227            let t1 = trig[n4 + i];
228
229            let yr = fp.i * t1 - fp.r * t0;
230            let yi = fp.r * t1 + fp.i * t0;
231
232            output[i * 2 * stride] = yr;
233            output[stride * (n2 - 1 - 2 * i)] = yi;
234        }
235    }
236
237    #[inline]
238    pub fn backward(
239        &self,
240        input: &[f32],
241        output: &mut [f32],
242        window: &[f32],
243        overlap: usize,
244        shift: usize,
245        stride: usize,
246    ) {
247        let st = self.kfft[shift]
248            .as_ref()
249            .expect("FFT state not initialized");
250        let n = self.n >> shift;
251        let n2 = n / 2;
252        let n4 = n / 4;
253        let overlap2 = overlap / 2;
254
255        let (trig, _) = self.get_trig(shift);
256
257        let mut f2_buf = [MaybeUninit::<KissCpx>::uninit(); MAX_N4];
258
259        let f2 = unsafe { std::slice::from_raw_parts_mut(f2_buf.as_mut_ptr() as *mut KissCpx, n4) };
260
261        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
262        unsafe {
263            if std::arch::is_x86_feature_detected!("avx") {
264                mdct_backward_pre_rotation_avx(input, f2, trig, &st.bitrev[..n4], n4, n2, stride);
265            } else {
266                for i in 0..n4 {
267                    let rev = st.bitrev[i] as usize;
268                    let x1 = input[2 * i * stride];
269                    let x2 = input[stride * (n2 - 1 - 2 * i)];
270                    let t0 = trig[i];
271                    let t1 = trig[n4 + i];
272
273                    let yr = x2 * t0 + x1 * t1;
274                    let yi = x1 * t0 - x2 * t1;
275
276                    f2[rev] = KissCpx::new(yi, yr);
277                }
278            }
279        }
280        #[cfg(all(
281            not(any(target_arch = "x86", target_arch = "x86_64")),
282            target_arch = "aarch64"
283        ))]
284        {
285            mdct_backward_pre_rotation_neon(input, f2, trig, &st.bitrev[..n4], n4, n2, stride);
286        }
287        #[cfg(all(
288            not(any(target_arch = "x86", target_arch = "x86_64")),
289            not(target_arch = "aarch64")
290        ))]
291        for i in 0..n4 {
292            let rev = st.bitrev[i] as usize;
293            let x1 = input[2 * i * stride];
294            let x2 = input[stride * (n2 - 1 - 2 * i)];
295            let t0 = trig[i];
296            let t1 = trig[n4 + i];
297
298            let yr = x2 * t0 + x1 * t1;
299            let yi = x1 * t0 - x2 * t1;
300
301            f2[rev] = KissCpx::new(yi, yr);
302        }
303
304        opus_fft_impl(st, f2);
305
306        assert!(output.len() >= overlap2 + n2);
307
308        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
309        unsafe {
310            if std::arch::is_x86_feature_detected!("avx") {
311                mdct_backward_post_rotation_avx(f2, trig, output, n4, n2, overlap2);
312            } else {
313                for i in 0..((n4 + 1) >> 1) {
314                    let im0 = f2[i].r;
315                    let re0 = f2[i].i;
316                    let t0_0 = trig[i];
317                    let t1_0 = trig[n4 + i];
318
319                    let yr0 = re0 * t0_0 + im0 * t1_0;
320                    let yi0 = re0 * t1_0 - im0 * t0_0;
321
322                    let j = n4 - 1 - i;
323                    let im1 = f2[j].r;
324                    let re1 = f2[j].i;
325                    let t0_1 = trig[j];
326                    let t1_1 = trig[n4 + j];
327
328                    let yr1 = re1 * t0_1 + im1 * t1_1;
329                    let yi1 = re1 * t1_1 - im1 * t0_1;
330
331                    output[overlap2 + 2 * i] = yr0;
332                    output[overlap2 + n2 - 1 - 2 * i] = yi0;
333                    output[overlap2 + n2 - 2 - 2 * i] = yr1;
334                    output[overlap2 + 2 * i + 1] = yi1;
335                }
336            }
337        }
338        #[cfg(all(
339            not(any(target_arch = "x86", target_arch = "x86_64")),
340            target_arch = "aarch64"
341        ))]
342        {
343            mdct_backward_post_rotation_neon(f2, trig, output, n4, n2, overlap2);
344        }
345        #[cfg(all(
346            not(any(target_arch = "x86", target_arch = "x86_64")),
347            not(target_arch = "aarch64")
348        ))]
349        for i in 0..((n4 + 1) >> 1) {
350            let im0 = f2[i].r;
351            let re0 = f2[i].i;
352            let t0_0 = trig[i];
353            let t1_0 = trig[n4 + i];
354
355            let yr0 = re0 * t0_0 + im0 * t1_0;
356            let yi0 = re0 * t1_0 - im0 * t0_0;
357
358            let j = n4 - 1 - i;
359            let im1 = f2[j].r;
360            let re1 = f2[j].i;
361            let t0_1 = trig[j];
362            let t1_1 = trig[n4 + j];
363
364            let yr1 = re1 * t0_1 + im1 * t1_1;
365            let yi1 = re1 * t1_1 - im1 * t0_1;
366
367            output[overlap2 + 2 * i] = yr0;
368            output[overlap2 + n2 - 1 - 2 * i] = yi0;
369            output[overlap2 + n2 - 2 - 2 * i] = yr1;
370            output[overlap2 + 2 * i + 1] = yi1;
371        }
372
373        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
374        unsafe {
375            if std::arch::is_x86_feature_detected!("avx") {
376                mdct_tdac_avx(output, window, overlap);
377            } else {
378                for i in 0..overlap2 {
379                    let x1 = output[overlap - 1 - i];
380                    let x2 = output[i];
381                    let wp1 = window[i];
382                    let wp2 = window[overlap - 1 - i];
383
384                    output[i] = x2 * wp2 - x1 * wp1;
385                    output[overlap - 1 - i] = x2 * wp1 + x1 * wp2;
386                }
387            }
388        }
389        #[cfg(all(
390            not(any(target_arch = "x86", target_arch = "x86_64")),
391            target_arch = "aarch64"
392        ))]
393        {
394            mdct_tdac_neon(output, window, overlap);
395        }
396        #[cfg(all(
397            not(any(target_arch = "x86", target_arch = "x86_64")),
398            not(target_arch = "aarch64")
399        ))]
400        for i in 0..overlap2 {
401            let x1 = output[overlap - 1 - i];
402            let x2 = output[i];
403            let wp1 = window[i];
404            let wp2 = window[overlap - 1 - i];
405
406            output[i] = x2 * wp2 - x1 * wp1;
407            output[overlap - 1 - i] = x2 * wp1 + x1 * wp2;
408        }
409    }
410}
411
412#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
413#[target_feature(enable = "avx")]
414unsafe fn mdct_pre_rotation_avx(
415    f: &[f32],
416    f2: &mut [KissCpx],
417    trig: &[f32],
418    bitrev: &[i16],
419    n4: usize,
420    scale: f32,
421) {
422    for i in 0..n4 {
423        let re = f[2 * i];
424        let im = f[2 * i + 1];
425        let t0 = trig[i];
426        let t1 = trig[n4 + i];
427
428        let yr = re * t0 - im * t1;
429        let yi = im * t0 + re * t1;
430
431        f2[bitrev[i] as usize] = KissCpx::new(yr * scale, yi * scale);
432    }
433}
434
435#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
436#[target_feature(enable = "avx")]
437unsafe fn mdct_post_rotation_avx(
438    f2: &[KissCpx],
439    trig: &[f32],
440    output: &mut [f32],
441    n4: usize,
442    n2: usize,
443    stride: usize,
444) {
445    for i in 0..n4 {
446        let fp = &f2[i];
447        let t0 = trig[i];
448        let t1 = trig[n4 + i];
449
450        let yr = fp.i * t1 - fp.r * t0;
451        let yi = fp.r * t1 + fp.i * t0;
452
453        output[i * 2 * stride] = yr;
454        output[stride * (n2 - 1 - 2 * i)] = yi;
455    }
456}
457
458#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
459#[target_feature(enable = "avx")]
460unsafe fn mdct_backward_pre_rotation_avx(
461    input: &[f32],
462    f2: &mut [KissCpx],
463    trig: &[f32],
464    bitrev: &[i16],
465    n4: usize,
466    n2: usize,
467    stride: usize,
468) {
469    for i in 0..n4 {
470        let rev = bitrev[i] as usize;
471        let x1 = input[2 * i * stride];
472        let x2 = input[stride * (n2 - 1 - 2 * i)];
473        let t0 = trig[i];
474        let t1 = trig[n4 + i];
475
476        let yr = x2 * t0 + x1 * t1;
477        let yi = x1 * t0 - x2 * t1;
478
479        f2[rev] = KissCpx::new(yi, yr);
480    }
481}
482
483#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
484#[target_feature(enable = "avx")]
485unsafe fn mdct_backward_post_rotation_avx(
486    f2: &[KissCpx],
487    trig: &[f32],
488    output: &mut [f32],
489    n4: usize,
490    n2: usize,
491    overlap2: usize,
492) {
493    for i in 0..((n4 + 1) >> 1) {
494        let im0 = f2[i].r;
495        let re0 = f2[i].i;
496        let t0_0 = trig[i];
497        let t1_0 = trig[n4 + i];
498
499        let yr0 = re0 * t0_0 + im0 * t1_0;
500        let yi0 = re0 * t1_0 - im0 * t0_0;
501
502        let j = n4 - 1 - i;
503        let im1 = f2[j].r;
504        let re1 = f2[j].i;
505        let t0_1 = trig[j];
506        let t1_1 = trig[n4 + j];
507
508        let yr1 = re1 * t0_1 + im1 * t1_1;
509        let yi1 = re1 * t1_1 - im1 * t0_1;
510
511        output[overlap2 + 2 * i] = yr0;
512        output[overlap2 + n2 - 1 - 2 * i] = yi0;
513        output[overlap2 + n2 - 2 - 2 * i] = yr1;
514        output[overlap2 + 2 * i + 1] = yi1;
515    }
516}
517
/// TDAC mirror of the inverse MDCT, eight lanes at a time with AVX.
///
/// For every `i < overlap / 2`, with `j = overlap - 1 - i`,
/// `x2 = output[i]` and `x1 = output[j]`, the pair is replaced by
/// `output[i] = x2 * window[j] - x1 * window[i]` and
/// `output[j] = x2 * window[i] + x1 * window[j]`.
///
/// # Safety
/// The CPU must support AVX.
///
/// # Panics
/// Panics if `overlap / 2 > 0` and `output` or `window` holds fewer than
/// `overlap` values.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx")]
unsafe fn mdct_tdac_avx(output: &mut [f32], window: &[f32], overlap: usize) {
    // The function is compiled for both x86 flavours; each has its own module.
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let overlap2 = overlap / 2;
    if overlap2 == 0 {
        return;
    }
    // The vector loads below are unchecked; make sure everything they touch
    // (indices below `overlap`) exists before the first one runs.
    assert!(
        output.len() >= overlap && window.len() >= overlap,
        "TDAC needs at least `overlap` output and window samples"
    );

    let mut i = 0usize;
    while i + 8 <= overlap2 {
        let x2 = _mm256_loadu_ps(output.as_ptr().add(i));
        let w1 = _mm256_loadu_ps(window.as_ptr().add(i));

        // The mirrored operands run backwards from `overlap - 1 - i`; gather
        // them into forward order so they line up lane by lane with x2 / w1.
        let mut x1_tmp = [0f32; 8];
        let mut w2_tmp = [0f32; 8];
        for j in 0..8 {
            x1_tmp[j] = output[overlap - 1 - (i + j)];
            w2_tmp[j] = window[overlap - 1 - (i + j)];
        }
        let x1 = _mm256_loadu_ps(x1_tmp.as_ptr());
        let w2 = _mm256_loadu_ps(w2_tmp.as_ptr());

        let out_fwd = _mm256_sub_ps(_mm256_mul_ps(x2, w2), _mm256_mul_ps(x1, w1));
        let out_rev = _mm256_add_ps(_mm256_mul_ps(x2, w1), _mm256_mul_ps(x1, w2));

        _mm256_storeu_ps(output.as_mut_ptr().add(i), out_fwd);

        let mut out_rev_tmp = [0f32; 8];
        _mm256_storeu_ps(out_rev_tmp.as_mut_ptr(), out_rev);
        for j in 0..8 {
            output[overlap - 1 - (i + j)] = out_rev_tmp[j];
        }

        i += 8;
    }

    // Remaining pairs that do not fill a whole vector.
    for i in i..overlap2 {
        let j = overlap - 1 - i;
        let x1 = output[j];
        let x2 = output[i];
        let (wp1, wp2) = (window[i], window[j]);
        output[i] = x2 * wp2 - x1 * wp1;
        output[j] = x2 * wp1 + x1 * wp2;
    }
}
563
/// NEON forward pre-rotation.
///
/// For each `k < n4`, rotates `(f[2k], f[2k + 1])` by
/// `(trig[k], trig[n4 + k])`, multiplies by `scale` and stores the result at
/// `f2[bitrev[k]]`. Four pairs are processed per vector step; the remainder is
/// handled in scalar code.
///
/// # Panics
/// Panics if `f` or `trig` holds fewer than `2 * n4` values, or if a
/// `bitrev` entry is out of range for `f2`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
fn mdct_pre_rotation_neon(
    f: &[f32],
    f2: &mut [KissCpx],
    trig: &[f32],
    bitrev: &[i16],
    n4: usize,
    scale: f32,
) {
    use std::arch::aarch64::*;

    // The vector loads below are unchecked, so validate their range up front.
    assert!(
        f.len() >= 2 * n4 && trig.len() >= 2 * n4,
        "pre-rotation input or twiddle table is too short"
    );

    let n4_vec = n4 & !3;
    // SAFETY: every vector step has i + 4 <= n4_vec <= n4, so the loads read
    // f[2i..2i + 8], trig[i..i + 4] and trig[n4 + i..n4 + i + 4], all in bounds
    // by the assertion above. The scatter into `f2` uses checked indexing, so
    // out-of-range `bitrev` values panic instead of corrupting memory.
    unsafe {
        let vscale = vdupq_n_f32(scale);
        let f_ptr = f.as_ptr();
        let trig_ptr = trig.as_ptr();

        let mut i = 0;
        while i < n4_vec {
            let t0 = vld1q_f32(trig_ptr.add(i));
            let t1 = vld1q_f32(trig_ptr.add(n4 + i));

            // De-interleave four (re, im) pairs into a re vector and an im vector.
            let pairs = vuzpq_f32(vld1q_f32(f_ptr.add(2 * i)), vld1q_f32(f_ptr.add(2 * i + 4)));
            let (re_v, im_v) = (pairs.0, pairs.1);

            let yr = vsubq_f32(vmulq_f32(re_v, t0), vmulq_f32(im_v, t1));
            let yi = vaddq_f32(vmulq_f32(im_v, t0), vmulq_f32(re_v, t1));

            let mut yr_arr = [0f32; 4];
            let mut yi_arr = [0f32; 4];
            vst1q_f32(yr_arr.as_mut_ptr(), vmulq_f32(yr, vscale));
            vst1q_f32(yi_arr.as_mut_ptr(), vmulq_f32(yi, vscale));

            for j in 0..4 {
                f2[bitrev[i + j] as usize] = KissCpx::new(yr_arr[j], yi_arr[j]);
            }

            i += 4;
        }
    }

    for i in n4_vec..n4 {
        let (re, im) = (f[2 * i], f[2 * i + 1]);
        let (t0, t1) = (trig[i], trig[n4 + i]);
        let yr = re * t0 - im * t1;
        let yi = im * t0 + re * t1;
        f2[bitrev[i] as usize] = KissCpx::new(yr * scale, yi * scale);
    }
}
628
/// NEON forward post-rotation.
///
/// Rotates FFT bin `i < n4` by `(trig[i], trig[n4 + i])`. The real part is
/// written to `output[2 * i * stride]` and the imaginary part to
/// `output[stride * (n2 - 1 - 2 * i)]`. Only unit stride is vectorised.
///
/// # Panics
/// For unit stride, panics if `f2` holds fewer than `n4` values, `trig`
/// fewer than `2 * n4`, `output` fewer than `n2`, or if `n2 < 2 * n4`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
fn mdct_post_rotation_neon(
    f2: &[KissCpx],
    trig: &[f32],
    output: &mut [f32],
    n4: usize,
    n2: usize,
    stride: usize,
) {
    use std::arch::aarch64::*;

    // Only unit stride matches the contiguous layout the vector code assumes.
    // Every other stride (including 0) uses the same scalar loop as the
    // portable path.
    if stride != 1 {
        for i in 0..n4 {
            let fp = &f2[i];
            let (t0, t1) = (trig[i], trig[n4 + i]);
            output[i * 2 * stride] = fp.i * t1 - fp.r * t0;
            output[stride * (n2 - 1 - 2 * i)] = fp.r * t1 + fp.i * t0;
        }
        return;
    }

    // The vector loop uses unchecked loads and stores; validate their range.
    assert!(
        f2.len() >= n4 && trig.len() >= 2 * n4 && output.len() >= n2 && n2 >= 2 * n4,
        "post-rotation buffers are too short"
    );

    let n4_vec = n4 & !3;
    // SAFETY: every vector step has i + 4 <= n4_vec <= n4. The loads read
    // floats [2i, 2i + 8) of `f2` (inside its 2 * n4 floats; assumes KissCpx is
    // two packed f32 {r, i}, as the original cast did — TODO confirm
    // #[repr(C)]), plus trig[i..i + 4] and trig[n4 + i..n4 + i + 4]. The stores
    // hit output[2 * (i + j)] and output[n2 - 1 - 2 * (i + j)], both below
    // `n2 <= output.len()`.
    unsafe {
        let f2_ptr = f2.as_ptr().cast::<f32>();
        let trig_ptr = trig.as_ptr();
        let out_ptr = output.as_mut_ptr();

        let mut i = 0;
        while i < n4_vec {
            // De-interleave four bins into a real vector and an imaginary vector.
            let bins = vuzpq_f32(vld1q_f32(f2_ptr.add(2 * i)), vld1q_f32(f2_ptr.add(2 * i + 4)));
            let (re_v, im_v) = (bins.0, bins.1);

            let t0 = vld1q_f32(trig_ptr.add(i));
            let t1 = vld1q_f32(trig_ptr.add(n4 + i));

            let mut yr = [0f32; 4];
            let mut yi = [0f32; 4];
            vst1q_f32(yr.as_mut_ptr(), vsubq_f32(vmulq_f32(im_v, t1), vmulq_f32(re_v, t0)));
            vst1q_f32(yi.as_mut_ptr(), vaddq_f32(vmulq_f32(re_v, t1), vmulq_f32(im_v, t0)));

            for j in 0..4 {
                *out_ptr.add(2 * (i + j)) = yr[j];
                *out_ptr.add(n2 - 1 - 2 * (i + j)) = yi[j];
            }

            i += 4;
        }
    }

    for i in n4_vec..n4 {
        let fp = &f2[i];
        let (t0, t1) = (trig[i], trig[n4 + i]);
        output[i * 2] = fp.i * t1 - fp.r * t0;
        output[n2 - 1 - 2 * i] = fp.r * t1 + fp.i * t0;
    }
}
699
/// NEON inverse pre-rotation.
///
/// For each `k < n4`, pairs `input[2k * stride]` with its mirror
/// `input[stride * (n2 - 1 - 2k)]` and rotates the pair by
/// `(trig[k], trig[n4 + k])`. The result is stored at `f2[bitrev[k]]` with
/// real and imaginary parts swapped. Only unit stride is vectorised.
///
/// # Panics
/// For unit stride with `n4 > 0`, panics if `input` holds fewer than `n2`
/// values, `trig` fewer than `2 * n4`, if `n2 < 2 * n4`, or if a `bitrev`
/// entry is out of range for `f2`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
fn mdct_backward_pre_rotation_neon(
    input: &[f32],
    f2: &mut [KissCpx],
    trig: &[f32],
    bitrev: &[i16],
    n4: usize,
    n2: usize,
    stride: usize,
) {
    use std::arch::aarch64::*;

    if stride != 1 {
        for i in 0..n4 {
            let x1 = input[2 * i * stride];
            let x2 = input[stride * (n2 - 1 - 2 * i)];
            let (t0, t1) = (trig[i], trig[n4 + i]);
            let yr = x2 * t0 + x1 * t1;
            let yi = x1 * t0 - x2 * t1;
            f2[bitrev[i] as usize] = KissCpx::new(yi, yr);
        }
        return;
    }

    if n4 == 0 {
        return;
    }
    // The vector loads below are unchecked, so validate their range up front.
    assert!(
        input.len() >= n2 && n2 >= 2 * n4 && trig.len() >= 2 * n4,
        "inverse pre-rotation input or twiddle table is too short"
    );

    let n4_vec = n4 & !3;
    // SAFETY: every vector step has i + 4 <= n4_vec <= n4 <= n2 / 2.
    // - The x1 loads read input[2i..2i + 8], which ends at or before 2 * n4_vec <= n2.
    // - The x2 loads read input[n2 - 8 - 2i..n2 - 2i], which starts at or after
    //   n2 - 2 * n4_vec >= 0 and never reaches input[n2].
    // - The twiddle loads stay below 2 * n4.
    // All of these are in bounds by the assertion above. The scatter into `f2`
    // uses checked indexing.
    unsafe {
        let in_ptr = input.as_ptr();
        let trig_ptr = trig.as_ptr();

        let mut i = 0;
        while i < n4_vec {
            // x1[k] = input[2k]: the even lanes of input[2i..2i + 8].
            let front = vuzpq_f32(vld1q_f32(in_ptr.add(2 * i)), vld1q_f32(in_ptr.add(2 * i + 4)));
            let x1_v = front.0;

            // x2[k] = input[n2 - 1 - 2k] runs backwards. Load the eight samples
            // ending at input[n2 - 1 - 2i], keep the odd lanes, then reverse the
            // lane order (vrev64 swaps within halves, vext by 2 swaps halves).
            let back = vuzpq_f32(
                vld1q_f32(in_ptr.add(n2 - 8 - 2 * i)),
                vld1q_f32(in_ptr.add(n2 - 4 - 2 * i)),
            );
            let x2_v = vrev64q_f32(back.1);
            let x2_v = vextq_f32(x2_v, x2_v, 2);

            let t0 = vld1q_f32(trig_ptr.add(i));
            let t1 = vld1q_f32(trig_ptr.add(n4 + i));

            let mut yr = [0f32; 4];
            let mut yi = [0f32; 4];
            vst1q_f32(yr.as_mut_ptr(), vaddq_f32(vmulq_f32(x2_v, t0), vmulq_f32(x1_v, t1)));
            vst1q_f32(yi.as_mut_ptr(), vsubq_f32(vmulq_f32(x1_v, t0), vmulq_f32(x2_v, t1)));

            for j in 0..4 {
                f2[bitrev[i + j] as usize] = KissCpx::new(yi[j], yr[j]);
            }

            i += 4;
        }
    }

    for i in n4_vec..n4 {
        let x1 = input[2 * i];
        let x2 = input[n2 - 1 - 2 * i];
        let (t0, t1) = (trig[i], trig[n4 + i]);
        let yr = x2 * t0 + x1 * t1;
        let yi = x1 * t0 - x2 * t1;
        f2[bitrev[i] as usize] = KissCpx::new(yi, yr);
    }
}
781
/// Inverse MDCT post-rotation used on aarch64 (plain scalar arithmetic).
///
/// Reads the FFT output with real and imaginary parts swapped, rotates each
/// bin by `(trig[k], trig[n4 + k])`, and writes
/// `output[overlap2..overlap2 + n2]` from both ends at once. For odd `n4` the
/// middle pair is computed twice with identical results.
///
/// # Panics
/// Panics if `n4 > 0` and `output` holds fewer than `overlap2 + n2` values,
/// or on out-of-range indices into `f2` / `trig`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
fn mdct_backward_post_rotation_neon(
    f2: &[KissCpx],
    trig: &[f32],
    output: &mut [f32],
    n4: usize,
    n2: usize,
    overlap2: usize,
) {
    if n4 == 0 {
        return;
    }

    // Borrow exactly the written region so every store is bounds-checked
    // against it, rather than going through unchecked pointer offsets.
    let out = &mut output[overlap2..overlap2 + n2];

    let rotate = |k: usize| -> (f32, f32) {
        let (re, im) = (f2[k].i, f2[k].r);
        let (t0, t1) = (trig[k], trig[n4 + k]);
        (re * t0 + im * t1, re * t1 - im * t0)
    };

    for i in 0..(n4 + 1) >> 1 {
        let (yr0, yi0) = rotate(i);
        let (yr1, yi1) = rotate(n4 - 1 - i);
        out[2 * i] = yr0;
        out[n2 - 1 - 2 * i] = yi0;
        out[n2 - 2 - 2 * i] = yr1;
        out[2 * i + 1] = yi1;
    }
}
867
/// NEON TDAC mirror of the inverse MDCT.
///
/// For every `i < overlap / 2`, with `j = overlap - 1 - i`, `x2 = output[i]`
/// and `x1 = output[j]`, it writes
/// `output[i] = x2 * window[j] - x1 * window[i]` and
/// `output[j] = x2 * window[i] + x1 * window[j]`. Whole groups of four pairs
/// use NEON; the remainder is scalar.
///
/// # Panics
/// Panics if `output` or `window` holds fewer than `overlap` values while
/// `overlap / 2 > 0`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
fn mdct_tdac_neon(output: &mut [f32], window: &[f32], overlap: usize) {
    use std::arch::aarch64::*;

    let overlap2 = overlap / 2;
    // Zero when overlap2 < 4, so short overlaps go entirely through the tail.
    let vec_end = overlap2 & !3;

    if vec_end > 0 {
        // The vector loop uses unchecked loads and stores; validate their range.
        assert!(
            output.len() >= overlap && window.len() >= overlap,
            "TDAC needs at least `overlap` output and window samples"
        );
        // SAFETY: each step has i + 4 <= vec_end <= overlap2. The forward
        // accesses touch [i, i + 4) and the mirrored ones
        // [overlap - 4 - i, overlap - i); both lie below `overlap` (asserted
        // above). They never overlap each other because
        // i + 4 <= overlap2 <= overlap - overlap2 <= overlap - 4 - i + 4.
        unsafe {
            let out_ptr = output.as_mut_ptr();
            let win_ptr = window.as_ptr();

            let mut i = 0;
            while i < vec_end {
                let x2 = vld1q_f32(out_ptr.add(i));
                let w1 = vld1q_f32(win_ptr.add(i));

                // Mirrored operands: load forwards, then reverse the lane order
                // (vrev64 swaps within halves, vext by 2 swaps the halves).
                let x1 = vrev64q_f32(vld1q_f32(out_ptr.add(overlap - 4 - i)));
                let x1 = vextq_f32(x1, x1, 2);
                let w2 = vrev64q_f32(vld1q_f32(win_ptr.add(overlap - 4 - i)));
                let w2 = vextq_f32(w2, w2, 2);

                let out_fwd = vsubq_f32(vmulq_f32(x2, w2), vmulq_f32(x1, w1));
                let out_rev = vaddq_f32(vmulq_f32(x2, w1), vmulq_f32(x1, w2));
                let out_rev = vrev64q_f32(out_rev);
                let out_rev = vextq_f32(out_rev, out_rev, 2);

                vst1q_f32(out_ptr.add(i), out_fwd);
                vst1q_f32(out_ptr.add(overlap - 4 - i), out_rev);

                i += 4;
            }
        }
    }

    for i in vec_end..overlap2 {
        let j = overlap - 1 - i;
        let x1 = output[j];
        let x2 = output[i];
        let (wp1, wp2) = (window[i], window[j]);
        output[i] = x2 * wp2 - x1 * wp1;
        output[j] = x2 * wp1 + x1 * wp2;
    }
}
926
#[cfg(test)]
mod mdct_tests {
    use super::MdctLookup;
    use crate::kiss_fft::{KissCpx, opus_fft_impl};

    /// Deterministic test spectrum: `amplitude * sin(0.01 * i)` for `i in 0..len`.
    fn sine_input(len: usize, amplitude: f32) -> Vec<f32> {
        (0..len)
            .map(|i| ((i as f32) * 0.01).sin() * amplitude)
            .collect()
    }

    /// Largest absolute value in `values`.
    ///
    /// Returns `f32::INFINITY` once any NaN or infinite sample is seen, so a
    /// non-finite output can never slip under a `< bound` assertion (a plain
    /// `fold(.., f32::max)` silently skips NaN and only tracks the signed max).
    fn peak_magnitude(values: &[f32]) -> f32 {
        values.iter().fold(0.0f32, |peak, &v| {
            if v.is_finite() {
                peak.max(v.abs())
            } else {
                f32::INFINITY
            }
        })
    }

    /// Largest element-wise `|a[k] - b[k]|` over the common prefix of `a` and `b`.
    ///
    /// A NaN difference yields `f32::INFINITY`, so a NaN on either side fails the
    /// caller's tolerance check instead of being dropped by `f32::max`.
    fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).fold(0.0f32, |worst, (&x, &y)| {
            let diff = (x - y).abs();
            if diff.is_nan() {
                f32::INFINITY
            } else {
                worst.max(diff)
            }
        })
    }

    /// Scalar reference for `MdctLookup::backward`, written out long-hand so the
    /// library path can be compared against it.
    ///
    /// Steps: pre-rotation of the strided input into bit-reversed FFT order,
    /// `opus_fft_impl`, post-rotation into `out[overlap/2 ..]`, and finally the
    /// windowed TDAC mirror over `out[0..overlap]`. `out` must be at least
    /// `max(overlap, overlap/2 + n/2)` long, where `n = mdct.n >> shift`.
    fn reference_backward(
        mdct: &MdctLookup,
        freq: &[f32],
        out: &mut [f32],
        window: &[f32],
        overlap: usize,
        shift: usize,
        stride: usize,
    ) {
        let n = mdct.n >> shift;
        let n2 = n / 2;
        let n4 = n / 4;
        let overlap2 = overlap / 2;

        let st = mdct.kfft[shift]
            .as_ref()
            .expect("FFT state not initialized");
        let (trig, _) = mdct.get_trig(shift);

        // Pre-rotation: pair coefficients from both ends of the (strided)
        // spectrum and store each rotated value at its bit-reversed FFT slot.
        let mut f2 = vec![KissCpx::new(0.0, 0.0); n4];
        for i in 0..n4 {
            let x1 = freq[2 * i * stride];
            let x2 = freq[stride * (n2 - 1 - 2 * i)];
            let t0 = trig[i];
            let t1 = trig[n4 + i];
            let yr = x2 * t0 + x1 * t1;
            let yi = x1 * t0 - x2 * t1;
            f2[st.bitrev[i] as usize] = KissCpx::new(yi, yr);
        }
        opus_fft_impl(st, &mut f2);

        // Post-rotation of FFT bin `k`; note the real/imag swap on read.
        let post_rotate = |k: usize| {
            let im = f2[k].r;
            let re = f2[k].i;
            let t0 = trig[k];
            let t1 = trig[n4 + k];
            (re * t0 + im * t1, re * t1 - im * t0)
        };

        // Walk from both ends at once, interleaving results into the output.
        for i in 0..((n4 + 1) >> 1) {
            let (yr0, yi0) = post_rotate(i);
            let (yr1, yi1) = post_rotate(n4 - 1 - i);
            out[overlap2 + 2 * i] = yr0;
            out[overlap2 + n2 - 1 - 2 * i] = yi0;
            out[overlap2 + n2 - 2 - 2 * i] = yr1;
            out[overlap2 + 2 * i + 1] = yi1;
        }

        // TDAC: mirror and window the first `overlap` samples.
        for i in 0..overlap2 {
            let x1 = out[overlap - 1 - i];
            let x2 = out[i];
            let wp1 = window[i];
            let wp2 = window[overlap - 1 - i];
            out[i] = x2 * wp2 - x1 * wp1;
            out[overlap - 1 - i] = x2 * wp1 + x1 * wp2;
        }
    }

    /// Runs the short-block inverse MDCT (`shift` = 3, `stride` = 8) on two
    /// interleaved sub-spectra and checks that the output stays finite and
    /// bounded in absolute value.
    #[test]
    fn test_mdct_backward_transient_no_blowup() {
        let mode = crate::modes::default_mode();
        let shift = 3;
        let n = mode.mdct.n >> shift; // 240 for the default mode (1920 >> 3)
        let overlap = mode.overlap; // 120
        let stride = 8;

        let freq = sine_input(960, 10.0);

        let out_len = n + overlap; // 360
        let mut output0 = vec![0.0f32; out_len];
        let mut output1 = vec![0.0f32; out_len];

        // Offsets 0 and 1 with `stride` = 8 read two different interleaved
        // sub-spectra of `freq`.
        mode.mdct.backward(
            &freq[0..],
            &mut output0,
            mode.window,
            overlap,
            shift,
            stride,
        );
        mode.mdct.backward(
            &freq[1..],
            &mut output1,
            mode.window,
            overlap,
            shift,
            stride,
        );

        // Peak |x| (not the signed max) so large negative excursions and
        // NaN/inf outputs are caught too.
        let peak0 = peak_magnitude(&output0);
        let peak1 = peak_magnitude(&output1);
        eprintln!("sub0 peak={} sub1 peak={}", peak0, peak1);
        eprintln!("sub0[60..70]={:?}", &output0[60..70]);
        eprintln!("sub1[60..70]={:?}", &output1[60..70]);

        assert!(peak0 < 500.0, "sub0 blowup: {}", peak0);
        assert!(peak1 < 500.0, "sub1 blowup: {}", peak1);
    }

    /// Compares `backward` against the scalar reference for the full-size,
    /// stride-1 (non-transient) MDCT.
    ///
    /// NOTE(review): whether `backward` takes a SIMD (e.g. NEON) path depends on
    /// the build target; where it does not, this compares two scalar paths.
    #[test]
    fn test_mdct_backward_stride1_neon_matches_scalar() {
        let mode = crate::modes::default_mode();
        let shift = 0; // non-transient full-size MDCT
        let n = mode.mdct.n >> shift; // 1920
        let n2 = n / 2; // 960
        let overlap = mode.overlap; // 120
        let overlap2 = overlap / 2; // 60
        let stride = 1;

        // n2 coefficients followed by 4 zeros of padding.
        // NOTE(review): presumably the padding guards a vectorized read slightly
        // past n2 — confirm against `backward`.
        let mut freq = sine_input(n2, 4577.0);
        freq.resize(n2 + 4, 0.0);

        let out_len = overlap2 + n2; // 60 + 960 = 1020
        // Both buffers carry 100 spare samples; only the first `out_len` are compared.
        let mut output_hw = vec![0.0f32; out_len + 100];
        mode.mdct.backward(
            &freq[..],
            &mut output_hw,
            mode.window,
            overlap,
            shift,
            stride,
        );

        let mut output_scalar = vec![0.0f32; out_len + 100];
        reference_backward(
            &mode.mdct,
            &freq,
            &mut output_scalar,
            mode.window,
            overlap,
            shift,
            stride,
        );

        let max_diff = max_abs_diff(&output_hw[..out_len], &output_scalar[..out_len]);
        assert!(
            max_diff < 0.5,
            "stride=1 NEON vs scalar mismatch: max_diff={}",
            max_diff
        );
    }

    /// Compares `backward` against the scalar reference for the short-block
    /// (`shift` = 3, `stride` = 8) transient MDCT, printing every sample that
    /// differs by more than 1e-3 before asserting the overall tolerance.
    #[test]
    fn test_mdct_backward_neon_matches_scalar() {
        let mode = crate::modes::default_mode();
        let shift = 3;
        let n = mode.mdct.n >> shift; // 240
        let overlap = mode.overlap; // 120
        let stride = 8;

        // Deterministic test spectrum: 200 * sin(0.01 * i).
        let freq = sine_input(960, 200.0);

        let out_len = n + overlap; // 360
        let mut output_hw = vec![0.0f32; out_len];
        mode.mdct.backward(
            &freq[0..],
            &mut output_hw,
            mode.window,
            overlap,
            shift,
            stride,
        );

        let mut output_scalar = vec![0.0f32; out_len];
        reference_backward(
            &mode.mdct,
            &freq,
            &mut output_scalar,
            mode.window,
            overlap,
            shift,
            stride,
        );

        // Diagnostics only; the assertion below uses the looser 0.1 bound.
        for (i, (&hw, &scalar)) in output_hw.iter().zip(&output_scalar).enumerate() {
            let diff = (hw - scalar).abs();
            if diff > 1e-3 {
                eprintln!(
                    "Mismatch at output[{}]: hw={} scalar={} diff={}",
                    i, hw, scalar, diff
                );
            }
        }
        let max_diff = max_abs_diff(&output_hw, &output_scalar);
        assert!(
            max_diff < 0.1,
            "NEON/HW vs scalar mismatch: max_diff={}",
            max_diff
        );
    }
}