Skip to main content

opus_rs/
celt.rs

1use crate::bands::{
2    SPREAD_NONE, SPREAD_NORMAL, compute_band_energies, denormalise_bands, haar1, log2amp,
3    normalise_bands, quant_all_bands, spreading_decision,
4};
5use crate::modes::{CeltMode, SPREAD_ICDF, TAPSET_ICDF, TF_SELECT_TABLE, TRIM_ICDF};
6use crate::quant_bands::{
7    quant_coarse_energy_advanced, quant_energy_finalise, quant_fine_energy, unquant_coarse_energy,
8    unquant_energy_finalise, unquant_fine_energy,
9};
10use crate::range_coder::RangeCoder;
11use crate::rate::{BITRES, clt_compute_allocation};
12
13#[cfg(target_arch = "aarch64")]
14use std::arch::aarch64::*;
15
/// Sums the absolute values of the first `n` elements of `x` with NEON.
///
/// Elements are consumed sixteen, then eight, then four at a time into a
/// single four-lane accumulator; the lanes are then added together and the
/// remaining (fewer than four) elements are added with scalar code.
///
/// # Safety
///
/// `n` must not exceed `x.len()`: the vector loads go through raw pointers
/// and are not bounds-checked.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn sum_abs_neon(x: &[f32], n: usize) -> f32 {
    debug_assert!(n <= x.len());

    let src = x.as_ptr();
    let mut sum_vec = vdupq_n_f32(0.0);
    let mut i = 0;

    // A plain lane-wise add is used for the accumulation. A fused
    // multiply-add by 1.0 would round to exactly the same value while
    // needing an extra constant vector per step.
    while i + 16 <= n {
        let x0 = vld1q_f32(src.add(i));
        let x1 = vld1q_f32(src.add(i + 4));
        let x2 = vld1q_f32(src.add(i + 8));
        let x3 = vld1q_f32(src.add(i + 12));

        sum_vec = vaddq_f32(sum_vec, vabsq_f32(x0));
        sum_vec = vaddq_f32(sum_vec, vabsq_f32(x1));
        sum_vec = vaddq_f32(sum_vec, vabsq_f32(x2));
        sum_vec = vaddq_f32(sum_vec, vabsq_f32(x3));

        i += 16;
    }

    while i + 8 <= n {
        let x0 = vld1q_f32(src.add(i));
        let x1 = vld1q_f32(src.add(i + 4));
        sum_vec = vaddq_f32(sum_vec, vabsq_f32(x0));
        sum_vec = vaddq_f32(sum_vec, vabsq_f32(x1));
        i += 8;
    }

    while i + 4 <= n {
        let x0 = vld1q_f32(src.add(i));
        sum_vec = vaddq_f32(sum_vec, vabsq_f32(x0));
        i += 4;
    }

    // Horizontal add of the four lanes, then the scalar tail.
    let mut sum = vaddvq_f32(sum_vec);

    for j in i..n {
        sum += x[j].abs();
    }

    sum
}
59
/// Returns the sum of the absolute values of every element of `x`.
///
/// The implementation is picked per target: on x86_64 the AVX kernel is used
/// when the CPU reports AVX at run time, on aarch64 the NEON kernel is always
/// used, and every other case (including x86_64 without AVX) takes the scalar
/// iterator below.
#[inline(always)]
fn sum_abs(x: &[f32]) -> f32 {
    #[cfg(target_arch = "x86_64")]
    unsafe {
        // Early return; when AVX is missing control falls through to the
        // scalar block, which is the tail expression on this target.
        if std::arch::is_x86_feature_detected!("avx") {
            return sum_abs_avx(x, x.len());
        }
    }
    #[cfg(target_arch = "aarch64")]
    unsafe {
        // Tail expression on aarch64 (the scalar block is compiled out there).
        sum_abs_neon(x, x.len())
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        x.iter().map(|&v| v.abs()).sum()
    }
}
77
// Not referenced in the visible part of this file.
// NOTE(review): presumably the largest frame size in samples - confirm against its users.
const MAX_FRAME_SIZE: usize = 2880;

// Not referenced in the visible part of this file.
// NOTE(review): presumably the length of the decoder history buffer - confirm against its users.
const DECODE_BUFFER_SIZE: usize = 3072;

// Lookup table used by `transient_analysis`: a masking index clamped to
// 0..=127 selects the weight that is added to the per-channel metric.
// The entries are non-increasing, from 255 down to 2.
const INV_TABLE: [u8; 128] = [
    255, 255, 156, 110, 86, 70, 59, 51, 45, 40, 37, 33, 31, 28, 26, 25, 23, 22, 21, 20, 19, 18, 17,
    16, 16, 15, 15, 14, 13, 13, 12, 12, 12, 12, 11, 11, 11, 10, 10, 10, 9, 9, 9, 9, 9, 9, 8, 8, 8,
    8, 8, 7, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
    5, 5, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 3, 3, 3,
    3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2,
];

// Upper bound on the `len` argument of `transient_analysis`; only checked
// there with a `debug_assert!`.
const MAX_TRANSIENT_LEN: usize = 3000;
91
92#[allow(clippy::too_many_arguments)]
93fn transient_analysis(
94    input: &[f32],
95    len: usize,
96    channels: usize,
97    tf_estimate: &mut f32,
98    tf_chan: &mut usize,
99    allow_weak_transients: bool,
100    weak_transient: &mut bool,
101    _tone_freq: f32,
102    toneishness: f32,
103    tmp: &mut [f32],
104    tmp2: &mut [f32],
105) -> bool {
106    let mut mask_metric = 0.0f32;
107    let mut forward_decay = 0.0625f32;
108
109    *weak_transient = false;
110    if allow_weak_transients {
111        forward_decay = 0.03125f32;
112    }
113
114    let len2 = len / 2;
115    debug_assert!(len <= MAX_TRANSIENT_LEN);
116
117    for c in 0..channels {
118        let mut mem0 = 0.0f32;
119        let mut mem1 = 0.0f32;
120
121        for i in 0..len {
122            let x = input[c * len + i];
123            let y = mem0 + x;
124            let mem00 = mem0;
125            mem0 = mem0 - x + 0.5 * mem1;
126            mem1 = x - mem00;
127            tmp[i] = y;
128        }
129
130        tmp[..12].fill(0.0);
131
132        let mut mean = 0.0f32;
133        mem0 = 0.0f32;
134        for i in 0..len2 {
135            let x2 = (tmp[2 * i] * tmp[2 * i] + tmp[2 * i + 1] * tmp[2 * i + 1]) / 16.0;
136            mean += x2 / 4096.0;
137            mem0 = x2 + (1.0 - forward_decay) * mem0;
138            tmp2[i] = forward_decay * mem0;
139        }
140
141        mem0 = 0.0f32;
142        let mut max_e = 0.0f32;
143        for i in (0..len2).rev() {
144            mem0 = tmp2[i] + 0.875 * mem0;
145            tmp2[i] = 0.125 * mem0;
146            if tmp2[i] > max_e {
147                max_e = tmp2[i];
148            }
149        }
150
151        mean = (mean * max_e * 0.5 * (len2 as f32)).sqrt();
152        let norm = (len2 as f32) / (1e-10 + mean);
153
154        let mut unmask = 0.0f32;
155        for i in (12..(len2 - 5)).step_by(4) {
156            let id = (64.0 * norm * (tmp2[i] + 1e-10)).floor() as i32;
157            let id = id.clamp(0, 127) as usize;
158            unmask += INV_TABLE[id] as f32;
159        }
160
161        unmask = 64.0 * unmask * 4.0 / (6.0 * (len2 as f32 - 17.0));
162        if unmask > mask_metric {
163            *tf_chan = c;
164            mask_metric = unmask;
165        }
166    }
167
168    let mut is_transient = mask_metric > 200.0;
169
170    if toneishness > 0.98 && _tone_freq < 0.026 {
171        is_transient = false;
172        mask_metric = 0.0;
173    }
174
175    *tf_estimate = (mask_metric - 150.0).clamp(0.0, 1.0);
176
177    is_transient
178}
179
180fn l1_metric(tmp: &[f32], n: usize, lm: i32, bias: f32) -> f32 {
181    #[cfg(target_arch = "x86_64")]
182    unsafe {
183        if n >= 16 && std::arch::is_x86_feature_detected!("avx") {
184            return l1_metric_avx(tmp, n, lm, bias);
185        }
186    }
187    #[cfg(target_arch = "aarch64")]
188    {
189        if n >= 16 {
190            return unsafe { l1_metric_neon(tmp, n, lm, bias) };
191        }
192    }
193
194    let mut l1 = 0.0f32;
195    for &tv in tmp[..n].iter() {
196        l1 += tv.abs();
197    }
198    l1 + (lm as f32) * bias * l1
199}
200
/// AVX kernel: sum of the absolute values of the first `n` elements of `x`.
///
/// Two eight-lane accumulators are filled sixteen samples at a time, one
/// further block of eight goes into the first accumulator, the lanes are
/// reduced to a scalar, and the last (fewer than eight) samples are added
/// with scalar code.
///
/// # Safety
///
/// The CPU must support AVX, and `n` must not exceed `x.len()` because the
/// vector loads are not bounds-checked.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx")]
unsafe fn sum_abs_avx(x: &[f32], n: usize) -> f32 {
    use std::arch::x86_64::*;

    let src = x.as_ptr();
    // -0.0 has only the sign bit set; `andnot(mask, v)` clears it, i.e. |v|.
    let sign_bit = _mm256_set1_ps(-0.0);
    let mut acc_a = _mm256_setzero_ps();
    let mut acc_b = _mm256_setzero_ps();
    let mut pos = 0usize;

    while n - pos >= 16 {
        let first = _mm256_loadu_ps(src.add(pos));
        let second = _mm256_loadu_ps(src.add(pos + 8));
        acc_a = _mm256_add_ps(acc_a, _mm256_andnot_ps(sign_bit, first));
        acc_b = _mm256_add_ps(acc_b, _mm256_andnot_ps(sign_bit, second));
        pos += 16;
    }

    // Fewer than 16 samples remain, so at most one more block of eight fits.
    if n - pos >= 8 {
        let block = _mm256_loadu_ps(src.add(pos));
        acc_a = _mm256_add_ps(acc_a, _mm256_andnot_ps(sign_bit, block));
        pos += 8;
    }

    // 8 lanes -> 4 lanes -> 2 lanes -> 1 lane.
    let both = _mm256_add_ps(acc_a, acc_b);
    let upper = _mm256_extractf128_ps(both, 1);
    let lower = _mm256_castps256_ps128(both);
    let quad = _mm_add_ps(lower, upper);
    let pair = _mm_add_ps(quad, _mm_movehl_ps(quad, quad));
    let lane1 = _mm_shuffle_ps(pair, pair, 0x55);
    let mut total = _mm_cvtss_f32(_mm_add_ss(pair, lane1));

    for v in &x[pos..n] {
        total += v.abs();
    }

    total
}
240
/// AVX variant of `l1_metric`: sums `|tmp[i]|` for `i < n` with
/// `sum_abs_avx` and applies the same `l1 + lm * bias * l1` weighting.
///
/// # Safety
///
/// Same requirements as `sum_abs_avx`, to which `tmp` and `n` are passed
/// unchanged.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx")]
unsafe fn l1_metric_avx(tmp: &[f32], n: usize, lm: i32, bias: f32) -> f32 {
    let l1 = sum_abs_avx(tmp, n);
    l1 + (lm as f32) * bias * l1
}
247
/// NEON variant of `l1_metric`: sums `|tmp[i]|` for `i < n` and returns
/// `l1 + lm * bias * l1`.
///
/// # Safety
///
/// The vector loads read through a raw pointer without bounds checks, so
/// `n` must not exceed `tmp.len()`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn l1_metric_neon(tmp: &[f32], n: usize, lm: i32, bias: f32) -> f32 {
    unsafe {
        // Single four-lane accumulator shared by all vector loops.
        let mut sum4 = vdupq_n_f32(0.0);
        let mut i = 0;

        // Sixteen samples (four vectors) per iteration.
        while i + 15 < n {
            let v0 = vld1q_f32(tmp.as_ptr().add(i));
            let v1 = vld1q_f32(tmp.as_ptr().add(i + 4));
            let v2 = vld1q_f32(tmp.as_ptr().add(i + 8));
            let v3 = vld1q_f32(tmp.as_ptr().add(i + 12));

            sum4 = vaddq_f32(sum4, vabsq_f32(v0));
            sum4 = vaddq_f32(sum4, vabsq_f32(v1));
            sum4 = vaddq_f32(sum4, vabsq_f32(v2));
            sum4 = vaddq_f32(sum4, vabsq_f32(v3));

            i += 16;
        }

        // Four samples per iteration for what is left.
        while i + 3 < n {
            let v = vld1q_f32(tmp.as_ptr().add(i));
            sum4 = vaddq_f32(sum4, vabsq_f32(v));
            i += 4;
        }

        // Two pairwise adds leave the total of all four lanes in lane 0.
        let sum2 = vpaddq_f32(sum4, sum4);
        let sum1 = vpaddq_f32(sum2, sum2);
        let mut l1 = vgetq_lane_f32(sum1, 0);

        // Scalar tail (fewer than four samples).
        while i < n {
            l1 += tmp[i].abs();
            i += 1;
        }

        l1 + (lm as f32) * bias * l1
    }
}
287
// Capacity of the per-band arrays in `tf_analysis` (`metric`, `importance`);
// that function debug-asserts `len <= MAX_NB_EBANDS`.
const MAX_NB_EBANDS: usize = 21;

// Capacity of the per-band scratch buffers in `tf_analysis`; one band's
// samples (`band width << lm`) are copied into them.
const MAX_TF_TMP: usize = 176;
291
292#[allow(clippy::too_many_arguments)]
293fn tf_analysis(
294    mode: &CeltMode,
295    len: usize,
296    is_transient: bool,
297    tf_res: &mut [i32],
298    lambda: i32,
299    x: &[f32],
300    n0: usize,
301    lm: i32,
302    tf_estimate: f32,
303    tf_chan: usize,
304) -> i32 {
305    debug_assert!(len <= MAX_NB_EBANDS);
306    let mut metric = [0i32; MAX_NB_EBANDS];
307    let mut tmp = [0.0f32; MAX_TF_TMP];
308    let mut tmp_1 = [0.0f32; MAX_TF_TMP];
309
310    let bias = 0.04 * (-0.25f32).max(0.5 - tf_estimate);
311
312    for (i, metric_i) in metric[..len].iter_mut().enumerate() {
313        let n = ((mode.e_bands[i + 1] - mode.e_bands[i]) as usize) << lm;
314        let narrow = (mode.e_bands[i + 1] - mode.e_bands[i]) == 1;
315        let offset = tf_chan * n0 + ((mode.e_bands[i] as usize) << lm);
316        tmp[..n].copy_from_slice(&x[offset..offset + n]);
317
318        let mut l1 = l1_metric(&tmp[..n], n, if is_transient { lm } else { 0 }, bias);
319        let mut best_l1 = l1;
320        let mut best_level = 0;
321
322        if is_transient && !narrow {
323            tmp_1[..n].copy_from_slice(&tmp[..n]);
324            haar1(&mut tmp_1[..n], n >> lm, 1 << lm);
325            l1 = l1_metric(&tmp_1[..n], n, lm + 1, bias);
326            if l1 < best_l1 {
327                best_l1 = l1;
328                best_level = -1;
329            }
330        }
331
332        for k in 0..(lm + if is_transient || narrow { 0 } else { 1 }) {
333            let b = if is_transient { lm - k - 1 } else { k + 1 };
334
335            haar1(&mut tmp[..n], n >> k, 1 << k);
336            l1 = l1_metric(&tmp[..n], n, b, bias);
337
338            if l1 < best_l1 {
339                best_l1 = l1;
340                best_level = k + 1;
341            }
342        }
343
344        if is_transient {
345            *metric_i = 2 * best_level;
346        } else {
347            *metric_i = -2 * best_level;
348        }
349
350        if narrow && (*metric_i == 0 || *metric_i == -2 * lm) {
351            *metric_i -= 1;
352        }
353    }
354
355    let mut tf_select = 0;
356    let importance = [1.0f32; MAX_NB_EBANDS];
357    let mut selcost = [0.0f32; 2];
358
359    for sel in 0..2 {
360        let mut cost0 = importance[0]
361            * ((metric[0]
362                - 2 * TF_SELECT_TABLE[lm as usize][4 * (is_transient as usize) + 2 * sel] as i32)
363                as f32)
364                .abs();
365        let mut cost1 = importance[0]
366            * ((metric[0]
367                - 2 * TF_SELECT_TABLE[lm as usize][4 * (is_transient as usize) + 2 * sel + 1]
368                    as i32) as f32)
369                .abs()
370            + (if is_transient { 0.0 } else { lambda as f32 });
371
372        for i in 1..len {
373            let curr0 = cost0.min(cost1 + lambda as f32);
374            let curr1 = (cost0 + lambda as f32).min(cost1);
375            cost0 = curr0
376                + importance[i]
377                    * ((metric[i]
378                        - 2 * TF_SELECT_TABLE[lm as usize][4 * (is_transient as usize) + 2 * sel]
379                            as i32) as f32)
380                        .abs();
381            cost1 = curr1
382                + importance[i]
383                    * ((metric[i]
384                        - 2 * TF_SELECT_TABLE[lm as usize]
385                            [4 * (is_transient as usize) + 2 * sel + 1]
386                            as i32) as f32)
387                        .abs();
388        }
389        selcost[sel] = cost0.min(cost1);
390    }
391
392    if selcost[1] < selcost[0] {
393        tf_select = 1;
394    }
395
396    let mut cost0 = importance[0]
397        * ((metric[0]
398            - 2 * TF_SELECT_TABLE[lm as usize][4 * (is_transient as usize) + 2 * tf_select] as i32)
399            as f32)
400            .abs();
401    let mut cost1 = importance[0]
402        * ((metric[0]
403            - 2 * TF_SELECT_TABLE[lm as usize][4 * (is_transient as usize) + 2 * tf_select + 1]
404                as i32) as f32)
405            .abs()
406        + (if is_transient { 0.0 } else { lambda as f32 });
407
408    tf_res[0] = if cost0 < cost1 { 0 } else { 1 };
409
410    for i in 1..len {
411        let curr0 = cost0.min(cost1 + lambda as f32);
412        let curr1 = (cost0 + lambda as f32).min(cost1);
413        cost0 = curr0
414            + importance[i]
415                * ((metric[i]
416                    - 2 * TF_SELECT_TABLE[lm as usize][4 * (is_transient as usize) + 2 * tf_select]
417                        as i32) as f32)
418                    .abs();
419        cost1 = curr1
420            + importance[i]
421                * ((metric[i]
422                    - 2 * TF_SELECT_TABLE[lm as usize]
423                        [4 * (is_transient as usize) + 2 * tf_select + 1]
424                        as i32) as f32)
425                    .abs();
426        tf_res[i] = if cost0 < cost1 { 0 } else { 1 };
427    }
428
429    tf_select as i32
430}
431
432fn tf_encode(
433    start: usize,
434    end: usize,
435    is_transient: bool,
436    tf_res: &mut [i32],
437    lm: i32,
438    mut tf_select: i32,
439    rc: &mut RangeCoder,
440) -> i32 {
441    let mut curr = 0;
442    let mut tf_changed = 0;
443    let mut logp = if is_transient { 2 } else { 4 };
444    let mut budget = rc.storage as i32 * 8;
445    let mut tell = rc.tell();
446
447    let tf_select_rsv = if lm > 0 && tell + logp < budget { 1 } else { 0 };
448    budget -= tf_select_rsv;
449
450    for tf_res_i in tf_res[start..end].iter_mut() {
451        if tell + logp <= budget {
452            rc.encode_bit_logp(*tf_res_i ^ curr != 0, logp as u32);
453            tell = rc.tell();
454            curr = *tf_res_i;
455            tf_changed |= curr;
456        } else {
457            *tf_res_i = curr;
458        }
459        logp = if is_transient { 4 } else { 5 };
460    }
461
462    if tf_select_rsv != 0
463        && TF_SELECT_TABLE[lm as usize][4 * (is_transient as usize) + (tf_changed as usize)]
464            != TF_SELECT_TABLE[lm as usize][4 * (is_transient as usize) + 2 + (tf_changed as usize)]
465    {
466        rc.encode_bit_logp(tf_select != 0, 1);
467    } else {
468        tf_select = 0;
469    }
470
471    for tf_res_i in tf_res[start..end].iter_mut() {
472        *tf_res_i = TF_SELECT_TABLE[lm as usize]
473            [4 * (is_transient as usize) + 2 * (tf_select as usize) + (*tf_res_i as usize)]
474            as i32;
475    }
476
477    tf_changed
478}
479
480fn tf_decode(
481    start: usize,
482    end: usize,
483    is_transient: bool,
484    tf_res: &mut [i32],
485    lm: i32,
486    rc: &mut RangeCoder,
487) {
488    let mut curr = 0;
489    let mut tf_changed = 0;
490    let mut logp = if is_transient { 2 } else { 4 };
491    let budget = rc.storage as i32 * 8;
492    let mut tell = rc.tell();
493
494    let tf_select_rsv = if lm > 0 && tell + logp < budget { 1 } else { 0 };
495    let budget = budget - tf_select_rsv;
496
497    for tf_res_i in tf_res[start..end].iter_mut() {
498        if tell + logp <= budget {
499            curr ^= if rc.decode_bit_logp(logp as u32) {
500                1
501            } else {
502                0
503            };
504            tell = rc.tell();
505            tf_changed |= curr;
506        }
507        *tf_res_i = curr;
508        logp = if is_transient { 4 } else { 5 };
509    }
510
511    let mut tf_select = 0;
512    let _budget = budget + tf_select_rsv;
513    if tf_select_rsv > 0
514        && TF_SELECT_TABLE[lm as usize][4 * (is_transient as usize) + (tf_changed as usize)]
515            != TF_SELECT_TABLE[lm as usize][4 * (is_transient as usize) + 2 + (tf_changed as usize)]
516    {
517        tf_select = if rc.decode_bit_logp(1) { 1 } else { 0 };
518    }
519
520    for tf_res_i in tf_res[start..end].iter_mut() {
521        *tf_res_i = TF_SELECT_TABLE[lm as usize]
522            [4 * (is_transient as usize) + 2 * (tf_select as usize) + (*tf_res_i as usize)]
523            as i32;
524    }
525}
526
527fn stereo_analysis(m: &CeltMode, x: &[f32], lm: i32, n0: usize) -> bool {
528    let mut sum_lr = 1e-9f32;
529    let mut sum_ms = 1e-9f32;
530
531    for i in 0..13 {
532        let start = (m.e_bands[i] as usize) << lm;
533        let end = (m.e_bands[i + 1] as usize) << lm;
534        for j in start..end {
535            let l = x[j];
536            let r = x[n0 + j];
537            let m_val = l + r;
538            let s_val = l - r;
539            sum_lr += l.abs() + r.abs();
540            sum_ms += m_val.abs() + s_val.abs();
541        }
542    }
543
544    sum_ms *= std::f32::consts::FRAC_1_SQRT_2;
545    let mut thetas = 13;
546    if lm <= 1 {
547        thetas -= 8;
548    }
549
550    let left = (((m.e_bands[13] as usize) << (lm + 1)) + thetas) as f32 * sum_ms;
551    let right = ((m.e_bands[13] as usize) << (lm + 1)) as f32 * sum_lr;
552
553    left > right
554}
555
// Smallest comb-filter period; `comb_filter` and `comb_filter_inplace` clamp
// their period arguments up to this value.
const COMBFILTER_MINPERIOD: usize = 15;
// Not referenced in the visible part of this file.
// NOTE(review): presumably the largest supported comb-filter period - confirm against its users.
const COMBFILTER_MAXPERIOD: usize = 1024;

// Tap weights of the symmetric 5-tap comb filter, one row per tapset
// (indexed by `tapset0` / `tapset1`). Column 0 weights the centre tap,
// column 1 the two taps at distance 1 and column 2 the two taps at distance 2.
const PREFILTER_GAINS: [[f32; 3]; 3] = [
    [0.306_640_6, 0.217_041, 0.129_638_7],
    [0.463_867_2, 0.268_066_4, 0.0],
    [0.799_804_7, 0.100_097_7, 0.0],
];
564
565#[allow(clippy::too_many_arguments)]
566fn comb_filter_const(
567    y: &mut [f32],
568    x: &[f32],
569    y_idx: usize,
570    x_idx: usize,
571    t: usize,
572    n: usize,
573    g10: f32,
574    g11: f32,
575    g12: f32,
576) {
577    #[cfg(target_arch = "aarch64")]
578    {
579        comb_filter_const_neon(y, x, y_idx, x_idx, t, n, g10, g11, g12);
580    }
581    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
582    unsafe {
583        if std::arch::is_x86_feature_detected!("avx") {
584            comb_filter_const_avx(y, x, y_idx, x_idx, t, n, g10, g11, g12);
585            return;
586        }
587    }
588    #[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
589    unsafe {
590        comb_filter_const_sse(y, x, y_idx, x_idx, t, n, g10, g11, g12);
591        #[allow(clippy::needless_return)]
592        return;
593    }
594    #[cfg(not(any(
595        target_arch = "aarch64",
596        all(target_arch = "x86_64", target_feature = "sse")
597    )))]
598    {
599        comb_filter_const_scalar(y, x, y_idx, x_idx, t, n, g10, g11, g12);
600    }
601}
602
/// Portable implementation of the constant-gain 5-tap comb filter.
///
/// For `i` in `0..n`:
/// `y[y_idx + i] = x[x_idx + i] + g10 * c0 + g11 * (c1 + c-1) + g12 * (c2 + c-2)`
/// where `ck` is `x[x_idx + i - t + k]`. The four most recent history samples
/// are kept in a small rolling window so each iteration loads one new sample.
#[inline]
#[allow(dead_code)]
fn comb_filter_const_scalar(
    y: &mut [f32],
    x: &[f32],
    y_idx: usize,
    x_idx: usize,
    t: usize,
    n: usize,
    g10: f32,
    g11: f32,
    g12: f32,
) {
    // Index of the oldest tap of the first output sample.
    let oldest = x_idx - t - 2;
    // Taps at offsets -t-2, -t-1, -t and -t+1, oldest first.
    let mut history = [x[oldest], x[oldest + 1], x[oldest + 2], x[oldest + 3]];

    for i in 0..n {
        let newest = x[oldest + 4 + i];
        let [far_left, left, centre, right] = history;
        y[y_idx + i] =
            x[x_idx + i] + g10 * centre + g11 * (right + left) + g12 * (newest + far_left);
        history = [left, centre, right, newest];
    }
}
636
/// Safe entry point for the NEON comb-filter kernel; forwards every argument
/// unchanged to `comb_filter_const_neon_impl`.
///
/// The kernel reads and writes through raw pointers without bounds checks, so
/// callers must pass indices that stay inside `x` and `y`.
#[cfg(target_arch = "aarch64")]
fn comb_filter_const_neon(
    y: &mut [f32],
    x: &[f32],
    y_idx: usize,
    x_idx: usize,
    t: usize,
    n: usize,
    g10: f32,
    g11: f32,
    g12: f32,
) {
    unsafe { comb_filter_const_neon_impl(y, x, y_idx, x_idx, t, n, g10, g11, g12) }
}
651
/// NEON kernel for the constant-gain 5-tap comb filter; produces four output
/// samples per iteration and finishes the last (fewer than four) samples with
/// scalar code.
///
/// # Safety
///
/// All vector loads and stores use raw pointers derived from `x` and `y`
/// without bounds checks. The caller must ensure that
/// `x[x_idx - t - 2 .. x_idx + n]` and `y[y_idx .. y_idx + n]` are in bounds
/// and that `t >= 2` (the code computes `t - 2`).
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn comb_filter_const_neon_impl(
    y: &mut [f32],
    x: &[f32],
    y_idx: usize,
    x_idx: usize,
    t: usize,
    n: usize,
    g10: f32,
    g11: f32,
    g12: f32,
) {
    use std::arch::aarch64::*;

    // Gains broadcast to all four lanes.
    let g10v = vdupq_n_f32(g10);
    let g11v = vdupq_n_f32(g11);
    let g12v = vdupq_n_f32(g12);

    let xbase = x.as_ptr().add(x_idx);
    let ybase = y.as_mut_ptr().add(y_idx);

    // History samples at offsets -t-2 .. -t+1 relative to the current output.
    let mut x0v = vld1q_f32(xbase.sub(t + 2));

    let mut i = 0;
    while i + 4 <= n {
        // Next four history samples, offsets -t+2 .. -t+5.
        let x4v = vld1q_f32(xbase.add(i).sub(t - 2));

        // `vextq_f32` slides a four-lane window over the eight consecutive
        // samples held in (x0v, x4v): start at lane 2 -> offsets -t .. -t+3.
        let x2v = vextq_f32(x0v, x4v, 2);

        // Start at lane 1 -> offsets -t-1 .. -t+2.
        let x1v = vextq_f32(x0v, x4v, 1);

        // Start at lane 3 -> offsets -t+1 .. -t+4.
        let x3v = vextq_f32(x0v, x4v, 3);

        let xi = vld1q_f32(xbase.add(i));

        // yi = x + g10 * centre + g11 * (inner pair) + g12 * (outer pair).
        let mut yi = xi;
        yi = vfmaq_f32(yi, g10v, x2v);
        yi = vfmaq_f32(yi, g11v, vaddq_f32(x1v, x3v));
        yi = vfmaq_f32(yi, g12v, vaddq_f32(x4v, x0v));
        vst1q_f32(ybase.add(i), yi);

        // The newest block becomes the oldest block of the next iteration.
        x0v = x4v;
        i += 4;
    }

    // Seed the scalar tail from the last loaded history vector
    // (lane 0 is the oldest sample).
    let x0v_arr: [f32; 4] = std::mem::transmute(x0v);
    let mut sx4 = x0v_arr[0];
    let mut sx3 = x0v_arr[1];
    let mut sx2 = x0v_arr[2];
    let mut sx1 = x0v_arr[3];

    while i < n {
        let sx0 = x[x_idx + i - t + 2];
        y[y_idx + i] = x[x_idx + i] + g10 * sx2 + g11 * (sx1 + sx3) + g12 * (sx0 + sx4);
        sx4 = sx3;
        sx3 = sx2;
        sx2 = sx1;
        sx1 = sx0;
        i += 1;
    }
}
715
/// SSE kernel for the constant-gain 5-tap comb filter; produces four output
/// samples per iteration and finishes the last (fewer than four) samples with
/// scalar code.
///
/// # Safety
///
/// All vector loads and stores use raw pointers derived from `x` and `y`
/// without bounds checks. The caller must ensure that
/// `x[x_idx - t - 2 .. x_idx + n]` and `y[y_idx .. y_idx + n]` are in bounds
/// and that `t >= 2` (the code computes `t - 2`).
#[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
#[inline(always)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn comb_filter_const_sse(
    y: &mut [f32],
    x: &[f32],
    y_idx: usize,
    x_idx: usize,
    t: usize,
    n: usize,
    g10: f32,
    g11: f32,
    g12: f32,
) {
    use std::arch::x86_64::*;

    // Gains broadcast to all four lanes.
    let g10v = _mm_set1_ps(g10);
    let g11v = _mm_set1_ps(g11);
    let g12v = _mm_set1_ps(g12);

    let xbase = x.as_ptr().add(x_idx);
    let ybase = y.as_mut_ptr().add(y_idx);
    // History samples at offsets -t-2 .. -t+1 relative to the current output.
    let mut x0v = _mm_loadu_ps(xbase.sub(t + 2));

    let mut i = 0;
    while i + 4 <= n {
        // Next four history samples, offsets -t+2 .. -t+5.
        let x4v = _mm_loadu_ps(xbase.add(i).sub(t - 2));

        // Upper half of x0v followed by lower half of x4v: offsets -t .. -t+3.
        let x2v = _mm_shuffle_ps(x0v, x4v, 0x4e);

        // Lanes 1,2 of x0v then lanes 1,2 of x2v: offsets -t-1 .. -t+2.
        let x1v = _mm_shuffle_ps(x0v, x2v, 0x99);

        // Lanes 1,2 of x2v then lanes 1,2 of x4v: offsets -t+1 .. -t+4.
        let x3v = _mm_shuffle_ps(x2v, x4v, 0x99);

        let xi = _mm_loadu_ps(xbase.add(i));

        // yi = x + g10 * centre + (g11 * inner pair + g12 * outer pair).
        let mut yi = xi;
        yi = _mm_add_ps(yi, _mm_mul_ps(g10v, x2v));
        let yi2 = _mm_add_ps(
            _mm_mul_ps(g11v, _mm_add_ps(x3v, x1v)),
            _mm_mul_ps(g12v, _mm_add_ps(x4v, x0v)),
        );
        yi = _mm_add_ps(yi, yi2);
        _mm_storeu_ps(ybase.add(i), yi);

        // The newest block becomes the oldest block of the next iteration.
        x0v = x4v;
        i += 4;
    }

    // Seed the scalar tail from the last loaded history vector
    // (lane 0 is the oldest sample).
    let x0v_arr: [f32; 4] = std::mem::transmute(x0v);
    let mut sx4 = x0v_arr[0];
    let mut sx3 = x0v_arr[1];
    let mut sx2 = x0v_arr[2];
    let mut sx1 = x0v_arr[3];

    while i < n {
        let sx0 = x[x_idx + i - t + 2];
        y[y_idx + i] = x[x_idx + i] + g10 * sx2 + g11 * (sx1 + sx3) + g12 * (sx0 + sx4);
        sx4 = sx3;
        sx3 = sx2;
        sx2 = sx1;
        sx1 = sx0;
        i += 1;
    }
}
781
782#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
783#[target_feature(enable = "avx,fma")]
784#[allow(unsafe_op_in_unsafe_fn)]
785unsafe fn comb_filter_const_avx(
786    y: &mut [f32],
787    x: &[f32],
788    y_idx: usize,
789    x_idx: usize,
790    t: usize,
791    n: usize,
792    g10: f32,
793    g11: f32,
794    g12: f32,
795) {
796    use std::arch::x86_64::*;
797
798    let g10v = _mm256_set1_ps(g10);
799    let g11v = _mm256_set1_ps(g11);
800    let g12v = _mm256_set1_ps(g12);
801
802    let xbase = x.as_ptr().add(x_idx);
803    let ybase = y.as_mut_ptr().add(y_idx);
804
805    let mut i = 0;
806
807    while i + 16 <= n {
808        let xi_a = _mm256_loadu_ps(xbase.add(i));
809        let x0_a = _mm256_loadu_ps(xbase.add(i).sub(t + 2));
810        let x4_a = _mm256_loadu_ps(xbase.add(i).sub(t - 2));
811
812        let x2_a = _mm256_loadu_ps(xbase.add(i).sub(t));
813        let x1x3_a = _mm256_add_ps(
814            _mm256_loadu_ps(xbase.add(i).sub(t + 1)),
815            _mm256_loadu_ps(xbase.add(i).sub(t - 1)),
816        );
817        let x0x4_a = _mm256_add_ps(x0_a, x4_a);
818
819        let mut yi_a = xi_a;
820        yi_a = _mm256_fmadd_ps(g10v, x2_a, yi_a);
821        yi_a = _mm256_fmadd_ps(g11v, x1x3_a, yi_a);
822        yi_a = _mm256_fmadd_ps(g12v, x0x4_a, yi_a);
823        _mm256_storeu_ps(ybase.add(i), yi_a);
824
825        let j = i + 8;
826        let xi_b = _mm256_loadu_ps(xbase.add(j));
827        let x0_b = _mm256_loadu_ps(xbase.add(j).sub(t + 2));
828        let x4_b = _mm256_loadu_ps(xbase.add(j).sub(t - 2));
829        let x2_b = _mm256_loadu_ps(xbase.add(j).sub(t));
830        let x1x3_b = _mm256_add_ps(
831            _mm256_loadu_ps(xbase.add(j).sub(t + 1)),
832            _mm256_loadu_ps(xbase.add(j).sub(t - 1)),
833        );
834        let x0x4_b = _mm256_add_ps(x0_b, x4_b);
835
836        let mut yi_b = xi_b;
837        yi_b = _mm256_fmadd_ps(g10v, x2_b, yi_b);
838        yi_b = _mm256_fmadd_ps(g11v, x1x3_b, yi_b);
839        yi_b = _mm256_fmadd_ps(g12v, x0x4_b, yi_b);
840        _mm256_storeu_ps(ybase.add(j), yi_b);
841
842        i += 16;
843    }
844
845    while i + 8 <= n {
846        let xi = _mm256_loadu_ps(xbase.add(i));
847        let x0 = _mm256_loadu_ps(xbase.add(i).sub(t + 2));
848        let x4 = _mm256_loadu_ps(xbase.add(i).sub(t - 2));
849        let x2 = _mm256_loadu_ps(xbase.add(i).sub(t));
850        let x1x3 = _mm256_add_ps(
851            _mm256_loadu_ps(xbase.add(i).sub(t + 1)),
852            _mm256_loadu_ps(xbase.add(i).sub(t - 1)),
853        );
854        let x0x4 = _mm256_add_ps(x0, x4);
855
856        let mut yi = xi;
857        yi = _mm256_fmadd_ps(g10v, x2, yi);
858        yi = _mm256_fmadd_ps(g11v, x1x3, yi);
859        yi = _mm256_fmadd_ps(g12v, x0x4, yi);
860        _mm256_storeu_ps(ybase.add(i), yi);
861
862        i += 8;
863    }
864
865    if i + 4 <= n {
866        comb_filter_const_sse_fma(y, x, y_idx + i, x_idx + i, t, n - i, g10, g11, g12);
867        return;
868    }
869
870    let mut sx4 = x[x_idx + i - t - 2];
871    let mut sx3 = x[x_idx + i - t - 1];
872    let mut sx2 = x[x_idx + i - t];
873    let mut sx1 = x[x_idx + i - t + 1];
874    while i < n {
875        let sx0 = x[x_idx + i - t + 2];
876        y[y_idx + i] = x[x_idx + i] + g10 * sx2 + g11 * (sx1 + sx3) + g12 * (sx0 + sx4);
877        sx4 = sx3;
878        sx3 = sx2;
879        sx2 = sx1;
880        sx1 = sx0;
881        i += 1;
882    }
883}
884
/// 128-bit FMA kernel for the constant-gain 5-tap comb filter; produces four
/// output samples per iteration and finishes the last (fewer than four)
/// samples with scalar code.
///
/// The intrinsics are imported from the module matching the target
/// architecture, so the function also builds for 32-bit x86, which its `cfg`
/// admits.
///
/// # Safety
///
/// The CPU must support both AVX and FMA (the function is compiled with
/// those features enabled). All vector loads and stores use raw pointers
/// derived from `x` and `y` without bounds checks: the caller must ensure
/// that `x[x_idx - t - 2 .. x_idx + n]` and `y[y_idx .. y_idx + n]` are in
/// bounds and that `t >= 2` (the code computes `t - 2`).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx,fma")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn comb_filter_const_sse_fma(
    y: &mut [f32],
    x: &[f32],
    y_idx: usize,
    x_idx: usize,
    t: usize,
    n: usize,
    g10: f32,
    g11: f32,
    g12: f32,
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    // Gains broadcast to all four lanes.
    let g10v = _mm_set1_ps(g10);
    let g11v = _mm_set1_ps(g11);
    let g12v = _mm_set1_ps(g12);

    let xbase = x.as_ptr().add(x_idx);
    let ybase = y.as_mut_ptr().add(y_idx);
    // History samples at offsets -t-2 .. -t+1 relative to the current output.
    let mut x0v = _mm_loadu_ps(xbase.sub(t + 2));

    let mut i = 0;
    while i + 4 <= n {
        // Next four history samples, offsets -t+2 .. -t+5.
        let x4v = _mm_loadu_ps(xbase.add(i).sub(t - 2));
        // Offsets -t .. -t+3 (centre tap).
        let x2v = _mm_shuffle_ps(x0v, x4v, 0x4e);
        // Offsets -t-1 .. -t+2.
        let x1v = _mm_shuffle_ps(x0v, x2v, 0x99);
        // Offsets -t+1 .. -t+4.
        let x3v = _mm_shuffle_ps(x2v, x4v, 0x99);
        let xi = _mm_loadu_ps(xbase.add(i));

        let mut yi = xi;
        yi = _mm_fmadd_ps(g10v, x2v, yi);
        yi = _mm_fmadd_ps(g11v, _mm_add_ps(x1v, x3v), yi);
        yi = _mm_fmadd_ps(g12v, _mm_add_ps(x0v, x4v), yi);
        _mm_storeu_ps(ybase.add(i), yi);

        // The newest block becomes the oldest block of the next iteration.
        x0v = x4v;
        i += 4;
    }

    // Seed the scalar tail from the last loaded history vector
    // (lane 0 is the oldest sample).
    let x0v_arr: [f32; 4] = std::mem::transmute(x0v);
    let mut sx4 = x0v_arr[0];
    let mut sx3 = x0v_arr[1];
    let mut sx2 = x0v_arr[2];
    let mut sx1 = x0v_arr[3];
    while i < n {
        let sx0 = x[x_idx + i - t + 2];
        y[y_idx + i] = x[x_idx + i] + g10 * sx2 + g11 * (sx1 + sx3) + g12 * (sx0 + sx4);
        sx4 = sx3;
        sx3 = sx2;
        sx2 = sx1;
        sx1 = sx0;
        i += 1;
    }
}
942
943#[allow(clippy::too_many_arguments)]
944fn comb_filter(
945    y: &mut [f32],
946    x: &[f32],
947    y_idx: usize,
948    x_idx: usize,
949    t0: usize,
950    t1: usize,
951    n: usize,
952    g0: f32,
953    g1: f32,
954    tapset0: i32,
955    tapset1: i32,
956    window: &[f32],
957    overlap: usize,
958) {
959    if g0 == 0.0 && g1 == 0.0 {
960        if x_idx != y_idx || !std::ptr::eq(x.as_ptr(), y.as_ptr()) {
961            y[y_idx..y_idx + n].copy_from_slice(&x[x_idx..x_idx + n]);
962        }
963        return;
964    }
965
966    let t0 = t0.clamp(
967        COMBFILTER_MINPERIOD,
968        x_idx.saturating_sub(2).max(COMBFILTER_MINPERIOD),
969    );
970    let t1 = t1.clamp(
971        COMBFILTER_MINPERIOD,
972        x_idx.saturating_sub(2).max(COMBFILTER_MINPERIOD),
973    );
974
975    let g00 = g0 * PREFILTER_GAINS[tapset0 as usize][0];
976    let g01 = g0 * PREFILTER_GAINS[tapset0 as usize][1];
977    let g02 = g0 * PREFILTER_GAINS[tapset0 as usize][2];
978
979    let g10 = g1 * PREFILTER_GAINS[tapset1 as usize][0];
980    let g11 = g1 * PREFILTER_GAINS[tapset1 as usize][1];
981    let g12 = g1 * PREFILTER_GAINS[tapset1 as usize][2];
982
983    let mut x1 = x[x_idx - t1 + 1];
984    let mut x2 = x[x_idx - t1];
985    let mut x3 = x[x_idx - t1 - 1];
986    let mut x4 = x[x_idx - t1 - 2];
987
988    let mut inner_overlap = overlap;
989    if g0 == g1 && t0 == t1 && tapset0 == tapset1 {
990        inner_overlap = 0;
991    }
992
993    let mut i = 0;
994    while i < inner_overlap && i < n {
995        let x0 = x[x_idx + i - t1 + 2];
996        let f = window[i] * window[i];
997        y[y_idx + i] = x[x_idx + i]
998            + (1.0 - f)
999                * (g00 * x[x_idx + i - t0]
1000                    + g01 * (x[x_idx + i - t0 + 1] + x[x_idx + i - t0 - 1])
1001                    + g02 * (x[x_idx + i - t0 + 2] + x[x_idx + i - t0 - 2]))
1002            + f * (g10 * x2 + g11 * (x1 + x3) + g12 * (x0 + x4));
1003
1004        x4 = x3;
1005        x3 = x2;
1006        x2 = x1;
1007        x1 = x0;
1008        i += 1;
1009    }
1010
1011    if i < n {
1012        if g1 == 0.0 {
1013            y[y_idx + i..y_idx + n].copy_from_slice(&x[x_idx + i..x_idx + n]);
1014        } else {
1015            comb_filter_const(y, x, y_idx + i, x_idx + i, t1, n - i, g10, g11, g12);
1016        }
1017    }
1018}
1019
1020/// In-place comb filter: buf[y_idx..y_idx+n] is both input and output.
1021/// Reference samples at buf[y_idx + i - T + offset] may already be filtered
1022/// if T < i, matching C libopus's in-place comb_filter(out, out, ...) behavior.
1023fn comb_filter_inplace(
1024    buf: &mut [f32],
1025    y_idx: usize,
1026    t0: usize,
1027    t1: usize,
1028    n: usize,
1029    g0: f32,
1030    g1: f32,
1031    tapset0: i32,
1032    tapset1: i32,
1033    window: &[f32],
1034    overlap: usize,
1035) {
1036    if g0 == 0.0 && g1 == 0.0 {
1037        // nothing to do; buf[y_idx..] already holds the input
1038        return;
1039    }
1040
1041    let t0 = t0.clamp(COMBFILTER_MINPERIOD, y_idx - 2);
1042    let t1 = t1.clamp(COMBFILTER_MINPERIOD, y_idx - 2);
1043
1044    let g00 = g0 * PREFILTER_GAINS[tapset0 as usize][0];
1045    let g01 = g0 * PREFILTER_GAINS[tapset0 as usize][1];
1046    let g02 = g0 * PREFILTER_GAINS[tapset0 as usize][2];
1047
1048    let g10 = g1 * PREFILTER_GAINS[tapset1 as usize][0];
1049    let g11 = g1 * PREFILTER_GAINS[tapset1 as usize][1];
1050    let g12 = g1 * PREFILTER_GAINS[tapset1 as usize][2];
1051
1052    let mut inner_overlap = overlap;
1053    if g0 == g1 && t0 == t1 && tapset0 == tapset1 {
1054        inner_overlap = 0;
1055    }
1056
1057    let mut i = 0;
1058    while i < inner_overlap && i < n {
1059        let idx = y_idx + i;
1060        let f = window[i] * window[i];
1061        let s = buf[idx]; // original input (not yet overwritten at idx)
1062        let r0 = buf[idx - t0];
1063        let r0p1 = buf[idx - t0 + 1];
1064        let r0m1 = buf[idx - t0 - 1];
1065        let r0p2 = buf[idx - t0 + 2];
1066        let r0m2 = buf[idx - t0 - 2];
1067        let r1 = buf[idx - t1];
1068        let r1p1 = buf[idx - t1 + 1];
1069        let r1m1 = buf[idx - t1 - 1];
1070        let r1p2 = buf[idx - t1 + 2];
1071        let r1m2 = buf[idx - t1 - 2];
1072        buf[idx] = s
1073            + (1.0 - f) * (g00 * r0 + g01 * (r0p1 + r0m1) + g02 * (r0p2 + r0m2))
1074            + f * (g10 * r1 + g11 * (r1p1 + r1m1) + g12 * (r1p2 + r1m2));
1075        i += 1;
1076    }
1077
1078    // Constant region: only new filter (t1, g1)
1079    while i < n {
1080        let idx = y_idx + i;
1081        let s = buf[idx];
1082        let r1 = buf[idx - t1];
1083        let r1p1 = buf[idx - t1 + 1];
1084        let r1m1 = buf[idx - t1 - 1];
1085        let r1p2 = buf[idx - t1 + 2];
1086        let r1m2 = buf[idx - t1 - 2];
1087        buf[idx] = s + g10 * r1 + g11 * (r1p1 + r1m1) + g12 * (r1p2 + r1m2);
1088        i += 1;
1089    }
1090}
1091
1092fn run_prefilter(
1093    in_buf: &mut [f32],
1094    prefilter_mem: &mut [f32],
1095    prefilter_period: usize,
1096    prefilter_gain: f32,
1097    prefilter_tapset: i32,
1098    tapset_decision: i32,
1099    window: &[f32],
1100    channels: usize,
1101    frame_size: usize,
1102    overlap: usize,
1103
1104    pre: &mut [f32],
1105    pitch_buf: &mut [f32],
1106    before: &mut [f32],
1107    after: &mut [f32],
1108) -> (bool, f32, usize) {
1109    let max_period = COMBFILTER_MAXPERIOD;
1110    let min_period = COMBFILTER_MINPERIOD;
1111    let buf_stride = frame_size + overlap;
1112    let pre_size = max_period + frame_size;
1113
1114    for c in 0..channels {
1115        pre[c * pre_size..c * pre_size + max_period]
1116            .copy_from_slice(&prefilter_mem[c * max_period..(c + 1) * max_period]);
1117        pre[c * pre_size + max_period..c * pre_size + pre_size].copy_from_slice(
1118            &in_buf[c * buf_stride + overlap..c * buf_stride + overlap + frame_size],
1119        );
1120    }
1121
1122    let pitch_buf_len = (max_period + frame_size) >> 1;
1123    {
1124        let pre_slices: Vec<&[f32]> = (0..channels)
1125            .map(|c| &pre[c * pre_size..c * pre_size + pre_size])
1126            .collect();
1127        crate::pitch::pitch_downsample(&pre_slices, pitch_buf, pitch_buf_len, channels, 2);
1128    }
1129
1130    let search_max = max_period - 3 * min_period;
1131    let pitch_result = crate::pitch::pitch_search(
1132        &pitch_buf[max_period >> 1..],
1133        pitch_buf,
1134        frame_size,
1135        search_max,
1136    );
1137    let mut pitch_index = (max_period - pitch_result).min(max_period - 2);
1138
1139    let gain1_raw = crate::pitch::remove_doubling(
1140        pitch_buf,
1141        max_period,
1142        min_period,
1143        frame_size,
1144        &mut pitch_index,
1145        prefilter_period,
1146        prefilter_gain,
1147    );
1148    let mut gain1 = gain1_raw * 0.7;
1149
1150    let mut pf_threshold = 0.2f32;
1151    if (pitch_index as i32 - prefilter_period as i32).unsigned_abs() as usize * 10 > pitch_index {
1152        pf_threshold += 0.2;
1153    }
1154    if prefilter_gain > 0.4 {
1155        pf_threshold -= 0.1;
1156    }
1157    if prefilter_gain > 0.55 {
1158        pf_threshold -= 0.1;
1159    }
1160    pf_threshold = pf_threshold.max(0.2);
1161
1162    let pf_on;
1163    if gain1 < pf_threshold {
1164        gain1 = 0.0;
1165        pf_on = false;
1166    } else {
1167        if (gain1 - prefilter_gain).abs() < 0.1 {
1168            gain1 = prefilter_gain;
1169        }
1170        let qg = ((gain1 * 32.0 / 3.0 + 0.5).floor() as i32 - 1).clamp(0, 7);
1171        gain1 = 0.09375 * (qg + 1) as f32;
1172        pf_on = true;
1173    }
1174
1175    let before = &mut before[..channels];
1176    for c in 0..channels {
1177        let start = c * buf_stride + overlap;
1178        before[c] = sum_abs(&in_buf[start..start + frame_size]);
1179    }
1180
1181    let offset = 0usize;
1182    let prev_period = prefilter_period.clamp(COMBFILTER_MINPERIOD, max_period - 2);
1183
1184    for c in 0..channels {
1185        if offset > 0 {
1186            let pre_c = &pre[c * pre_size..];
1187            comb_filter(
1188                in_buf,
1189                pre_c,
1190                c * buf_stride + overlap,
1191                max_period,
1192                prev_period,
1193                prev_period,
1194                offset,
1195                -prefilter_gain,
1196                -prefilter_gain,
1197                prefilter_tapset,
1198                prefilter_tapset,
1199                window,
1200                0,
1201            );
1202        }
1203
1204        {
1205            let pre_c = &pre[c * pre_size..];
1206            comb_filter(
1207                in_buf,
1208                pre_c,
1209                c * buf_stride + overlap + offset,
1210                max_period + offset,
1211                prev_period,
1212                pitch_index,
1213                frame_size - offset,
1214                -prefilter_gain,
1215                -gain1,
1216                prefilter_tapset,
1217                tapset_decision,
1218                window,
1219                overlap,
1220            );
1221        }
1222    }
1223
1224    let after = &mut after[..channels];
1225    for c in 0..channels {
1226        let start = c * buf_stride + overlap;
1227        after[c] = sum_abs(&in_buf[start..start + frame_size]);
1228    }
1229
1230    let cancel_pitch = (0..channels).any(|c| after[c] > before[c]);
1231
1232    if cancel_pitch {
1233        for c in 0..channels {
1234            in_buf[c * buf_stride + overlap..c * buf_stride + overlap + frame_size]
1235                .copy_from_slice(
1236                    &pre[c * pre_size + max_period..c * pre_size + max_period + frame_size],
1237                );
1238        }
1239
1240        for c in 0..channels {
1241            if frame_size >= max_period {
1242                prefilter_mem[c * max_period..(c + 1) * max_period].copy_from_slice(
1243                    &pre[c * pre_size + frame_size..c * pre_size + frame_size + max_period],
1244                );
1245            } else {
1246                let shift = max_period - frame_size;
1247                prefilter_mem.copy_within(
1248                    c * max_period + frame_size..(c + 1) * max_period,
1249                    c * max_period,
1250                );
1251                prefilter_mem[c * max_period + shift..(c + 1) * max_period].copy_from_slice(
1252                    &pre[c * pre_size + max_period..c * pre_size + max_period + frame_size],
1253                );
1254            }
1255        }
1256        return (false, 0.0, pitch_index);
1257    }
1258
1259    for c in 0..channels {
1260        if frame_size >= max_period {
1261            prefilter_mem[c * max_period..(c + 1) * max_period].copy_from_slice(
1262                &pre[c * pre_size + frame_size..c * pre_size + frame_size + max_period],
1263            );
1264        } else {
1265            let shift = max_period - frame_size;
1266            prefilter_mem.copy_within(
1267                c * max_period + frame_size..(c + 1) * max_period,
1268                c * max_period,
1269            );
1270            prefilter_mem[c * max_period + shift..(c + 1) * max_period].copy_from_slice(
1271                &pre[c * pre_size + max_period..c * pre_size + max_period + frame_size],
1272            );
1273        }
1274    }
1275
1276    (pf_on, gain1, pitch_index)
1277}
1278
/// Number of extra `f32` slots allocated past the nominal end of the
/// normalised-spectrum work buffer (`CeltEncoder::w_x`): eight times
/// `crate::pvq::MAX_PVQ_N`.
// NOTE(review): presumably lets strided band accesses run past the frame end
// without going out of bounds -- confirm against the PVQ/band code.
1279const STRIDE_ACCESS_PAD: usize = crate::pvq::MAX_PVQ_N * 8;
1280
/// Persistent state and scratch storage for one CELT encoder instance.
///
/// Fields prefixed with `w_` are work buffers allocated once in
/// `CeltEncoder::new` and sliced to the current frame size on each call, so
/// that encoding a frame does not have to allocate them again.
1281pub struct CeltEncoder {
    // Static mode description (band edges, window, MDCT, overlap, ...).
1282    mode: &'static CeltMode,
    // Number of audio channels handled by this encoder.
1283    channels: usize,
    // Complexity setting (initialised to 9); `encode_impl` compares it against
    // fixed thresholds to enable or disable individual analysis stages.
1284    pub complexity: i32,
    // Pre-emphasised input history, `2048 + overlap` samples per channel.
1285    syn_mem: Vec<f32>,
    // Same size as `syn_mem`.
    // NOTE(review): usage is outside the visible part of the file.
1286    enc_decode_mem: Vec<f32>,
    // Band energies carried over from the previous frame; handed to the coarse
    // energy quantiser as mutable state.
1287    old_band_e: Vec<f32>,
    // Pre-emphasis filter memory, one value per channel.
1288    preemph_mem: Vec<f32>,
    // Running state handed to `spreading_decision`.
1289    tonal_average: i32,
1290    hf_average: i32,
1291    tapset_decision: i32,
1292    spread_decision: i32,
    // Stereo intensity level chosen by `hysteresis_decision`.
1293    intensity: i32,
1294    last_coded_bands: i32,
    // Unfiltered input history for the pitch pre-filter,
    // `COMBFILTER_MAXPERIOD` samples per channel.
1295    prefilter_mem: Vec<f32>,
    // Pre-filter parameters passed to `run_prefilter` as the cross-fade start.
    // NOTE(review): presumably those of the previous frame -- the update site
    // is outside the visible part of the file.
1296    prefilter_period: usize,
1297    prefilter_gain: f32,
1298    prefilter_tapset: i32,
    // Further per-band energy histories, `nb_ebands * channels` entries each.
1299    old_band_e2: Vec<f32>,
1300    old_band_e3: Vec<f32>,
1301    last_band_log_e: Vec<f32>,
    // Mutable state handed to the coarse energy quantiser.
1302    delayed_intra: f32,
1303
    // --- Per-frame work buffers ---
1304    w_in_buf: Vec<f32>,
1305    w_freq: Vec<f32>,
1306    w_band_e: Vec<f32>,
1307    w_x: Vec<f32>,
1308    w_band_log_e: Vec<f32>,
1309    w_error: Vec<f32>,
1310    w_tf_res: Vec<i32>,
1311    w_cap: Vec<i32>,
1312    w_offsets: Vec<i32>,
1313    w_pulses: Vec<i32>,
1314    w_ebits: Vec<i32>,
1315    w_fine_priority: Vec<i32>,
1316    w_collapse_masks: Vec<u32>,
1317    w_band_amp_synth: Vec<f32>,
1318    w_freq_synth: Vec<f32>,
1319    consec_transient: i32,
1320
    // --- Scratch buffers for `run_prefilter` ---
1321    w_prefilter_pre: Vec<f32>,
1322    w_prefilter_pitch_buf: Vec<f32>,
1323    w_prefilter_before: Vec<f32>,
1324    w_prefilter_after: Vec<f32>,
1325
    // --- Scratch buffers for `transient_analysis` ---
1326    w_transient_tmp: Vec<f32>,
1327    w_transient_tmp2: Vec<f32>,
1328}
1329
/// Bitrate thresholds (the caller passes the equivalent rate divided by 1000)
/// separating the stereo intensity levels.
const INTEN_THRESHOLDS: [i32; 21] = [
    1, 2, 3, 4, 5, 6, 7, 8, 16, 24, 36, 44, 50, 56, 62, 67, 72, 79, 88, 106, 134,
];
/// Hysteresis margin applied around the matching entry of `INTEN_THRESHOLDS`.
const INTEN_HYSTERESIS: [i32; 21] = [
    1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 4, 5, 6, 8, 8,
];

/// Maps `val` onto a level in `0..=thresholds.len()` with hysteresis.
///
/// The raw level is the number of leading thresholds that `val` has reached.
/// The previous level `prev` is kept instead when `val` has not moved past the
/// neighbouring threshold by more than its hysteresis margin:
///
/// * moving up: kept while `val < thresholds[prev] + hysteresis[prev]`;
/// * moving down: kept while
///   `val > thresholds[prev - 1] - hysteresis[prev - 1]`.
///
/// The downward rule also applies when the raw level is 0, as in the libopus
/// reference `hysteresis_decision`. A `prev` outside `0..=thresholds.len()` is
/// clamped into that range instead of indexing out of bounds.
///
/// # Panics
///
/// Panics if `hysteresis` is shorter than the index of the threshold consulted.
fn hysteresis_decision(val: i32, thresholds: &[i32], hysteresis: &[i32], prev: i32) -> i32 {
    // Raw level: stop at the first threshold that `val` is still below.
    let level = thresholds.iter().take_while(|&&t| val >= t).count();
    let prev_level = prev.clamp(0, thresholds.len() as i32) as usize;

    let hold_previous = if level > prev_level {
        val < thresholds[prev_level] + hysteresis[prev_level]
    } else if level < prev_level {
        // `prev_level > level >= 0`, so `prev_level - 1` is a valid index.
        val > thresholds[prev_level - 1] - hysteresis[prev_level - 1]
    } else {
        true
    };

    if hold_previous {
        prev_level as i32
    } else {
        level as i32
    }
}
1355
1356#[allow(clippy::too_many_arguments)]
1357fn alloc_trim_analysis(
1358    mode: &CeltMode,
1359    x: &[f32],
1360    band_log_e: &[f32],
1361    end: usize,
1362    lm: i32,
1363    channels: usize,
1364    n0: usize,
1365    stereo_saving: &mut f32,
1366    tf_estimate: f32,
1367    intensity: i32,
1368    surround_trim: f32,
1369    equiv_rate: i32,
1370) -> i32 {
1371    let mut trim = 5.0f32;
1372    if equiv_rate < 64000 {
1373        trim = 4.0;
1374    } else if equiv_rate < 80000 {
1375        let frac = (equiv_rate - 64000) as f32 / 1024.0;
1376        trim = 4.0 + (1.0 / 16.0) * frac;
1377    }
1378
1379    if channels == 2 {
1380        let mut sum = 0.0f32;
1381        for i in 0..8 {
1382            let offset = (mode.e_bands[i] as usize) << lm;
1383            let n = ((mode.e_bands[i + 1] - mode.e_bands[i]) as usize) << lm;
1384            let mut partial = 0.0f32;
1385            for j in 0..n {
1386                partial += x[offset + j] * x[n0 + offset + j];
1387            }
1388            sum += partial;
1389        }
1390        sum = (sum / 8.0).abs().min(1.0);
1391        let mut min_xc = sum;
1392        for i in 8..intensity as usize {
1393            let offset = (mode.e_bands[i] as usize) << lm;
1394            let n = ((mode.e_bands[i + 1] - mode.e_bands[i]) as usize) << lm;
1395            let mut partial = 0.0f32;
1396            for j in 0..n {
1397                partial += x[offset + j] * x[n0 + offset + j];
1398            }
1399            min_xc = min_xc.min(partial.abs());
1400        }
1401        min_xc = min_xc.min(1.0);
1402
1403        let log_xc = (1.001 - sum * sum).log2();
1404        let log_xc2 = (log_xc * 0.5).max((1.001 - min_xc * min_xc).log2());
1405
1406        trim += (-4.0f32).max(0.75 * log_xc);
1407        *stereo_saving = (*stereo_saving + 0.25).min(-0.5 * log_xc2);
1408    }
1409
1410    let mut diff = 0.0f32;
1411    for c in 0..channels {
1412        for i in 0..end - 1 {
1413            diff += band_log_e[c * mode.nb_ebands + i] * (2 + 2 * i as i32 - end as i32) as f32;
1414        }
1415    }
1416    diff /= (channels * (end - 1)) as f32;
1417    trim -= (-2.0f32).max(2.0f32.min((diff + 1.0) / 6.0));
1418    trim -= surround_trim;
1419    trim -= 2.0 * tf_estimate;
1420
1421    let trim_index = (trim + 0.5).floor() as i32;
1422    trim_index.clamp(0, 10)
1423}
1424
/// Returns the median of three values.
///
/// Values are ordered with [`f32::total_cmp`], which is a total order: the
/// result does not depend on argument position, and a positive NaN sorts above
/// every number (so one NaN input yields the larger of the other two).
#[inline(always)]
fn median3(a: f32, b: f32, c: f32) -> f32 {
    let mut v = [a, b, c];
    v.sort_unstable_by(f32::total_cmp);
    v[1]
}
1431
/// Returns the median of the first five elements of `v`.
///
/// Values are ordered with [`f32::total_cmp`], which is a total order: the
/// result does not depend on element position, and a positive NaN sorts above
/// every number.
///
/// # Panics
///
/// Panics if `v` has fewer than five elements.
#[inline(always)]
fn median5(v: &[f32]) -> f32 {
    let mut x = [v[0], v[1], v[2], v[3], v[4]];
    x.sort_unstable_by(f32::total_cmp);
    x[2]
}
1438
1439#[allow(clippy::too_many_arguments)]
1440fn dynalloc_analysis_simple(
1441    mode: &CeltMode,
1442    band_log_e: &[f32],
1443    old_band_e: &[f32],
1444    start: usize,
1445    end: usize,
1446    channels: usize,
1447    lm: usize,
1448    effective_bytes: usize,
1449    is_transient: bool,
1450    offsets: &mut [i32],
1451    cap: &[i32],
1452) {
1453    offsets.fill(0);
1454    if effective_bytes < (30 + 5 * lm) {
1455        return;
1456    }
1457
1458    let nb = mode.nb_ebands;
1459    let mut follower = vec![0.0f32; nb * channels];
1460
1461    for c in 0..channels {
1462        let base = c * nb;
1463        let mut band_log_e3 = vec![0.0f32; end];
1464        for i in 0..end {
1465            let mut e = band_log_e[base + i];
1466            if lm == 0 && i < 8 {
1467                e = e.max(old_band_e[base + i]);
1468            }
1469            band_log_e3[i] = e;
1470        }
1471
1472        let mut last = 0usize;
1473        follower[base] = band_log_e3[0];
1474        for i in 1..end {
1475            if band_log_e3[i] > band_log_e3[i - 1] + 0.5 {
1476                last = i;
1477            }
1478            follower[base + i] = (follower[base + i - 1] + 1.5).min(band_log_e3[i]);
1479        }
1480        for i in (0..last).rev() {
1481            follower[base + i] =
1482                follower[base + i].min((follower[base + i + 1] + 2.0).min(band_log_e3[i]));
1483        }
1484
1485        let offset = 1.0f32;
1486        if end >= 5 {
1487            for i in 2..end - 2 {
1488                follower[base + i] =
1489                    follower[base + i].max(median5(&band_log_e3[i - 2..i + 3]) - offset);
1490            }
1491        }
1492        if end >= 3 {
1493            let l = median3(band_log_e3[0], band_log_e3[1], band_log_e3[2]) - offset;
1494            follower[base] = follower[base].max(l);
1495            follower[base + 1] = follower[base + 1].max(l);
1496
1497            let r = median3(
1498                band_log_e3[end - 3],
1499                band_log_e3[end - 2],
1500                band_log_e3[end - 1],
1501            ) - offset;
1502            follower[base + end - 2] = follower[base + end - 2].max(r);
1503            follower[base + end - 1] = follower[base + end - 1].max(r);
1504        }
1505    }
1506
1507    if channels == 2 {
1508        for i in start..end {
1509            let l = follower[i];
1510            let r = follower[nb + i];
1511            let r2 = r.max(l - 4.0);
1512            let l2 = l.max(r - 4.0);
1513            follower[i] =
1514                ((band_log_e[i] - l2).max(0.0) + (band_log_e[nb + i] - r2).max(0.0)) * 0.5;
1515        }
1516    } else {
1517        for i in start..end {
1518            follower[i] = (band_log_e[i] - follower[i]).max(0.0);
1519        }
1520    }
1521
1522    if !is_transient {
1523        for i in start..end {
1524            follower[i] *= 0.5;
1525        }
1526    }
1527
1528    let mut tot_boost = 0i32;
1529    for i in start..end {
1530        let mut f = follower[i].min(4.0);
1531        if i < 8 {
1532            f *= 2.0;
1533        }
1534        if i >= 12 {
1535            f *= 0.5;
1536        }
1537
1538        let width = channels as i32 * (mode.e_bands[i + 1] - mode.e_bands[i]) as i32 * (1 << lm);
1539        let (boost, boost_bits) = if width < 6 {
1540            let b = f.floor().max(0.0) as i32;
1541            (b, (b * width) << BITRES)
1542        } else if width > 48 {
1543            let b = (f * 8.0).floor().max(0.0) as i32;
1544            (b, ((b * width) << BITRES) / 8)
1545        } else {
1546            let b = (f * width as f32 / 6.0).floor().max(0.0) as i32;
1547            (b, (b * 6) << BITRES)
1548        };
1549
1550        // Keep dynalloc bounded so allocator still has base bits in CBR usage.
1551        let cap_bits = ((2 * effective_bytes as i32) / 3) << (BITRES + 3);
1552        if tot_boost + boost_bits > cap_bits {
1553            offsets[i] = ((cap_bits - tot_boost) >> BITRES).max(0);
1554            break;
1555        }
1556
1557        let quanta = (width << BITRES).min((6 << BITRES).max(width));
1558        let mut boost_count = boost;
1559        let mut as_bits = boost_count * quanta;
1560        if as_bits > cap[i] {
1561            as_bits = cap[i];
1562            boost_count = as_bits / quanta;
1563        }
1564
1565        offsets[i] = boost_count.max(0);
1566        tot_boost += boost_bits.max(0);
1567    }
1568}
1569
1570impl CeltEncoder {
1571    pub fn new(mode: &'static CeltMode, channels: usize) -> Self {
1572        let overlap = mode.overlap;
1573        let channel_mem_size = 2048 + overlap;
1574        let syn_mem_size = channels * channel_mem_size;
1575        let nb_ebands = mode.nb_ebands;
1576        let nb_x_ch = nb_ebands * channels;
1577        let frame_x_ch = MAX_FRAME_SIZE * channels;
1578        let bufstride_x_ch = (MAX_FRAME_SIZE + overlap) * channels;
1579        Self {
1580            mode,
1581            channels,
1582            complexity: 9,
1583            syn_mem: vec![0.0; syn_mem_size],
1584            enc_decode_mem: vec![0.0; syn_mem_size],
1585            old_band_e: vec![0.0; nb_x_ch],
1586            preemph_mem: vec![0.0; channels],
1587            tonal_average: 256,
1588            hf_average: 0,
1589            tapset_decision: 0,
1590            spread_decision: SPREAD_NORMAL,
1591            intensity: 0,
1592            last_coded_bands: 0,
1593            prefilter_mem: vec![0.0; channels * COMBFILTER_MAXPERIOD],
1594            prefilter_period: COMBFILTER_MINPERIOD,
1595            prefilter_gain: 0.0,
1596            prefilter_tapset: 0,
1597            old_band_e2: vec![0.0; nb_x_ch],
1598            old_band_e3: vec![0.0; nb_x_ch],
1599            last_band_log_e: vec![0.0; nb_x_ch],
1600            delayed_intra: 0.0,
1601
1602            w_in_buf: vec![0.0; bufstride_x_ch],
1603            w_freq: vec![0.0; frame_x_ch + 4],
1604            w_band_e: vec![0.0; nb_x_ch],
1605
1606            w_x: vec![0.0; frame_x_ch + STRIDE_ACCESS_PAD],
1607            w_band_log_e: vec![0.0; nb_x_ch],
1608            w_error: vec![0.0; nb_x_ch],
1609            w_tf_res: vec![0; nb_ebands],
1610            w_cap: vec![0; nb_ebands],
1611            w_offsets: vec![0; nb_ebands],
1612            w_pulses: vec![0; nb_ebands],
1613            w_ebits: vec![0; nb_x_ch],
1614            w_fine_priority: vec![0; nb_x_ch],
1615            w_collapse_masks: vec![0; nb_x_ch],
1616            w_band_amp_synth: vec![0.0; nb_x_ch],
1617            w_freq_synth: vec![0.0; frame_x_ch + 4],
1618
1619            w_prefilter_pre: vec![0.0; channels * (COMBFILTER_MAXPERIOD + MAX_FRAME_SIZE)],
1620            w_prefilter_pitch_buf: vec![0.0; (COMBFILTER_MAXPERIOD + MAX_FRAME_SIZE) >> 1],
1621            w_prefilter_before: vec![0.0; channels],
1622            w_prefilter_after: vec![0.0; channels],
1623            w_transient_tmp: vec![0.0; MAX_TRANSIENT_LEN],
1624            w_transient_tmp2: vec![0.0; MAX_TRANSIENT_LEN / 2],
1625            consec_transient: 0,
1626        }
1627    }
1628
1629    pub fn encode(&mut self, pcm: &[f32], frame_size: usize, rc: &mut RangeCoder) {
1630        self.encode_impl(pcm, frame_size, rc, 0, None)
1631    }
1632
1633    pub fn encode_with_start_band(
1634        &mut self,
1635        pcm: &[f32],
1636        frame_size: usize,
1637        rc: &mut RangeCoder,
1638        start_band: usize,
1639    ) {
1640        self.encode_impl(pcm, frame_size, rc, start_band, None)
1641    }
1642
1643    pub fn encode_with_budget(
1644        &mut self,
1645        pcm: &[f32],
1646        frame_size: usize,
1647        rc: &mut RangeCoder,
1648        start_band: usize,
1649        total_bits: i32,
1650    ) {
1651        self.encode_impl(pcm, frame_size, rc, start_band, Some(total_bits))
1652    }
1653
1654    fn encode_impl(
1655        &mut self,
1656        pcm: &[f32],
1657        frame_size: usize,
1658        rc: &mut RangeCoder,
1659        start_band: usize,
1660        explicit_total_bits: Option<i32>,
1661    ) {
1662        let mode = self.mode;
1663        let channels = self.channels;
1664        let nb_ebands = mode.nb_ebands;
1665        let overlap = mode.overlap;
1666
1667        let mut lm = 0;
1668        while (mode.short_mdct_size << lm) != frame_size {
1669            lm += 1;
1670            if lm > mode.max_lm {
1671                break;
1672            }
1673        }
1674        if (mode.short_mdct_size << lm) != frame_size {
1675            lm = 0;
1676        }
1677
1678        let syn_mem_size = 2048 + overlap;
1679        for c in 0..channels {
1680            let channel_offset = c * syn_mem_size;
1681
1682            self.syn_mem.copy_within(
1683                channel_offset + frame_size..channel_offset + syn_mem_size,
1684                channel_offset,
1685            );
1686
1687            let mut m = self.preemph_mem[c];
1688            let coef = mode.preemph[0];
1689            for i in 0..frame_size {
1690                let x = pcm[c * frame_size + i] * 32768.0;
1691                let val = x - m;
1692                self.syn_mem[channel_offset + syn_mem_size - frame_size + i] = val;
1693                m = x * coef;
1694            }
1695            self.preemph_mem[c] = m;
1696        }
1697
1698        let buf_stride = frame_size + overlap;
1699        let in_buf = &mut self.w_in_buf[..buf_stride * channels];
1700        for c in 0..channels {
1701            let channel_offset = c * syn_mem_size;
1702            let in_buf_offset = c * buf_stride;
1703
1704            let src_start = syn_mem_size - frame_size - overlap;
1705            in_buf[in_buf_offset..in_buf_offset + buf_stride].copy_from_slice(
1706                &self.syn_mem[channel_offset + src_start..channel_offset + syn_mem_size],
1707            );
1708        }
1709
1710        let mut tf_estimate = 0.0f32;
1711        let mut tf_chan = 0;
1712        let mut weak_transient = false;
1713
1714        let is_transient = if self.complexity >= 1 {
1715            transient_analysis(
1716                in_buf,
1717                buf_stride,
1718                channels,
1719                &mut tf_estimate,
1720                &mut tf_chan,
1721                false,
1722                &mut weak_transient,
1723                0.0,
1724                0.0,
1725                &mut self.w_transient_tmp,
1726                &mut self.w_transient_tmp2,
1727            )
1728        } else {
1729            false
1730        };
1731
1732        let pf_enabled = start_band == 0 && self.complexity >= 5;
1733        let (pf_on, gain1, pitch_index) = if pf_enabled {
1734            run_prefilter(
1735                in_buf,
1736                &mut self.prefilter_mem,
1737                self.prefilter_period,
1738                self.prefilter_gain,
1739                self.prefilter_tapset,
1740                self.tapset_decision,
1741                mode.window,
1742                channels,
1743                frame_size,
1744                overlap,
1745                &mut self.w_prefilter_pre,
1746                &mut self.w_prefilter_pitch_buf,
1747                &mut self.w_prefilter_before,
1748                &mut self.w_prefilter_after,
1749            )
1750        } else {
1751            (false, 0.0f32, COMBFILTER_MINPERIOD)
1752        };
1753
1754        // Save the prefiltered overlap for the next frame.
1755        // In libopus, st->in_mem stores the overlap separately and run_prefilter
1756        // copies it to/from in[]. Here we emulate that by updating syn_mem with
1757        // the last overlap samples of in_buf (which were prefiltered in place).
1758        let syn_mem_size = 2048 + overlap;
1759        for c in 0..channels {
1760            let channel_offset = c * syn_mem_size;
1761            let in_buf_offset = c * buf_stride;
1762            self.syn_mem[channel_offset + syn_mem_size - overlap
1763                ..channel_offset + syn_mem_size]
1764                .copy_from_slice(
1765                    &in_buf[in_buf_offset + frame_size..in_buf_offset + buf_stride],
1766                );
1767        }
1768
1769        let freq = &mut self.w_freq[..frame_size * channels];
1770        let (shift, b) = if is_transient {
1771            (mode.max_lm, 1 << lm)
1772        } else {
1773            (mode.max_lm - lm, 1)
1774        };
1775        let n = frame_size / b;
1776
1777        for c in 0..channels {
1778            let c_buf_offset = c * buf_stride;
1779
1780            if c == 0 && b == 1 && channels == 1 {
1781                let mut max_val = 0.0f32;
1782                let check_len = (frame_size + overlap).min(buf_stride);
1783                for j in 0..check_len {
1784                    max_val = max_val.max(in_buf[c_buf_offset + j].abs());
1785                }
1786            }
1787
1788            for i in 0..b {
1789                mode.mdct.forward(
1790                    &in_buf[c_buf_offset + i * n..],
1791                    &mut freq[c * frame_size + i..],
1792                    mode.window,
1793                    overlap,
1794                    shift,
1795                    b,
1796                );
1797            }
1798        }
1799
1800        let band_e = &mut self.w_band_e[..nb_ebands * channels];
1801        compute_band_energies(mode, freq, band_e, nb_ebands, channels, lm);
1802
1803        let x_pad_end = (frame_size * channels + STRIDE_ACCESS_PAD).min(self.w_x.len());
1804        let x = &mut self.w_x[..x_pad_end];
1805        normalise_bands(
1806            mode,
1807            freq,
1808            x,
1809            band_e,
1810            nb_ebands,
1811            channels,
1812            (1 << lm) as usize,
1813        );
1814
1815        if channels == 1 {
1816            let _ = freq[0];
1817        }
1818
1819        let band_log_e = &mut self.w_band_log_e[..nb_ebands * channels];
1820        crate::bands::amp2log2(mode, start_band, nb_ebands, band_e, band_log_e, channels);
1821
1822        let total_bits = explicit_total_bits.unwrap_or_else(|| (rc.buf.len() * 8) as i32);
1823        self.w_error[..nb_ebands * channels].fill(0.0);
1824        let error = &mut self.w_error[..nb_ebands * channels];
1825
1826        let tell = rc.tell();
1827        let silence = false;
1828        if tell == 1 {
1829            rc.encode_bit_logp(silence, 15);
1830        }
1831
1832        if start_band == 0 && !silence && rc.tell() + 16 <= total_bits {
1833            rc.encode_bit_logp(pf_on, 1);
1834            if pf_on {
1835                let qg = (gain1 / 0.09375 - 1.0 + 0.5).floor() as i32;
1836                let qg = qg.clamp(0, 7);
1837                let pi = (pitch_index + 1) as u32;
1838                let octave = 31 - pi.leading_zeros();
1839                let octave = (octave as i32 - 5).max(0) as u32;
1840                rc.enc_uint(octave, 6);
1841                rc.enc_bits(pi - (16 << octave), 4 + octave);
1842                rc.enc_bits(qg as u32, 3);
1843                rc.encode_icdf(self.tapset_decision, &TAPSET_ICDF, 2);
1844            }
1845        }
1846
1847        let mut short_blocks = false;
1848        if lm > 0 && rc.tell() + 3 <= total_bits {
1849            rc.encode_bit_logp(is_transient, 3);
1850            if is_transient {
1851                short_blocks = true;
1852            }
1853        }
1854
1855        if short_blocks {
1856            let b = 1 << lm;
1857            let n = frame_size / b;
1858            for c in 0..channels {
1859                let c_offset = c * buf_stride;
1860                for i in 0..b {
1861                    mode.mdct.forward(
1862                        &in_buf[c_offset + i * n..c_offset + buf_stride],
1863                        &mut freq[c * frame_size + i..],
1864                        mode.window,
1865                        overlap,
1866                        mode.max_lm,
1867                        b,
1868                    );
1869                }
1870            }
1871
1872            compute_band_energies(mode, freq, band_e, nb_ebands, channels, lm);
1873            normalise_bands(
1874                mode,
1875                freq,
1876                x,
1877                band_e,
1878                nb_ebands,
1879                channels,
1880                (1 << lm) as usize,
1881            );
1882        }
1883
1884        let intra_ener = if self.complexity >= 4 {
1885            false
1886        } else {
1887            self.old_band_e[..nb_ebands * channels]
1888                .iter()
1889                .all(|&e| e <= -27.0)
1890        };
1891        quant_coarse_energy_advanced(
1892            mode,
1893            start_band,
1894            nb_ebands,
1895            nb_ebands,
1896            band_log_e,
1897            &mut self.old_band_e,
1898            total_bits as u32,
1899            error,
1900            rc,
1901            channels,
1902            lm,
1903            (total_bits / 8) as usize,
1904            is_transient || intra_ener,
1905            &mut self.delayed_intra,
1906            self.complexity >= 4,
1907            0,
1908            false,
1909        );
1910        self.w_tf_res[..nb_ebands].fill(0);
1911        let tf_res = &mut self.w_tf_res[..nb_ebands];
1912        let effective_bytes = ((total_bits / 8) as usize).max(1);
1913        let lambda = 80.max(20480 / effective_bytes + 2) as i32;
1914
1915        let tf_select = if self.complexity >= 2 && effective_bytes >= 15 * channels {
1916            tf_analysis(
1917                mode,
1918                nb_ebands,
1919                is_transient,
1920                tf_res,
1921                lambda,
1922                x,
1923                frame_size,
1924                lm as i32,
1925                tf_estimate,
1926                tf_chan,
1927            )
1928        } else {
1929            0
1930        };
1931        tf_encode(
1932            start_band,
1933            nb_ebands,
1934            is_transient,
1935            tf_res,
1936            lm as i32,
1937            tf_select,
1938            rc,
1939        );
1940
1941        let mut dual_stereo_val = if channels == 2 {
1942            stereo_analysis(mode, x, lm as i32, frame_size) as i32
1943        } else {
1944            0
1945        };
1946
1947        let mut stereo_saving = 0.0f32;
1948        let equiv_rate = (total_bits * 48000) / frame_size as i32;
1949        if channels == 2 {
1950            self.intensity = hysteresis_decision(
1951                equiv_rate / 1000,
1952                &INTEN_THRESHOLDS,
1953                &INTEN_HYSTERESIS,
1954                self.intensity,
1955            );
1956            self.intensity = self.intensity.clamp(0, nb_ebands as i32);
1957        }
1958
1959        if self.complexity == 0 {
1960            self.spread_decision = SPREAD_NONE;
1961            if rc.tell() + 4 <= total_bits {
1962                rc.encode_icdf(self.spread_decision, &SPREAD_ICDF, 5);
1963            }
1964        } else if rc.tell() + 4 <= total_bits {
1965            if is_transient || self.complexity < 3 || effective_bytes < 10 * channels {
1966                self.spread_decision = SPREAD_NORMAL;
1967            } else {
1968                let update_hf = lm == mode.max_lm;
1969                let spread_weights = [32i32; 21];
1970                self.spread_decision = spreading_decision(
1971                    mode,
1972                    x,
1973                    &mut self.tonal_average,
1974                    self.spread_decision,
1975                    &mut self.hf_average,
1976                    &mut self.tapset_decision,
1977                    update_hf,
1978                    nb_ebands,
1979                    channels,
1980                    (1 << lm) as usize,
1981                    &spread_weights,
1982                );
1983            }
1984            rc.encode_icdf(self.spread_decision, &SPREAD_ICDF, 5);
1985        } else {
1986            self.spread_decision = SPREAD_NORMAL;
1987        }
1988
1989        self.w_cap[..nb_ebands].fill(0);
1990        let cap = &mut self.w_cap[..nb_ebands];
1991        for (i, cap_i) in cap.iter_mut().enumerate() {
1992            let n = (mode.e_bands[i + 1] - mode.e_bands[i]) << lm;
1993            *cap_i = ((mode.cache.caps[nb_ebands * (2 * lm + channels - 1) + i] as i32 + 64)
1994                * channels as i32
1995                * n as i32)
1996                >> 2;
1997        }
1998
1999        self.w_offsets[..nb_ebands].fill(0);
2000        let offsets = &mut self.w_offsets[..nb_ebands];
2001
2002        dynalloc_analysis_simple(
2003            mode,
2004            band_log_e,
2005            &self.old_band_e,
2006            start_band,
2007            nb_ebands,
2008            channels,
2009            lm,
2010            effective_bytes,
2011            is_transient,
2012            offsets,
2013            cap,
2014        );
2015
2016        let mut dynalloc_logp = 6i32;
2017        let total_bits_bitres = total_bits << BITRES;
2018        let mut total_boost = 0i32;
2019        let mut tell_frac = rc.tell_frac();
2020
2021        for i in start_band..nb_ebands {
2022            let width =
2023                channels as i32 * (mode.e_bands[i + 1] - mode.e_bands[i]) as i32 * (1 << lm);
2024            let quanta = (width << BITRES).min((6 << BITRES).max(width));
2025            let mut dynalloc_loop_logp = dynalloc_logp;
2026            let mut boost = 0i32;
2027            let mut j = 0i32;
2028
2029            while tell_frac + (dynalloc_loop_logp << BITRES) < total_bits_bitres - total_boost
2030                && boost < cap[i]
2031            {
2032                let flag = j < offsets[i];
2033                rc.encode_bit_logp(flag, dynalloc_loop_logp as u32);
2034                tell_frac = rc.tell_frac();
2035                if !flag {
2036                    break;
2037                }
2038                boost += quanta;
2039                total_boost += quanta;
2040                dynalloc_loop_logp = 1;
2041                j += 1;
2042            }
2043
2044            if j > 0 {
2045                dynalloc_logp = 2.max(dynalloc_logp - 1);
2046            }
2047            offsets[i] = boost;
2048        }
2049
2050        let alloc_trim = alloc_trim_analysis(
2051            mode,
2052            x,
2053            band_log_e,
2054            nb_ebands,
2055            lm as i32,
2056            channels,
2057            frame_size,
2058            &mut stereo_saving,
2059            tf_estimate,
2060            self.intensity,
2061            0.0,
2062            equiv_rate,
2063        );
2064        if rc.tell_frac() + (6 << BITRES) <= total_bits_bitres - total_boost {
2065            rc.encode_icdf(alloc_trim, &TRIM_ICDF, 7);
2066        }
2067
2068        let mut intensity = self.intensity;
2069        self.w_pulses[..nb_ebands].fill(0);
2070        let pulses = &mut self.w_pulses[..nb_ebands];
2071
2072        let stereo = channels > 1;
2073        let ebands_stereo = if stereo {
2074            nb_ebands * channels
2075        } else {
2076            nb_ebands
2077        };
2078        self.w_fine_priority[..ebands_stereo].fill(0);
2079        let fine_priority = &mut self.w_fine_priority[..ebands_stereo];
2080        self.w_ebits[..ebands_stereo].fill(0);
2081        let ebits = &mut self.w_ebits[..ebands_stereo];
2082        let mut balance = 0;
2083
2084        self.last_coded_bands = clt_compute_allocation(
2085            mode,
2086            start_band,
2087            nb_ebands,
2088            offsets,
2089            cap,
2090            alloc_trim,
2091            &mut intensity,
2092            &mut dual_stereo_val,
2093            (total_bits << BITRES) - rc.tell_frac() - 1,
2094            &mut balance,
2095            pulses,
2096            ebits,
2097            fine_priority,
2098            channels as i32,
2099            lm as i32,
2100            rc,
2101            true,
2102            0,
2103            nb_ebands as i32 - 1,
2104        );
2105
2106        quant_fine_energy(
2107            mode,
2108            start_band,
2109            nb_ebands,
2110            &mut self.old_band_e,
2111            error,
2112            ebits,
2113            rc,
2114            channels,
2115        );
2116
2117        self.w_collapse_masks[..nb_ebands * channels].fill(0);
2118        let collapse_masks = &mut self.w_collapse_masks[..nb_ebands * channels];
2119        let (x_split, y_split) = x.split_at_mut(frame_size);
2120        let y_opt = if channels == 2 { Some(y_split) } else { None };
2121
2122        let anti_collapse_rsv = if is_transient && lm >= 2 {
2123            let remaining = (total_bits << BITRES) - rc.tell_frac() - 1;
2124            if remaining >= ((lm as i32 + 2) << BITRES) {
2125                1i32 << BITRES
2126            } else {
2127                0
2128            }
2129        } else {
2130            0
2131        };
2132
2133        let mut dual_stereo = dual_stereo_val != 0;
2134
2135        let theta_rdo = channels == 2 && !dual_stereo && self.complexity >= 8;
2136        let resynth = theta_rdo;
2137
2138        quant_all_bands(
2139            true,
2140            mode,
2141            start_band,
2142            nb_ebands,
2143            x_split,
2144            y_opt,
2145            collapse_masks,
2146            band_e,
2147            pulses,
2148            short_blocks,
2149            self.spread_decision,
2150            &mut dual_stereo,
2151            intensity as usize,
2152            tf_res,
2153            (total_bits << BITRES) - anti_collapse_rsv,
2154            &mut balance,
2155            rc,
2156            lm as i32,
2157            self.last_coded_bands,
2158            resynth,
2159            &mut 0u32,
2160        );
2161
2162        if anti_collapse_rsv > 0 {
2163            let anti_collapse_on = if self.consec_transient < 2 {
2164                1u32
2165            } else {
2166                0u32
2167            };
2168            rc.enc_bits(anti_collapse_on, 1);
2169        }
2170
2171        quant_energy_finalise(
2172            mode,
2173            start_band,
2174            nb_ebands,
2175            &mut self.old_band_e,
2176            error,
2177            ebits,
2178            fine_priority,
2179            total_bits - rc.tell(),
2180            rc,
2181            channels,
2182        );
2183
2184        if resynth {
2185            let band_amp_synth = &mut self.w_band_amp_synth[..nb_ebands * channels];
2186            log2amp(mode, nb_ebands, band_amp_synth, &self.old_band_e, channels);
2187            self.w_freq_synth[..frame_size * channels].fill(0.0);
2188            let freq_synth = &mut self.w_freq_synth[..frame_size * channels];
2189            denormalise_bands(
2190                mode,
2191                x,
2192                freq_synth,
2193                band_amp_synth,
2194                start_band,
2195                nb_ebands,
2196                channels,
2197                (1 << lm) as usize,
2198            );
2199            let (syn_shift, syn_b) = if is_transient {
2200                (mode.max_lm, 1 << lm)
2201            } else {
2202                (mode.max_lm - lm, 1)
2203            };
2204            let syn_n = frame_size / syn_b;
2205            let decode_buf_size = 2048;
2206
2207            for c in 0..channels {
2208                let co = c * syn_mem_size;
2209                self.enc_decode_mem
2210                    .copy_within(co + frame_size..co + decode_buf_size + overlap, co);
2211            }
2212
2213            for c in 0..channels {
2214                let co = c * syn_mem_size;
2215                let out_syn_idx = decode_buf_size - frame_size;
2216                for bi in 0..syn_b {
2217                    let syn_stride = if is_transient {
2218                        mode.short_mdct_size
2219                    } else {
2220                        syn_n
2221                    };
2222                    mode.mdct.backward(
2223                        &freq_synth[c * frame_size + bi..],
2224                        &mut self.enc_decode_mem[co + out_syn_idx + bi * syn_stride..],
2225                        mode.window,
2226                        overlap,
2227                        syn_shift,
2228                        syn_b,
2229                    );
2230                }
2231            }
2232        }
2233
2234        self.last_band_log_e.copy_from_slice(&self.old_band_e);
2235
2236        if !is_transient {
2237            self.old_band_e3.copy_from_slice(&self.old_band_e2);
2238            self.old_band_e2.copy_from_slice(&self.old_band_e);
2239        } else {
2240            for i in 0..channels * nb_ebands {
2241                self.old_band_e2[i] = self.old_band_e2[i].min(self.old_band_e[i]);
2242            }
2243        }
2244
2245        rc.pad_to_bits(total_bits);
2246
2247        if pf_on {
2248            self.prefilter_period = pitch_index;
2249            self.prefilter_gain = gain1;
2250            self.prefilter_tapset = self.tapset_decision;
2251        } else {
2252            self.prefilter_period = COMBFILTER_MINPERIOD;
2253            self.prefilter_gain = 0.0;
2254            self.prefilter_tapset = self.tapset_decision;
2255        }
2256
2257        if is_transient {
2258            self.consec_transient += 1;
2259        } else {
2260            self.consec_transient = 0;
2261        }
2262    }
2263}
2264
2265pub struct CeltDecoder {
2266    mode: &'static CeltMode,
2267    channels: usize,
2268    decode_mem: Vec<f32>,
2269    old_band_e: Vec<f32>,
2270    preemph_mem: Vec<f32>,
2271    prefilter_mem: Vec<f32>,
2272    prefilter_period: usize,
2273    prefilter_period_old: usize,
2274    prefilter_gain: f32,
2275    prefilter_gain_old: f32,
2276    prefilter_tapset: i32,
2277    prefilter_tapset_old: i32,
2278    old_band_e2: Vec<f32>,
2279    old_band_e3: Vec<f32>,
2280    rng: u32,
2281
2282    w_tf_res: Vec<i32>,
2283    w_cap: Vec<i32>,
2284    w_offsets: Vec<i32>,
2285    w_pulses: Vec<i32>,
2286    w_ebits: Vec<i32>,
2287    w_fine_priority: Vec<i32>,
2288    w_x: Vec<f32>,
2289    w_collapse_masks: Vec<u32>,
2290    w_freq: Vec<f32>,
2291    w_band_amp: Vec<f32>,
2292    w_pcm_frame: Vec<f32>,
2293    w_post: Vec<f32>,
2294}
2295
2296impl CeltDecoder {
2297    pub fn new(mode: &'static CeltMode, channels: usize) -> Self {
2298        let overlap = mode.overlap;
2299        let nb_ebands = mode.nb_ebands;
2300        let nb_x_ch = nb_ebands * channels;
2301        let dec_frame_x_ch = DECODE_BUFFER_SIZE * channels;
2302        Self {
2303            mode,
2304            channels,
2305            decode_mem: vec![0.0; channels * (DECODE_BUFFER_SIZE + overlap)],
2306            old_band_e: vec![0.0; nb_x_ch],
2307            preemph_mem: vec![0.0; channels],
2308            prefilter_mem: vec![0.0; channels * COMBFILTER_MAXPERIOD],
2309            prefilter_period: COMBFILTER_MINPERIOD,
2310            prefilter_period_old: COMBFILTER_MINPERIOD,
2311            prefilter_gain: 0.0,
2312            prefilter_gain_old: 0.0,
2313            prefilter_tapset: 0,
2314            prefilter_tapset_old: 0,
2315            old_band_e2: vec![0.0; nb_x_ch],
2316            old_band_e3: vec![0.0; nb_x_ch],
2317            rng: 0,
2318
2319            w_tf_res: vec![0; nb_ebands],
2320            w_cap: vec![0; nb_ebands],
2321            w_offsets: vec![0; nb_ebands],
2322            w_pulses: vec![0; nb_ebands],
2323            w_ebits: vec![0; nb_x_ch],
2324            w_fine_priority: vec![0; nb_x_ch],
2325
2326            w_x: vec![0.0; dec_frame_x_ch + STRIDE_ACCESS_PAD],
2327            w_collapse_masks: vec![0; nb_x_ch],
2328            w_freq: vec![0.0; dec_frame_x_ch + 4], // +4: NEON backward pre-rotation reads up to 3 elements past n2
2329            w_band_amp: vec![0.0; nb_x_ch],
2330            w_pcm_frame: vec![0.0; DECODE_BUFFER_SIZE],
2331            w_post: vec![0.0; DECODE_BUFFER_SIZE + COMBFILTER_MAXPERIOD],
2332        }
2333    }
2334
2335    pub fn decode(&mut self, compressed: &[u8], frame_size: usize, pcm: &mut [f32]) -> usize {
2336        self.decode_impl(compressed, frame_size, pcm, 0)
2337    }
2338
2339    pub fn decode_with_start_band(
2340        &mut self,
2341        compressed: &[u8],
2342        frame_size: usize,
2343        pcm: &mut [f32],
2344        start_band: usize,
2345    ) -> usize {
2346        self.decode_impl(compressed, frame_size, pcm, start_band)
2347    }
2348
2349    pub fn decode_from_range_coder(
2350        &mut self,
2351        rc: &mut RangeCoder,
2352        total_bits: i32,
2353        frame_size: usize,
2354        pcm: &mut [f32],
2355        start_band: usize,
2356    ) -> usize {
2357        self.decode_impl_from_rc(rc, total_bits, frame_size, pcm, start_band)
2358    }
2359
2360    fn decode_impl(
2361        &mut self,
2362        compressed: &[u8],
2363        frame_size: usize,
2364        pcm: &mut [f32],
2365        start_band: usize,
2366    ) -> usize {
2367        let total_bits = (compressed.len() * 8) as i32;
2368        let mut rc = RangeCoder::new_decoder(compressed);
2369        self.decode_impl_from_rc(&mut rc, total_bits, frame_size, pcm, start_band)
2370    }
2371
2372    fn decode_impl_from_rc(
2373        &mut self,
2374        rc: &mut RangeCoder,
2375        total_bits: i32,
2376        frame_size: usize,
2377        pcm: &mut [f32],
2378        start_band: usize,
2379    ) -> usize {
2380        let mode = self.mode;
2381        let channels = self.channels;
2382        let nb_ebands = mode.nb_ebands;
2383        let overlap = mode.overlap;
2384
2385        let mut lm = 0;
2386        while (mode.short_mdct_size << lm) != frame_size {
2387            lm += 1;
2388            if lm > mode.max_lm {
2389                break;
2390            }
2391        }
2392        if (mode.short_mdct_size << lm) != frame_size {
2393            lm = 0;
2394        }
2395
2396        let tell = rc.tell();
2397        let mut silence = false;
2398        if tell >= total_bits {
2399            silence = true;
2400        } else if tell == 1 {
2401            silence = rc.decode_bit_logp(15);
2402        }
2403
2404        if silence {
2405            pcm[..frame_size * channels].fill(0.0);
2406            return frame_size;
2407        }
2408
2409        let mut pf_on = false;
2410        let mut pitch_index = COMBFILTER_MINPERIOD;
2411        let mut gain1 = 0.0f32;
2412        let mut prefilter_tapset = 0;
2413
2414        if start_band == 0 && !silence && rc.tell() + 16 <= total_bits {
2415            pf_on = rc.decode_bit_logp(1);
2416            if pf_on {
2417                let octave = rc.dec_uint(6);
2418                pitch_index = ((16 << octave) + rc.dec_bits(4 + octave)) as usize - 1;
2419                let qg = rc.dec_bits(3);
2420                if rc.tell() + 2 <= total_bits {
2421                    prefilter_tapset = rc.decode_icdf(&TAPSET_ICDF, 2) as usize;
2422                }
2423                gain1 = 0.09375 * (qg as f32 + 1.0);
2424            }
2425        }
2426        if start_band != 0 {
2427            self.prefilter_gain = 0.0;
2428        }
2429
2430        let mut is_transient = false;
2431        if lm > 0 && rc.tell() + 3 <= total_bits {
2432            is_transient = rc.decode_bit_logp(3);
2433        }
2434        let short_blocks = is_transient;
2435
2436        unquant_coarse_energy(
2437            mode,
2438            start_band,
2439            nb_ebands,
2440            &mut self.old_band_e,
2441            total_bits as u32,
2442            rc,
2443            channels,
2444            lm,
2445        );
2446        self.w_tf_res[..nb_ebands].fill(0);
2447        let tf_res = &mut self.w_tf_res[..nb_ebands];
2448        tf_decode(start_band, nb_ebands, is_transient, tf_res, lm as i32, rc);
2449
2450        let spread_decision = if rc.tell() + 4 <= total_bits {
2451            rc.decode_icdf(&SPREAD_ICDF, 5)
2452        } else {
2453            SPREAD_NORMAL
2454        };
2455
2456        self.w_cap[..nb_ebands].fill(0);
2457        let cap = &mut self.w_cap[..nb_ebands];
2458        for (i, cap_i) in cap.iter_mut().enumerate() {
2459            let n = (mode.e_bands[i + 1] - mode.e_bands[i]) << lm;
2460            *cap_i = ((mode.cache.caps[nb_ebands * (2 * lm + channels - 1) + i] as i32 + 64)
2461                * channels as i32
2462                * n as i32)
2463                >> 2;
2464        }
2465
2466        self.w_offsets[..nb_ebands].fill(0);
2467        let offsets = &mut self.w_offsets[..nb_ebands];
2468        let mut dynalloc_logp = 6i32;
2469        let mut total_bits_bitres = total_bits << BITRES;
2470        let mut tell_frac = rc.tell_frac();
2471        for i in start_band..nb_ebands {
2472            let width =
2473                channels as i32 * (mode.e_bands[i + 1] - mode.e_bands[i]) as i32 * (1 << lm);
2474            let quanta = (width << BITRES).min((6i32 << BITRES).max(width));
2475            let mut dynalloc_loop_logp = dynalloc_logp;
2476            let mut boost = 0i32;
2477            while tell_frac + (dynalloc_loop_logp << BITRES) < total_bits_bitres && boost < cap[i] {
2478                let flag = rc.decode_bit_logp(dynalloc_loop_logp as u32);
2479                tell_frac = rc.tell_frac();
2480                if !flag {
2481                    break;
2482                }
2483                boost += quanta;
2484                total_bits_bitres -= quanta;
2485                dynalloc_loop_logp = 1;
2486            }
2487            offsets[i] = boost;
2488            if boost > 0 {
2489                dynalloc_logp = dynalloc_logp.max(2) - 1;
2490                dynalloc_logp = dynalloc_logp.max(2);
2491            }
2492        }
2493
2494        let alloc_trim = if rc.tell_frac() + (6 << BITRES) <= total_bits_bitres {
2495            rc.decode_icdf(&TRIM_ICDF, 7)
2496        } else {
2497            5
2498        };
2499        let anti_collapse_rsv = if is_transient && lm >= 2 {
2500            let remaining = (total_bits << BITRES) - rc.tell_frac() - 1;
2501            if remaining >= ((lm as i32 + 2) << BITRES) {
2502                1i32 << BITRES
2503            } else {
2504                0
2505            }
2506        } else {
2507            0
2508        };
2509
2510        let mut intensity = 0;
2511        let mut dual_stereo_val = if channels == 2 { 1 } else { 0 };
2512        let mut balance = 0;
2513        self.w_pulses[..nb_ebands].fill(0);
2514        let pulses = &mut self.w_pulses[..nb_ebands];
2515
2516        let ebands_stereo = if channels > 1 {
2517            nb_ebands * channels
2518        } else {
2519            nb_ebands
2520        };
2521        self.w_fine_priority[..ebands_stereo].fill(0);
2522        let fine_priority = &mut self.w_fine_priority[..ebands_stereo];
2523        self.w_ebits[..ebands_stereo].fill(0);
2524        let ebits = &mut self.w_ebits[..ebands_stereo];
2525
2526        let alloc_bits = (total_bits << BITRES) - rc.tell_frac() - 1 - anti_collapse_rsv;
2527        let coded_bands = clt_compute_allocation(
2528            mode,
2529            start_band,
2530            nb_ebands,
2531            offsets,
2532            cap,
2533            alloc_trim,
2534            &mut intensity,
2535            &mut dual_stereo_val,
2536            alloc_bits,
2537            &mut balance,
2538            pulses,
2539            ebits,
2540            fine_priority,
2541            channels as i32,
2542            lm as i32,
2543            rc,
2544            false,
2545            0,
2546            nb_ebands as i32 - 1,
2547        );
2548
2549        unquant_fine_energy(
2550            mode,
2551            start_band,
2552            nb_ebands,
2553            &mut self.old_band_e,
2554            ebits,
2555            rc,
2556            channels,
2557        );
2558
2559        if frame_size > DECODE_BUFFER_SIZE + overlap {
2560            return 0;
2561        }
2562
2563        self.w_x[..frame_size * channels].fill(0.0);
2564
2565        let x_pad_end = (frame_size * channels + STRIDE_ACCESS_PAD).min(self.w_x.len());
2566        let x = &mut self.w_x[..x_pad_end];
2567        self.w_collapse_masks[..nb_ebands * channels].fill(0);
2568        let collapse_masks = &mut self.w_collapse_masks[..nb_ebands * channels];
2569
2570        let (x_split, y_split) = x.split_at_mut(frame_size);
2571        let y_opt = if channels == 2 { Some(y_split) } else { None };
2572
2573        let mut dual_stereo = dual_stereo_val != 0;
2574        self.w_band_amp[..nb_ebands * channels].fill(0.0);
2575        let band_amp = &mut self.w_band_amp[..nb_ebands * channels];
2576        log2amp(mode, nb_ebands, band_amp, &self.old_band_e, channels);
2577        quant_all_bands(
2578            false,
2579            mode,
2580            start_band,
2581            nb_ebands,
2582            x_split,
2583            y_opt,
2584            collapse_masks,
2585            band_amp,
2586            pulses,
2587            short_blocks,
2588            spread_decision,
2589            &mut dual_stereo,
2590            intensity as usize,
2591            tf_res,
2592            (total_bits << BITRES) - anti_collapse_rsv,
2593            &mut balance,
2594            rc,
2595            lm as i32,
2596            coded_bands,
2597            true,
2598            &mut self.rng,
2599        );
2600        // Trace X values for comparison with C decoder
2601        let mut anti_collapse_on = false;
2602        if anti_collapse_rsv > 0 {
2603            anti_collapse_on = rc.dec_bits(1) != 0;
2604        }
2605
2606        unquant_energy_finalise(
2607            mode,
2608            start_band,
2609            nb_ebands,
2610            &mut self.old_band_e,
2611            ebits,
2612            fine_priority,
2613            total_bits - rc.tell(),
2614            rc,
2615            channels,
2616        );
2617        if anti_collapse_on {
2618            self.rng = crate::bands::anti_collapse(
2619                mode,
2620                x,
2621                collapse_masks,
2622                lm as i32,
2623                channels,
2624                frame_size,
2625                start_band,
2626                nb_ebands,
2627                &self.old_band_e,
2628                &self.old_band_e2,
2629                &self.old_band_e3,
2630                pulses,
2631                self.rng,
2632            );
2633        }
2634
2635        // Recompute band_amp after unquant_energy_finalise, which adjusts old_band_e.
2636        // (Mirrors the encoder's resynth path: log2amp is called after quant_energy_finalise.)
2637        log2amp(mode, nb_ebands, band_amp, &self.old_band_e, channels);
2638        self.w_freq[..frame_size * channels].fill(0.0);
2639        let freq = &mut self.w_freq[..frame_size * channels];
2640        denormalise_bands(
2641            mode,
2642            x,
2643            freq,
2644            band_amp,
2645            start_band,
2646            nb_ebands,
2647            channels,
2648            (1 << lm) as usize,
2649        );
2650        // Always trace freq and band_amp for comparison
2651
2652        let (shift, b) = if short_blocks {
2653            (mode.max_lm, 1 << lm)
2654        } else {
2655            (mode.max_lm - lm, 1)
2656        };
2657        let n = frame_size / b;
2658
2659        for c in 0..channels {
2660            let channel_mem_offset = c * (DECODE_BUFFER_SIZE + overlap);
2661
2662            let mem_size = DECODE_BUFFER_SIZE + overlap;
2663            self.decode_mem.copy_within(
2664                channel_mem_offset + frame_size..channel_mem_offset + mem_size,
2665                channel_mem_offset,
2666            );
2667
2668            let out_syn_idx = DECODE_BUFFER_SIZE - frame_size;
2669
2670            for i in 0..b {
2671                let block_freq_idx = c * frame_size + i;
2672                // Stride between short-block MDCT outputs is short_mdct_size (not n).
2673                // In libopus: out_syn[c] + NB*b, where NB = mode->shortMdctSize.
2674                // For non-transient b=1, i*n == 0 either way.
2675                let block_stride = if short_blocks {
2676                    mode.short_mdct_size
2677                } else {
2678                    n
2679                };
2680                let block_out_idx = channel_mem_offset + out_syn_idx + i * block_stride;
2681                let available_len = self.decode_mem.len() - block_out_idx;
2682                if available_len < n + overlap {
2683                    panic!(
2684                        "MDCT backward buffer too small: need {}, have {} (out_syn_idx={}, n={}, overlap={})",
2685                        n + overlap,
2686                        available_len,
2687                        out_syn_idx,
2688                        n,
2689                        overlap
2690                    );
2691                }
2692                self.mode.mdct.backward(
2693                    &freq[block_freq_idx..],
2694                    &mut self.decode_mem[block_out_idx..],
2695                    mode.window,
2696                    overlap,
2697                    shift,
2698                    b,
2699                );
2700            }
2701
2702            const SIG_SAT: f32 = 536870911.0;
2703            for i in 0..frame_size {
2704                let v = &mut self.decode_mem[channel_mem_offset + out_syn_idx + i];
2705                *v = v.clamp(-SIG_SAT, SIG_SAT);
2706            }
2707
2708            self.w_pcm_frame[..frame_size].fill(0.0);
2709            let pcm_frame = &mut self.w_pcm_frame[..frame_size];
2710
2711            pcm_frame.copy_from_slice(
2712                &self.decode_mem[channel_mem_offset + out_syn_idx
2713                    ..channel_mem_offset + out_syn_idx + frame_size],
2714            );
2715            if pf_on || self.prefilter_gain > 0.0 || self.prefilter_gain_old > 0.0 {
2716                // Set up w_post = [prefilter_mem | pcm_frame] for history access.
2717                // We apply combfilter in-place on w_post[COMBFILTER_MAXPERIOD..] so that
2718                // later samples can reference already-filtered earlier samples, matching C's
2719                // in-place comb_filter behavior.
2720                self.w_post[..COMBFILTER_MAXPERIOD].copy_from_slice(
2721                    &self.prefilter_mem[c * COMBFILTER_MAXPERIOD..(c + 1) * COMBFILTER_MAXPERIOD],
2722                );
2723                self.w_post[COMBFILTER_MAXPERIOD..COMBFILTER_MAXPERIOD + frame_size]
2724                    .copy_from_slice(pcm_frame);
2725
2726                let short_n = mode.short_mdct_size;
2727                // Call 1: first short_n samples, transition old→current params
2728                // Apply in-place on w_post[COMBFILTER_MAXPERIOD..], output overwrites input
2729                comb_filter_inplace(
2730                    &mut self.w_post,
2731                    COMBFILTER_MAXPERIOD,
2732                    self.prefilter_period_old,
2733                    self.prefilter_period,
2734                    short_n,
2735                    self.prefilter_gain_old,
2736                    self.prefilter_gain,
2737                    self.prefilter_tapset_old,
2738                    self.prefilter_tapset,
2739                    mode.window,
2740                    overlap,
2741                );
2742                if lm != 0 {
2743                    // Call 2: remaining N-short_n samples, transition current→new params
2744                    comb_filter_inplace(
2745                        &mut self.w_post,
2746                        COMBFILTER_MAXPERIOD + short_n,
2747                        self.prefilter_period,
2748                        pitch_index,
2749                        frame_size - short_n,
2750                        self.prefilter_gain,
2751                        gain1,
2752                        self.prefilter_tapset,
2753                        prefilter_tapset as i32,
2754                        mode.window,
2755                        overlap,
2756                    );
2757                }
2758
2759                pcm_frame.copy_from_slice(
2760                    &self.w_post[COMBFILTER_MAXPERIOD..COMBFILTER_MAXPERIOD + frame_size],
2761                );
2762
2763                self.decode_mem[channel_mem_offset + out_syn_idx
2764                    ..channel_mem_offset + out_syn_idx + frame_size]
2765                    .copy_from_slice(pcm_frame);
2766            }
2767            let mut new_mem = [0.0f32; COMBFILTER_MAXPERIOD];
2768            if frame_size >= COMBFILTER_MAXPERIOD {
2769                new_mem.copy_from_slice(&pcm_frame[frame_size - COMBFILTER_MAXPERIOD..frame_size]);
2770            } else {
2771                new_mem[..COMBFILTER_MAXPERIOD - frame_size].copy_from_slice(
2772                    &self.prefilter_mem
2773                        [c * COMBFILTER_MAXPERIOD + frame_size..(c + 1) * COMBFILTER_MAXPERIOD],
2774                );
2775                new_mem[COMBFILTER_MAXPERIOD - frame_size..].copy_from_slice(pcm_frame);
2776            }
2777            self.prefilter_mem[c * COMBFILTER_MAXPERIOD..(c + 1) * COMBFILTER_MAXPERIOD]
2778                .copy_from_slice(&new_mem);
2779
2780            let coef = mode.preemph[0];
2781            let mut m = self.preemph_mem[c];
2782            const VERY_SMALL: f32 = 1e-30f32;
2783            for i in 0..frame_size {
2784                let x = pcm_frame[i];
2785                let val = (x + VERY_SMALL + m).clamp(-SIG_SAT, SIG_SAT);
2786                pcm[c * frame_size + i] = val * (1.0 / 32768.0);
2787                m = val * coef;
2788            }
2789            self.preemph_mem[c] = m;
2790        }
2791
2792        self.prefilter_period_old = self.prefilter_period;
2793        self.prefilter_gain_old = self.prefilter_gain;
2794        self.prefilter_tapset_old = self.prefilter_tapset;
2795
2796        if pf_on {
2797            self.prefilter_period = pitch_index;
2798            self.prefilter_gain = gain1;
2799            self.prefilter_tapset = prefilter_tapset as i32;
2800        } else {
2801            self.prefilter_period = COMBFILTER_MINPERIOD;
2802            self.prefilter_gain = 0.0;
2803            self.prefilter_tapset = 0;
2804        }
2805
2806        if lm > 0 {
2807            self.prefilter_period_old = self.prefilter_period;
2808            self.prefilter_gain_old = self.prefilter_gain;
2809            self.prefilter_tapset_old = self.prefilter_tapset;
2810        }
2811
2812        if !is_transient {
2813            self.old_band_e3.copy_from_slice(&self.old_band_e2);
2814            self.old_band_e2.copy_from_slice(&self.old_band_e);
2815        } else {
2816            let nb_ebands = mode.nb_ebands;
2817            for i in 0..channels * nb_ebands {
2818                self.old_band_e2[i] = self.old_band_e2[i].min(self.old_band_e[i]);
2819            }
2820        }
2821
2822        self.rng = rc.rng;
2823
2824        frame_size
2825    }
2826}