1use crate::kiss_fft::{KissCpx, KissFftState, opus_fft_impl};
2use std::f32::consts::PI;
3use std::mem::MaybeUninit;
4
5const MAX_N2: usize = 960;
6const MAX_N4: usize = 480;
7
8pub struct MdctLookup {
9 pub n: usize,
10 pub max_lm: usize,
11 kfft: Vec<Option<KissFftState>>,
12 trig: Vec<f32>,
13}
14
15impl MdctLookup {
16 pub fn new(n: usize, max_lm: usize) -> Self {
17 let mut kfft = Vec::new();
18 let mut trig = Vec::new();
19 let mut curr_n = n;
20
21 for shift in 0..=max_lm {
22 let n4 = curr_n / 4;
23
24 if shift == 0 {
25 kfft.push(KissFftState::new(n4));
26 } else if let Some(base) = kfft.first().unwrap().as_ref() {
27 kfft.push(KissFftState::new_sub(base, n4));
28 } else {
29 kfft.push(None);
30 }
31
32 let n2 = curr_n / 2;
33 for i in 0..n2 {
34 let angle = 2.0 * PI * (i as f32 + 0.125) / curr_n as f32;
35 trig.push(angle.cos());
36 }
37
38 curr_n >>= 1;
39 }
40
41 Self {
42 n,
43 max_lm,
44 kfft,
45 trig,
46 }
47 }
48
49 fn get_trig(&self, shift: usize) -> (&[f32], usize) {
50 let mut offset = 0;
51 let mut curr_n = self.n;
52 for _ in 0..shift {
53 offset += curr_n / 2;
54 curr_n >>= 1;
55 }
56 (&self.trig[offset..offset + curr_n / 2], curr_n / 4)
57 }
58
59 pub fn get_trig_debug(&self, shift: usize) -> &[f32] {
60 let (trig, _) = self.get_trig(shift);
61 trig
62 }
63
64 #[inline]
65 pub fn forward(
66 &self,
67 input: &[f32],
68 output: &mut [f32],
69 window: &[f32],
70 overlap: usize,
71 shift: usize,
72 stride: usize,
73 ) {
74 let st = self.kfft[shift]
75 .as_ref()
76 .expect("FFT state not initialized");
77 let n = self.n >> shift;
78 let n2 = n / 2;
79 let n4 = n / 4;
80 let scale = st.scale();
81
82 let (trig, _) = self.get_trig(shift);
83 let overlap2 = overlap / 2;
84
85 let mut f_buf = [MaybeUninit::<f32>::uninit(); MAX_N2];
86 let mut f2_buf = [MaybeUninit::<KissCpx>::uninit(); MAX_N4];
87
88 let f = unsafe { std::slice::from_raw_parts_mut(f_buf.as_mut_ptr() as *mut f32, n2) };
89 let f2 = unsafe { std::slice::from_raw_parts_mut(f2_buf.as_mut_ptr() as *mut KissCpx, n4) };
90
91 assert!(input.len() >= n2 + overlap2);
92 assert!(window.len() >= overlap);
93 assert!(
94 output.len() >= n2,
95 "MDCT forward: output buffer too small (need {}, have {})",
96 n2,
97 output.len()
98 );
99
100 {
101 let mut yp = 0usize;
102 let mut xp1 = overlap2;
103 let mut xp2 = n2 - 1 + overlap2;
104 let mut wp1 = overlap2;
105
106 let mut wp2 = overlap2.saturating_sub(1);
107
108 let limit = overlap.div_ceil(4);
109 let mid = n4.saturating_sub(limit);
110
111 let loop1_iters = limit.min(n4);
112 for _ in 0..loop1_iters {
113 let w1 = window[wp1];
114 let w2 = window[wp2];
115
116 f[yp] = input[xp1 + n2] * w2 + input[xp2] * w1;
117 yp += 1;
118
119 f[yp] = input[xp1] * w1 - input[xp2 - n2] * w2;
120 yp += 1;
121
122 xp1 += 2;
123 xp2 -= 2;
124 wp1 += 2;
125 wp2 = wp2.saturating_sub(2);
126 }
127
128 for _ in limit..mid {
129 f[yp] = input[xp2];
130 yp += 1;
131
132 f[yp] = input[xp1];
133 yp += 1;
134 xp1 += 2;
135 xp2 -= 2;
136 }
137
138 let loop3_iters = if mid > limit { n4 - mid } else { 0 };
139 let mut wp1_l3 = 0usize;
140 let mut wp2_l3 = overlap.saturating_sub(1);
141 for _ in 0..loop3_iters {
142 let w1 = window[wp1_l3];
143 let w2 = window[wp2_l3];
144
145 f[yp] = -input[xp1 - n2] * w1 + input[xp2] * w2;
146 yp += 1;
147
148 f[yp] = input[xp1] * w2 + input[xp2 + n2] * w1;
149 yp += 1;
150
151 xp1 += 2;
152 xp2 -= 2;
153 wp1_l3 += 2;
154 wp2_l3 -= 2;
155 }
156 }
157
158 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
159 unsafe {
160 if std::arch::is_x86_feature_detected!("avx") {
161 mdct_pre_rotation_avx(f, f2, trig, &st.bitrev[..n4], n4, scale);
162 } else {
163 for i in 0..n4 {
164 let re = f[2 * i];
165 let im = f[2 * i + 1];
166 let t0 = trig[i];
167 let t1 = trig[n4 + i];
168
169 let yr = re * t0 - im * t1;
170 let yi = im * t0 + re * t1;
171
172 f2[st.bitrev[i] as usize] = KissCpx::new(yr * scale, yi * scale);
173 }
174 }
175 }
176 #[cfg(all(
177 not(any(target_arch = "x86", target_arch = "x86_64")),
178 target_arch = "aarch64"
179 ))]
180 {
181 mdct_pre_rotation_neon(f, f2, trig, &st.bitrev[..n4], n4, scale);
182 }
183 #[cfg(all(
184 not(any(target_arch = "x86", target_arch = "x86_64")),
185 not(target_arch = "aarch64")
186 ))]
187 for i in 0..n4 {
188 let re = f[2 * i];
189 let im = f[2 * i + 1];
190 let t0 = trig[i];
191 let t1 = trig[n4 + i];
192
193 let yr = re * t0 - im * t1;
194 let yi = im * t0 + re * t1;
195
196 f2[st.bitrev[i] as usize] = KissCpx::new(yr * scale, yi * scale);
197 }
198
199 opus_fft_impl(st, f2);
200
201 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
202 unsafe {
203 if std::arch::is_x86_feature_detected!("avx") {
204 mdct_post_rotation_avx(f2, trig, output, n4, n2, stride);
205 } else {
206 for i in 0..n4 {
207 let fp = &f2[i];
208 let t0 = trig[i];
209 let t1 = trig[n4 + i];
210
211 let yr = fp.i * t1 - fp.r * t0;
212 let yi = fp.r * t1 + fp.i * t0;
213
214 output[i * 2 * stride] = yr;
215 output[stride * (n2 - 1 - 2 * i)] = yi;
216 }
217 }
218 }
219 #[cfg(all(
220 not(any(target_arch = "x86", target_arch = "x86_64")),
221 target_arch = "aarch64"
222 ))]
223 {
224 mdct_post_rotation_neon(f2, trig, output, n4, n2, stride);
225 }
226 #[cfg(all(
227 not(any(target_arch = "x86", target_arch = "x86_64")),
228 not(target_arch = "aarch64")
229 ))]
230 for i in 0..n4 {
231 let fp = &f2[i];
232 let t0 = trig[i];
233 let t1 = trig[n4 + i];
234
235 let yr = fp.i * t1 - fp.r * t0;
236 let yi = fp.r * t1 + fp.i * t0;
237
238 output[i * 2 * stride] = yr;
239 output[stride * (n2 - 1 - 2 * i)] = yi;
240 }
241 }
242
243 #[inline]
244 pub fn backward(
245 &self,
246 input: &[f32],
247 output: &mut [f32],
248 window: &[f32],
249 overlap: usize,
250 shift: usize,
251 stride: usize,
252 ) {
253 let st = self.kfft[shift]
254 .as_ref()
255 .expect("FFT state not initialized");
256 let n = self.n >> shift;
257 let n2 = n / 2;
258 let n4 = n / 4;
259 let overlap2 = overlap / 2;
260
261 let (trig, _) = self.get_trig(shift);
262
263 let mut f2_buf = [MaybeUninit::<KissCpx>::uninit(); MAX_N4];
264
265 let f2 = unsafe { std::slice::from_raw_parts_mut(f2_buf.as_mut_ptr() as *mut KissCpx, n4) };
266
267 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
268 unsafe {
269 if std::arch::is_x86_feature_detected!("avx") {
270 mdct_backward_pre_rotation_avx(input, f2, trig, &st.bitrev[..n4], n4, n2, stride);
271 } else {
272 for i in 0..n4 {
273 let rev = st.bitrev[i] as usize;
274 let x1 = input[2 * i * stride];
275 let x2 = input[stride * (n2 - 1 - 2 * i)];
276 let t0 = trig[i];
277 let t1 = trig[n4 + i];
278
279 let yr = x2 * t0 + x1 * t1;
280 let yi = x1 * t0 - x2 * t1;
281
282 f2[rev] = KissCpx::new(yi, yr);
283 }
284 }
285 }
286 #[cfg(all(
287 not(any(target_arch = "x86", target_arch = "x86_64")),
288 target_arch = "aarch64"
289 ))]
290 {
291 mdct_backward_pre_rotation_neon(input, f2, trig, &st.bitrev[..n4], n4, n2, stride);
292 }
293 #[cfg(all(
294 not(any(target_arch = "x86", target_arch = "x86_64")),
295 not(target_arch = "aarch64")
296 ))]
297 for i in 0..n4 {
298 let rev = st.bitrev[i] as usize;
299 let x1 = input[2 * i * stride];
300 let x2 = input[stride * (n2 - 1 - 2 * i)];
301 let t0 = trig[i];
302 let t1 = trig[n4 + i];
303
304 let yr = x2 * t0 + x1 * t1;
305 let yi = x1 * t0 - x2 * t1;
306
307 f2[rev] = KissCpx::new(yi, yr);
308 }
309
310 opus_fft_impl(st, f2);
311
312 assert!(output.len() >= overlap2 + n2);
313
314 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
315 unsafe {
316 if std::arch::is_x86_feature_detected!("avx") {
317 mdct_backward_post_rotation_avx(f2, trig, output, n4, n2, overlap2);
318 } else {
319 for i in 0..((n4 + 1) >> 1) {
320 let im0 = f2[i].r;
321 let re0 = f2[i].i;
322 let t0_0 = trig[i];
323 let t1_0 = trig[n4 + i];
324
325 let yr0 = re0 * t0_0 + im0 * t1_0;
326 let yi0 = re0 * t1_0 - im0 * t0_0;
327
328 let j = n4 - 1 - i;
329 let im1 = f2[j].r;
330 let re1 = f2[j].i;
331 let t0_1 = trig[j];
332 let t1_1 = trig[n4 + j];
333
334 let yr1 = re1 * t0_1 + im1 * t1_1;
335 let yi1 = re1 * t1_1 - im1 * t0_1;
336
337 output[overlap2 + 2 * i] = yr0;
338 output[overlap2 + n2 - 1 - 2 * i] = yi0;
339 output[overlap2 + n2 - 2 - 2 * i] = yr1;
340 output[overlap2 + 2 * i + 1] = yi1;
341 }
342 }
343 }
344 #[cfg(all(
345 not(any(target_arch = "x86", target_arch = "x86_64")),
346 target_arch = "aarch64"
347 ))]
348 {
349 mdct_backward_post_rotation_neon(f2, trig, output, n4, n2, overlap2);
350 }
351 #[cfg(all(
352 not(any(target_arch = "x86", target_arch = "x86_64")),
353 not(target_arch = "aarch64")
354 ))]
355 for i in 0..((n4 + 1) >> 1) {
356 let im0 = f2[i].r;
357 let re0 = f2[i].i;
358 let t0_0 = trig[i];
359 let t1_0 = trig[n4 + i];
360
361 let yr0 = re0 * t0_0 + im0 * t1_0;
362 let yi0 = re0 * t1_0 - im0 * t0_0;
363
364 let j = n4 - 1 - i;
365 let im1 = f2[j].r;
366 let re1 = f2[j].i;
367 let t0_1 = trig[j];
368 let t1_1 = trig[n4 + j];
369
370 let yr1 = re1 * t0_1 + im1 * t1_1;
371 let yi1 = re1 * t1_1 - im1 * t0_1;
372
373 output[overlap2 + 2 * i] = yr0;
374 output[overlap2 + n2 - 1 - 2 * i] = yi0;
375 output[overlap2 + n2 - 2 - 2 * i] = yr1;
376 output[overlap2 + 2 * i + 1] = yi1;
377 }
378
379 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
380 unsafe {
381 if std::arch::is_x86_feature_detected!("avx") {
382 mdct_tdac_avx(output, window, overlap);
383 } else {
384 for i in 0..overlap2 {
385 let x1 = output[overlap - 1 - i];
386 let x2 = output[i];
387 let wp1 = window[i];
388 let wp2 = window[overlap - 1 - i];
389
390 output[i] = x2 * wp2 - x1 * wp1;
391 output[overlap - 1 - i] = x2 * wp1 + x1 * wp2;
392 }
393 }
394 }
395 #[cfg(all(
396 not(any(target_arch = "x86", target_arch = "x86_64")),
397 target_arch = "aarch64"
398 ))]
399 {
400 mdct_tdac_neon(output, window, overlap);
401 }
402 #[cfg(all(
403 not(any(target_arch = "x86", target_arch = "x86_64")),
404 not(target_arch = "aarch64")
405 ))]
406 for i in 0..overlap2 {
407 let x1 = output[overlap - 1 - i];
408 let x2 = output[i];
409 let wp1 = window[i];
410 let wp2 = window[overlap - 1 - i];
411
412 output[i] = x2 * wp2 - x1 * wp1;
413 output[overlap - 1 - i] = x2 * wp1 + x1 * wp2;
414 }
415 }
416}
417
418#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
419#[target_feature(enable = "avx")]
420unsafe fn mdct_pre_rotation_avx(
421 f: &[f32],
422 f2: &mut [KissCpx],
423 trig: &[f32],
424 bitrev: &[i16],
425 n4: usize,
426 scale: f32,
427) {
428 for i in 0..n4 {
429 let re = f[2 * i];
430 let im = f[2 * i + 1];
431 let t0 = trig[i];
432 let t1 = trig[n4 + i];
433
434 let yr = re * t0 - im * t1;
435 let yi = im * t0 + re * t1;
436
437 f2[bitrev[i] as usize] = KissCpx::new(yr * scale, yi * scale);
438 }
439}
440
441#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
442#[target_feature(enable = "avx")]
443unsafe fn mdct_post_rotation_avx(
444 f2: &[KissCpx],
445 trig: &[f32],
446 output: &mut [f32],
447 n4: usize,
448 n2: usize,
449 stride: usize,
450) {
451 for i in 0..n4 {
452 let fp = &f2[i];
453 let t0 = trig[i];
454 let t1 = trig[n4 + i];
455
456 let yr = fp.i * t1 - fp.r * t0;
457 let yi = fp.r * t1 + fp.i * t0;
458
459 output[i * 2 * stride] = yr;
460 output[stride * (n2 - 1 - 2 * i)] = yi;
461 }
462}
463
464#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
465#[target_feature(enable = "avx")]
466unsafe fn mdct_backward_pre_rotation_avx(
467 input: &[f32],
468 f2: &mut [KissCpx],
469 trig: &[f32],
470 bitrev: &[i16],
471 n4: usize,
472 n2: usize,
473 stride: usize,
474) {
475 for i in 0..n4 {
476 let rev = bitrev[i] as usize;
477 let x1 = input[2 * i * stride];
478 let x2 = input[stride * (n2 - 1 - 2 * i)];
479 let t0 = trig[i];
480 let t1 = trig[n4 + i];
481
482 let yr = x2 * t0 + x1 * t1;
483 let yi = x1 * t0 - x2 * t1;
484
485 f2[rev] = KissCpx::new(yi, yr);
486 }
487}
488
489#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
490#[target_feature(enable = "avx")]
491unsafe fn mdct_backward_post_rotation_avx(
492 f2: &[KissCpx],
493 trig: &[f32],
494 output: &mut [f32],
495 n4: usize,
496 n2: usize,
497 overlap2: usize,
498) {
499 for i in 0..((n4 + 1) >> 1) {
500 let im0 = f2[i].r;
501 let re0 = f2[i].i;
502 let t0_0 = trig[i];
503 let t1_0 = trig[n4 + i];
504
505 let yr0 = re0 * t0_0 + im0 * t1_0;
506 let yi0 = re0 * t1_0 - im0 * t0_0;
507
508 let j = n4 - 1 - i;
509 let im1 = f2[j].r;
510 let re1 = f2[j].i;
511 let t0_1 = trig[j];
512 let t1_1 = trig[n4 + j];
513
514 let yr1 = re1 * t0_1 + im1 * t1_1;
515 let yi1 = re1 * t1_1 - im1 * t0_1;
516
517 output[overlap2 + 2 * i] = yr0;
518 output[overlap2 + n2 - 1 - 2 * i] = yi0;
519 output[overlap2 + n2 - 2 - 2 * i] = yr1;
520 output[overlap2 + 2 * i + 1] = yi1;
521 }
522}
523
/// AVX implementation of the inverse-MDCT TDAC mirroring step.
///
/// For each `i < overlap / 2`, combines `output[i]` and its mirror
/// `output[overlap - 1 - i]` in place using `window[i]` and
/// `window[overlap - 1 - i]`. Eight pairs are handled per vector iteration:
/// the mirrored blocks are loaded contiguously and lane-reversed in registers
/// instead of being gathered element by element.
///
/// # Safety
/// The CPU must support AVX.
///
/// # Panics
/// If the vector loop runs and `output` or `window` is shorter than `overlap`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx")]
unsafe fn mdct_tdac_avx(output: &mut [f32], window: &[f32], overlap: usize) {
    // `std::arch::x86_64` does not exist on 32-bit x86; import the module
    // that matches the target being compiled.
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    // Reverses the eight f32 lanes of a 256-bit vector: reverse inside each
    // 128-bit half, then swap the halves.
    macro_rules! reverse8 {
        ($v:expr) => {{
            let within_halves = _mm256_permute_ps($v, 0x1B);
            _mm256_permute2f128_ps(within_halves, within_halves, 0x01)
        }};
    }

    let overlap2 = overlap / 2;
    let vec_end = overlap2 & !7;

    if vec_end > 0 {
        // The vector loop uses unchecked loads/stores up to index overlap - 1.
        assert!(output.len() >= overlap, "AVX TDAC: output shorter than overlap");
        assert!(window.len() >= overlap, "AVX TDAC: window shorter than overlap");
    }

    for lo in (0..vec_end).step_by(8) {
        // Mirrored block covers indices overlap-8-lo ..= overlap-1-lo. Since
        // lo + 8 <= overlap2 <= overlap - overlap2 <= hi, both blocks are in
        // bounds (asserts above) and disjoint.
        let hi = overlap - 8 - lo;

        let x2 = _mm256_loadu_ps(output.as_ptr().add(lo));
        let x1 = reverse8!(_mm256_loadu_ps(output.as_ptr().add(hi)));
        let w1 = _mm256_loadu_ps(window.as_ptr().add(lo));
        let w2 = reverse8!(_mm256_loadu_ps(window.as_ptr().add(hi)));

        let out_fwd = _mm256_sub_ps(_mm256_mul_ps(x2, w2), _mm256_mul_ps(x1, w1));
        let out_rev = _mm256_add_ps(_mm256_mul_ps(x2, w1), _mm256_mul_ps(x1, w2));

        _mm256_storeu_ps(output.as_mut_ptr().add(lo), out_fwd);
        _mm256_storeu_ps(output.as_mut_ptr().add(hi), reverse8!(out_rev));
    }

    for i in vec_end..overlap2 {
        let x1 = output[overlap - 1 - i];
        let x2 = output[i];
        let wp1 = window[i];
        let wp2 = window[overlap - 1 - i];
        output[i] = x2 * wp2 - x1 * wp1;
        output[overlap - 1 - i] = x2 * wp1 + x1 * wp2;
    }
}
569
/// NEON forward pre-rotation (aarch64).
///
/// For each `i < n4`, rotates the folded pair `(f[2i], f[2i + 1])` by the
/// twiddles `trig[i]` / `trig[n4 + i]`, multiplies by `scale` and stores the
/// result at `f2[bitrev[i]]`. Four pairs are rotated per vector iteration;
/// the bit-reversed scatter uses checked indexing.
///
/// # Panics
/// If `f` or `trig` is shorter than `2 * n4`, `bitrev` is shorter than `n4`,
/// or a `bitrev` entry is not a valid index into `f2`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
fn mdct_pre_rotation_neon(
    f: &[f32],
    f2: &mut [KissCpx],
    trig: &[f32],
    bitrev: &[i16],
    n4: usize,
    scale: f32,
) {
    use std::arch::aarch64::*;

    // The vector loads below are unchecked; these keep them in bounds.
    assert!(f.len() >= 2 * n4, "NEON pre-rotation: input shorter than 2 * n4");
    assert!(trig.len() >= 2 * n4, "NEON pre-rotation: twiddles shorter than 2 * n4");
    assert!(bitrev.len() >= n4, "NEON pre-rotation: bitrev shorter than n4");

    let n4_vec = n4 & !3;
    let mut yr_lanes = [0.0f32; 4];
    let mut yi_lanes = [0.0f32; 4];

    for i in (0..n4_vec).step_by(4) {
        // SAFETY: i + 4 <= n4_vec <= n4, so f[2i..2i + 8], trig[i..i + 4] and
        // trig[n4 + i..n4 + i + 4] are in bounds by the asserts above; the
        // stores target the local 4-element arrays. NEON is baseline on aarch64.
        unsafe {
            let vscale = vdupq_n_f32(scale);
            let t0 = vld1q_f32(trig.as_ptr().add(i));
            let t1 = vld1q_f32(trig.as_ptr().add(n4 + i));

            // De-interleave four (re, im) pairs into separate re / im vectors.
            let pairs = vuzpq_f32(
                vld1q_f32(f.as_ptr().add(2 * i)),
                vld1q_f32(f.as_ptr().add(2 * i + 4)),
            );
            let (re_v, im_v) = (pairs.0, pairs.1);

            let yr = vsubq_f32(vmulq_f32(re_v, t0), vmulq_f32(im_v, t1));
            let yi = vaddq_f32(vmulq_f32(im_v, t0), vmulq_f32(re_v, t1));

            vst1q_f32(yr_lanes.as_mut_ptr(), vmulq_f32(yr, vscale));
            vst1q_f32(yi_lanes.as_mut_ptr(), vmulq_f32(yi, vscale));
        }

        for j in 0..4 {
            f2[bitrev[i + j] as usize] = KissCpx::new(yr_lanes[j], yi_lanes[j]);
        }
    }

    for i in n4_vec..n4 {
        let re = f[2 * i];
        let im = f[2 * i + 1];
        let t0 = trig[i];
        let t1 = trig[n4 + i];
        let yr = re * t0 - im * t1;
        let yi = im * t0 + re * t1;
        f2[bitrev[i] as usize] = KissCpx::new(yr * scale, yi * scale);
    }
}
634
/// NEON forward post-rotation (aarch64).
///
/// Rotates each FFT output `f2[i]` (`i < n4`) by the twiddles `trig[i]` /
/// `trig[n4 + i]` and writes the real result to `output[2 * i * stride]` and
/// the imaginary one to `output[(n2 - 1 - 2 * i) * stride]`. Only the
/// unit-stride case is vectorised; any other stride (including 0) uses the
/// same scalar formula as the portable path.
///
/// # Panics
/// If `f2` is shorter than `n4`, `trig` shorter than `2 * n4`, or an output
/// index is out of bounds.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
fn mdct_post_rotation_neon(
    f2: &[KissCpx],
    trig: &[f32],
    output: &mut [f32],
    n4: usize,
    n2: usize,
    stride: usize,
) {
    use std::arch::aarch64::*;

    if stride != 1 {
        for i in 0..n4 {
            let fp = &f2[i];
            let t0 = trig[i];
            let t1 = trig[n4 + i];
            let yr = fp.i * t1 - fp.r * t0;
            let yi = fp.r * t1 + fp.i * t0;
            output[i * 2 * stride] = yr;
            output[stride * (n2 - 1 - 2 * i)] = yi;
        }
        return;
    }

    // The vector loads below are unchecked; the stores use checked indexing.
    assert!(f2.len() >= n4, "NEON post-rotation: f2 shorter than n4");
    assert!(trig.len() >= 2 * n4, "NEON post-rotation: twiddles shorter than 2 * n4");

    let n4_vec = n4 & !3;
    // NOTE(review): reading `f2` as interleaved f32s assumes `KissCpx` is laid
    // out as `{ r, i }` without padding, as the previous code already did.
    let f2_ptr = f2.as_ptr() as *const f32;
    let mut yr_lanes = [0.0f32; 4];
    let mut yi_lanes = [0.0f32; 4];

    for i in (0..n4_vec).step_by(4) {
        // SAFETY: i + 4 <= n4_vec <= n4, so the 8 floats of f2[i..i + 4] and
        // trig[i..i + 4] / trig[n4 + i..n4 + i + 4] are in bounds by the asserts;
        // stores target local arrays.
        unsafe {
            let pairs = vuzpq_f32(vld1q_f32(f2_ptr.add(2 * i)), vld1q_f32(f2_ptr.add(2 * i + 4)));
            let (re_v, im_v) = (pairs.0, pairs.1);

            let t0 = vld1q_f32(trig.as_ptr().add(i));
            let t1 = vld1q_f32(trig.as_ptr().add(n4 + i));

            let yr = vsubq_f32(vmulq_f32(im_v, t1), vmulq_f32(re_v, t0));
            let yi = vaddq_f32(vmulq_f32(re_v, t1), vmulq_f32(im_v, t0));

            vst1q_f32(yr_lanes.as_mut_ptr(), yr);
            vst1q_f32(yi_lanes.as_mut_ptr(), yi);
        }

        for j in 0..4 {
            output[2 * (i + j)] = yr_lanes[j];
            output[n2 - 1 - 2 * (i + j)] = yi_lanes[j];
        }
    }

    for i in n4_vec..n4 {
        let fp = &f2[i];
        let t0 = trig[i];
        let t1 = trig[n4 + i];
        let yr = fp.i * t1 - fp.r * t0;
        let yi = fp.r * t1 + fp.i * t0;
        output[i * 2] = yr;
        output[n2 - 1 - 2 * i] = yi;
    }
}
705
/// NEON inverse-MDCT pre-rotation (aarch64).
///
/// For each `i < n4`, combines `x1 = input[2i * stride]` and
/// `x2 = input[(n2 - 1 - 2i) * stride]` with the twiddles `trig[i]` /
/// `trig[n4 + i]` and stores `(x1*t0 - x2*t1, x2*t0 + x1*t1)` at
/// `f2[bitrev[i]]`. Only the unit-stride case is vectorised. In that case no
/// input element at or beyond index `n2` is read.
///
/// # Panics
/// If a buffer is too short or a `bitrev` entry is not a valid index into `f2`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
fn mdct_backward_pre_rotation_neon(
    input: &[f32],
    f2: &mut [KissCpx],
    trig: &[f32],
    bitrev: &[i16],
    n4: usize,
    n2: usize,
    stride: usize,
) {
    use std::arch::aarch64::*;

    if stride != 1 {
        for i in 0..n4 {
            let rev = bitrev[i] as usize;
            let x1 = input[2 * i * stride];
            let x2 = input[stride * (n2 - 1 - 2 * i)];
            let t0 = trig[i];
            let t1 = trig[n4 + i];
            let yr = x2 * t0 + x1 * t1;
            let yi = x1 * t0 - x2 * t1;
            f2[rev] = KissCpx::new(yi, yr);
        }
        return;
    }

    // The vector loads below are unchecked; make sure every lane they touch exists.
    assert!(
        2 * n4 <= n2 && input.len() >= n2,
        "NEON backward pre-rotation: input shorter than n2"
    );
    assert!(
        trig.len() >= 2 * n4,
        "NEON backward pre-rotation: twiddles shorter than 2 * n4"
    );

    let n4_vec = n4 & !3;
    let in_ptr = input.as_ptr();
    let mut yr_lanes = [0.0f32; 4];
    let mut yi_lanes = [0.0f32; 4];

    for i in (0..n4_vec).step_by(4) {
        // SAFETY: i + 4 <= n4_vec <= n4 and 2 * n4 <= n2 <= input.len(), so
        // input[2i..2i + 8] and input[n2 - 2i - 8..n2 - 2i] are in bounds (the
        // subtraction cannot underflow because 2i + 8 <= 2 * n4 <= n2), as are
        // the trig loads; stores target local arrays.
        unsafe {
            // x1[j] = input[2(i + j)]: even lanes of input[2i..2i + 8].
            let x1_v = vuzpq_f32(vld1q_f32(in_ptr.add(2 * i)), vld1q_f32(in_ptr.add(2 * i + 4))).0;

            // x2[j] = input[n2 - 1 - 2(i + j)]: odd lanes of
            // input[n2 - 2i - 8..n2 - 2i], reversed. Anchoring the window here
            // keeps its last lane at input[n2 - 1 - 2i], so input[n2] is never read.
            let x2_fwd = vuzpq_f32(
                vld1q_f32(in_ptr.add(n2 - 2 * i - 8)),
                vld1q_f32(in_ptr.add(n2 - 2 * i - 4)),
            )
            .1;
            let x2_half_rev = vrev64q_f32(x2_fwd);
            let x2_v = vextq_f32(x2_half_rev, x2_half_rev, 2);

            let t0 = vld1q_f32(trig.as_ptr().add(i));
            let t1 = vld1q_f32(trig.as_ptr().add(n4 + i));

            let yr = vaddq_f32(vmulq_f32(x2_v, t0), vmulq_f32(x1_v, t1));
            let yi = vsubq_f32(vmulq_f32(x1_v, t0), vmulq_f32(x2_v, t1));

            vst1q_f32(yr_lanes.as_mut_ptr(), yr);
            vst1q_f32(yi_lanes.as_mut_ptr(), yi);
        }

        for j in 0..4 {
            f2[bitrev[i + j] as usize] = KissCpx::new(yi_lanes[j], yr_lanes[j]);
        }
    }

    for i in n4_vec..n4 {
        let rev = bitrev[i] as usize;
        let x1 = input[2 * i];
        let x2 = input[n2 - 1 - 2 * i];
        let t0 = trig[i];
        let t1 = trig[n4 + i];
        let yr = x2 * t0 + x1 * t1;
        let yi = x1 * t0 - x2 * t1;
        f2[rev] = KissCpx::new(yi, yr);
    }
}
787
/// Inverse-MDCT post-rotation used on aarch64.
///
/// Rotates `f2[i]` and its mirror `f2[n4 - 1 - i]` (reading their real and
/// imaginary parts swapped) by their twiddles and writes the de-shuffled
/// results into `output[overlap2 .. overlap2 + n2]`, working inwards from
/// both ends. When `n4` is odd the middle entry is processed twice with
/// identical results.
///
/// This step is scalar code on a bounds-checked sub-slice; the previous
/// hand-unrolled raw-pointer version did the same arithmetic without bounds
/// checks.
///
/// # Panics
/// If `output` is shorter than `overlap2 + n2`, or `f2` / `trig` are too short.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
fn mdct_backward_post_rotation_neon(
    f2: &[KissCpx],
    trig: &[f32],
    output: &mut [f32],
    n4: usize,
    n2: usize,
    overlap2: usize,
) {
    let out = &mut output[overlap2..overlap2 + n2];

    for i in 0..(n4 + 1) / 2 {
        let j = n4 - 1 - i;

        let re0 = f2[i].i;
        let im0 = f2[i].r;
        let (t0_0, t1_0) = (trig[i], trig[n4 + i]);
        let yr0 = re0 * t0_0 + im0 * t1_0;
        let yi0 = re0 * t1_0 - im0 * t0_0;

        let re1 = f2[j].i;
        let im1 = f2[j].r;
        let (t0_1, t1_1) = (trig[j], trig[n4 + j]);
        let yr1 = re1 * t0_1 + im1 * t1_1;
        let yi1 = re1 * t1_1 - im1 * t0_1;

        out[2 * i] = yr0;
        out[n2 - 1 - 2 * i] = yi0;
        out[n2 - 2 - 2 * i] = yr1;
        out[2 * i + 1] = yi1;
    }
}
873
/// NEON implementation of the inverse-MDCT TDAC mirroring step (aarch64).
///
/// For each `i < overlap / 2`, combines `output[i]` and its mirror
/// `output[overlap - 1 - i]` in place using `window[i]` and
/// `window[overlap - 1 - i]`. Four pairs are handled per vector iteration,
/// with the mirrored blocks lane-reversed in registers.
///
/// # Panics
/// If the vector loop runs and `output` or `window` is shorter than `overlap`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
fn mdct_tdac_neon(output: &mut [f32], window: &[f32], overlap: usize) {
    use std::arch::aarch64::*;

    let overlap2 = overlap / 2;
    if overlap2 < 4 {
        for i in 0..overlap2 {
            let x1 = output[overlap - 1 - i];
            let x2 = output[i];
            let wp1 = window[i];
            let wp2 = window[overlap - 1 - i];
            output[i] = x2 * wp2 - x1 * wp1;
            output[overlap - 1 - i] = x2 * wp1 + x1 * wp2;
        }
        return;
    }

    // The vector loop uses unchecked loads/stores up to index overlap - 1.
    assert!(output.len() >= overlap, "NEON TDAC: output shorter than overlap");
    assert!(window.len() >= overlap, "NEON TDAC: window shorter than overlap");

    let vec_end = overlap2 & !3;
    let out_ptr = output.as_mut_ptr();
    let win_ptr = window.as_ptr();

    for i in (0..vec_end).step_by(4) {
        // Start of the mirrored block covering overlap-4-i ..= overlap-1-i.
        let hi = overlap - 4 - i;

        // SAFETY: i + 4 <= overlap2 <= overlap and hi + 4 = overlap - i <= overlap,
        // so every access is in bounds by the asserts above; the blocks
        // [i, i + 4) and [hi, hi + 4) are disjoint because
        // i + 4 <= overlap2 <= overlap - overlap2 <= hi.
        unsafe {
            let x2 = vld1q_f32(out_ptr.add(i));
            let x1_half_rev = vrev64q_f32(vld1q_f32(out_ptr.add(hi)));
            let x1 = vextq_f32(x1_half_rev, x1_half_rev, 2);

            let wp1 = vld1q_f32(win_ptr.add(i));
            let wp2_half_rev = vrev64q_f32(vld1q_f32(win_ptr.add(hi)));
            let wp2 = vextq_f32(wp2_half_rev, wp2_half_rev, 2);

            let out_fwd = vsubq_f32(vmulq_f32(x2, wp2), vmulq_f32(x1, wp1));
            let out_rev = vaddq_f32(vmulq_f32(x2, wp1), vmulq_f32(x1, wp2));
            let out_rev_half = vrev64q_f32(out_rev);

            vst1q_f32(out_ptr.add(i), out_fwd);
            vst1q_f32(out_ptr.add(hi), vextq_f32(out_rev_half, out_rev_half, 2));
        }
    }

    for i in vec_end..overlap2 {
        let x1 = output[overlap - 1 - i];
        let x2 = output[i];
        output[i] = x2 * window[overlap - 1 - i] - x1 * window[i];
        output[overlap - 1 - i] = x2 * window[i] + x1 * window[overlap - 1 - i];
    }
}
932
#[cfg(test)]
mod mdct_tests {
    use super::MdctLookup;
    use crate::kiss_fft::{opus_fft_impl, KissCpx};

    /// Portable scalar reference for `MdctLookup::backward`.
    ///
    /// Performs the same pre-rotation, FFT, post-rotation and TDAC mirroring
    /// steps with plain loops, so the dispatched (AVX/NEON/scalar)
    /// implementation can be compared against it. Returns `out_len` samples.
    fn scalar_backward(
        mdct: &MdctLookup,
        input: &[f32],
        window: &[f32],
        overlap: usize,
        shift: usize,
        stride: usize,
        out_len: usize,
    ) -> Vec<f32> {
        let n = mdct.n >> shift;
        let n2 = n / 2;
        let n4 = n / 4;
        let overlap2 = overlap / 2;
        let st = mdct.kfft[shift].as_ref().unwrap();
        let (trig, _) = mdct.get_trig(shift);

        let mut f2 = vec![KissCpx::new(0.0, 0.0); n4];
        for i in 0..n4 {
            let rev = st.bitrev[i] as usize;
            let x1 = input[2 * i * stride];
            let x2 = input[stride * (n2 - 1 - 2 * i)];
            let t0 = trig[i];
            let t1 = trig[n4 + i];
            let yr = x2 * t0 + x1 * t1;
            let yi = x1 * t0 - x2 * t1;
            f2[rev] = KissCpx::new(yi, yr);
        }
        opus_fft_impl(st, &mut f2);

        let mut out = vec![0.0f32; out_len];
        for i in 0..((n4 + 1) >> 1) {
            let im0 = f2[i].r;
            let re0 = f2[i].i;
            let t0_0 = trig[i];
            let t1_0 = trig[n4 + i];
            let yr0 = re0 * t0_0 + im0 * t1_0;
            let yi0 = re0 * t1_0 - im0 * t0_0;
            let j = n4 - 1 - i;
            let im1 = f2[j].r;
            let re1 = f2[j].i;
            let t0_1 = trig[j];
            let t1_1 = trig[n4 + j];
            let yr1 = re1 * t0_1 + im1 * t1_1;
            let yi1 = re1 * t1_1 - im1 * t0_1;
            out[overlap2 + 2 * i] = yr0;
            out[overlap2 + n2 - 1 - 2 * i] = yi0;
            out[overlap2 + n2 - 2 - 2 * i] = yr1;
            out[overlap2 + 2 * i + 1] = yi1;
        }
        for i in 0..overlap2 {
            let x1 = out[overlap - 1 - i];
            let x2 = out[i];
            let wp1 = window[i];
            let wp2 = window[overlap - 1 - i];
            out[i] = x2 * wp2 - x1 * wp1;
            out[overlap - 1 - i] = x2 * wp1 + x1 * wp2;
        }
        out
    }

    /// Largest absolute element-wise difference between two slices.
    fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).abs())
            .fold(0.0f32, f32::max)
    }

    #[test]
    fn test_mdct_backward_transient_no_blowup() {
        let mode = crate::modes::default_mode();
        let shift = 3;
        let n = mode.mdct.n >> shift;
        let overlap = mode.overlap;
        let stride = 8;

        let frame_size = 960usize;
        let freq: Vec<f32> = (0..frame_size)
            .map(|i| ((i as f32) * 0.01).sin() * 10.0)
            .collect();

        let out_len = n + overlap;
        let mut output0 = vec![0.0f32; out_len];
        let mut output1 = vec![0.0f32; out_len];

        mode.mdct.backward(&freq[0..], &mut output0, mode.window, overlap, shift, stride);
        mode.mdct.backward(&freq[1..], &mut output1, mode.window, overlap, shift, stride);

        let max0 = output0.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let max1 = output1.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        eprintln!("sub0 max={} sub1 max={}", max0, max1);
        eprintln!("sub0[60..70]={:?}", &output0[60..70]);
        eprintln!("sub1[60..70]={:?}", &output1[60..70]);

        assert!(max0.abs() < 500.0, "sub0 blowup: {}", max0);
        assert!(max1.abs() < 500.0, "sub1 blowup: {}", max1);
    }

    #[test]
    fn test_mdct_backward_stride1_neon_matches_scalar() {
        let mode = crate::modes::default_mode();
        let shift = 0;
        let n = mode.mdct.n >> shift;
        let n2 = n / 2;
        let overlap = mode.overlap;
        let overlap2 = overlap / 2;
        let stride = 1;

        let mut freq = vec![0.0f32; n2 + 4];
        for (i, v) in freq.iter_mut().take(n2).enumerate() {
            *v = ((i as f32) * 0.01).sin() * 4577.0;
        }

        let out_len = overlap2 + n2;
        let mut output_hw = vec![0.0f32; out_len + 100];
        mode.mdct.backward(&freq[..], &mut output_hw, mode.window, overlap, shift, stride);

        let output_scalar =
            scalar_backward(&mode.mdct, &freq, mode.window, overlap, shift, stride, out_len);

        let max_diff = max_abs_diff(&output_hw[..out_len], &output_scalar);
        assert!(
            max_diff < 0.5,
            "stride=1 NEON vs scalar mismatch: max_diff={}",
            max_diff
        );
    }

    #[test]
    fn test_mdct_backward_neon_matches_scalar() {
        let mode = crate::modes::default_mode();
        let shift = 3;
        let n = mode.mdct.n >> shift;
        let overlap = mode.overlap;
        let stride = 8;

        let frame_size = 960usize;
        let freq: Vec<f32> = (0..frame_size)
            .map(|i| ((i as f32) * 0.01).sin() * 200.0)
            .collect();

        let out_len = n + overlap;
        let mut output_hw = vec![0.0f32; out_len];
        mode.mdct.backward(&freq[0..], &mut output_hw, mode.window, overlap, shift, stride);

        let output_scalar =
            scalar_backward(&mode.mdct, &freq, mode.window, overlap, shift, stride, out_len);

        for (i, (hw, scalar)) in output_hw.iter().zip(output_scalar.iter()).enumerate() {
            let diff = (hw - scalar).abs();
            if diff > 1e-3 {
                eprintln!(
                    "Mismatch at output[{}]: hw={} scalar={} diff={}",
                    i, hw, scalar, diff
                );
            }
        }
        let max_diff = max_abs_diff(&output_hw, &output_scalar);
        assert!(
            max_diff < 0.1,
            "NEON/HW vs scalar mismatch: max_diff={}",
            max_diff
        );
    }

    /// Regression test for the forward fold: every scratch entry must be
    /// written on each call. A noisy transform runs first so that any entry
    /// the fold skipped would still hold stale non-zero data for the silent one.
    #[test]
    fn test_mdct_forward_zero_input_gives_zero_output() {
        let mode = crate::modes::default_mode();
        let overlap = mode.overlap;

        for shift in 0..=mode.mdct.max_lm {
            let n2 = (mode.mdct.n >> shift) / 2;
            let in_len = n2 + overlap;
            let mut out = vec![0.0f32; n2];

            let noisy: Vec<f32> = (0..in_len)
                .map(|i| ((i as f32) * 0.37).sin() * 1000.0)
                .collect();
            mode.mdct.forward(&noisy, &mut out, mode.window, overlap, shift, 1);

            let silence = vec![0.0f32; in_len];
            mode.mdct.forward(&silence, &mut out, mode.window, overlap, shift, 1);

            let peak = out.iter().fold(0.0f32, |m, v| m.max(v.abs()));
            assert!(
                peak == 0.0,
                "shift={}: forward MDCT of silence is not silent (peak {})",
                shift,
                peak
            );
        }
    }
}