Skip to main content

opus_rs/
celt.rs

1use crate::bands::{
2    SPREAD_NONE, SPREAD_NORMAL, compute_band_energies, denormalise_bands, haar1, log2amp,
3    normalise_bands, quant_all_bands, spreading_decision,
4};
5use crate::modes::{CeltMode, SPREAD_ICDF, TAPSET_ICDF, TF_SELECT_TABLE, TRIM_ICDF};
6use crate::quant_bands::{
7    quant_coarse_energy_advanced, quant_energy_finalise, quant_fine_energy, unquant_coarse_energy,
8    unquant_energy_finalise, unquant_fine_energy,
9};
10use crate::range_coder::RangeCoder;
11use crate::rate::{BITRES, clt_compute_allocation};
12
13#[cfg(target_arch = "aarch64")]
14use std::arch::aarch64::*;
15
/// Sums `|x[i]|` for `i` in `0..n` using NEON.
///
/// Four-lane groups are accumulated in order into a single vector
/// accumulator, which is then reduced horizontally; any leftover elements
/// (fewer than four) are added one by one afterwards.
///
/// # Safety
/// `n` must not exceed `x.len()`: the vector loads go through a raw pointer
/// and are not bounds-checked.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn sum_abs_neon(x: &[f32], n: usize) -> f32 {
    let base = x.as_ptr();
    let mut acc = vdupq_n_f32(0.0);
    let mut pos = 0;

    // One 4-lane group per step. Every group is added to the same
    // accumulator in increasing index order, so the rounding sequence does
    // not depend on how many groups are handled per loop iteration.
    while pos + 4 <= n {
        let magnitudes = vabsq_f32(vld1q_f32(base.add(pos)));
        acc = vaddq_f32(acc, magnitudes);
        pos += 4;
    }

    // Horizontal reduction of the four lanes, then the scalar remainder.
    let mut total = vaddvq_f32(acc);
    for j in pos..n {
        total += x[j].abs();
    }

    total
}
59
/// Returns the sum of the absolute values of every element of `x`.
///
/// Dispatch is per target: on x86_64 the AVX kernel is used when the CPU
/// reports AVX at run time; on aarch64 the NEON kernel is always used; every
/// non-aarch64 target ends with a scalar iterator sum as the fallback.
#[inline(always)]
fn sum_abs(x: &[f32]) -> f32 {
    #[cfg(target_arch = "x86_64")]
    unsafe {
        // Run-time check: the AVX kernel is only entered when AVX is present.
        if std::arch::is_x86_feature_detected!("avx") {
            return sum_abs_avx(x, x.len());
        }
    }
    #[cfg(target_arch = "aarch64")]
    unsafe {
        // The element count handed to the kernel is exactly the slice length.
        sum_abs_neon(x, x.len())
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        // Scalar path; on x86_64 this is reached only when AVX is unavailable.
        x.iter().map(|&v| v.abs()).sum()
    }
}
77
// NOTE(review): not referenced in the visible part of this file — presumably
// the largest supported frame length in samples; confirm against the users.
const MAX_FRAME_SIZE: usize = 2880;

// NOTE(review): not referenced in the visible part of this file — presumably
// the length of the decoder history buffer; confirm against the users.
const DECODE_BUFFER_SIZE: usize = 3072;

// 128-entry lookup table used by `transient_analysis`, which indexes it with
// a value clamped to 0..=127 and accumulates the entries. Entries are
// non-increasing with the index.
// NOTE(review): looks like a scaled reciprocal table — confirm the scale.
const INV_TABLE: [u8; 128] = [
    255, 255, 156, 110, 86, 70, 59, 51, 45, 40, 37, 33, 31, 28, 26, 25, 23, 22, 21, 20, 19, 18, 17,
    16, 16, 15, 15, 14, 13, 13, 12, 12, 12, 12, 11, 11, 11, 10, 10, 10, 9, 9, 9, 9, 9, 9, 8, 8, 8,
    8, 8, 7, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
    5, 5, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 3, 3, 3,
    3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2,
];

// Upper bound on the `len` argument of `transient_analysis` (checked there
// with a `debug_assert!`).
const MAX_TRANSIENT_LEN: usize = 3000;
91
/// Bundle of signal-analysis values.
///
/// NOTE(review): nothing in the visible part of this file reads these fields,
/// so the per-field notes are limited to what the declaration shows; the
/// names presumably follow the reference Opus analysis structure — confirm.
#[derive(Debug, Clone, Copy)]
pub struct AnalysisInfo {
    /// `false` in the default value; presumably marks whether the remaining
    /// fields carry real data — confirm against the producer.
    pub valid: bool,
    pub tonality: f32,
    pub tonality_slope: f32,
    pub noisiness: f32,
    pub activity: f32,
    pub music_prob: f32,
    pub music_prob_min: f32,
    pub music_prob_max: f32,
    pub bandwidth: i32,
    pub activity_probability: f32,
    /// The only field whose default is not zero (it defaults to 1.0).
    pub max_pitch_ratio: f32,
    /// One byte per leak band (LEAK_BANDS = 19).
    pub leak_boost: [u8; 19],
}

impl Default for AnalysisInfo {
    /// Builds an analysis record with `valid == false`, every numeric field
    /// zero, and `max_pitch_ratio` set to 1.0.
    fn default() -> Self {
        AnalysisInfo {
            // Flag and integer/array fields.
            valid: false,
            bandwidth: 0,
            leak_boost: [0u8; 19],
            // The single non-zero default.
            max_pitch_ratio: 1.0,
            // Remaining floating-point fields all start at zero.
            tonality: 0.0,
            tonality_slope: 0.0,
            noisiness: 0.0,
            activity: 0.0,
            music_prob: 0.0,
            music_prob_min: 0.0,
            music_prob_max: 0.0,
            activity_probability: 0.0,
        }
    }
}
126
127#[allow(clippy::too_many_arguments)]
128fn transient_analysis(
129    input: &[f32],
130    len: usize,
131    channels: usize,
132    tf_estimate: &mut f32,
133    tf_chan: &mut usize,
134    allow_weak_transients: bool,
135    weak_transient: &mut bool,
136    _tone_freq: f32,
137    toneishness: f32,
138    tmp: &mut [f32],
139    tmp2: &mut [f32],
140) -> bool {
141    let mut mask_metric = 0.0f32;
142    let mut forward_decay = 0.0625f32;
143
144    *weak_transient = false;
145    if allow_weak_transients {
146        forward_decay = 0.03125f32;
147    }
148
149    let len2 = len / 2;
150    debug_assert!(len <= MAX_TRANSIENT_LEN);
151
152    for c in 0..channels {
153        let mut mem0 = 0.0f32;
154        let mut mem1 = 0.0f32;
155
156        for i in 0..len {
157            let x = input[c * len + i];
158            let y = mem0 + x;
159            let mem00 = mem0;
160            mem0 = mem0 - x + 0.5 * mem1;
161            mem1 = x - mem00;
162            tmp[i] = y;
163        }
164
165        tmp[..12].fill(0.0);
166
167        let mut mean = 0.0f32;
168        mem0 = 0.0f32;
169        for i in 0..len2 {
170            let x2 = (tmp[2 * i] * tmp[2 * i] + tmp[2 * i + 1] * tmp[2 * i + 1]) / 16.0;
171            mean += x2 / 4096.0;
172            mem0 = x2 + (1.0 - forward_decay) * mem0;
173            tmp2[i] = forward_decay * mem0;
174        }
175
176        mem0 = 0.0f32;
177        let mut max_e = 0.0f32;
178        for i in (0..len2).rev() {
179            mem0 = tmp2[i] + 0.875 * mem0;
180            tmp2[i] = 0.125 * mem0;
181            if tmp2[i] > max_e {
182                max_e = tmp2[i];
183            }
184        }
185
186        mean = (mean * max_e * 0.5 * (len2 as f32)).sqrt();
187        let norm = (len2 as f32) / (1e-10 + mean);
188
189        let mut unmask = 0.0f32;
190        for i in (12..(len2 - 5)).step_by(4) {
191            let id = (64.0 * norm * (tmp2[i] + 1e-10)).floor() as i32;
192            let id = id.clamp(0, 127) as usize;
193            unmask += INV_TABLE[id] as f32;
194        }
195
196        unmask = 64.0 * unmask * 4.0 / (6.0 * (len2 as f32 - 17.0));
197        if unmask > mask_metric {
198            *tf_chan = c;
199            mask_metric = unmask;
200        }
201    }
202
203    let mut is_transient = mask_metric > 200.0;
204
205    if toneishness > 0.98 && _tone_freq < 0.026 {
206        is_transient = false;
207        mask_metric = 0.0;
208    }
209
210    *tf_estimate = (mask_metric - 150.0).clamp(0.0, 1.0);
211
212    is_transient
213}
214
215fn l1_metric(tmp: &[f32], n: usize, lm: i32, bias: f32) -> f32 {
216    #[cfg(target_arch = "x86_64")]
217    unsafe {
218        if n >= 16 && std::arch::is_x86_feature_detected!("avx") {
219            return l1_metric_avx(tmp, n, lm, bias);
220        }
221    }
222    #[cfg(target_arch = "aarch64")]
223    {
224        if n >= 16 {
225            return unsafe { l1_metric_neon(tmp, n, lm, bias) };
226        }
227    }
228
229    let mut l1 = 0.0f32;
230    for &tv in tmp[..n].iter() {
231        l1 += tv.abs();
232    }
233    l1 + (lm as f32) * bias * l1
234}
235
/// Sums `|x[i]|` for `i` in `0..n` using AVX.
///
/// Two 8-lane accumulators handle 16 values per step, a single-accumulator
/// loop handles a remaining group of 8, and anything left is added with
/// scalar code after the horizontal reduction.
///
/// # Safety
/// The CPU must support AVX, and `n` must not exceed `x.len()` (the vector
/// loads go through a raw pointer and are not bounds-checked).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx")]
unsafe fn sum_abs_avx(x: &[f32], n: usize) -> f32 {
    use std::arch::x86_64::*;

    let src = x.as_ptr();
    // Only the sign bit is set in -0.0, so AND-NOT with it clears the sign.
    let sign_bit = _mm256_set1_ps(-0.0);
    let mut acc_a = _mm256_setzero_ps();
    let mut acc_b = _mm256_setzero_ps();
    let mut pos = 0usize;

    while pos + 16 <= n {
        let first = _mm256_andnot_ps(sign_bit, _mm256_loadu_ps(src.add(pos)));
        let second = _mm256_andnot_ps(sign_bit, _mm256_loadu_ps(src.add(pos + 8)));
        acc_a = _mm256_add_ps(acc_a, first);
        acc_b = _mm256_add_ps(acc_b, second);
        pos += 16;
    }

    while pos + 8 <= n {
        let group = _mm256_andnot_ps(sign_bit, _mm256_loadu_ps(src.add(pos)));
        acc_a = _mm256_add_ps(acc_a, group);
        pos += 8;
    }

    // Horizontal reduction: 8 lanes -> 4 -> 2 -> 1.
    let merged = _mm256_add_ps(acc_a, acc_b);
    let upper = _mm256_extractf128_ps(merged, 1);
    let lower = _mm256_castps256_ps128(merged);
    let quad = _mm_add_ps(lower, upper);
    let pair = _mm_add_ps(quad, _mm_movehl_ps(quad, quad));
    let lane1 = _mm_shuffle_ps(pair, pair, 0x55);
    let mut total = _mm_cvtss_f32(_mm_add_ss(pair, lane1));

    // Scalar remainder (fewer than eight values).
    for j in pos..n {
        total += x[j].abs();
    }

    total
}
275
276#[cfg(target_arch = "x86_64")]
277#[target_feature(enable = "avx")]
278unsafe fn l1_metric_avx(tmp: &[f32], n: usize, lm: i32, bias: f32) -> f32 {
279    let l1 = sum_abs_avx(tmp, n);
280    l1 + (lm as f32) * bias * l1
281}
282
/// NEON variant of the biased L1 norm: `l1 + lm * bias * l1`, where `l1` is
/// the sum of absolute values of the first `n` elements of `tmp`.
///
/// # Safety
/// `n` must not exceed `tmp.len()`: the vector loads go through a raw pointer
/// and are not bounds-checked.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn l1_metric_neon(tmp: &[f32], n: usize, lm: i32, bias: f32) -> f32 {
    unsafe {
        let mut sum4 = vdupq_n_f32(0.0);
        let mut i = 0;

        // 16 values per iteration, all folded into the one accumulator.
        while i + 15 < n {
            let v0 = vld1q_f32(tmp.as_ptr().add(i));
            let v1 = vld1q_f32(tmp.as_ptr().add(i + 4));
            let v2 = vld1q_f32(tmp.as_ptr().add(i + 8));
            let v3 = vld1q_f32(tmp.as_ptr().add(i + 12));

            sum4 = vaddq_f32(sum4, vabsq_f32(v0));
            sum4 = vaddq_f32(sum4, vabsq_f32(v1));
            sum4 = vaddq_f32(sum4, vabsq_f32(v2));
            sum4 = vaddq_f32(sum4, vabsq_f32(v3));

            i += 16;
        }

        // Remaining full groups of four.
        while i + 3 < n {
            let v = vld1q_f32(tmp.as_ptr().add(i));
            sum4 = vaddq_f32(sum4, vabsq_f32(v));
            i += 4;
        }

        // Two pairwise adds leave the total of all four lanes in lane 0.
        let sum2 = vpaddq_f32(sum4, sum4);
        let sum1 = vpaddq_f32(sum2, sum2);
        let mut l1 = vgetq_lane_f32(sum1, 0);

        // Scalar remainder (fewer than four values).
        while i < n {
            l1 += tmp[i].abs();
            i += 1;
        }

        l1 + (lm as f32) * bias * l1
    }
}
322
// Capacity of the per-band arrays in `tf_analysis`; its `len` argument is
// debug-asserted not to exceed this.
const MAX_NB_EBANDS: usize = 21;

// Size of the scratch buffers in `tf_analysis`; one band's samples
// (band width << lm) are copied into them, so that product must fit.
const MAX_TF_TMP: usize = 176;
326
327#[allow(clippy::too_many_arguments)]
328fn tf_analysis(
329    mode: &CeltMode,
330    len: usize,
331    is_transient: bool,
332    tf_res: &mut [i32],
333    lambda: i32,
334    x: &[f32],
335    n0: usize,
336    lm: i32,
337    tf_estimate: f32,
338    tf_chan: usize,
339) -> i32 {
340    debug_assert!(len <= MAX_NB_EBANDS);
341    let mut metric = [0i32; MAX_NB_EBANDS];
342    let mut tmp = [0.0f32; MAX_TF_TMP];
343    let mut tmp_1 = [0.0f32; MAX_TF_TMP];
344
345    let bias = 0.04 * (-0.25f32).max(0.5 - tf_estimate);
346
347    for (i, metric_i) in metric[..len].iter_mut().enumerate() {
348        let n = ((mode.e_bands[i + 1] - mode.e_bands[i]) as usize) << lm;
349        let narrow = (mode.e_bands[i + 1] - mode.e_bands[i]) == 1;
350        let offset = tf_chan * n0 + ((mode.e_bands[i] as usize) << lm);
351        tmp[..n].copy_from_slice(&x[offset..offset + n]);
352
353        let mut l1 = l1_metric(&tmp[..n], n, if is_transient { lm } else { 0 }, bias);
354        let mut best_l1 = l1;
355        let mut best_level = 0;
356
357        if is_transient && !narrow {
358            tmp_1[..n].copy_from_slice(&tmp[..n]);
359            haar1(&mut tmp_1[..n], n >> lm, 1 << lm);
360            l1 = l1_metric(&tmp_1[..n], n, lm + 1, bias);
361            if l1 < best_l1 {
362                best_l1 = l1;
363                best_level = -1;
364            }
365        }
366
367        for k in 0..(lm + if is_transient || narrow { 0 } else { 1 }) {
368            let b = if is_transient { lm - k - 1 } else { k + 1 };
369
370            haar1(&mut tmp[..n], n >> k, 1 << k);
371            l1 = l1_metric(&tmp[..n], n, b, bias);
372
373            if l1 < best_l1 {
374                best_l1 = l1;
375                best_level = k + 1;
376            }
377        }
378
379        if is_transient {
380            *metric_i = 2 * best_level;
381        } else {
382            *metric_i = -2 * best_level;
383        }
384
385        if narrow && (*metric_i == 0 || *metric_i == -2 * lm) {
386            *metric_i -= 1;
387        }
388    }
389
390    let mut tf_select = 0;
391    let importance = [1.0f32; MAX_NB_EBANDS];
392    let mut selcost = [0.0f32; 2];
393
394    for sel in 0..2 {
395        let mut cost0 = importance[0]
396            * ((metric[0]
397                - 2 * TF_SELECT_TABLE[lm as usize][4 * (is_transient as usize) + 2 * sel] as i32)
398                as f32)
399                .abs();
400        let mut cost1 = importance[0]
401            * ((metric[0]
402                - 2 * TF_SELECT_TABLE[lm as usize][4 * (is_transient as usize) + 2 * sel + 1]
403                    as i32) as f32)
404                .abs()
405            + (if is_transient { 0.0 } else { lambda as f32 });
406
407        for i in 1..len {
408            let curr0 = cost0.min(cost1 + lambda as f32);
409            let curr1 = (cost0 + lambda as f32).min(cost1);
410            cost0 = curr0
411                + importance[i]
412                    * ((metric[i]
413                        - 2 * TF_SELECT_TABLE[lm as usize][4 * (is_transient as usize) + 2 * sel]
414                            as i32) as f32)
415                        .abs();
416            cost1 = curr1
417                + importance[i]
418                    * ((metric[i]
419                        - 2 * TF_SELECT_TABLE[lm as usize]
420                            [4 * (is_transient as usize) + 2 * sel + 1]
421                            as i32) as f32)
422                        .abs();
423        }
424        selcost[sel] = cost0.min(cost1);
425    }
426
427    if selcost[1] < selcost[0] {
428        tf_select = 1;
429    }
430
431    let mut cost0 = importance[0]
432        * ((metric[0]
433            - 2 * TF_SELECT_TABLE[lm as usize][4 * (is_transient as usize) + 2 * tf_select] as i32)
434            as f32)
435            .abs();
436    let mut cost1 = importance[0]
437        * ((metric[0]
438            - 2 * TF_SELECT_TABLE[lm as usize][4 * (is_transient as usize) + 2 * tf_select + 1]
439                as i32) as f32)
440            .abs()
441        + (if is_transient { 0.0 } else { lambda as f32 });
442
443    tf_res[0] = if cost0 < cost1 { 0 } else { 1 };
444
445    for i in 1..len {
446        let curr0 = cost0.min(cost1 + lambda as f32);
447        let curr1 = (cost0 + lambda as f32).min(cost1);
448        cost0 = curr0
449            + importance[i]
450                * ((metric[i]
451                    - 2 * TF_SELECT_TABLE[lm as usize][4 * (is_transient as usize) + 2 * tf_select]
452                        as i32) as f32)
453                    .abs();
454        cost1 = curr1
455            + importance[i]
456                * ((metric[i]
457                    - 2 * TF_SELECT_TABLE[lm as usize]
458                        [4 * (is_transient as usize) + 2 * tf_select + 1]
459                        as i32) as f32)
460                    .abs();
461        tf_res[i] = if cost0 < cost1 { 0 } else { 1 };
462    }
463
464    tf_select as i32
465}
466
467fn tf_encode(
468    start: usize,
469    end: usize,
470    is_transient: bool,
471    tf_res: &mut [i32],
472    lm: i32,
473    mut tf_select: i32,
474    rc: &mut RangeCoder,
475) -> i32 {
476    let mut curr = 0;
477    let mut tf_changed = 0;
478    let mut logp = if is_transient { 2 } else { 4 };
479    let mut budget = rc.storage as i32 * 8;
480    let mut tell = rc.tell();
481
482    let tf_select_rsv = if lm > 0 && tell + logp < budget { 1 } else { 0 };
483    budget -= tf_select_rsv;
484
485    for tf_res_i in tf_res[start..end].iter_mut() {
486        if tell + logp <= budget {
487            rc.encode_bit_logp(*tf_res_i ^ curr != 0, logp as u32);
488            tell = rc.tell();
489            curr = *tf_res_i;
490            tf_changed |= curr;
491        } else {
492            *tf_res_i = curr;
493        }
494        logp = if is_transient { 4 } else { 5 };
495    }
496
497    if tf_select_rsv != 0
498        && TF_SELECT_TABLE[lm as usize][4 * (is_transient as usize) + (tf_changed as usize)]
499            != TF_SELECT_TABLE[lm as usize][4 * (is_transient as usize) + 2 + (tf_changed as usize)]
500    {
501        rc.encode_bit_logp(tf_select != 0, 1);
502    } else {
503        tf_select = 0;
504    }
505
506    for tf_res_i in tf_res[start..end].iter_mut() {
507        *tf_res_i = TF_SELECT_TABLE[lm as usize]
508            [4 * (is_transient as usize) + 2 * (tf_select as usize) + (*tf_res_i as usize)]
509            as i32;
510    }
511
512    tf_changed
513}
514
515fn tf_decode(
516    start: usize,
517    end: usize,
518    is_transient: bool,
519    tf_res: &mut [i32],
520    lm: i32,
521    rc: &mut RangeCoder,
522) {
523    let mut curr = 0;
524    let mut tf_changed = 0;
525    let mut logp = if is_transient { 2 } else { 4 };
526    let budget = rc.storage as i32 * 8;
527    let mut tell = rc.tell();
528
529    let tf_select_rsv = if lm > 0 && tell + logp < budget { 1 } else { 0 };
530    let budget = budget - tf_select_rsv;
531
532    for tf_res_i in tf_res[start..end].iter_mut() {
533        if tell + logp <= budget {
534            curr ^= if rc.decode_bit_logp(logp as u32) {
535                1
536            } else {
537                0
538            };
539            tell = rc.tell();
540            tf_changed |= curr;
541        }
542        *tf_res_i = curr;
543        logp = if is_transient { 4 } else { 5 };
544    }
545
546    let mut tf_select = 0;
547    let _budget = budget + tf_select_rsv;
548    if tf_select_rsv > 0
549        && TF_SELECT_TABLE[lm as usize][4 * (is_transient as usize) + (tf_changed as usize)]
550            != TF_SELECT_TABLE[lm as usize][4 * (is_transient as usize) + 2 + (tf_changed as usize)]
551    {
552        tf_select = if rc.decode_bit_logp(1) { 1 } else { 0 };
553    }
554
555    for tf_res_i in tf_res[start..end].iter_mut() {
556        *tf_res_i = TF_SELECT_TABLE[lm as usize]
557            [4 * (is_transient as usize) + 2 * (tf_select as usize) + (*tf_res_i as usize)]
558            as i32;
559    }
560}
561
562fn stereo_analysis(m: &CeltMode, x: &[f32], lm: i32, n0: usize) -> bool {
563    let mut sum_lr = 1e-9f32;
564    let mut sum_ms = 1e-9f32;
565
566    for i in 0..13 {
567        let start = (m.e_bands[i] as usize) << lm;
568        let end = (m.e_bands[i + 1] as usize) << lm;
569        for j in start..end {
570            let l = x[j];
571            let r = x[n0 + j];
572            let m_val = l + r;
573            let s_val = l - r;
574            sum_lr += l.abs() + r.abs();
575            sum_ms += m_val.abs() + s_val.abs();
576        }
577    }
578
579    sum_ms *= std::f32::consts::FRAC_1_SQRT_2;
580    let mut thetas = 13;
581    if lm <= 1 {
582        thetas -= 8;
583    }
584
585    let left = (((m.e_bands[13] as usize) << (lm + 1)) + thetas) as f32 * sum_ms;
586    let right = ((m.e_bands[13] as usize) << (lm + 1)) as f32 * sum_lr;
587
588    left > right
589}
590
// Smallest period the comb filters will use; `comb_filter` and
// `comb_filter_inplace` clamp their periods up to this value.
const COMBFILTER_MINPERIOD: usize = 15;
// NOTE(review): not referenced in the visible part of this file — presumably
// the largest supported comb-filter period; confirm against the users.
const COMBFILTER_MAXPERIOD: usize = 1024;

// Tap weights of the comb filter, one row per tapset. Within a row, entry 0
// weights the sample at the period, entry 1 the two samples at distance 1
// from it and entry 2 the two samples at distance 2 (see `comb_filter`).
const PREFILTER_GAINS: [[f32; 3]; 3] = [
    [0.306_640_6, 0.217_041, 0.129_638_7],
    [0.463_867_2, 0.268_066_4, 0.0],
    [0.799_804_7, 0.100_097_7, 0.0],
];
599
600#[allow(clippy::too_many_arguments)]
601fn comb_filter_const(
602    y: &mut [f32],
603    x: &[f32],
604    y_idx: usize,
605    x_idx: usize,
606    t: usize,
607    n: usize,
608    g10: f32,
609    g11: f32,
610    g12: f32,
611) {
612    #[cfg(target_arch = "aarch64")]
613    {
614        comb_filter_const_neon(y, x, y_idx, x_idx, t, n, g10, g11, g12);
615    }
616    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
617    unsafe {
618        if std::arch::is_x86_feature_detected!("avx") {
619            comb_filter_const_avx(y, x, y_idx, x_idx, t, n, g10, g11, g12);
620            return;
621        }
622    }
623    #[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
624    unsafe {
625        comb_filter_const_sse(y, x, y_idx, x_idx, t, n, g10, g11, g12);
626        #[allow(clippy::needless_return)]
627        return;
628    }
629    #[cfg(not(any(
630        target_arch = "aarch64",
631        all(target_arch = "x86_64", target_feature = "sse")
632    )))]
633    {
634        comb_filter_const_scalar(y, x, y_idx, x_idx, t, n, g10, g11, g12);
635    }
636}
637
/// Scalar reference implementation of the constant-coefficient comb filter.
///
/// For `i` in `0..n` it writes
/// `y[y_idx+i] = x[x_idx+i] + g10*c + g11*(b + d) + g12*(a + e)`, where
/// `e, d, c, b, a` are the five consecutive input samples
/// `x[x_idx+i-t-2] ..= x[x_idx+i-t+2]`. The four older taps are carried
/// between iterations so only one new sample is read per output.
#[inline]
#[allow(dead_code)]
fn comb_filter_const_scalar(
    y: &mut [f32],
    x: &[f32],
    y_idx: usize,
    x_idx: usize,
    t: usize,
    n: usize,
    g10: f32,
    g11: f32,
    g12: f32,
) {
    // Prime the sliding window with the four samples preceding the first
    // look-ahead tap, oldest first.
    let mut tap_m2 = x[x_idx - t - 2];
    let mut tap_m1 = x[x_idx - t - 1];
    let mut tap_0 = x[x_idx - t];
    let mut tap_p1 = x[x_idx - t + 1];

    for k in 0..n {
        let tap_p2 = x[x_idx + k - t + 2];
        y[y_idx + k] =
            x[x_idx + k] + g10 * tap_0 + g11 * (tap_p1 + tap_m1) + g12 * (tap_p2 + tap_m2);
        // Slide the window forward by one sample.
        (tap_m2, tap_m1, tap_0, tap_p1) = (tap_m1, tap_0, tap_p1, tap_p2);
    }
}
671
/// Safe entry point for the NEON comb-filter kernel; forwards every argument
/// unchanged to `comb_filter_const_neon_impl`.
///
/// NOTE(review): no bounds are checked here, while the kernel reads and
/// writes through raw pointers — the caller is relied on to pass indices
/// that stay inside `x` and `y`.
#[cfg(target_arch = "aarch64")]
fn comb_filter_const_neon(
    y: &mut [f32],
    x: &[f32],
    y_idx: usize,
    x_idx: usize,
    t: usize,
    n: usize,
    g10: f32,
    g11: f32,
    g12: f32,
) {
    unsafe { comb_filter_const_neon_impl(y, x, y_idx, x_idx, t, n, g10, g11, g12) }
}
686
/// NEON kernel of the constant-coefficient comb filter: four outputs per
/// vector iteration, then a scalar loop for the remaining (fewer than four)
/// samples.
///
/// # Safety
/// The vector part reads `x` from `x_idx - t - 2` up to `x_idx + n` and
/// writes `y[y_idx .. y_idx + n]` through raw pointers without bounds checks;
/// the caller must guarantee those ranges are in bounds and that `t >= 2`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn comb_filter_const_neon_impl(
    y: &mut [f32],
    x: &[f32],
    y_idx: usize,
    x_idx: usize,
    t: usize,
    n: usize,
    g10: f32,
    g11: f32,
    g12: f32,
) {
    use std::arch::aarch64::*;

    let g10v = vdupq_n_f32(g10);
    let g11v = vdupq_n_f32(g11);
    let g12v = vdupq_n_f32(g12);

    let xbase = x.as_ptr().add(x_idx);
    let ybase = y.as_mut_ptr().add(y_idx);

    // Oldest tap group: the four samples starting at x[x_idx - t - 2].
    let mut x0v = vld1q_f32(xbase.sub(t + 2));

    let mut i = 0;
    while i + 4 <= n {
        // Newest tap group: the four samples starting at x[x_idx + i - t + 2].
        let x4v = vld1q_f32(xbase.add(i).sub(t - 2));

        // The three groups in between are spliced from the two loaded
        // vectors (offsets 2, 1 and 3 into their concatenation).
        let x2v = vextq_f32(x0v, x4v, 2);

        let x1v = vextq_f32(x0v, x4v, 1);

        let x3v = vextq_f32(x0v, x4v, 3);

        let xi = vld1q_f32(xbase.add(i));

        let mut yi = xi;
        yi = vfmaq_f32(yi, g10v, x2v);
        yi = vfmaq_f32(yi, g11v, vaddq_f32(x1v, x3v));
        yi = vfmaq_f32(yi, g12v, vaddq_f32(x4v, x0v));
        vst1q_f32(ybase.add(i), yi);

        // The newest group becomes the oldest group four samples later.
        x0v = x4v;
        i += 4;
    }

    // Seed the scalar sliding window from the last vector of old taps.
    let x0v_arr: [f32; 4] = std::mem::transmute(x0v);
    let mut sx4 = x0v_arr[0];
    let mut sx3 = x0v_arr[1];
    let mut sx2 = x0v_arr[2];
    let mut sx1 = x0v_arr[3];

    while i < n {
        let sx0 = x[x_idx + i - t + 2];
        y[y_idx + i] = x[x_idx + i] + g10 * sx2 + g11 * (sx1 + sx3) + g12 * (sx0 + sx4);
        sx4 = sx3;
        sx3 = sx2;
        sx2 = sx1;
        sx1 = sx0;
        i += 1;
    }
}
750
/// SSE kernel of the constant-coefficient comb filter: four outputs per
/// vector iteration, then a scalar loop for the remaining (fewer than four)
/// samples.
///
/// # Safety
/// The vector part reads `x` from `x_idx - t - 2` up to `x_idx + n` and
/// writes `y[y_idx .. y_idx + n]` through raw pointers without bounds checks;
/// the caller must guarantee those ranges are in bounds and that `t >= 2`.
#[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
#[inline(always)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn comb_filter_const_sse(
    y: &mut [f32],
    x: &[f32],
    y_idx: usize,
    x_idx: usize,
    t: usize,
    n: usize,
    g10: f32,
    g11: f32,
    g12: f32,
) {
    use std::arch::x86_64::*;

    let gain_center = _mm_set1_ps(g10);
    let gain_inner = _mm_set1_ps(g11);
    let gain_outer = _mm_set1_ps(g12);

    let src = x.as_ptr().add(x_idx);
    let dst = y.as_mut_ptr().add(y_idx);

    // Oldest tap group: the four samples starting at x[x_idx - t - 2].
    let mut oldest = _mm_loadu_ps(src.sub(t + 2));

    let mut pos = 0;
    while pos + 4 <= n {
        // Newest tap group: the four samples starting at x[x_idx+pos-t+2].
        let newest = _mm_loadu_ps(src.add(pos).sub(t - 2));

        // Splice the three intermediate groups out of the two loaded ones.
        let center = _mm_shuffle_ps(oldest, newest, 0x4e);
        let before = _mm_shuffle_ps(oldest, center, 0x99);
        let after = _mm_shuffle_ps(center, newest, 0x99);

        let current = _mm_loadu_ps(src.add(pos));

        let with_center = _mm_add_ps(current, _mm_mul_ps(gain_center, center));
        let side_terms = _mm_add_ps(
            _mm_mul_ps(gain_inner, _mm_add_ps(after, before)),
            _mm_mul_ps(gain_outer, _mm_add_ps(newest, oldest)),
        );
        _mm_storeu_ps(dst.add(pos), _mm_add_ps(with_center, side_terms));

        // The newest group becomes the oldest group four samples later.
        oldest = newest;
        pos += 4;
    }

    // Seed the scalar sliding window from the last vector of old taps.
    let mut window = [0.0f32; 4];
    _mm_storeu_ps(window.as_mut_ptr(), oldest);
    let [mut tap_m2, mut tap_m1, mut tap_0, mut tap_p1] = window;

    while pos < n {
        let tap_p2 = x[x_idx + pos - t + 2];
        y[y_idx + pos] =
            x[x_idx + pos] + g10 * tap_0 + g11 * (tap_p1 + tap_m1) + g12 * (tap_p2 + tap_m2);
        tap_m2 = tap_m1;
        tap_m1 = tap_0;
        tap_0 = tap_p1;
        tap_p1 = tap_p2;
        pos += 1;
    }
}
816
817#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
818#[target_feature(enable = "avx,fma")]
819#[allow(unsafe_op_in_unsafe_fn)]
820unsafe fn comb_filter_const_avx(
821    y: &mut [f32],
822    x: &[f32],
823    y_idx: usize,
824    x_idx: usize,
825    t: usize,
826    n: usize,
827    g10: f32,
828    g11: f32,
829    g12: f32,
830) {
831    use std::arch::x86_64::*;
832
833    let g10v = _mm256_set1_ps(g10);
834    let g11v = _mm256_set1_ps(g11);
835    let g12v = _mm256_set1_ps(g12);
836
837    let xbase = x.as_ptr().add(x_idx);
838    let ybase = y.as_mut_ptr().add(y_idx);
839
840    let mut i = 0;
841
842    while i + 16 <= n {
843        let xi_a = _mm256_loadu_ps(xbase.add(i));
844        let x0_a = _mm256_loadu_ps(xbase.add(i).sub(t + 2));
845        let x4_a = _mm256_loadu_ps(xbase.add(i).sub(t - 2));
846
847        let x2_a = _mm256_loadu_ps(xbase.add(i).sub(t));
848        let x1x3_a = _mm256_add_ps(
849            _mm256_loadu_ps(xbase.add(i).sub(t + 1)),
850            _mm256_loadu_ps(xbase.add(i).sub(t - 1)),
851        );
852        let x0x4_a = _mm256_add_ps(x0_a, x4_a);
853
854        let mut yi_a = xi_a;
855        yi_a = _mm256_fmadd_ps(g10v, x2_a, yi_a);
856        yi_a = _mm256_fmadd_ps(g11v, x1x3_a, yi_a);
857        yi_a = _mm256_fmadd_ps(g12v, x0x4_a, yi_a);
858        _mm256_storeu_ps(ybase.add(i), yi_a);
859
860        let j = i + 8;
861        let xi_b = _mm256_loadu_ps(xbase.add(j));
862        let x0_b = _mm256_loadu_ps(xbase.add(j).sub(t + 2));
863        let x4_b = _mm256_loadu_ps(xbase.add(j).sub(t - 2));
864        let x2_b = _mm256_loadu_ps(xbase.add(j).sub(t));
865        let x1x3_b = _mm256_add_ps(
866            _mm256_loadu_ps(xbase.add(j).sub(t + 1)),
867            _mm256_loadu_ps(xbase.add(j).sub(t - 1)),
868        );
869        let x0x4_b = _mm256_add_ps(x0_b, x4_b);
870
871        let mut yi_b = xi_b;
872        yi_b = _mm256_fmadd_ps(g10v, x2_b, yi_b);
873        yi_b = _mm256_fmadd_ps(g11v, x1x3_b, yi_b);
874        yi_b = _mm256_fmadd_ps(g12v, x0x4_b, yi_b);
875        _mm256_storeu_ps(ybase.add(j), yi_b);
876
877        i += 16;
878    }
879
880    while i + 8 <= n {
881        let xi = _mm256_loadu_ps(xbase.add(i));
882        let x0 = _mm256_loadu_ps(xbase.add(i).sub(t + 2));
883        let x4 = _mm256_loadu_ps(xbase.add(i).sub(t - 2));
884        let x2 = _mm256_loadu_ps(xbase.add(i).sub(t));
885        let x1x3 = _mm256_add_ps(
886            _mm256_loadu_ps(xbase.add(i).sub(t + 1)),
887            _mm256_loadu_ps(xbase.add(i).sub(t - 1)),
888        );
889        let x0x4 = _mm256_add_ps(x0, x4);
890
891        let mut yi = xi;
892        yi = _mm256_fmadd_ps(g10v, x2, yi);
893        yi = _mm256_fmadd_ps(g11v, x1x3, yi);
894        yi = _mm256_fmadd_ps(g12v, x0x4, yi);
895        _mm256_storeu_ps(ybase.add(i), yi);
896
897        i += 8;
898    }
899
900    if i + 4 <= n {
901        comb_filter_const_sse_fma(y, x, y_idx + i, x_idx + i, t, n - i, g10, g11, g12);
902        return;
903    }
904
905    let mut sx4 = x[x_idx + i - t - 2];
906    let mut sx3 = x[x_idx + i - t - 1];
907    let mut sx2 = x[x_idx + i - t];
908    let mut sx1 = x[x_idx + i - t + 1];
909    while i < n {
910        let sx0 = x[x_idx + i - t + 2];
911        y[y_idx + i] = x[x_idx + i] + g10 * sx2 + g11 * (sx1 + sx3) + g12 * (sx0 + sx4);
912        sx4 = sx3;
913        sx3 = sx2;
914        sx2 = sx1;
915        sx1 = sx0;
916        i += 1;
917    }
918}
919
/// 4-lane FMA kernel of the constant-coefficient comb filter, used by
/// `comb_filter_const_avx` for a remainder of four or more samples. Four
/// outputs are produced per vector iteration; the last (fewer than four)
/// samples are finished with scalar code.
///
/// # Safety
/// The CPU must support AVX and FMA. The vector part reads `x` from
/// `x_idx - t - 2` up to `x_idx + n` and writes `y[y_idx .. y_idx + n]`
/// through raw pointers without bounds checks; the caller must guarantee
/// those ranges are in bounds and that `t >= 2`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx,fma")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn comb_filter_const_sse_fma(
    y: &mut [f32],
    x: &[f32],
    y_idx: usize,
    x_idx: usize,
    t: usize,
    n: usize,
    g10: f32,
    g11: f32,
    g12: f32,
) {
    // This function is built for both 32- and 64-bit x86, so the intrinsics
    // must come from the module that exists on the current target.
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let g10v = _mm_set1_ps(g10);
    let g11v = _mm_set1_ps(g11);
    let g12v = _mm_set1_ps(g12);

    let xbase = x.as_ptr().add(x_idx);
    let ybase = y.as_mut_ptr().add(y_idx);
    // Oldest tap group: the four samples starting at x[x_idx - t - 2].
    let mut x0v = _mm_loadu_ps(xbase.sub(t + 2));

    let mut i = 0;
    while i + 4 <= n {
        // Newest tap group, then the three in between spliced from the two
        // loaded vectors.
        let x4v = _mm_loadu_ps(xbase.add(i).sub(t - 2));
        let x2v = _mm_shuffle_ps(x0v, x4v, 0x4e);
        let x1v = _mm_shuffle_ps(x0v, x2v, 0x99);
        let x3v = _mm_shuffle_ps(x2v, x4v, 0x99);
        let xi = _mm_loadu_ps(xbase.add(i));

        let mut yi = xi;
        yi = _mm_fmadd_ps(g10v, x2v, yi);
        yi = _mm_fmadd_ps(g11v, _mm_add_ps(x1v, x3v), yi);
        yi = _mm_fmadd_ps(g12v, _mm_add_ps(x0v, x4v), yi);
        _mm_storeu_ps(ybase.add(i), yi);

        // The newest group becomes the oldest group four samples later.
        x0v = x4v;
        i += 4;
    }

    // Seed the scalar sliding window from the last vector of old taps.
    let x0v_arr: [f32; 4] = std::mem::transmute(x0v);
    let mut sx4 = x0v_arr[0];
    let mut sx3 = x0v_arr[1];
    let mut sx2 = x0v_arr[2];
    let mut sx1 = x0v_arr[3];
    while i < n {
        let sx0 = x[x_idx + i - t + 2];
        y[y_idx + i] = x[x_idx + i] + g10 * sx2 + g11 * (sx1 + sx3) + g12 * (sx0 + sx4);
        sx4 = sx3;
        sx3 = sx2;
        sx2 = sx1;
        sx1 = sx0;
        i += 1;
    }
}
977
978#[allow(clippy::too_many_arguments)]
979fn comb_filter(
980    y: &mut [f32],
981    x: &[f32],
982    y_idx: usize,
983    x_idx: usize,
984    t0: usize,
985    t1: usize,
986    n: usize,
987    g0: f32,
988    g1: f32,
989    tapset0: i32,
990    tapset1: i32,
991    window: &[f32],
992    overlap: usize,
993) {
994    if g0 == 0.0 && g1 == 0.0 {
995        if x_idx != y_idx || !std::ptr::eq(x.as_ptr(), y.as_ptr()) {
996            y[y_idx..y_idx + n].copy_from_slice(&x[x_idx..x_idx + n]);
997        }
998        return;
999    }
1000
1001    let t0 = t0.clamp(
1002        COMBFILTER_MINPERIOD,
1003        x_idx.saturating_sub(2).max(COMBFILTER_MINPERIOD),
1004    );
1005    let t1 = t1.clamp(
1006        COMBFILTER_MINPERIOD,
1007        x_idx.saturating_sub(2).max(COMBFILTER_MINPERIOD),
1008    );
1009
1010    let g00 = g0 * PREFILTER_GAINS[tapset0 as usize][0];
1011    let g01 = g0 * PREFILTER_GAINS[tapset0 as usize][1];
1012    let g02 = g0 * PREFILTER_GAINS[tapset0 as usize][2];
1013
1014    let g10 = g1 * PREFILTER_GAINS[tapset1 as usize][0];
1015    let g11 = g1 * PREFILTER_GAINS[tapset1 as usize][1];
1016    let g12 = g1 * PREFILTER_GAINS[tapset1 as usize][2];
1017
1018    let mut x1 = x[x_idx - t1 + 1];
1019    let mut x2 = x[x_idx - t1];
1020    let mut x3 = x[x_idx - t1 - 1];
1021    let mut x4 = x[x_idx - t1 - 2];
1022
1023    let mut inner_overlap = overlap;
1024    if g0 == g1 && t0 == t1 && tapset0 == tapset1 {
1025        inner_overlap = 0;
1026    }
1027
1028    let mut i = 0;
1029    while i < inner_overlap && i < n {
1030        let x0 = x[x_idx + i - t1 + 2];
1031        let f = window[i] * window[i];
1032        y[y_idx + i] = x[x_idx + i]
1033            + (1.0 - f)
1034                * (g00 * x[x_idx + i - t0]
1035                    + g01 * (x[x_idx + i - t0 + 1] + x[x_idx + i - t0 - 1])
1036                    + g02 * (x[x_idx + i - t0 + 2] + x[x_idx + i - t0 - 2]))
1037            + f * (g10 * x2 + g11 * (x1 + x3) + g12 * (x0 + x4));
1038
1039        x4 = x3;
1040        x3 = x2;
1041        x2 = x1;
1042        x1 = x0;
1043        i += 1;
1044    }
1045
1046    if i < n {
1047        if g1 == 0.0 {
1048            y[y_idx + i..y_idx + n].copy_from_slice(&x[x_idx + i..x_idx + n]);
1049        } else {
1050            comb_filter_const(y, x, y_idx + i, x_idx + i, t1, n - i, g10, g11, g12);
1051        }
1052    }
1053}
1054
1055/// In-place comb filter: buf[y_idx..y_idx+n] is both input and output.
1056/// Reference samples at buf[y_idx + i - T + offset] may already be filtered
1057/// if T < i, matching C libopus's in-place comb_filter(out, out, ...) behavior.
1058fn comb_filter_inplace(
1059    buf: &mut [f32],
1060    y_idx: usize,
1061    t0: usize,
1062    t1: usize,
1063    n: usize,
1064    g0: f32,
1065    g1: f32,
1066    tapset0: i32,
1067    tapset1: i32,
1068    window: &[f32],
1069    overlap: usize,
1070) {
1071    if g0 == 0.0 && g1 == 0.0 {
1072        // nothing to do; buf[y_idx..] already holds the input
1073        return;
1074    }
1075
1076    let t0 = t0.clamp(COMBFILTER_MINPERIOD, y_idx - 2);
1077    let t1 = t1.clamp(COMBFILTER_MINPERIOD, y_idx - 2);
1078
1079    let g00 = g0 * PREFILTER_GAINS[tapset0 as usize][0];
1080    let g01 = g0 * PREFILTER_GAINS[tapset0 as usize][1];
1081    let g02 = g0 * PREFILTER_GAINS[tapset0 as usize][2];
1082
1083    let g10 = g1 * PREFILTER_GAINS[tapset1 as usize][0];
1084    let g11 = g1 * PREFILTER_GAINS[tapset1 as usize][1];
1085    let g12 = g1 * PREFILTER_GAINS[tapset1 as usize][2];
1086
1087    let mut inner_overlap = overlap;
1088    if g0 == g1 && t0 == t1 && tapset0 == tapset1 {
1089        inner_overlap = 0;
1090    }
1091
1092    let mut i = 0;
1093    while i < inner_overlap && i < n {
1094        let idx = y_idx + i;
1095        let f = window[i] * window[i];
1096        let s = buf[idx]; // original input (not yet overwritten at idx)
1097        let r0 = buf[idx - t0];
1098        let r0p1 = buf[idx - t0 + 1];
1099        let r0m1 = buf[idx - t0 - 1];
1100        let r0p2 = buf[idx - t0 + 2];
1101        let r0m2 = buf[idx - t0 - 2];
1102        let r1 = buf[idx - t1];
1103        let r1p1 = buf[idx - t1 + 1];
1104        let r1m1 = buf[idx - t1 - 1];
1105        let r1p2 = buf[idx - t1 + 2];
1106        let r1m2 = buf[idx - t1 - 2];
1107        buf[idx] = s
1108            + (1.0 - f) * (g00 * r0 + g01 * (r0p1 + r0m1) + g02 * (r0p2 + r0m2))
1109            + f * (g10 * r1 + g11 * (r1p1 + r1m1) + g12 * (r1p2 + r1m2));
1110        i += 1;
1111    }
1112
1113    // Constant region: only new filter (t1, g1)
1114    while i < n {
1115        let idx = y_idx + i;
1116        let s = buf[idx];
1117        let r1 = buf[idx - t1];
1118        let r1p1 = buf[idx - t1 + 1];
1119        let r1m1 = buf[idx - t1 - 1];
1120        let r1p2 = buf[idx - t1 + 2];
1121        let r1m2 = buf[idx - t1 - 2];
1122        buf[idx] = s + g10 * r1 + g11 * (r1p1 + r1m1) + g12 * (r1p2 + r1m2);
1123        i += 1;
1124    }
1125}
1126
1127fn run_prefilter(
1128    in_buf: &mut [f32],
1129    prefilter_mem: &mut [f32],
1130    prefilter_period: usize,
1131    prefilter_gain: f32,
1132    prefilter_tapset: i32,
1133    tapset_decision: i32,
1134    window: &[f32],
1135    channels: usize,
1136    frame_size: usize,
1137    overlap: usize,
1138
1139    pre: &mut [f32],
1140    pitch_buf: &mut [f32],
1141    before: &mut [f32],
1142    after: &mut [f32],
1143
1144    analysis: &AnalysisInfo,
1145    loss_rate: i32,
1146) -> (bool, f32, usize) {
1147    let max_period = COMBFILTER_MAXPERIOD;
1148    let min_period = COMBFILTER_MINPERIOD;
1149    let buf_stride = frame_size + overlap;
1150    let pre_size = max_period + frame_size;
1151
1152    for c in 0..channels {
1153        pre[c * pre_size..c * pre_size + max_period]
1154            .copy_from_slice(&prefilter_mem[c * max_period..(c + 1) * max_period]);
1155        pre[c * pre_size + max_period..c * pre_size + pre_size].copy_from_slice(
1156            &in_buf[c * buf_stride + overlap..c * buf_stride + overlap + frame_size],
1157        );
1158    }
1159
1160    let pitch_buf_len = (max_period + frame_size) >> 1;
1161    {
1162        let pre_slices: Vec<&[f32]> = (0..channels)
1163            .map(|c| &pre[c * pre_size..c * pre_size + pre_size])
1164            .collect();
1165        crate::pitch::pitch_downsample(&pre_slices, pitch_buf, pitch_buf_len, channels, 2);
1166    }
1167
1168    let search_max = max_period - 3 * min_period;
1169    let pitch_result = crate::pitch::pitch_search(
1170        &pitch_buf[max_period >> 1..],
1171        pitch_buf,
1172        frame_size,
1173        search_max,
1174    );
1175    let mut pitch_index = (max_period - pitch_result).min(max_period - 2);
1176
1177    let gain1_raw = crate::pitch::remove_doubling(
1178        pitch_buf,
1179        max_period,
1180        min_period,
1181        frame_size,
1182        &mut pitch_index,
1183        prefilter_period,
1184        prefilter_gain,
1185    );
1186    let mut gain1 = gain1_raw * 0.7;
1187
1188    // Apply max_pitch_ratio from analysis if available
1189    if analysis.valid {
1190        gain1 *= analysis.max_pitch_ratio;
1191    }
1192
1193    // Apply loss_rate scaling: halve at 2%, quarter at 4%, zero at 8%
1194    if loss_rate >= 8 {
1195        gain1 = 0.0;
1196    } else if loss_rate > 0 {
1197        gain1 *= 1.0 - (loss_rate as f32) / 8.0;
1198    }
1199
1200    let mut pf_threshold = 0.2f32;
1201    if (pitch_index as i32 - prefilter_period as i32).unsigned_abs() as usize * 10 > pitch_index {
1202        pf_threshold += 0.2;
1203    }
1204    if prefilter_gain > 0.4 {
1205        pf_threshold -= 0.1;
1206    }
1207    if prefilter_gain > 0.55 {
1208        pf_threshold -= 0.1;
1209    }
1210    pf_threshold = pf_threshold.max(0.2);
1211
1212    let pf_on;
1213    if gain1 < pf_threshold {
1214        gain1 = 0.0;
1215        pf_on = false;
1216    } else {
1217        if (gain1 - prefilter_gain).abs() < 0.1 {
1218            gain1 = prefilter_gain;
1219        }
1220        let qg = ((gain1 * 32.0 / 3.0 + 0.5).floor() as i32 - 1).clamp(0, 7);
1221        gain1 = 0.09375 * (qg + 1) as f32;
1222        pf_on = true;
1223    }
1224
1225    let before = &mut before[..channels];
1226    for c in 0..channels {
1227        let start = c * buf_stride + overlap;
1228        before[c] = sum_abs(&in_buf[start..start + frame_size]);
1229    }
1230
1231    let offset = 0usize;
1232    let prev_period = prefilter_period.clamp(COMBFILTER_MINPERIOD, max_period - 2);
1233
1234    for c in 0..channels {
1235        if offset > 0 {
1236            let pre_c = &pre[c * pre_size..];
1237            comb_filter(
1238                in_buf,
1239                pre_c,
1240                c * buf_stride + overlap,
1241                max_period,
1242                prev_period,
1243                prev_period,
1244                offset,
1245                -prefilter_gain,
1246                -prefilter_gain,
1247                prefilter_tapset,
1248                prefilter_tapset,
1249                window,
1250                0,
1251            );
1252        }
1253
1254        {
1255            let pre_c = &pre[c * pre_size..];
1256            comb_filter(
1257                in_buf,
1258                pre_c,
1259                c * buf_stride + overlap + offset,
1260                max_period + offset,
1261                prev_period,
1262                pitch_index,
1263                frame_size - offset,
1264                -prefilter_gain,
1265                -gain1,
1266                prefilter_tapset,
1267                tapset_decision,
1268                window,
1269                overlap,
1270            );
1271        }
1272    }
1273
1274    let after = &mut after[..channels];
1275    for c in 0..channels {
1276        let start = c * buf_stride + overlap;
1277        after[c] = sum_abs(&in_buf[start..start + frame_size]);
1278    }
1279
1280    let cancel_pitch = (0..channels).any(|c| after[c] > before[c]);
1281
1282    if cancel_pitch {
1283        for c in 0..channels {
1284            in_buf[c * buf_stride + overlap..c * buf_stride + overlap + frame_size]
1285                .copy_from_slice(
1286                    &pre[c * pre_size + max_period..c * pre_size + max_period + frame_size],
1287                );
1288        }
1289
1290        for c in 0..channels {
1291            if frame_size >= max_period {
1292                prefilter_mem[c * max_period..(c + 1) * max_period].copy_from_slice(
1293                    &pre[c * pre_size + frame_size..c * pre_size + frame_size + max_period],
1294                );
1295            } else {
1296                let shift = max_period - frame_size;
1297                prefilter_mem.copy_within(
1298                    c * max_period + frame_size..(c + 1) * max_period,
1299                    c * max_period,
1300                );
1301                prefilter_mem[c * max_period + shift..(c + 1) * max_period].copy_from_slice(
1302                    &pre[c * pre_size + max_period..c * pre_size + max_period + frame_size],
1303                );
1304            }
1305        }
1306        return (false, 0.0, pitch_index);
1307    }
1308
1309    for c in 0..channels {
1310        if frame_size >= max_period {
1311            prefilter_mem[c * max_period..(c + 1) * max_period].copy_from_slice(
1312                &pre[c * pre_size + frame_size..c * pre_size + frame_size + max_period],
1313            );
1314        } else {
1315            let shift = max_period - frame_size;
1316            prefilter_mem.copy_within(
1317                c * max_period + frame_size..(c + 1) * max_period,
1318                c * max_period,
1319            );
1320            prefilter_mem[c * max_period + shift..(c + 1) * max_period].copy_from_slice(
1321                &pre[c * pre_size + max_period..c * pre_size + max_period + frame_size],
1322            );
1323        }
1324    }
1325
1326    (pf_on, gain1, pitch_index)
1327}
1328
/// Number of extra `f32` slots appended to the normalised-coefficient work buffer
/// (`CeltEncoder::w_x`) beyond `MAX_FRAME_SIZE * channels`.
// NOTE(review): presumably headroom so strided PVQ accesses near the end of the
// buffer stay in bounds — confirm against `crate::pvq`.
const STRIDE_ACCESS_PAD: usize = crate::pvq::MAX_PVQ_N * 8;
1330
/// CELT encoder state: configuration, inter-frame memories, and preallocated
/// work buffers (`w_*`) that are sized once in `new`.
pub struct CeltEncoder {
    /// Static mode description (band layout, overlap, window, MDCT).
    mode: &'static CeltMode,
    /// Number of audio channels.
    channels: usize,
    /// Complexity setting (initialised to 9); the encoder enables optional analysis
    /// stages once it reaches certain levels.
    pub complexity: i32,
    /// Per-channel history of pre-emphasised input, `2048 + overlap` samples each.
    syn_mem: Vec<f32>,
    /// Allocated with the same size as `syn_mem`.
    enc_decode_mem: Vec<f32>,
    /// Per-band log energies handed to (and updated by) the coarse energy quantiser.
    old_band_e: Vec<f32>,
    /// Pre-emphasis filter state, one value per channel.
    preemph_mem: Vec<f32>,
    tonal_average: i32,
    hf_average: i32,
    /// Tapset passed to the pre-filter and written to the bitstream when it is on.
    tapset_decision: i32,
    spread_decision: i32,
    intensity: i32,
    last_coded_bands: i32,
    /// Pitch history for the pre-filter, `COMBFILTER_MAXPERIOD` samples per channel.
    prefilter_mem: Vec<f32>,
    /// Pre-filter period carried over between frames (starts at `COMBFILTER_MINPERIOD`).
    prefilter_period: usize,
    /// Pre-filter gain carried over between frames.
    prefilter_gain: f32,
    /// Pre-filter tapset carried over between frames.
    prefilter_tapset: i32,
    old_band_e2: Vec<f32>,
    old_band_e3: Vec<f32>,
    last_band_log_e: Vec<f32>,
    /// State passed to the coarse energy quantiser.
    delayed_intra: f32,

    // Work buffers, reused on every frame.
    /// Windowed input: `frame_size + overlap` samples per channel.
    w_in_buf: Vec<f32>,
    /// MDCT output.
    w_freq: Vec<f32>,
    /// Band energies.
    w_band_e: Vec<f32>,
    /// Normalised coefficients, padded by `STRIDE_ACCESS_PAD`.
    w_x: Vec<f32>,
    /// Band log energies.
    w_band_log_e: Vec<f32>,
    /// Coarse-quantisation error per band.
    w_error: Vec<f32>,
    /// Per-band time-frequency resolution flags.
    w_tf_res: Vec<i32>,
    w_cap: Vec<i32>,
    w_offsets: Vec<i32>,
    w_pulses: Vec<i32>,
    w_ebits: Vec<i32>,
    w_fine_priority: Vec<i32>,
    w_collapse_masks: Vec<u32>,
    w_band_amp_synth: Vec<f32>,
    w_freq_synth: Vec<f32>,
    consec_transient: i32,

    // Scratch space for `run_prefilter`.
    w_prefilter_pre: Vec<f32>,
    w_prefilter_pitch_buf: Vec<f32>,
    w_prefilter_before: Vec<f32>,
    w_prefilter_after: Vec<f32>,

    // Scratch space for `transient_analysis`.
    w_transient_tmp: Vec<f32>,
    w_transient_tmp2: Vec<f32>,

    /// Signal analysis results consulted by the pre-filter (`valid`, `tonality`,
    /// `max_pitch_ratio`).
    analysis: AnalysisInfo,
    /// Packet-loss rate used to scale down the pre-filter gain.
    loss_rate: i32,
}
1382
/// Ascending decision thresholds, one per entry of `INTEN_HYSTERESIS`.
// NOTE(review): presumably the `thresholds` table passed to `hysteresis_decision`
// when choosing the intensity-stereo band from the bitrate — confirm at the call site.
const INTEN_THRESHOLDS: [i32; 21] = [
    1, 2, 3, 4, 5, 6, 7, 8, 16, 24, 36, 44, 50, 56, 62, 67, 72, 79, 88, 106, 134,
];
/// Hysteresis margin paired index-by-index with `INTEN_THRESHOLDS`.
const INTEN_HYSTERESIS: [i32; 21] = [
    1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 4, 5, 6, 8, 8,
];
1389
/// Maps `val` to a bucket index with hysteresis around the previous decision.
///
/// The raw bucket is the index of the first threshold strictly greater than `val`
/// (or `thresholds.len()` when there is none). The previous decision `prev` is kept
/// instead when:
///
/// * the raw bucket is above `prev` but `val` is still below
///   `thresholds[prev] + hysteresis[prev]`, or
/// * the raw bucket is below `prev` but `val` is still above
///   `thresholds[prev - 1] - hysteresis[prev - 1]`.
///
/// The downward rule applies to every raw bucket, including bucket 0, as in the
/// libopus reference routine.
///
/// # Panics
///
/// Panics if `prev` is negative or indexes past the tables when it is consulted.
fn hysteresis_decision(val: i32, thresholds: &[i32], hysteresis: &[i32], prev: i32) -> i32 {
    let raw = thresholds
        .iter()
        .position(|&threshold| val < threshold)
        .unwrap_or(thresholds.len()) as i32;

    // Moving up: require `val` to clear the upper edge of the previous bucket.
    if raw > prev && val < thresholds[prev as usize] + hysteresis[prev as usize] {
        return prev;
    }
    // Moving down: `raw < prev` implies `prev >= 1`, so `prev - 1` is a valid index.
    if raw < prev && val > thresholds[prev as usize - 1] - hysteresis[prev as usize - 1] {
        return prev;
    }
    raw
}
1408
1409#[allow(clippy::too_many_arguments)]
1410fn alloc_trim_analysis(
1411    mode: &CeltMode,
1412    x: &[f32],
1413    band_log_e: &[f32],
1414    end: usize,
1415    lm: i32,
1416    channels: usize,
1417    n0: usize,
1418    stereo_saving: &mut f32,
1419    tf_estimate: f32,
1420    intensity: i32,
1421    surround_trim: f32,
1422    equiv_rate: i32,
1423) -> i32 {
1424    let mut trim = 5.0f32;
1425    if equiv_rate < 64000 {
1426        trim = 4.0;
1427    } else if equiv_rate < 80000 {
1428        let frac = (equiv_rate - 64000) as f32 / 1024.0;
1429        trim = 4.0 + (1.0 / 16.0) * frac;
1430    }
1431
1432    if channels == 2 {
1433        let mut sum = 0.0f32;
1434        for i in 0..8 {
1435            let offset = (mode.e_bands[i] as usize) << lm;
1436            let n = ((mode.e_bands[i + 1] - mode.e_bands[i]) as usize) << lm;
1437            let mut partial = 0.0f32;
1438            for j in 0..n {
1439                partial += x[offset + j] * x[n0 + offset + j];
1440            }
1441            sum += partial;
1442        }
1443        sum = (sum / 8.0).abs().min(1.0);
1444        let mut min_xc = sum;
1445        for i in 8..intensity as usize {
1446            let offset = (mode.e_bands[i] as usize) << lm;
1447            let n = ((mode.e_bands[i + 1] - mode.e_bands[i]) as usize) << lm;
1448            let mut partial = 0.0f32;
1449            for j in 0..n {
1450                partial += x[offset + j] * x[n0 + offset + j];
1451            }
1452            min_xc = min_xc.min(partial.abs());
1453        }
1454        min_xc = min_xc.min(1.0);
1455
1456        let log_xc = (1.001 - sum * sum).log2();
1457        let log_xc2 = (log_xc * 0.5).max((1.001 - min_xc * min_xc).log2());
1458
1459        trim += (-4.0f32).max(0.75 * log_xc);
1460        *stereo_saving = (*stereo_saving + 0.25).min(-0.5 * log_xc2);
1461    }
1462
1463    let mut diff = 0.0f32;
1464    for c in 0..channels {
1465        for i in 0..end - 1 {
1466            diff += band_log_e[c * mode.nb_ebands + i] * (2 + 2 * i as i32 - end as i32) as f32;
1467        }
1468    }
1469    diff /= (channels * (end - 1)) as f32;
1470    trim -= (-2.0f32).max(2.0f32.min((diff + 1.0) / 6.0));
1471    trim -= surround_trim;
1472    trim -= 2.0 * tf_estimate;
1473
1474    let trim_index = (trim + 0.5).floor() as i32;
1475    trim_index.clamp(0, 10)
1476}
1477
/// Returns the median of three values; incomparable pairs are treated as equal
/// while sorting.
#[inline(always)]
fn median3(a: f32, b: f32, c: f32) -> f32 {
    let mut ordered = [a, b, c];
    ordered.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal));
    let [_, middle, _] = ordered;
    middle
}
1484
/// Returns the median of the first five elements of `v`; incomparable pairs are
/// treated as equal while sorting.
///
/// # Panics
///
/// Panics if `v` holds fewer than five elements.
#[inline(always)]
fn median5(v: &[f32]) -> f32 {
    let mut ordered = [0.0f32; 5];
    ordered.copy_from_slice(&v[..5]);
    ordered.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal));
    let [_, _, middle, _, _] = ordered;
    middle
}
1491
1492#[allow(clippy::too_many_arguments)]
1493fn dynalloc_analysis_simple(
1494    mode: &CeltMode,
1495    band_log_e: &[f32],
1496    old_band_e: &[f32],
1497    start: usize,
1498    end: usize,
1499    channels: usize,
1500    lm: usize,
1501    effective_bytes: usize,
1502    is_transient: bool,
1503    offsets: &mut [i32],
1504    cap: &[i32],
1505) {
1506    offsets.fill(0);
1507    if effective_bytes < (30 + 5 * lm) {
1508        return;
1509    }
1510
1511    let nb = mode.nb_ebands;
1512    let mut follower = vec![0.0f32; nb * channels];
1513
1514    for c in 0..channels {
1515        let base = c * nb;
1516        let mut band_log_e3 = vec![0.0f32; end];
1517        for i in 0..end {
1518            let mut e = band_log_e[base + i];
1519            if lm == 0 && i < 8 {
1520                e = e.max(old_band_e[base + i]);
1521            }
1522            band_log_e3[i] = e;
1523        }
1524
1525        let mut last = 0usize;
1526        follower[base] = band_log_e3[0];
1527        for i in 1..end {
1528            if band_log_e3[i] > band_log_e3[i - 1] + 0.5 {
1529                last = i;
1530            }
1531            follower[base + i] = (follower[base + i - 1] + 1.5).min(band_log_e3[i]);
1532        }
1533        for i in (0..last).rev() {
1534            follower[base + i] =
1535                follower[base + i].min((follower[base + i + 1] + 2.0).min(band_log_e3[i]));
1536        }
1537
1538        let offset = 1.0f32;
1539        if end >= 5 {
1540            for i in 2..end - 2 {
1541                follower[base + i] =
1542                    follower[base + i].max(median5(&band_log_e3[i - 2..i + 3]) - offset);
1543            }
1544        }
1545        if end >= 3 {
1546            let l = median3(band_log_e3[0], band_log_e3[1], band_log_e3[2]) - offset;
1547            follower[base] = follower[base].max(l);
1548            follower[base + 1] = follower[base + 1].max(l);
1549
1550            let r = median3(
1551                band_log_e3[end - 3],
1552                band_log_e3[end - 2],
1553                band_log_e3[end - 1],
1554            ) - offset;
1555            follower[base + end - 2] = follower[base + end - 2].max(r);
1556            follower[base + end - 1] = follower[base + end - 1].max(r);
1557        }
1558    }
1559
1560    if channels == 2 {
1561        for i in start..end {
1562            let l = follower[i];
1563            let r = follower[nb + i];
1564            let r2 = r.max(l - 4.0);
1565            let l2 = l.max(r - 4.0);
1566            follower[i] =
1567                ((band_log_e[i] - l2).max(0.0) + (band_log_e[nb + i] - r2).max(0.0)) * 0.5;
1568        }
1569    } else {
1570        for i in start..end {
1571            follower[i] = (band_log_e[i] - follower[i]).max(0.0);
1572        }
1573    }
1574
1575    if !is_transient {
1576        for i in start..end {
1577            follower[i] *= 0.5;
1578        }
1579    }
1580
1581    let mut tot_boost = 0i32;
1582    for i in start..end {
1583        let mut f = follower[i].min(4.0);
1584        if i < 8 {
1585            f *= 2.0;
1586        }
1587        if i >= 12 {
1588            f *= 0.5;
1589        }
1590
1591        let width = channels as i32 * (mode.e_bands[i + 1] - mode.e_bands[i]) as i32 * (1 << lm);
1592        let (boost, boost_bits) = if width < 6 {
1593            let b = f.floor().max(0.0) as i32;
1594            (b, (b * width) << BITRES)
1595        } else if width > 48 {
1596            let b = (f * 8.0).floor().max(0.0) as i32;
1597            (b, ((b * width) << BITRES) / 8)
1598        } else {
1599            let b = (f * width as f32 / 6.0).floor().max(0.0) as i32;
1600            (b, (b * 6) << BITRES)
1601        };
1602
1603        // Keep dynalloc bounded so allocator still has base bits in CBR usage.
1604        let cap_bits = ((2 * effective_bytes as i32) / 3) << (BITRES + 3);
1605        if tot_boost + boost_bits > cap_bits {
1606            offsets[i] = ((cap_bits - tot_boost) >> BITRES).max(0);
1607            break;
1608        }
1609
1610        let quanta = (width << BITRES).min((6 << BITRES).max(width));
1611        let mut boost_count = boost;
1612        let mut as_bits = boost_count * quanta;
1613        if as_bits > cap[i] {
1614            as_bits = cap[i];
1615            boost_count = as_bits / quanta;
1616        }
1617
1618        offsets[i] = boost_count.max(0);
1619        tot_boost += boost_bits.max(0);
1620    }
1621}
1622
1623impl CeltEncoder {
1624    pub fn new(mode: &'static CeltMode, channels: usize) -> Self {
1625        let overlap = mode.overlap;
1626        let channel_mem_size = 2048 + overlap;
1627        let syn_mem_size = channels * channel_mem_size;
1628        let nb_ebands = mode.nb_ebands;
1629        let nb_x_ch = nb_ebands * channels;
1630        let frame_x_ch = MAX_FRAME_SIZE * channels;
1631        let bufstride_x_ch = (MAX_FRAME_SIZE + overlap) * channels;
1632        Self {
1633            mode,
1634            channels,
1635            complexity: 9,
1636            syn_mem: vec![0.0; syn_mem_size],
1637            enc_decode_mem: vec![0.0; syn_mem_size],
1638            old_band_e: vec![0.0; nb_x_ch],
1639            preemph_mem: vec![0.0; channels],
1640            tonal_average: 256,
1641            hf_average: 0,
1642            tapset_decision: 0,
1643            spread_decision: SPREAD_NORMAL,
1644            intensity: 0,
1645            last_coded_bands: 0,
1646            prefilter_mem: vec![0.0; channels * COMBFILTER_MAXPERIOD],
1647            prefilter_period: COMBFILTER_MINPERIOD,
1648            prefilter_gain: 0.0,
1649            prefilter_tapset: 0,
1650            old_band_e2: vec![0.0; nb_x_ch],
1651            old_band_e3: vec![0.0; nb_x_ch],
1652            last_band_log_e: vec![0.0; nb_x_ch],
1653            delayed_intra: 0.0,
1654
1655            w_in_buf: vec![0.0; bufstride_x_ch],
1656            w_freq: vec![0.0; frame_x_ch + 4],
1657            w_band_e: vec![0.0; nb_x_ch],
1658
1659            w_x: vec![0.0; frame_x_ch + STRIDE_ACCESS_PAD],
1660            w_band_log_e: vec![0.0; nb_x_ch],
1661            w_error: vec![0.0; nb_x_ch],
1662            w_tf_res: vec![0; nb_ebands],
1663            w_cap: vec![0; nb_ebands],
1664            w_offsets: vec![0; nb_ebands],
1665            w_pulses: vec![0; nb_ebands],
1666            w_ebits: vec![0; nb_x_ch],
1667            w_fine_priority: vec![0; nb_x_ch],
1668            w_collapse_masks: vec![0; nb_x_ch],
1669            w_band_amp_synth: vec![0.0; nb_x_ch],
1670            w_freq_synth: vec![0.0; frame_x_ch + 4],
1671
1672            w_prefilter_pre: vec![0.0; channels * (COMBFILTER_MAXPERIOD + MAX_FRAME_SIZE)],
1673            w_prefilter_pitch_buf: vec![0.0; (COMBFILTER_MAXPERIOD + MAX_FRAME_SIZE) >> 1],
1674            w_prefilter_before: vec![0.0; channels],
1675            w_prefilter_after: vec![0.0; channels],
1676            w_transient_tmp: vec![0.0; MAX_TRANSIENT_LEN],
1677            w_transient_tmp2: vec![0.0; MAX_TRANSIENT_LEN / 2],
1678            consec_transient: 0,
1679
1680            analysis: AnalysisInfo::default(),
1681            loss_rate: 0,
1682        }
1683    }
1684
    /// Encodes one frame into `rc`, coding from band 0 with the bit budget taken
    /// from the size of the range coder's buffer.
    ///
    /// `pcm` holds `frame_size` samples per channel, one channel after the other.
    pub fn encode(&mut self, pcm: &[f32], frame_size: usize, rc: &mut RangeCoder) {
        self.encode_impl(pcm, frame_size, rc, 0, None)
    }
1688
    /// Like [`encode`](Self::encode), but coding starts at `start_band` instead of
    /// band 0. The bit budget is still taken from the size of the range coder's buffer.
    pub fn encode_with_start_band(
        &mut self,
        pcm: &[f32],
        frame_size: usize,
        rc: &mut RangeCoder,
        start_band: usize,
    ) {
        self.encode_impl(pcm, frame_size, rc, start_band, None)
    }
1698
    /// Like [`encode_with_start_band`](Self::encode_with_start_band), but with an
    /// explicit budget of `total_bits` bits instead of one derived from the range
    /// coder's buffer size.
    pub fn encode_with_budget(
        &mut self,
        pcm: &[f32],
        frame_size: usize,
        rc: &mut RangeCoder,
        start_band: usize,
        total_bits: i32,
    ) {
        self.encode_impl(pcm, frame_size, rc, start_band, Some(total_bits))
    }
1709
1710    fn encode_impl(
1711        &mut self,
1712        pcm: &[f32],
1713        frame_size: usize,
1714        rc: &mut RangeCoder,
1715        start_band: usize,
1716        explicit_total_bits: Option<i32>,
1717    ) {
1718        let mode = self.mode;
1719        let channels = self.channels;
1720        let nb_ebands = mode.nb_ebands;
1721        let overlap = mode.overlap;
1722
1723        let mut lm = 0;
1724        while (mode.short_mdct_size << lm) != frame_size {
1725            lm += 1;
1726            if lm > mode.max_lm {
1727                break;
1728            }
1729        }
1730        if (mode.short_mdct_size << lm) != frame_size {
1731            lm = 0;
1732        }
1733
1734        let syn_mem_size = 2048 + overlap;
1735        for c in 0..channels {
1736            let channel_offset = c * syn_mem_size;
1737
1738            self.syn_mem.copy_within(
1739                channel_offset + frame_size..channel_offset + syn_mem_size,
1740                channel_offset,
1741            );
1742
1743            let mut m = self.preemph_mem[c];
1744            let coef = mode.preemph[0];
1745            for i in 0..frame_size {
1746                let x = pcm[c * frame_size + i] * 32768.0;
1747                let val = x - m;
1748                self.syn_mem[channel_offset + syn_mem_size - frame_size + i] = val;
1749                m = x * coef;
1750            }
1751            self.preemph_mem[c] = m;
1752        }
1753
1754        let buf_stride = frame_size + overlap;
1755        let in_buf = &mut self.w_in_buf[..buf_stride * channels];
1756        for c in 0..channels {
1757            let channel_offset = c * syn_mem_size;
1758            let in_buf_offset = c * buf_stride;
1759
1760            let src_start = syn_mem_size - frame_size - overlap;
1761            in_buf[in_buf_offset..in_buf_offset + buf_stride].copy_from_slice(
1762                &self.syn_mem[channel_offset + src_start..channel_offset + syn_mem_size],
1763            );
1764        }
1765
1766        let mut tf_estimate = 0.0f32;
1767        let mut tf_chan = 0;
1768        let mut weak_transient = false;
1769
1770        let is_transient = if self.complexity >= 1 {
1771            transient_analysis(
1772                in_buf,
1773                buf_stride,
1774                channels,
1775                &mut tf_estimate,
1776                &mut tf_chan,
1777                false,
1778                &mut weak_transient,
1779                0.0,
1780                0.0,
1781                &mut self.w_transient_tmp,
1782                &mut self.w_transient_tmp2,
1783            )
1784        } else {
1785            false
1786        };
1787
1788        // Check for pure tone: if tonality is very high, bypass pitch search
1789        let toneishness = if self.analysis.valid {
1790            self.analysis.tonality
1791        } else {
1792            0.0
1793        };
1794        let _tone_freq = 0.0f32; // Would be set from analysis if available
1795
1796        let pf_enabled =
1797            start_band == 0 && self.complexity >= 5 && toneishness < 0.99 && channels == 1;
1798        let (pf_on, gain1, pitch_index) = if pf_enabled {
1799            run_prefilter(
1800                in_buf,
1801                &mut self.prefilter_mem,
1802                self.prefilter_period,
1803                self.prefilter_gain,
1804                self.prefilter_tapset,
1805                self.tapset_decision,
1806                mode.window,
1807                channels,
1808                frame_size,
1809                overlap,
1810                &mut self.w_prefilter_pre,
1811                &mut self.w_prefilter_pitch_buf,
1812                &mut self.w_prefilter_before,
1813                &mut self.w_prefilter_after,
1814                &self.analysis,
1815                self.loss_rate,
1816            )
1817        } else {
1818            (false, 0.0f32, COMBFILTER_MINPERIOD)
1819        };
1820
1821        // Save the prefiltered overlap for the next frame.
1822        // In libopus, st->in_mem stores the overlap separately and run_prefilter
1823        // copies it to/from in[]. Here we emulate that by updating syn_mem with
1824        // the last overlap samples of in_buf (which were prefiltered in place).
1825        let syn_mem_size = 2048 + overlap;
1826        for c in 0..channels {
1827            let channel_offset = c * syn_mem_size;
1828            let in_buf_offset = c * buf_stride;
1829            self.syn_mem[channel_offset + syn_mem_size - overlap..channel_offset + syn_mem_size]
1830                .copy_from_slice(&in_buf[in_buf_offset + frame_size..in_buf_offset + buf_stride]);
1831        }
1832
1833        let freq = &mut self.w_freq[..frame_size * channels];
1834        let (shift, b) = if is_transient {
1835            (mode.max_lm, 1 << lm)
1836        } else {
1837            (mode.max_lm - lm, 1)
1838        };
1839        let n = frame_size / b;
1840
1841        for c in 0..channels {
1842            let c_buf_offset = c * buf_stride;
1843
1844            if c == 0 && b == 1 && channels == 1 {
1845                let mut max_val = 0.0f32;
1846                let check_len = (frame_size + overlap).min(buf_stride);
1847                for j in 0..check_len {
1848                    max_val = max_val.max(in_buf[c_buf_offset + j].abs());
1849                }
1850            }
1851
1852            for i in 0..b {
1853                mode.mdct.forward(
1854                    &in_buf[c_buf_offset + i * n..],
1855                    &mut freq[c * frame_size + i..],
1856                    mode.window,
1857                    overlap,
1858                    shift,
1859                    b,
1860                );
1861            }
1862        }
1863
1864        let band_e = &mut self.w_band_e[..nb_ebands * channels];
1865        compute_band_energies(mode, freq, band_e, nb_ebands, channels, lm);
1866
1867        let x_pad_end = (frame_size * channels + STRIDE_ACCESS_PAD).min(self.w_x.len());
1868        let x = &mut self.w_x[..x_pad_end];
1869        normalise_bands(
1870            mode,
1871            freq,
1872            x,
1873            band_e,
1874            nb_ebands,
1875            channels,
1876            (1 << lm) as usize,
1877        );
1878
1879        if channels == 1 {
1880            let _ = freq[0];
1881        }
1882
1883        let band_log_e = &mut self.w_band_log_e[..nb_ebands * channels];
1884        crate::bands::amp2log2(mode, start_band, nb_ebands, band_e, band_log_e, channels);
1885
1886        let total_bits = explicit_total_bits.unwrap_or_else(|| (rc.buf.len() * 8) as i32);
1887        self.w_error[..nb_ebands * channels].fill(0.0);
1888        let error = &mut self.w_error[..nb_ebands * channels];
1889
1890        let tell = rc.tell();
1891        let silence = false;
1892        if tell == 1 {
1893            rc.encode_bit_logp(silence, 15);
1894        }
1895
1896        if start_band == 0 && !silence && rc.tell() + 16 <= total_bits {
1897            rc.encode_bit_logp(pf_on, 1);
1898            if pf_on {
1899                let qg = (gain1 / 0.09375 - 1.0 + 0.5).floor() as i32;
1900                let qg = qg.clamp(0, 7);
1901                let pi = (pitch_index + 1) as u32;
1902                let octave = 31 - pi.leading_zeros();
1903                let octave = (octave as i32 - 5).max(0) as u32;
1904                rc.enc_uint(octave, 6);
1905                rc.enc_bits(pi - (16 << octave), 4 + octave);
1906                rc.enc_bits(qg as u32, 3);
1907                rc.encode_icdf(self.tapset_decision, &TAPSET_ICDF, 2);
1908            }
1909        }
1910
1911        let mut short_blocks = false;
1912        if lm > 0 && rc.tell() + 3 <= total_bits {
1913            rc.encode_bit_logp(is_transient, 3);
1914            if is_transient {
1915                short_blocks = true;
1916            }
1917        }
1918
1919        if short_blocks {
1920            let b = 1 << lm;
1921            let n = frame_size / b;
1922            for c in 0..channels {
1923                let c_offset = c * buf_stride;
1924                for i in 0..b {
1925                    mode.mdct.forward(
1926                        &in_buf[c_offset + i * n..c_offset + buf_stride],
1927                        &mut freq[c * frame_size + i..],
1928                        mode.window,
1929                        overlap,
1930                        mode.max_lm,
1931                        b,
1932                    );
1933                }
1934            }
1935
1936            compute_band_energies(mode, freq, band_e, nb_ebands, channels, lm);
1937            normalise_bands(
1938                mode,
1939                freq,
1940                x,
1941                band_e,
1942                nb_ebands,
1943                channels,
1944                (1 << lm) as usize,
1945            );
1946        }
1947
1948        let intra_ener = if self.complexity >= 4 {
1949            false
1950        } else {
1951            self.old_band_e[..nb_ebands * channels]
1952                .iter()
1953                .all(|&e| e <= -27.0)
1954        };
1955        quant_coarse_energy_advanced(
1956            mode,
1957            start_band,
1958            nb_ebands,
1959            nb_ebands,
1960            band_log_e,
1961            &mut self.old_band_e,
1962            total_bits as u32,
1963            error,
1964            rc,
1965            channels,
1966            lm,
1967            (total_bits / 8) as usize,
1968            is_transient || intra_ener,
1969            &mut self.delayed_intra,
1970            self.complexity >= 4,
1971            0,
1972            false,
1973        );
1974        self.w_tf_res[..nb_ebands].fill(0);
1975        let tf_res = &mut self.w_tf_res[..nb_ebands];
1976        let effective_bytes = ((total_bits / 8) as usize).max(1);
1977        let lambda = 80.max(20480 / effective_bytes + 2) as i32;
1978
1979        let tf_select = if self.complexity >= 2 && effective_bytes >= 15 * channels {
1980            tf_analysis(
1981                mode,
1982                nb_ebands,
1983                is_transient,
1984                tf_res,
1985                lambda,
1986                x,
1987                frame_size,
1988                lm as i32,
1989                tf_estimate,
1990                tf_chan,
1991            )
1992        } else {
1993            0
1994        };
1995        tf_encode(
1996            start_band,
1997            nb_ebands,
1998            is_transient,
1999            tf_res,
2000            lm as i32,
2001            tf_select,
2002            rc,
2003        );
2004
2005        let mut dual_stereo_val = if channels == 2 {
2006            stereo_analysis(mode, x, lm as i32, frame_size) as i32
2007        } else {
2008            0
2009        };
2010
2011        let mut stereo_saving = 0.0f32;
2012        let equiv_rate = (total_bits * 48000) / frame_size as i32;
2013        if channels == 2 {
2014            self.intensity = hysteresis_decision(
2015                equiv_rate / 1000,
2016                &INTEN_THRESHOLDS,
2017                &INTEN_HYSTERESIS,
2018                self.intensity,
2019            );
2020            self.intensity = self.intensity.clamp(0, nb_ebands as i32);
2021        }
2022
2023        if self.complexity == 0 {
2024            self.spread_decision = SPREAD_NONE;
2025            if rc.tell() + 4 <= total_bits {
2026                rc.encode_icdf(self.spread_decision, &SPREAD_ICDF, 5);
2027            }
2028        } else if rc.tell() + 4 <= total_bits {
2029            if is_transient || self.complexity < 3 || effective_bytes < 10 * channels {
2030                self.spread_decision = SPREAD_NORMAL;
2031            } else {
2032                let update_hf = lm == mode.max_lm;
2033                let spread_weights = [32i32; 21];
2034                self.spread_decision = spreading_decision(
2035                    mode,
2036                    x,
2037                    &mut self.tonal_average,
2038                    self.spread_decision,
2039                    &mut self.hf_average,
2040                    &mut self.tapset_decision,
2041                    update_hf,
2042                    nb_ebands,
2043                    channels,
2044                    (1 << lm) as usize,
2045                    &spread_weights,
2046                );
2047            }
2048            rc.encode_icdf(self.spread_decision, &SPREAD_ICDF, 5);
2049        } else {
2050            self.spread_decision = SPREAD_NORMAL;
2051        }
2052
2053        self.w_cap[..nb_ebands].fill(0);
2054        let cap = &mut self.w_cap[..nb_ebands];
2055        for (i, cap_i) in cap.iter_mut().enumerate() {
2056            let n = (mode.e_bands[i + 1] - mode.e_bands[i]) << lm;
2057            *cap_i = ((mode.cache.caps[nb_ebands * (2 * lm + channels - 1) + i] as i32 + 64)
2058                * channels as i32
2059                * n as i32)
2060                >> 2;
2061        }
2062
2063        self.w_offsets[..nb_ebands].fill(0);
2064        let offsets = &mut self.w_offsets[..nb_ebands];
2065
2066        dynalloc_analysis_simple(
2067            mode,
2068            band_log_e,
2069            &self.old_band_e,
2070            start_band,
2071            nb_ebands,
2072            channels,
2073            lm,
2074            effective_bytes,
2075            is_transient,
2076            offsets,
2077            cap,
2078        );
2079
2080        let mut dynalloc_logp = 6i32;
2081        let total_bits_bitres = total_bits << BITRES;
2082        let mut total_boost = 0i32;
2083        let mut tell_frac = rc.tell_frac();
2084
2085        for i in start_band..nb_ebands {
2086            let width =
2087                channels as i32 * (mode.e_bands[i + 1] - mode.e_bands[i]) as i32 * (1 << lm);
2088            let quanta = (width << BITRES).min((6 << BITRES).max(width));
2089            let mut dynalloc_loop_logp = dynalloc_logp;
2090            let mut boost = 0i32;
2091            let mut j = 0i32;
2092
2093            while tell_frac + (dynalloc_loop_logp << BITRES) < total_bits_bitres - total_boost
2094                && boost < cap[i]
2095            {
2096                let flag = j < offsets[i];
2097                rc.encode_bit_logp(flag, dynalloc_loop_logp as u32);
2098                tell_frac = rc.tell_frac();
2099                if !flag {
2100                    break;
2101                }
2102                boost += quanta;
2103                total_boost += quanta;
2104                dynalloc_loop_logp = 1;
2105                j += 1;
2106            }
2107
2108            if j > 0 {
2109                dynalloc_logp = 2.max(dynalloc_logp - 1);
2110            }
2111            offsets[i] = boost;
2112        }
2113
2114        let alloc_trim = alloc_trim_analysis(
2115            mode,
2116            x,
2117            band_log_e,
2118            nb_ebands,
2119            lm as i32,
2120            channels,
2121            frame_size,
2122            &mut stereo_saving,
2123            tf_estimate,
2124            self.intensity,
2125            0.0,
2126            equiv_rate,
2127        );
2128        if rc.tell_frac() + (6 << BITRES) <= total_bits_bitres - total_boost {
2129            rc.encode_icdf(alloc_trim, &TRIM_ICDF, 7);
2130        }
2131
2132        let mut intensity = self.intensity;
2133        self.w_pulses[..nb_ebands].fill(0);
2134        let pulses = &mut self.w_pulses[..nb_ebands];
2135
2136        let stereo = channels > 1;
2137        let ebands_stereo = if stereo {
2138            nb_ebands * channels
2139        } else {
2140            nb_ebands
2141        };
2142        self.w_fine_priority[..ebands_stereo].fill(0);
2143        let fine_priority = &mut self.w_fine_priority[..ebands_stereo];
2144        self.w_ebits[..ebands_stereo].fill(0);
2145        let ebits = &mut self.w_ebits[..ebands_stereo];
2146        let mut balance = 0;
2147
2148        self.last_coded_bands = clt_compute_allocation(
2149            mode,
2150            start_band,
2151            nb_ebands,
2152            offsets,
2153            cap,
2154            alloc_trim,
2155            &mut intensity,
2156            &mut dual_stereo_val,
2157            (total_bits << BITRES) - rc.tell_frac() - 1,
2158            &mut balance,
2159            pulses,
2160            ebits,
2161            fine_priority,
2162            channels as i32,
2163            lm as i32,
2164            rc,
2165            true,
2166            0,
2167            nb_ebands as i32 - 1,
2168        );
2169
2170        quant_fine_energy(
2171            mode,
2172            start_band,
2173            nb_ebands,
2174            &mut self.old_band_e,
2175            error,
2176            ebits,
2177            rc,
2178            channels,
2179        );
2180
2181        self.w_collapse_masks[..nb_ebands * channels].fill(0);
2182        let collapse_masks = &mut self.w_collapse_masks[..nb_ebands * channels];
2183        let (x_split, y_split) = x.split_at_mut(frame_size);
2184        let y_opt = if channels == 2 { Some(y_split) } else { None };
2185
2186        let anti_collapse_rsv = if is_transient && lm >= 2 {
2187            let remaining = (total_bits << BITRES) - rc.tell_frac() - 1;
2188            if remaining >= ((lm as i32 + 2) << BITRES) {
2189                1i32 << BITRES
2190            } else {
2191                0
2192            }
2193        } else {
2194            0
2195        };
2196
2197        let mut dual_stereo = dual_stereo_val != 0;
2198
2199        let theta_rdo = channels == 2 && !dual_stereo && self.complexity >= 8;
2200        let resynth = theta_rdo;
2201
2202        quant_all_bands(
2203            true,
2204            mode,
2205            start_band,
2206            nb_ebands,
2207            x_split,
2208            y_opt,
2209            collapse_masks,
2210            band_e,
2211            pulses,
2212            short_blocks,
2213            self.spread_decision,
2214            &mut dual_stereo,
2215            intensity as usize,
2216            tf_res,
2217            (total_bits << BITRES) - anti_collapse_rsv,
2218            &mut balance,
2219            rc,
2220            lm as i32,
2221            self.last_coded_bands,
2222            resynth,
2223            false,
2224            &mut 0u32,
2225        );
2226
2227        if anti_collapse_rsv > 0 {
2228            let anti_collapse_on = if self.consec_transient < 2 {
2229                1u32
2230            } else {
2231                0u32
2232            };
2233            rc.enc_bits(anti_collapse_on, 1);
2234        }
2235
2236        quant_energy_finalise(
2237            mode,
2238            start_band,
2239            nb_ebands,
2240            &mut self.old_band_e,
2241            error,
2242            ebits,
2243            fine_priority,
2244            total_bits - rc.tell(),
2245            rc,
2246            channels,
2247        );
2248
2249        if resynth {
2250            let band_amp_synth = &mut self.w_band_amp_synth[..nb_ebands * channels];
2251            log2amp(mode, nb_ebands, band_amp_synth, &self.old_band_e, channels);
2252            self.w_freq_synth[..frame_size * channels].fill(0.0);
2253            let freq_synth = &mut self.w_freq_synth[..frame_size * channels];
2254            denormalise_bands(
2255                mode,
2256                x,
2257                freq_synth,
2258                band_amp_synth,
2259                start_band,
2260                nb_ebands,
2261                channels,
2262                (1 << lm) as usize,
2263            );
2264            let (syn_shift, syn_b) = if is_transient {
2265                (mode.max_lm, 1 << lm)
2266            } else {
2267                (mode.max_lm - lm, 1)
2268            };
2269            let syn_n = frame_size / syn_b;
2270            let decode_buf_size = 2048;
2271
2272            for c in 0..channels {
2273                let co = c * syn_mem_size;
2274                self.enc_decode_mem
2275                    .copy_within(co + frame_size..co + decode_buf_size + overlap, co);
2276            }
2277
2278            for c in 0..channels {
2279                let co = c * syn_mem_size;
2280                let out_syn_idx = decode_buf_size - frame_size;
2281                for bi in 0..syn_b {
2282                    let syn_stride = if is_transient {
2283                        mode.short_mdct_size
2284                    } else {
2285                        syn_n
2286                    };
2287                    mode.mdct.backward(
2288                        &freq_synth[c * frame_size + bi..],
2289                        &mut self.enc_decode_mem[co + out_syn_idx + bi * syn_stride..],
2290                        mode.window,
2291                        overlap,
2292                        syn_shift,
2293                        syn_b,
2294                    );
2295                }
2296            }
2297        }
2298
2299        self.last_band_log_e.copy_from_slice(&self.old_band_e);
2300
2301        if !is_transient {
2302            self.old_band_e3.copy_from_slice(&self.old_band_e2);
2303            self.old_band_e2.copy_from_slice(&self.old_band_e);
2304        } else {
2305            for i in 0..channels * nb_ebands {
2306                self.old_band_e2[i] = self.old_band_e2[i].min(self.old_band_e[i]);
2307            }
2308        }
2309
2310        rc.pad_to_bits(total_bits);
2311
2312        if pf_on {
2313            self.prefilter_period = pitch_index;
2314            self.prefilter_gain = gain1;
2315            self.prefilter_tapset = self.tapset_decision;
2316        } else {
2317            self.prefilter_period = COMBFILTER_MINPERIOD;
2318            self.prefilter_gain = 0.0;
2319            self.prefilter_tapset = self.tapset_decision;
2320        }
2321
2322        if is_transient {
2323            self.consec_transient += 1;
2324        } else {
2325            self.consec_transient = 0;
2326        }
2327    }
2328}
2329
/// Persistent state of a CELT decoder: per-channel synthesis history, band-energy
/// history, comb-filter parameters, and scratch buffers reused by every decode call.
2330pub struct CeltDecoder {
    // Codec mode; the decoder reads its band layout, overlap, window and MDCT from it.
2331    mode: &'static CeltMode,
    // Number of channels; most buffers below are sized as a multiple of it.
2332    channels: usize,
    // Synthesis history, `DECODE_BUFFER_SIZE + overlap` samples per channel. Each decode
    // shifts it left by one frame and writes the inverse-MDCT output near its end.
2333    decode_mem: Vec<f32>,
    // Band energies, one per band per channel; updated by the energy unquantisers and
    // converted with `log2amp` before the bands are denormalised.
2334    old_band_e: Vec<f32>,
    // One value per channel.
    // NOTE(review): presumably the de-emphasis filter memory; its use is outside this view.
2335    preemph_mem: Vec<f32>,
    // Most recent `COMBFILTER_MAXPERIOD` samples of each channel, used as history for
    // the comb filter.
2336    prefilter_mem: Vec<f32>,
    // Comb-filter period/gain/tapset pairs. `comb_filter_inplace` is first called with the
    // `_old` values as the starting parameters and the plain values as the ending ones.
2337    prefilter_period: usize,
2338    prefilter_period_old: usize,
2339    prefilter_gain: f32,
2340    prefilter_gain_old: f32,
2341    prefilter_tapset: i32,
2342    prefilter_tapset_old: i32,
    // Further band-energy histories (one value per band per channel) handed to
    // `anti_collapse`.
2343    old_band_e2: Vec<f32>,
2344    old_band_e3: Vec<f32>,
    // Random state threaded through `quant_all_bands` and replaced by the value
    // returned from `anti_collapse`.
2345    rng: u32,
2346
    // Scratch buffers (`w_*`): allocated once in `new` and cleared or overwritten by
    // each decode call before use.
2347    w_tf_res: Vec<i32>,
2348    w_cap: Vec<i32>,
2349    w_offsets: Vec<i32>,
2350    w_pulses: Vec<i32>,
2351    w_ebits: Vec<i32>,
2352    w_fine_priority: Vec<i32>,
2353    w_x: Vec<f32>,
2354    w_collapse_masks: Vec<u32>,
2355    w_freq: Vec<f32>,
2356    w_band_amp: Vec<f32>,
2357    w_pcm_frame: Vec<f32>,
2358    w_post: Vec<f32>,
2359}
2360
2361impl CeltDecoder {
2362    pub fn new(mode: &'static CeltMode, channels: usize) -> Self {
2363        let overlap = mode.overlap;
2364        let nb_ebands = mode.nb_ebands;
2365        let nb_x_ch = nb_ebands * channels;
2366        let dec_frame_x_ch = DECODE_BUFFER_SIZE * channels;
2367        Self {
2368            mode,
2369            channels,
2370            decode_mem: vec![0.0; channels * (DECODE_BUFFER_SIZE + overlap)],
2371            old_band_e: vec![0.0; nb_x_ch],
2372            preemph_mem: vec![0.0; channels],
2373            prefilter_mem: vec![0.0; channels * COMBFILTER_MAXPERIOD],
2374            prefilter_period: COMBFILTER_MINPERIOD,
2375            prefilter_period_old: COMBFILTER_MINPERIOD,
2376            prefilter_gain: 0.0,
2377            prefilter_gain_old: 0.0,
2378            prefilter_tapset: 0,
2379            prefilter_tapset_old: 0,
2380            old_band_e2: vec![0.0; nb_x_ch],
2381            old_band_e3: vec![0.0; nb_x_ch],
2382            rng: 0,
2383
2384            w_tf_res: vec![0; nb_ebands],
2385            w_cap: vec![0; nb_ebands],
2386            w_offsets: vec![0; nb_ebands],
2387            w_pulses: vec![0; nb_ebands],
2388            w_ebits: vec![0; nb_x_ch],
2389            w_fine_priority: vec![0; nb_x_ch],
2390
2391            w_x: vec![0.0; dec_frame_x_ch + STRIDE_ACCESS_PAD],
2392            w_collapse_masks: vec![0; nb_x_ch],
2393            w_freq: vec![0.0; dec_frame_x_ch + 4], // +4: NEON backward pre-rotation reads up to 3 elements past n2
2394            w_band_amp: vec![0.0; nb_x_ch],
2395            w_pcm_frame: vec![0.0; DECODE_BUFFER_SIZE],
2396            w_post: vec![0.0; DECODE_BUFFER_SIZE + COMBFILTER_MAXPERIOD],
2397        }
2398    }
2399
2400    pub fn decode(&mut self, compressed: &[u8], frame_size: usize, pcm: &mut [f32]) -> usize {
2401        self.decode_impl(compressed, frame_size, pcm, 0, self.mode.nb_ebands)
2402    }
2403
2404    pub fn decode_with_start_band(
2405        &mut self,
2406        compressed: &[u8],
2407        frame_size: usize,
2408        pcm: &mut [f32],
2409        start_band: usize,
2410    ) -> usize {
2411        self.decode_impl(compressed, frame_size, pcm, start_band, self.mode.nb_ebands)
2412    }
2413
2414    pub fn decode_from_range_coder(
2415        &mut self,
2416        rc: &mut RangeCoder,
2417        total_bits: i32,
2418        frame_size: usize,
2419        pcm: &mut [f32],
2420        start_band: usize,
2421    ) -> usize {
2422        self.decode_impl_from_rc(
2423            rc,
2424            total_bits,
2425            frame_size,
2426            pcm,
2427            start_band,
2428            self.mode.nb_ebands,
2429        )
2430    }
2431
2432    pub fn decode_from_range_coder_with_band_range(
2433        &mut self,
2434        rc: &mut RangeCoder,
2435        total_bits: i32,
2436        frame_size: usize,
2437        pcm: &mut [f32],
2438        start_band: usize,
2439        end_band: usize,
2440    ) -> usize {
2441        self.decode_impl_from_rc(rc, total_bits, frame_size, pcm, start_band, end_band)
2442    }
2443
2444    fn decode_impl(
2445        &mut self,
2446        compressed: &[u8],
2447        frame_size: usize,
2448        pcm: &mut [f32],
2449        start_band: usize,
2450        end_band: usize,
2451    ) -> usize {
2452        let total_bits = (compressed.len() * 8) as i32;
2453        let mut rc = RangeCoder::new_decoder(compressed);
2454        self.decode_impl_from_rc(&mut rc, total_bits, frame_size, pcm, start_band, end_band)
2455    }
2456
2457    fn decode_impl_from_rc(
2458        &mut self,
2459        rc: &mut RangeCoder,
2460        total_bits: i32,
2461        frame_size: usize,
2462        pcm: &mut [f32],
2463        start_band: usize,
2464        end_band: usize,
2465    ) -> usize {
2466        let mode = self.mode;
2467        let channels = self.channels;
2468        let nb_ebands = mode.nb_ebands;
2469        let end_band = end_band.min(nb_ebands).max(start_band);
2470        let overlap = mode.overlap;
2471
2472        let mut lm = 0;
2473        while (mode.short_mdct_size << lm) != frame_size {
2474            lm += 1;
2475            if lm > mode.max_lm {
2476                break;
2477            }
2478        }
2479        if (mode.short_mdct_size << lm) != frame_size {
2480            lm = 0;
2481        }
2482
2483        let tell = rc.tell();
2484        let mut silence = false;
2485        if tell >= total_bits {
2486            silence = true;
2487        } else if tell == 1 {
2488            silence = rc.decode_bit_logp(15);
2489        }
2490
2491        if silence {
2492            pcm[..frame_size * channels].fill(0.0);
2493            return frame_size;
2494        }
2495
2496        let mut pf_on = false;
2497        let mut pitch_index = COMBFILTER_MINPERIOD;
2498        let mut gain1 = 0.0f32;
2499        let mut prefilter_tapset = 0;
2500
2501        if start_band == 0 && !silence && rc.tell() + 16 <= total_bits {
2502            pf_on = rc.decode_bit_logp(1);
2503            if pf_on {
2504                let octave = rc.dec_uint(6);
2505                pitch_index = ((16 << octave) + rc.dec_bits(4 + octave)) as usize - 1;
2506                let qg = rc.dec_bits(3);
2507                if rc.tell() + 2 <= total_bits {
2508                    prefilter_tapset = rc.decode_icdf(&TAPSET_ICDF, 2) as usize;
2509                }
2510                gain1 = 0.09375 * (qg as f32 + 1.0);
2511            }
2512        }
2513        if start_band != 0 {
2514            self.prefilter_gain = 0.0;
2515        }
2516
2517        let mut is_transient = false;
2518        if lm > 0 && rc.tell() + 3 <= total_bits {
2519            is_transient = rc.decode_bit_logp(3);
2520        }
2521        let short_blocks = is_transient;
2522
2523        let intra_ener = if rc.tell() + 3 <= total_bits {
2524            rc.decode_bit_logp(3)
2525        } else {
2526            false
2527        };
2528
2529        unquant_coarse_energy(
2530            mode,
2531            start_band,
2532            end_band,
2533            &mut self.old_band_e,
2534            intra_ener,
2535            rc,
2536            channels,
2537            lm,
2538        );
2539        self.w_tf_res[..nb_ebands].fill(0);
2540        let tf_res = &mut self.w_tf_res[..nb_ebands];
2541        tf_decode(start_band, end_band, is_transient, tf_res, lm as i32, rc);
2542
2543        let spread_decision = if rc.tell() + 4 <= total_bits {
2544            rc.decode_icdf(&SPREAD_ICDF, 5)
2545        } else {
2546            SPREAD_NORMAL
2547        };
2548
2549        self.w_cap[..nb_ebands].fill(0);
2550        let cap = &mut self.w_cap[..nb_ebands];
2551        for (i, cap_i) in cap.iter_mut().enumerate() {
2552            let n = (mode.e_bands[i + 1] - mode.e_bands[i]) << lm;
2553            *cap_i = ((mode.cache.caps[nb_ebands * (2 * lm + channels - 1) + i] as i32 + 64)
2554                * channels as i32
2555                * n as i32)
2556                >> 2;
2557        }
2558
2559        self.w_offsets[..nb_ebands].fill(0);
2560        let offsets = &mut self.w_offsets[..nb_ebands];
2561        let mut dynalloc_logp = 6i32;
2562        let mut total_bits_bitres = total_bits << BITRES;
2563        let mut tell_frac = rc.tell_frac();
2564        for i in start_band..end_band {
2565            let width =
2566                channels as i32 * (mode.e_bands[i + 1] - mode.e_bands[i]) as i32 * (1 << lm);
2567            let quanta = (width << BITRES).min((6i32 << BITRES).max(width));
2568            let mut dynalloc_loop_logp = dynalloc_logp;
2569            let mut boost = 0i32;
2570            while tell_frac + (dynalloc_loop_logp << BITRES) < total_bits_bitres && boost < cap[i] {
2571                let flag = rc.decode_bit_logp(dynalloc_loop_logp as u32);
2572                tell_frac = rc.tell_frac();
2573                if !flag {
2574                    break;
2575                }
2576                boost += quanta;
2577                total_bits_bitres -= quanta;
2578                dynalloc_loop_logp = 1;
2579            }
2580            offsets[i] = boost;
2581            if boost > 0 {
2582                dynalloc_logp = dynalloc_logp.max(2) - 1;
2583                dynalloc_logp = dynalloc_logp.max(2);
2584            }
2585        }
2586
2587        let alloc_trim = if rc.tell_frac() + (6 << BITRES) <= total_bits_bitres {
2588            rc.decode_icdf(&TRIM_ICDF, 7)
2589        } else {
2590            5
2591        };
2592        let anti_collapse_rsv = if is_transient && lm >= 2 {
2593            let remaining = (total_bits << BITRES) - rc.tell_frac() - 1;
2594            if remaining >= ((lm as i32 + 2) << BITRES) {
2595                1i32 << BITRES
2596            } else {
2597                0
2598            }
2599        } else {
2600            0
2601        };
2602
2603        let mut intensity = 0;
2604        let mut dual_stereo_val = if channels == 2 { 1 } else { 0 };
2605        let mut balance = 0;
2606        self.w_pulses[..nb_ebands].fill(0);
2607        let pulses = &mut self.w_pulses[..nb_ebands];
2608
2609        let ebands_stereo = if channels > 1 {
2610            nb_ebands * channels
2611        } else {
2612            nb_ebands
2613        };
2614        self.w_fine_priority[..ebands_stereo].fill(0);
2615        let fine_priority = &mut self.w_fine_priority[..ebands_stereo];
2616        self.w_ebits[..ebands_stereo].fill(0);
2617        let ebits = &mut self.w_ebits[..ebands_stereo];
2618
2619        let alloc_bits = (total_bits << BITRES) - rc.tell_frac() - 1 - anti_collapse_rsv;
2620        let coded_bands = clt_compute_allocation(
2621            mode,
2622            start_band,
2623            end_band,
2624            offsets,
2625            cap,
2626            alloc_trim,
2627            &mut intensity,
2628            &mut dual_stereo_val,
2629            alloc_bits,
2630            &mut balance,
2631            pulses,
2632            ebits,
2633            fine_priority,
2634            channels as i32,
2635            lm as i32,
2636            rc,
2637            false,
2638            0,
2639            end_band as i32 - 1,
2640        );
2641
2642        unquant_fine_energy(
2643            mode,
2644            start_band,
2645            end_band,
2646            &mut self.old_band_e,
2647            ebits,
2648            rc,
2649            channels,
2650        );
2651
2652        if frame_size > DECODE_BUFFER_SIZE + overlap {
2653            return 0;
2654        }
2655
2656        self.w_x[..frame_size * channels].fill(0.0);
2657
2658        let x_pad_end = (frame_size * channels + STRIDE_ACCESS_PAD).min(self.w_x.len());
2659        let x = &mut self.w_x[..x_pad_end];
2660        self.w_collapse_masks[..nb_ebands * channels].fill(0);
2661        let collapse_masks = &mut self.w_collapse_masks[..nb_ebands * channels];
2662
2663        let (x_split, y_split) = x.split_at_mut(frame_size);
2664        let y_opt = if channels == 2 { Some(y_split) } else { None };
2665
2666        let mut dual_stereo = dual_stereo_val != 0;
2667        self.w_band_amp[..nb_ebands * channels].fill(0.0);
2668        let band_amp = &mut self.w_band_amp[..nb_ebands * channels];
2669        log2amp(mode, nb_ebands, band_amp, &self.old_band_e, channels);
2670        quant_all_bands(
2671            false,
2672            mode,
2673            start_band,
2674            end_band,
2675            x_split,
2676            y_opt,
2677            collapse_masks,
2678            band_amp,
2679            pulses,
2680            short_blocks,
2681            spread_decision,
2682            &mut dual_stereo,
2683            intensity as usize,
2684            tf_res,
2685            (total_bits << BITRES) - anti_collapse_rsv,
2686            &mut balance,
2687            rc,
2688            lm as i32,
2689            coded_bands,
2690            true,
2691            false,
2692            &mut self.rng,
2693        );
2694        // Trace X values for comparison with C decoder
2695        let mut anti_collapse_on = false;
2696        if anti_collapse_rsv > 0 {
2697            anti_collapse_on = rc.dec_bits(1) != 0;
2698        }
2699
2700        unquant_energy_finalise(
2701            mode,
2702            start_band,
2703            end_band,
2704            &mut self.old_band_e,
2705            ebits,
2706            fine_priority,
2707            total_bits - rc.tell(),
2708            rc,
2709            channels,
2710        );
2711        if anti_collapse_on {
2712            self.rng = crate::bands::anti_collapse(
2713                mode,
2714                x,
2715                collapse_masks,
2716                lm as i32,
2717                channels,
2718                frame_size,
2719                start_band,
2720                nb_ebands,
2721                &self.old_band_e,
2722                &self.old_band_e2,
2723                &self.old_band_e3,
2724                pulses,
2725                self.rng,
2726            );
2727        }
2728
2729        // Recompute band_amp after unquant_energy_finalise, which adjusts old_band_e.
2730        // (Mirrors the encoder's resynth path: log2amp is called after quant_energy_finalise.)
2731        log2amp(mode, nb_ebands, band_amp, &self.old_band_e, channels);
2732        self.w_freq[..frame_size * channels].fill(0.0);
2733        let freq = &mut self.w_freq[..frame_size * channels];
2734        denormalise_bands(
2735            mode,
2736            x,
2737            freq,
2738            band_amp,
2739            start_band,
2740            end_band,
2741            channels,
2742            (1 << lm) as usize,
2743        );
2744        // Always trace freq and band_amp for comparison
2745
2746        let (shift, b) = if short_blocks {
2747            (mode.max_lm, 1 << lm)
2748        } else {
2749            (mode.max_lm - lm, 1)
2750        };
2751        let n = frame_size / b;
2752
2753        for c in 0..channels {
2754            let channel_mem_offset = c * (DECODE_BUFFER_SIZE + overlap);
2755
2756            let mem_size = DECODE_BUFFER_SIZE + overlap;
2757            self.decode_mem.copy_within(
2758                channel_mem_offset + frame_size..channel_mem_offset + mem_size,
2759                channel_mem_offset,
2760            );
2761
2762            let out_syn_idx = DECODE_BUFFER_SIZE - frame_size;
2763
2764            for i in 0..b {
2765                let block_freq_idx = c * frame_size + i;
2766                // Stride between short-block MDCT outputs is short_mdct_size (not n).
2767                // In libopus: out_syn[c] + NB*b, where NB = mode->shortMdctSize.
2768                // For non-transient b=1, i*n == 0 either way.
2769                let block_stride = if short_blocks {
2770                    mode.short_mdct_size
2771                } else {
2772                    n
2773                };
2774                let block_out_idx = channel_mem_offset + out_syn_idx + i * block_stride;
2775                let available_len = self.decode_mem.len() - block_out_idx;
2776                if available_len < n + overlap {
2777                    panic!(
2778                        "MDCT backward buffer too small: need {}, have {} (out_syn_idx={}, n={}, overlap={})",
2779                        n + overlap,
2780                        available_len,
2781                        out_syn_idx,
2782                        n,
2783                        overlap
2784                    );
2785                }
2786                self.mode.mdct.backward(
2787                    &freq[block_freq_idx..],
2788                    &mut self.decode_mem[block_out_idx..],
2789                    mode.window,
2790                    overlap,
2791                    shift,
2792                    b,
2793                );
2794            }
2795
2796            const SIG_SAT: f32 = 536870911.0;
2797            for i in 0..frame_size {
2798                let v = &mut self.decode_mem[channel_mem_offset + out_syn_idx + i];
2799                *v = v.clamp(-SIG_SAT, SIG_SAT);
2800            }
2801
2802            self.w_pcm_frame[..frame_size].fill(0.0);
2803            let pcm_frame = &mut self.w_pcm_frame[..frame_size];
2804
2805            pcm_frame.copy_from_slice(
2806                &self.decode_mem[channel_mem_offset + out_syn_idx
2807                    ..channel_mem_offset + out_syn_idx + frame_size],
2808            );
2809            if pf_on || self.prefilter_gain > 0.0 || self.prefilter_gain_old > 0.0 {
2810                // Set up w_post = [prefilter_mem | pcm_frame] for history access.
2811                // We apply combfilter in-place on w_post[COMBFILTER_MAXPERIOD..] so that
2812                // later samples can reference already-filtered earlier samples, matching C's
2813                // in-place comb_filter behavior.
2814                self.w_post[..COMBFILTER_MAXPERIOD].copy_from_slice(
2815                    &self.prefilter_mem[c * COMBFILTER_MAXPERIOD..(c + 1) * COMBFILTER_MAXPERIOD],
2816                );
2817                self.w_post[COMBFILTER_MAXPERIOD..COMBFILTER_MAXPERIOD + frame_size]
2818                    .copy_from_slice(pcm_frame);
2819
2820                let short_n = mode.short_mdct_size;
2821                // Call 1: first short_n samples, transition old→current params
2822                // Apply in-place on w_post[COMBFILTER_MAXPERIOD..], output overwrites input
2823                comb_filter_inplace(
2824                    &mut self.w_post,
2825                    COMBFILTER_MAXPERIOD,
2826                    self.prefilter_period_old,
2827                    self.prefilter_period,
2828                    short_n,
2829                    self.prefilter_gain_old,
2830                    self.prefilter_gain,
2831                    self.prefilter_tapset_old,
2832                    self.prefilter_tapset,
2833                    mode.window,
2834                    overlap,
2835                );
2836                if lm != 0 {
2837                    // Call 2: remaining N-short_n samples, transition current→new params
2838                    comb_filter_inplace(
2839                        &mut self.w_post,
2840                        COMBFILTER_MAXPERIOD + short_n,
2841                        self.prefilter_period,
2842                        pitch_index,
2843                        frame_size - short_n,
2844                        self.prefilter_gain,
2845                        gain1,
2846                        self.prefilter_tapset,
2847                        prefilter_tapset as i32,
2848                        mode.window,
2849                        overlap,
2850                    );
2851                }
2852
2853                pcm_frame.copy_from_slice(
2854                    &self.w_post[COMBFILTER_MAXPERIOD..COMBFILTER_MAXPERIOD + frame_size],
2855                );
2856
2857                self.decode_mem[channel_mem_offset + out_syn_idx
2858                    ..channel_mem_offset + out_syn_idx + frame_size]
2859                    .copy_from_slice(pcm_frame);
2860            }
2861            let mut new_mem = [0.0f32; COMBFILTER_MAXPERIOD];
2862            if frame_size >= COMBFILTER_MAXPERIOD {
2863                new_mem.copy_from_slice(&pcm_frame[frame_size - COMBFILTER_MAXPERIOD..frame_size]);
2864            } else {
2865                new_mem[..COMBFILTER_MAXPERIOD - frame_size].copy_from_slice(
2866                    &self.prefilter_mem
2867                        [c * COMBFILTER_MAXPERIOD + frame_size..(c + 1) * COMBFILTER_MAXPERIOD],
2868                );
2869                new_mem[COMBFILTER_MAXPERIOD - frame_size..].copy_from_slice(pcm_frame);
2870            }
2871            self.prefilter_mem[c * COMBFILTER_MAXPERIOD..(c + 1) * COMBFILTER_MAXPERIOD]
2872                .copy_from_slice(&new_mem);
2873
2874            let coef = mode.preemph[0];
2875            let mut m = self.preemph_mem[c];
2876            const VERY_SMALL: f32 = 1e-30f32;
2877            for i in 0..frame_size {
2878                let x = pcm_frame[i];
2879                let val = (x + VERY_SMALL + m).clamp(-SIG_SAT, SIG_SAT);
2880                pcm[c * frame_size + i] = val * (1.0 / 32768.0);
2881                m = val * coef;
2882            }
2883            self.preemph_mem[c] = m;
2884        }
2885
2886        self.prefilter_period_old = self.prefilter_period;
2887        self.prefilter_gain_old = self.prefilter_gain;
2888        self.prefilter_tapset_old = self.prefilter_tapset;
2889
2890        if pf_on {
2891            self.prefilter_period = pitch_index;
2892            self.prefilter_gain = gain1;
2893            self.prefilter_tapset = prefilter_tapset as i32;
2894        } else {
2895            self.prefilter_period = COMBFILTER_MINPERIOD;
2896            self.prefilter_gain = 0.0;
2897            self.prefilter_tapset = 0;
2898        }
2899
2900        if lm > 0 {
2901            self.prefilter_period_old = self.prefilter_period;
2902            self.prefilter_gain_old = self.prefilter_gain;
2903            self.prefilter_tapset_old = self.prefilter_tapset;
2904        }
2905
2906        if !is_transient {
2907            self.old_band_e3.copy_from_slice(&self.old_band_e2);
2908            self.old_band_e2.copy_from_slice(&self.old_band_e);
2909        } else {
2910            let nb_ebands = mode.nb_ebands;
2911            for i in 0..channels * nb_ebands {
2912                self.old_band_e2[i] = self.old_band_e2[i].min(self.old_band_e[i]);
2913            }
2914        }
2915
2916        self.rng = rc.rng;
2917
2918        frame_size
2919    }
2920}
2921
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{modes, range_coder::RangeCoder};

    /// Regression test for the crash reported against opus-rs 0.1.19.
    ///
    /// The encoder is driven directly through `CeltEncoder` with an invalid
    /// `frame_size` of 48, so the validation performed by
    /// `OpusEncoder::encode()` is deliberately skipped. The original report
    /// involved G.729-decoded PCM (8 kHz) arriving at the 48 kHz Opus encoder
    /// without correct resampling, which yields 48-sample frames where 480
    /// were expected.
    ///
    /// Why it panics: the lm-search in `encode_impl` only matches frame sizes
    /// 120, 240, 480 and 960. For 48 it finds nothing and silently falls back
    /// to lm=0. With lm=0 and shift=max_lm=3 this gives n=1920>>3=240, n2=120
    /// and overlap2=60. `in_buf` holds just frame_size+overlap=168 elements,
    /// whereas `forward()` needs input.len() >= n2+overlap2 = 180.
    ///
    /// * 0.1.21+: the input-size assertion in `forward()` fires immediately.
    /// * 0.1.19: that assertion did not exist, so the failure surfaced at the
    ///   MDCT output write instead:
    ///   "index out of bounds: the len is 48 but the index is 119".
    ///
    /// Both outcomes are panics, which is all this test checks; it confirms
    /// the crash path is real. The actual fix lives in `OpusEncoder::encode()`,
    /// which returns `Err` before `CeltEncoder` is ever reached.
    #[test]
    #[should_panic]
    fn test_celt_frame_size_48_panics_confirms_crash_path() {
        let mode = modes::default_mode();

        // Provide at least frame_size samples so the input slice itself is
        // not the reason for the failure.
        let samples = vec![0.0f32; 48 + mode.overlap];
        let mut coder = RangeCoder::new_encoder(100);
        let mut encoder = CeltEncoder::new(mode, 1);

        // frame_size=48 makes the lm-search fail and fall back to lm=0, after
        // which forward() is expected to panic.
        encoder.encode_with_budget(&samples, 48, &mut coder, 0, 800);
    }
}