1use crate::kiss_fft::{KissCpx, KissFftState, opus_fft_impl};
2use std::f32::consts::PI;
3use std::mem::MaybeUninit;
4
/// Capacity, in `f32` samples, of the stack scratch buffer that holds the
/// folded `n / 2` input values in the forward MDCT.
const MAX_N2: usize = 960;
/// Capacity, in complex bins, of the stack scratch buffer for the
/// `n / 4`-point FFT. Derived from `MAX_N2` so the two bounds cannot drift.
const MAX_N4: usize = MAX_N2 / 2;
7
/// Precomputed tables for the MDCT at length `n` and at each halved length
/// `n >> shift` for `shift` in `0..=max_lm` (one FFT state and one twiddle
/// run per shift; see `MdctLookup::new`).
pub struct MdctLookup {
    /// Transform length used for shift 0.
    pub n: usize,
    /// Largest shift for which tables were built.
    pub max_lm: usize,
    /// FFT state of size `(n >> shift) / 4`, indexed by shift. An entry is
    /// `None` when `KissFftState::new` / `new_sub` returned `None`.
    kfft: Vec<Option<KissFftState>>,
    /// All twiddle tables concatenated in shift order; shift `s` owns
    /// `(n >> s) / 2` consecutive entries.
    trig: Vec<f32>,
}
14
15impl MdctLookup {
16 pub fn new(n: usize, max_lm: usize) -> Self {
17 let mut kfft = Vec::new();
18 let mut trig = Vec::new();
19 let mut curr_n = n;
20
21 for shift in 0..=max_lm {
22 let n4 = curr_n / 4;
23
24 if shift == 0 {
25 kfft.push(KissFftState::new(n4));
26 } else if let Some(base) = kfft.first().unwrap().as_ref() {
27 kfft.push(KissFftState::new_sub(base, n4));
28 } else {
29 kfft.push(None);
30 }
31
32 let n2 = curr_n / 2;
33 for i in 0..n2 {
34 let angle = 2.0 * PI * (i as f32 + 0.125) / curr_n as f32;
35 trig.push(angle.cos());
36 }
37
38 curr_n >>= 1;
39 }
40
41 Self {
42 n,
43 max_lm,
44 kfft,
45 trig,
46 }
47 }
48
49 fn get_trig(&self, shift: usize) -> (&[f32], usize) {
50 let mut offset = 0;
51 let mut curr_n = self.n;
52 for _ in 0..shift {
53 offset += curr_n / 2;
54 curr_n >>= 1;
55 }
56 (&self.trig[offset..offset + curr_n / 2], curr_n / 4)
57 }
58
59 pub fn get_trig_debug(&self, shift: usize) -> &[f32] {
60 let (trig, _) = self.get_trig(shift);
61 trig
62 }
63
64 #[inline]
65 pub fn forward(
66 &self,
67 input: &[f32],
68 output: &mut [f32],
69 window: &[f32],
70 overlap: usize,
71 shift: usize,
72 stride: usize,
73 ) {
74 let st = self.kfft[shift]
75 .as_ref()
76 .expect("FFT state not initialized");
77 let n = self.n >> shift;
78 let n2 = n / 2;
79 let n4 = n / 4;
80 let scale = st.scale();
81
82 let (trig, _) = self.get_trig(shift);
83 let overlap2 = overlap / 2;
84
85 let mut f_buf = [MaybeUninit::<f32>::uninit(); MAX_N2];
86 let mut f2_buf = [MaybeUninit::<KissCpx>::uninit(); MAX_N4];
87
88 let f = unsafe { std::slice::from_raw_parts_mut(f_buf.as_mut_ptr() as *mut f32, n2) };
89 let f2 = unsafe { std::slice::from_raw_parts_mut(f2_buf.as_mut_ptr() as *mut KissCpx, n4) };
90
91 assert!(input.len() >= n2 + overlap2);
92 assert!(window.len() >= overlap);
93
94 {
95 let mut yp = 0usize;
96 let mut xp1 = overlap2;
97 let mut xp2 = n2 - 1 + overlap2;
98 let mut wp1 = overlap2;
99
100 let mut wp2 = overlap2.saturating_sub(1);
101
102 let limit = overlap.div_ceil(4);
103 let mid = n4.saturating_sub(limit);
104
105 let loop1_iters = limit.min(n4);
106 for _ in 0..loop1_iters {
107 let w1 = window[wp1];
108 let w2 = window[wp2];
109
110 f[yp] = input[xp1 + n2] * w2 + input[xp2] * w1;
111 yp += 1;
112
113 f[yp] = input[xp1] * w1 - input[xp2 - n2] * w2;
114 yp += 1;
115
116 xp1 += 2;
117 xp2 -= 2;
118 wp1 += 2;
119 wp2 = wp2.saturating_sub(2);
120 }
121
122 for _ in limit..mid {
123 f[yp] = input[xp2];
124 yp += 1;
125
126 f[yp] = input[xp1];
127 yp += 1;
128 xp1 += 2;
129 xp2 -= 2;
130 }
131
132 let loop3_iters = if mid > limit { n4 - mid } else { 0 };
133 let mut wp1_l3 = 0usize;
134 let mut wp2_l3 = overlap.saturating_sub(1);
135 for _ in 0..loop3_iters {
136 let w1 = window[wp1_l3];
137 let w2 = window[wp2_l3];
138
139 f[yp] = -input[xp1 - n2] * w1 + input[xp2] * w2;
140 yp += 1;
141
142 f[yp] = input[xp1] * w2 + input[xp2 + n2] * w1;
143 yp += 1;
144
145 xp1 += 2;
146 xp2 -= 2;
147 wp1_l3 += 2;
148 wp2_l3 -= 2;
149 }
150 }
151
152 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
153 unsafe {
154 if std::arch::is_x86_feature_detected!("avx") {
155 mdct_pre_rotation_avx(f, f2, trig, &st.bitrev[..n4], n4, scale);
156 } else {
157 for i in 0..n4 {
158 let re = f[2 * i];
159 let im = f[2 * i + 1];
160 let t0 = trig[i];
161 let t1 = trig[n4 + i];
162
163 let yr = re * t0 - im * t1;
164 let yi = im * t0 + re * t1;
165
166 f2[st.bitrev[i] as usize] = KissCpx::new(yr * scale, yi * scale);
167 }
168 }
169 }
170 #[cfg(all(
171 not(any(target_arch = "x86", target_arch = "x86_64")),
172 target_arch = "aarch64"
173 ))]
174 {
175 mdct_pre_rotation_neon(f, f2, trig, &st.bitrev[..n4], n4, scale);
176 }
177 #[cfg(all(
178 not(any(target_arch = "x86", target_arch = "x86_64")),
179 not(target_arch = "aarch64")
180 ))]
181 for i in 0..n4 {
182 let re = f[2 * i];
183 let im = f[2 * i + 1];
184 let t0 = trig[i];
185 let t1 = trig[n4 + i];
186
187 let yr = re * t0 - im * t1;
188 let yi = im * t0 + re * t1;
189
190 f2[st.bitrev[i] as usize] = KissCpx::new(yr * scale, yi * scale);
191 }
192
193 opus_fft_impl(st, f2);
194
195 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
196 unsafe {
197 if std::arch::is_x86_feature_detected!("avx") {
198 mdct_post_rotation_avx(f2, trig, output, n4, n2, stride);
199 } else {
200 for i in 0..n4 {
201 let fp = &f2[i];
202 let t0 = trig[i];
203 let t1 = trig[n4 + i];
204
205 let yr = fp.i * t1 - fp.r * t0;
206 let yi = fp.r * t1 + fp.i * t0;
207
208 output[i * 2 * stride] = yr;
209 output[stride * (n2 - 1 - 2 * i)] = yi;
210 }
211 }
212 }
213 #[cfg(all(
214 not(any(target_arch = "x86", target_arch = "x86_64")),
215 target_arch = "aarch64"
216 ))]
217 {
218 mdct_post_rotation_neon(f2, trig, output, n4, n2, stride);
219 }
220 #[cfg(all(
221 not(any(target_arch = "x86", target_arch = "x86_64")),
222 not(target_arch = "aarch64")
223 ))]
224 for i in 0..n4 {
225 let fp = &f2[i];
226 let t0 = trig[i];
227 let t1 = trig[n4 + i];
228
229 let yr = fp.i * t1 - fp.r * t0;
230 let yi = fp.r * t1 + fp.i * t0;
231
232 output[i * 2 * stride] = yr;
233 output[stride * (n2 - 1 - 2 * i)] = yi;
234 }
235 }
236
237 #[inline]
238 pub fn backward(
239 &self,
240 input: &[f32],
241 output: &mut [f32],
242 window: &[f32],
243 overlap: usize,
244 shift: usize,
245 stride: usize,
246 ) {
247 let st = self.kfft[shift]
248 .as_ref()
249 .expect("FFT state not initialized");
250 let n = self.n >> shift;
251 let n2 = n / 2;
252 let n4 = n / 4;
253 let overlap2 = overlap / 2;
254
255 let (trig, _) = self.get_trig(shift);
256
257 let mut f2_buf = [MaybeUninit::<KissCpx>::uninit(); MAX_N4];
258
259 let f2 = unsafe { std::slice::from_raw_parts_mut(f2_buf.as_mut_ptr() as *mut KissCpx, n4) };
260
261 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
262 unsafe {
263 if std::arch::is_x86_feature_detected!("avx") {
264 mdct_backward_pre_rotation_avx(input, f2, trig, &st.bitrev[..n4], n4, n2, stride);
265 } else {
266 for i in 0..n4 {
267 let rev = st.bitrev[i] as usize;
268 let x1 = input[2 * i * stride];
269 let x2 = input[stride * (n2 - 1 - 2 * i)];
270 let t0 = trig[i];
271 let t1 = trig[n4 + i];
272
273 let yr = x2 * t0 + x1 * t1;
274 let yi = x1 * t0 - x2 * t1;
275
276 f2[rev] = KissCpx::new(yi, yr);
277 }
278 }
279 }
280 #[cfg(all(
281 not(any(target_arch = "x86", target_arch = "x86_64")),
282 target_arch = "aarch64"
283 ))]
284 {
285 mdct_backward_pre_rotation_neon(input, f2, trig, &st.bitrev[..n4], n4, n2, stride);
286 }
287 #[cfg(all(
288 not(any(target_arch = "x86", target_arch = "x86_64")),
289 not(target_arch = "aarch64")
290 ))]
291 for i in 0..n4 {
292 let rev = st.bitrev[i] as usize;
293 let x1 = input[2 * i * stride];
294 let x2 = input[stride * (n2 - 1 - 2 * i)];
295 let t0 = trig[i];
296 let t1 = trig[n4 + i];
297
298 let yr = x2 * t0 + x1 * t1;
299 let yi = x1 * t0 - x2 * t1;
300
301 f2[rev] = KissCpx::new(yi, yr);
302 }
303
304 opus_fft_impl(st, f2);
305
306 assert!(output.len() >= overlap2 + n2);
307
308 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
309 unsafe {
310 if std::arch::is_x86_feature_detected!("avx") {
311 mdct_backward_post_rotation_avx(f2, trig, output, n4, n2, overlap2);
312 } else {
313 for i in 0..((n4 + 1) >> 1) {
314 let im0 = f2[i].r;
315 let re0 = f2[i].i;
316 let t0_0 = trig[i];
317 let t1_0 = trig[n4 + i];
318
319 let yr0 = re0 * t0_0 + im0 * t1_0;
320 let yi0 = re0 * t1_0 - im0 * t0_0;
321
322 let j = n4 - 1 - i;
323 let im1 = f2[j].r;
324 let re1 = f2[j].i;
325 let t0_1 = trig[j];
326 let t1_1 = trig[n4 + j];
327
328 let yr1 = re1 * t0_1 + im1 * t1_1;
329 let yi1 = re1 * t1_1 - im1 * t0_1;
330
331 output[overlap2 + 2 * i] = yr0;
332 output[overlap2 + n2 - 1 - 2 * i] = yi0;
333 output[overlap2 + n2 - 2 - 2 * i] = yr1;
334 output[overlap2 + 2 * i + 1] = yi1;
335 }
336 }
337 }
338 #[cfg(all(
339 not(any(target_arch = "x86", target_arch = "x86_64")),
340 target_arch = "aarch64"
341 ))]
342 {
343 mdct_backward_post_rotation_neon(f2, trig, output, n4, n2, overlap2);
344 }
345 #[cfg(all(
346 not(any(target_arch = "x86", target_arch = "x86_64")),
347 not(target_arch = "aarch64")
348 ))]
349 for i in 0..((n4 + 1) >> 1) {
350 let im0 = f2[i].r;
351 let re0 = f2[i].i;
352 let t0_0 = trig[i];
353 let t1_0 = trig[n4 + i];
354
355 let yr0 = re0 * t0_0 + im0 * t1_0;
356 let yi0 = re0 * t1_0 - im0 * t0_0;
357
358 let j = n4 - 1 - i;
359 let im1 = f2[j].r;
360 let re1 = f2[j].i;
361 let t0_1 = trig[j];
362 let t1_1 = trig[n4 + j];
363
364 let yr1 = re1 * t0_1 + im1 * t1_1;
365 let yi1 = re1 * t1_1 - im1 * t0_1;
366
367 output[overlap2 + 2 * i] = yr0;
368 output[overlap2 + n2 - 1 - 2 * i] = yi0;
369 output[overlap2 + n2 - 2 - 2 * i] = yr1;
370 output[overlap2 + 2 * i + 1] = yi1;
371 }
372
373 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
374 unsafe {
375 if std::arch::is_x86_feature_detected!("avx") {
376 mdct_tdac_avx(output, window, overlap);
377 } else {
378 for i in 0..overlap2 {
379 let x1 = output[overlap - 1 - i];
380 let x2 = output[i];
381 let wp1 = window[i];
382 let wp2 = window[overlap - 1 - i];
383
384 output[i] = x2 * wp2 - x1 * wp1;
385 output[overlap - 1 - i] = x2 * wp1 + x1 * wp2;
386 }
387 }
388 }
389 #[cfg(all(
390 not(any(target_arch = "x86", target_arch = "x86_64")),
391 target_arch = "aarch64"
392 ))]
393 {
394 mdct_tdac_neon(output, window, overlap);
395 }
396 #[cfg(all(
397 not(any(target_arch = "x86", target_arch = "x86_64")),
398 not(target_arch = "aarch64")
399 ))]
400 for i in 0..overlap2 {
401 let x1 = output[overlap - 1 - i];
402 let x2 = output[i];
403 let wp1 = window[i];
404 let wp2 = window[overlap - 1 - i];
405
406 output[i] = x2 * wp2 - x1 * wp1;
407 output[overlap - 1 - i] = x2 * wp1 + x1 * wp2;
408 }
409 }
410}
411
412#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
413#[target_feature(enable = "avx")]
414unsafe fn mdct_pre_rotation_avx(
415 f: &[f32],
416 f2: &mut [KissCpx],
417 trig: &[f32],
418 bitrev: &[i16],
419 n4: usize,
420 scale: f32,
421) {
422 for i in 0..n4 {
423 let re = f[2 * i];
424 let im = f[2 * i + 1];
425 let t0 = trig[i];
426 let t1 = trig[n4 + i];
427
428 let yr = re * t0 - im * t1;
429 let yi = im * t0 + re * t1;
430
431 f2[bitrev[i] as usize] = KissCpx::new(yr * scale, yi * scale);
432 }
433}
434
435#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
436#[target_feature(enable = "avx")]
437unsafe fn mdct_post_rotation_avx(
438 f2: &[KissCpx],
439 trig: &[f32],
440 output: &mut [f32],
441 n4: usize,
442 n2: usize,
443 stride: usize,
444) {
445 for i in 0..n4 {
446 let fp = &f2[i];
447 let t0 = trig[i];
448 let t1 = trig[n4 + i];
449
450 let yr = fp.i * t1 - fp.r * t0;
451 let yi = fp.r * t1 + fp.i * t0;
452
453 output[i * 2 * stride] = yr;
454 output[stride * (n2 - 1 - 2 * i)] = yi;
455 }
456}
457
458#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
459#[target_feature(enable = "avx")]
460unsafe fn mdct_backward_pre_rotation_avx(
461 input: &[f32],
462 f2: &mut [KissCpx],
463 trig: &[f32],
464 bitrev: &[i16],
465 n4: usize,
466 n2: usize,
467 stride: usize,
468) {
469 for i in 0..n4 {
470 let rev = bitrev[i] as usize;
471 let x1 = input[2 * i * stride];
472 let x2 = input[stride * (n2 - 1 - 2 * i)];
473 let t0 = trig[i];
474 let t1 = trig[n4 + i];
475
476 let yr = x2 * t0 + x1 * t1;
477 let yi = x1 * t0 - x2 * t1;
478
479 f2[rev] = KissCpx::new(yi, yr);
480 }
481}
482
483#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
484#[target_feature(enable = "avx")]
485unsafe fn mdct_backward_post_rotation_avx(
486 f2: &[KissCpx],
487 trig: &[f32],
488 output: &mut [f32],
489 n4: usize,
490 n2: usize,
491 overlap2: usize,
492) {
493 for i in 0..((n4 + 1) >> 1) {
494 let im0 = f2[i].r;
495 let re0 = f2[i].i;
496 let t0_0 = trig[i];
497 let t1_0 = trig[n4 + i];
498
499 let yr0 = re0 * t0_0 + im0 * t1_0;
500 let yi0 = re0 * t1_0 - im0 * t0_0;
501
502 let j = n4 - 1 - i;
503 let im1 = f2[j].r;
504 let re1 = f2[j].i;
505 let t0_1 = trig[j];
506 let t1_1 = trig[n4 + j];
507
508 let yr1 = re1 * t0_1 + im1 * t1_1;
509 let yi1 = re1 * t1_1 - im1 * t0_1;
510
511 output[overlap2 + 2 * i] = yr0;
512 output[overlap2 + n2 - 1 - 2 * i] = yi0;
513 output[overlap2 + n2 - 2 - 2 * i] = yr1;
514 output[overlap2 + 2 * i + 1] = yi1;
515 }
516}
517
/// TDAC mirror of the IMDCT output, vectorised with AVX.
///
/// For every `i < overlap / 2`, with `x2 = output[i]` and
/// `x1 = output[overlap - 1 - i]`, the pair is replaced by
/// `output[i] = x2 * w[overlap - 1 - i] - x1 * w[i]` and
/// `output[overlap - 1 - i] = x2 * w[i] + x1 * w[overlap - 1 - i]`,
/// the same arithmetic as the portable loop. Eight pairs are processed per
/// iteration. The mirrored half is loaded with one vector load and reversed
/// in registers; the scalar loop handles any remainder.
///
/// # Safety
/// The caller must have verified that the running CPU supports AVX.
///
/// # Panics
/// Panics if `overlap / 2 > 0` and `output` or `window` holds fewer than
/// `overlap` samples.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx")]
unsafe fn mdct_tdac_avx(output: &mut [f32], window: &[f32], overlap: usize) {
    // The intrinsics live in a different module on 32-bit x86.
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    /// Reverses the eight lanes of `v`: swap the two 128-bit halves, then
    /// reverse the four lanes inside each half (both are AVX instructions).
    #[inline]
    #[target_feature(enable = "avx")]
    unsafe fn reverse8(v: __m256) -> __m256 {
        _mm256_permute_ps(_mm256_permute2f128_ps(v, v, 0x01), 0x1B)
    }

    let overlap2 = overlap / 2;
    if overlap2 == 0 {
        return;
    }
    // The vector loads/stores below bypass bounds checks. Every index they
    // touch is < overlap, so these two checks make them sound.
    assert!(output.len() >= overlap, "TDAC output shorter than overlap");
    assert!(window.len() >= overlap, "TDAC window shorter than overlap");

    let mut i = 0usize;
    while i + 8 <= overlap2 {
        // [i, i + 8) and its mirror [overlap - 8 - i, overlap - i) are
        // disjoint because i + 8 <= overlap2 <= overlap - overlap2.
        let mirror = overlap - 8 - i;
        let out = output.as_mut_ptr();
        let win = window.as_ptr();

        let x2 = _mm256_loadu_ps(out.add(i));
        // Lane k of x1 / w2 is element overlap - 1 - (i + k).
        let x1 = reverse8(_mm256_loadu_ps(out.add(mirror)));
        let w1 = _mm256_loadu_ps(win.add(i));
        let w2 = reverse8(_mm256_loadu_ps(win.add(mirror)));

        let out_fwd = _mm256_sub_ps(_mm256_mul_ps(x2, w2), _mm256_mul_ps(x1, w1));
        let out_rev = _mm256_add_ps(_mm256_mul_ps(x2, w1), _mm256_mul_ps(x1, w2));

        _mm256_storeu_ps(out.add(i), out_fwd);
        _mm256_storeu_ps(out.add(mirror), reverse8(out_rev));

        i += 8;
    }

    for i in i..overlap2 {
        let x1 = output[overlap - 1 - i];
        let x2 = output[i];
        let wp1 = window[i];
        let wp2 = window[overlap - 1 - i];
        output[i] = x2 * wp2 - x1 * wp1;
        output[overlap - 1 - i] = x2 * wp1 + x1 * wp2;
    }
}
563
/// Forward pre-rotation for aarch64 using NEON.
///
/// Produces the same values as the portable loop: for each `i < n4`, the
/// folded value `(f[2i], f[2i + 1])` is rotated by `(trig[i], trig[n4 + i])`,
/// scaled by `scale`, and stored at `f2[bitrev[i]]`. Four rotations are
/// computed per vector iteration. The bit-reversed scatter stays scalar and
/// bounds-checked.
///
/// # Panics
/// Panics in any of these cases:
/// * `f.len() < 2 * n4` or `trig.len() < 2 * n4`.
/// * `bitrev.len() < n4`.
/// * a `bitrev` entry is out of range for `f2`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
fn mdct_pre_rotation_neon(
    f: &[f32],
    f2: &mut [KissCpx],
    trig: &[f32],
    bitrev: &[i16],
    n4: usize,
    scale: f32,
) {
    use std::arch::aarch64::*;

    // The vector loads go through raw pointers; these checks bound every
    // address they can touch.
    assert!(f.len() >= 2 * n4, "pre-rotation input too short");
    assert!(trig.len() >= 2 * n4, "twiddle table too short");
    assert!(bitrev.len() >= n4, "bit-reversal table too short");

    let n4_vec = n4 & !3;
    let mut i = 0;
    while i < n4_vec {
        let mut yr = [0.0f32; 4];
        let mut yi = [0.0f32; 4];
        // SAFETY: i + 4 <= n4, so the loads read f[2i .. 2i + 8] within
        // f[..2 * n4], and trig[i .. i + 4] / trig[n4 + i .. n4 + i + 4]
        // within trig[..2 * n4]; both ranges are asserted above. The stores
        // target local 4-element arrays. NEON is a baseline aarch64 feature.
        unsafe {
            let vscale = vdupq_n_f32(scale);
            let t0 = vld1q_f32(trig.as_ptr().add(i));
            let t1 = vld1q_f32(trig.as_ptr().add(n4 + i));

            // De-interleave four complex values into real / imaginary lanes.
            let pairs = vuzpq_f32(
                vld1q_f32(f.as_ptr().add(2 * i)),
                vld1q_f32(f.as_ptr().add(2 * i + 4)),
            );
            let re_v = pairs.0;
            let im_v = pairs.1;

            let rot_re = vsubq_f32(vmulq_f32(re_v, t0), vmulq_f32(im_v, t1));
            let rot_im = vaddq_f32(vmulq_f32(im_v, t0), vmulq_f32(re_v, t1));

            vst1q_f32(yr.as_mut_ptr(), vmulq_f32(rot_re, vscale));
            vst1q_f32(yi.as_mut_ptr(), vmulq_f32(rot_im, vscale));
        }

        for j in 0..4 {
            f2[bitrev[i + j] as usize] = KissCpx::new(yr[j], yi[j]);
        }

        i += 4;
    }

    for i in n4_vec..n4 {
        let re = f[2 * i];
        let im = f[2 * i + 1];
        let t0 = trig[i];
        let t1 = trig[n4 + i];
        let yr = re * t0 - im * t1;
        let yi = im * t0 + re * t1;
        f2[bitrev[i] as usize] = KissCpx::new(yr * scale, yi * scale);
    }
}
628
/// Forward post-rotation for aarch64 using NEON.
///
/// Produces the same values as the portable loop: bin `i` is rotated by
/// `(trig[i], trig[n4 + i])` and written to `output[2 * i * stride]` and
/// `output[stride * (n2 - 1 - 2 * i)]`. Only `stride == 1` uses vectors;
/// every other stride, including 0, takes the scalar loop so the indexing
/// formula is always honoured. Output writes are bounds-checked.
///
/// # Panics
/// Panics if `f2.len() < n4`, `trig.len() < 2 * n4`, or `output` is too short.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
fn mdct_post_rotation_neon(
    f2: &[KissCpx],
    trig: &[f32],
    output: &mut [f32],
    n4: usize,
    n2: usize,
    stride: usize,
) {
    use std::arch::aarch64::*;

    // The vector path reads `f2` as interleaved f32 pairs (`r` then `i`).
    // This checks the size half of that assumption at compile time.
    // NOTE(review): field order relies on KissCpx's layout, presumably #[repr(C)] — confirm.
    const _: () = assert!(std::mem::size_of::<KissCpx>() == 2 * std::mem::size_of::<f32>());

    assert!(f2.len() >= n4, "FFT buffer too short");
    assert!(trig.len() >= 2 * n4, "twiddle table too short");

    if stride != 1 {
        for i in 0..n4 {
            let fp = &f2[i];
            let t0 = trig[i];
            let t1 = trig[n4 + i];
            let yr = fp.i * t1 - fp.r * t0;
            let yi = fp.r * t1 + fp.i * t0;
            output[i * 2 * stride] = yr;
            output[stride * (n2 - 1 - 2 * i)] = yi;
        }
        return;
    }

    let f2_ptr = f2.as_ptr() as *const f32;
    let n4_vec = n4 & !3;
    let mut i = 0;
    while i < n4_vec {
        let mut yr = [0.0f32; 4];
        let mut yi = [0.0f32; 4];
        // SAFETY: i + 4 <= n4 <= f2.len(), so the two loads cover complex
        // bins f2[i .. i + 4] (8 floats). trig[i .. i + 4] and
        // trig[n4 + i .. n4 + i + 4] lie within trig[..2 * n4]. Both bounds
        // are asserted above. Stores go to local 4-element arrays.
        unsafe {
            let c0 = vld1q_f32(f2_ptr.add(2 * i));
            let c1 = vld1q_f32(f2_ptr.add(2 * i + 4));
            let t0 = vld1q_f32(trig.as_ptr().add(i));
            let t1 = vld1q_f32(trig.as_ptr().add(n4 + i));

            let ri = vuzpq_f32(c0, c1);
            let r_v = ri.0;
            let i_v = ri.1;

            vst1q_f32(yr.as_mut_ptr(), vsubq_f32(vmulq_f32(i_v, t1), vmulq_f32(r_v, t0)));
            vst1q_f32(yi.as_mut_ptr(), vaddq_f32(vmulq_f32(r_v, t1), vmulq_f32(i_v, t0)));
        }

        for j in 0..4 {
            output[(i + j) * 2] = yr[j];
            output[n2 - 1 - 2 * (i + j)] = yi[j];
        }

        i += 4;
    }

    for i in n4_vec..n4 {
        let fp = &f2[i];
        let t0 = trig[i];
        let t1 = trig[n4 + i];
        let yr = fp.i * t1 - fp.r * t0;
        let yi = fp.r * t1 + fp.i * t0;
        output[i * 2] = yr;
        output[n2 - 1 - 2 * i] = yi;
    }
}
699
/// Backward pre-rotation for aarch64 using NEON.
///
/// Produces the same values as the portable loop. For each `i < n4`, the
/// coefficient `x1 = input[2i * stride]` is paired with
/// `x2 = input[stride * (n2 - 1 - 2i)]` and the pair is rotated by
/// `(trig[i], trig[n4 + i])`. The result is stored at `f2[bitrev[i]]` as
/// `(yi, yr)`.
///
/// Only `stride == 1` is vectorised. The mirrored `x2` values are gathered
/// from `input[n2 - 8 - 2i .. n2 - 2i]`, which keeps every read below `n2`.
///
/// # Panics
/// Panics in any of these cases:
/// * `trig.len() < 2 * n4`.
/// * `bitrev.len() < n4`.
/// * `input` is too short.
/// * a `bitrev` entry is out of range for `f2`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
fn mdct_backward_pre_rotation_neon(
    input: &[f32],
    f2: &mut [KissCpx],
    trig: &[f32],
    bitrev: &[i16],
    n4: usize,
    n2: usize,
    stride: usize,
) {
    use std::arch::aarch64::*;

    assert!(trig.len() >= 2 * n4, "twiddle table too short");
    assert!(bitrev.len() >= n4, "bit-reversal table too short");

    if stride != 1 {
        for i in 0..n4 {
            let rev = bitrev[i] as usize;
            let x1 = input[2 * i * stride];
            let x2 = input[stride * (n2 - 1 - 2 * i)];
            let t0 = trig[i];
            let t1 = trig[n4 + i];
            let yr = x2 * t0 + x1 * t1;
            let yi = x1 * t0 - x2 * t1;
            f2[rev] = KissCpx::new(yi, yr);
        }
        return;
    }

    // With stride 1 the highest index read is n2 - 1.
    assert!(n4 == 0 || input.len() >= n2, "backward MDCT input too short");

    let n4_vec = n4 & !3;
    let mut i = 0;
    while i < n4_vec {
        let mut yr = [0.0f32; 4];
        let mut yi = [0.0f32; 4];
        // SAFETY: i + 4 <= n4 and 2 * n4 <= n2 <= input.len().
        // - x1 loads read input[2i .. 2i + 8], which ends at most at 2 * n4.
        // - x2 loads read input[n2 - 8 - 2i .. n2 - 2i]. The start does not
        //   underflow because 2i + 8 <= 2 * n4 <= n2, and the end is <= n2.
        // - trig loads stay inside trig[..2 * n4] (asserted above).
        // - Stores go to local arrays.
        unsafe {
            let base = input.as_ptr();

            // Even lanes: input[2i], input[2i + 2], input[2i + 4], input[2i + 6].
            let x1_v = vuzpq_f32(vld1q_f32(base.add(2 * i)), vld1q_f32(base.add(2 * i + 4))).0;

            // Odd lanes: input[n2-7-2i], [n2-5-2i], [n2-3-2i], [n2-1-2i];
            // reversing them yields x2 for indices i, i+1, i+2, i+3.
            let g0 = vld1q_f32(base.add(n2 - 8 - 2 * i));
            let g1 = vld1q_f32(base.add(n2 - 4 - 2 * i));
            let x2_fwd = vuzpq_f32(g0, g1).1;
            let x2_v = vrev64q_f32(x2_fwd);
            let x2_v = vextq_f32(x2_v, x2_v, 2);

            let t0 = vld1q_f32(trig.as_ptr().add(i));
            let t1 = vld1q_f32(trig.as_ptr().add(n4 + i));

            vst1q_f32(yr.as_mut_ptr(), vaddq_f32(vmulq_f32(x2_v, t0), vmulq_f32(x1_v, t1)));
            vst1q_f32(yi.as_mut_ptr(), vsubq_f32(vmulq_f32(x1_v, t0), vmulq_f32(x2_v, t1)));
        }

        for j in 0..4 {
            f2[bitrev[i + j] as usize] = KissCpx::new(yi[j], yr[j]);
        }

        i += 4;
    }

    for i in n4_vec..n4 {
        let x1 = input[2 * i];
        let x2 = input[n2 - 1 - 2 * i];
        let t0 = trig[i];
        let t1 = trig[n4 + i];
        let yr = x2 * t0 + x1 * t1;
        let yi = x1 * t0 - x2 * t1;
        f2[bitrev[i] as usize] = KissCpx::new(yi, yr);
    }
}
781
/// Backward post-rotation for aarch64.
///
/// This stage uses no SIMD intrinsics. It is plain bounds-checked Rust with
/// the same arithmetic as the portable loop. Each step handles bin `i` and
/// its mirror `j = n4 - 1 - i`, reading `.i` as the real part and `.r` as the
/// imaginary part. It writes four samples at both ends of
/// `output[overlap2 .. overlap2 + n2]`. For odd `n4` the middle bin is
/// visited twice and produces identical values.
///
/// # Panics
/// Panics if `f2`, `trig` or `output` is too short for `n4`, `n2` and `overlap2`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
fn mdct_backward_post_rotation_neon(
    f2: &[KissCpx],
    trig: &[f32],
    output: &mut [f32],
    n4: usize,
    n2: usize,
    overlap2: usize,
) {
    for i in 0..((n4 + 1) >> 1) {
        let j = n4 - 1 - i;

        let re0 = f2[i].i;
        let im0 = f2[i].r;
        let t0_0 = trig[i];
        let t1_0 = trig[n4 + i];
        let yr0 = re0 * t0_0 + im0 * t1_0;
        let yi0 = re0 * t1_0 - im0 * t0_0;

        let re1 = f2[j].i;
        let im1 = f2[j].r;
        let t0_1 = trig[j];
        let t1_1 = trig[n4 + j];
        let yr1 = re1 * t0_1 + im1 * t1_1;
        let yi1 = re1 * t1_1 - im1 * t0_1;

        output[overlap2 + 2 * i] = yr0;
        output[overlap2 + n2 - 1 - 2 * i] = yi0;
        output[overlap2 + n2 - 2 - 2 * i] = yr1;
        output[overlap2 + 2 * i + 1] = yi1;
    }
}
867
/// TDAC mirror of the IMDCT output for aarch64 using NEON.
///
/// Same arithmetic as the portable loop. For every `i < overlap / 2`, with
/// `x2 = output[i]` and `x1 = output[overlap - 1 - i]`:
/// `output[i] = x2 * w[overlap - 1 - i] - x1 * w[i]` and
/// `output[overlap - 1 - i] = x2 * w[i] + x1 * w[overlap - 1 - i]`.
/// Four pairs are handled per vector iteration; the rest use scalar code.
///
/// # Panics
/// Panics if `overlap / 2 > 0` and `output` or `window` holds fewer than
/// `overlap` samples.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
fn mdct_tdac_neon(output: &mut [f32], window: &[f32], overlap: usize) {
    use std::arch::aarch64::*;

    let overlap2 = overlap / 2;
    if overlap2 == 0 {
        return;
    }
    // The vector loads/stores bypass bounds checks. Every index they touch is
    // < overlap, so these checks make them sound.
    assert!(output.len() >= overlap, "TDAC output shorter than overlap");
    assert!(window.len() >= overlap, "TDAC window shorter than overlap");

    // End of the span that whole 4-lane vectors cover (overlap2 rounded down to a multiple of 4).
    let vec_end = overlap2 & !3;
    let mut i = 0;
    while i < vec_end {
        let mirror = overlap - 4 - i;
        // SAFETY: i + 4 <= overlap2, so [i, i + 4) and
        // [mirror, mirror + 4) = [overlap - 4 - i, overlap - i) both lie
        // inside [0, overlap). That range fits in `output` and `window`
        // (asserted above). The two ranges are disjoint because
        // i + 4 <= overlap2 <= overlap - overlap2 <= mirror.
        unsafe {
            let out = output.as_mut_ptr();
            let win = window.as_ptr();

            let x2 = vld1q_f32(out.add(i));
            // Full 4-lane reversal: swap within 64-bit halves, then swap halves.
            let x1 = vrev64q_f32(vld1q_f32(out.add(mirror)));
            let x1 = vextq_f32(x1, x1, 2);

            let wp1 = vld1q_f32(win.add(i));
            let wp2 = vrev64q_f32(vld1q_f32(win.add(mirror)));
            let wp2 = vextq_f32(wp2, wp2, 2);

            let out_fwd = vsubq_f32(vmulq_f32(x2, wp2), vmulq_f32(x1, wp1));
            let out_rev = vaddq_f32(vmulq_f32(x2, wp1), vmulq_f32(x1, wp2));
            let out_rev = vrev64q_f32(out_rev);
            let out_rev = vextq_f32(out_rev, out_rev, 2);

            vst1q_f32(out.add(i), out_fwd);
            vst1q_f32(out.add(mirror), out_rev);
        }
        i += 4;
    }

    for i in vec_end..overlap2 {
        let x1 = output[overlap - 1 - i];
        let x2 = output[i];
        output[i] = x2 * window[overlap - 1 - i] - x1 * window[i];
        output[overlap - 1 - i] = x2 * window[i] + x1 * window[overlap - 1 - i];
    }
}
926
#[cfg(test)]
mod mdct_tests {
    use crate::kiss_fft::{opus_fft_impl, KissCpx, KissFftState};

    /// Plain scalar inverse MDCT used as the reference for
    /// `MdctLookup::backward`. It runs pre-rotation, FFT, post-rotation and
    /// the TDAC mirror, writing into `out`.
    ///
    /// `trig` is the twiddle slice for the shift under test, so its length is
    /// that transform's half-size `n2`.
    fn scalar_backward(
        st: &KissFftState,
        trig: &[f32],
        freq: &[f32],
        window: &[f32],
        overlap: usize,
        stride: usize,
        out: &mut [f32],
    ) {
        let n2 = trig.len();
        let n4 = n2 / 2;
        let overlap2 = overlap / 2;

        let mut f2 = vec![KissCpx::new(0.0, 0.0); n4];
        for i in 0..n4 {
            let x1 = freq[2 * i * stride];
            let x2 = freq[stride * (n2 - 1 - 2 * i)];
            let (t0, t1) = (trig[i], trig[n4 + i]);
            let yr = x2 * t0 + x1 * t1;
            let yi = x1 * t0 - x2 * t1;
            f2[st.bitrev[i] as usize] = KissCpx::new(yi, yr);
        }
        opus_fft_impl(st, &mut f2);

        for i in 0..((n4 + 1) >> 1) {
            let j = n4 - 1 - i;

            let (im0, re0) = (f2[i].r, f2[i].i);
            let (t0_0, t1_0) = (trig[i], trig[n4 + i]);
            let yr0 = re0 * t0_0 + im0 * t1_0;
            let yi0 = re0 * t1_0 - im0 * t0_0;

            let (im1, re1) = (f2[j].r, f2[j].i);
            let (t0_1, t1_1) = (trig[j], trig[n4 + j]);
            let yr1 = re1 * t0_1 + im1 * t1_1;
            let yi1 = re1 * t1_1 - im1 * t0_1;

            out[overlap2 + 2 * i] = yr0;
            out[overlap2 + n2 - 1 - 2 * i] = yi0;
            out[overlap2 + n2 - 2 - 2 * i] = yr1;
            out[overlap2 + 2 * i + 1] = yi1;
        }

        for i in 0..overlap2 {
            let x1 = out[overlap - 1 - i];
            let x2 = out[i];
            let (wp1, wp2) = (window[i], window[overlap - 1 - i]);
            out[i] = x2 * wp2 - x1 * wp1;
            out[overlap - 1 - i] = x2 * wp1 + x1 * wp2;
        }
    }

    /// Largest absolute element-wise difference between two equally long slices.
    fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).abs())
            .fold(0.0f32, f32::max)
    }

    #[test]
    fn test_mdct_backward_transient_no_blowup() {
        let mode = crate::modes::default_mode();
        let shift = 3;
        let stride = 8;
        let overlap = mode.overlap;
        let out_len = (mode.mdct.n >> shift) + overlap;

        let freq: Vec<f32> = (0..960usize)
            .map(|i| ((i as f32) * 0.01).sin() * 10.0)
            .collect();

        let mut output0 = vec![0.0f32; out_len];
        let mut output1 = vec![0.0f32; out_len];
        mode.mdct
            .backward(&freq[0..], &mut output0, mode.window, overlap, shift, stride);
        mode.mdct
            .backward(&freq[1..], &mut output1, mode.window, overlap, shift, stride);

        let peak = |v: &[f32]| v.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let max0 = peak(&output0);
        let max1 = peak(&output1);
        eprintln!("sub0 max={} sub1 max={}", max0, max1);
        eprintln!("sub0[60..70]={:?}", &output0[60..70]);
        eprintln!("sub1[60..70]={:?}", &output1[60..70]);

        assert!(max0.abs() < 500.0, "sub0 blowup: {}", max0);
        assert!(max1.abs() < 500.0, "sub1 blowup: {}", max1);
    }

    #[test]
    fn test_mdct_backward_stride1_neon_matches_scalar() {
        let mode = crate::modes::default_mode();
        let shift = 0;
        let stride = 1;
        let overlap = mode.overlap;
        let n2 = (mode.mdct.n >> shift) / 2;
        let out_len = overlap / 2 + n2;

        // n2 coefficients followed by four zero samples of padding.
        let mut freq = vec![0.0f32; n2 + 4];
        for (i, v) in freq.iter_mut().take(n2).enumerate() {
            *v = ((i as f32) * 0.01).sin() * 4577.0;
        }

        let mut output_hw = vec![0.0f32; out_len + 100];
        mode.mdct
            .backward(&freq[..], &mut output_hw, mode.window, overlap, shift, stride);

        let st = mode.mdct.kfft[shift].as_ref().unwrap();
        let (trig, _) = mode.mdct.get_trig(shift);
        let mut output_scalar = vec![0.0f32; out_len + 100];
        scalar_backward(st, trig, &freq, mode.window, overlap, stride, &mut output_scalar);

        let max_diff = max_abs_diff(&output_hw[..out_len], &output_scalar[..out_len]);
        assert!(
            max_diff < 0.5,
            "stride=1 NEON vs scalar mismatch: max_diff={}",
            max_diff
        );
    }

    #[test]
    fn test_mdct_backward_neon_matches_scalar() {
        let mode = crate::modes::default_mode();
        let shift = 3;
        let stride = 8;
        let overlap = mode.overlap;
        let out_len = (mode.mdct.n >> shift) + overlap;

        let freq: Vec<f32> = (0..960usize)
            .map(|i| ((i as f32) * 0.01).sin() * 200.0)
            .collect();

        let mut output_hw = vec![0.0f32; out_len];
        mode.mdct
            .backward(&freq[0..], &mut output_hw, mode.window, overlap, shift, stride);

        let st = mode.mdct.kfft[shift].as_ref().unwrap();
        let (trig, _) = mode.mdct.get_trig(shift);
        let mut output_scalar = vec![0.0f32; out_len];
        scalar_backward(st, trig, &freq, mode.window, overlap, stride, &mut output_scalar);

        for (i, (&hw, &scalar)) in output_hw.iter().zip(output_scalar.iter()).enumerate() {
            let diff = (hw - scalar).abs();
            if diff > 1e-3 {
                eprintln!(
                    "Mismatch at output[{}]: hw={} scalar={} diff={}",
                    i, hw, scalar, diff
                );
            }
        }

        let max_diff = max_abs_diff(&output_hw, &output_scalar);
        assert!(
            max_diff < 0.1,
            "NEON/HW vs scalar mismatch: max_diff={}",
            max_diff
        );
    }
}