Skip to main content

opus_rs/
celt.rs

1use crate::bands::{
2    SPREAD_NONE, SPREAD_NORMAL, compute_band_energies, denormalise_bands, haar1, log2amp,
3    normalise_bands, quant_all_bands, spreading_decision,
4};
5use crate::modes::{CeltMode, SPREAD_ICDF, TAPSET_ICDF, TF_SELECT_TABLE, TRIM_ICDF};
6use crate::quant_bands::{
7    quant_coarse_energy, quant_energy_finalise, quant_fine_energy, unquant_coarse_energy,
8    unquant_energy_finalise, unquant_fine_energy,
9};
10use crate::range_coder::RangeCoder;
11use crate::rate::{BITRES, clt_compute_allocation};
12
13#[cfg(target_arch = "aarch64")]
14use std::arch::aarch64::*;
15
/// NEON-optimized sum of absolute values over the first `n` elements of `x`.
///
/// `n` is clamped to `x.len()`, so the vector loads can never run past the
/// end of the slice even when the caller passes an oversized count.
///
/// # Safety
///
/// Executes AArch64 NEON intrinsics; the caller must run this on a CPU where
/// NEON is available.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn sum_abs_neon(x: &[f32], n: usize) -> f32 {
    let n = n.min(x.len());
    let base = x.as_ptr();
    let mut i = 0;

    // SAFETY: every `vld1q_f32` below reads the 4 lanes `[k, k + 4)` with
    // `k + 4 <= n <= x.len()`, so all loads stay inside `x`.
    let mut sum = unsafe {
        let mut acc = vdupq_n_f32(0.0);

        // Main loop: 16 elements (four 4-lane vectors) per iteration.
        while i + 16 <= n {
            // A plain add gives exactly the same rounding as the previous
            // fused multiply-add by 1.0.
            acc = vaddq_f32(acc, vabsq_f32(vld1q_f32(base.add(i))));
            acc = vaddq_f32(acc, vabsq_f32(vld1q_f32(base.add(i + 4))));
            acc = vaddq_f32(acc, vabsq_f32(vld1q_f32(base.add(i + 8))));
            acc = vaddq_f32(acc, vabsq_f32(vld1q_f32(base.add(i + 12))));
            i += 16;
        }

        // Remaining whole vectors (at most three); the accumulation order is
        // the same as with separate 8-wide and 4-wide loops.
        while i + 4 <= n {
            acc = vaddq_f32(acc, vabsq_f32(vld1q_f32(base.add(i))));
            i += 4;
        }

        // Horizontal add of the four lanes.
        vaddvq_f32(acc)
    };

    // Scalar tail (fewer than 4 elements).
    for &v in &x[i..n] {
        sum += v.abs();
    }

    sum
}
65
/// Returns the sum of `|v|` over every element of `x`.
///
/// On `aarch64` the whole slice is handed to `sum_abs_neon`; every other
/// target uses an iterator reduction.
#[inline(always)]
fn sum_abs(x: &[f32]) -> f32 {
    #[cfg(target_arch = "aarch64")]
    let total = unsafe { sum_abs_neon(x, x.len()) };

    #[cfg(not(target_arch = "aarch64"))]
    let total = x.iter().copied().map(f32::abs).sum::<f32>();

    total
}
78
/// Upper bound on `frame_size`; the other `MAX_*` sizing constants in this
/// file are derived from it.
#[allow(dead_code)]
const MAX_FRAME_SIZE: usize = 2880;

/// NOTE(review): not referenced in the visible part of this file; presumably
/// the length of a per-channel decode/synthesis history buffer — confirm at
/// the use site.
const DECODE_BUFFER_SIZE: usize = 3072;

/// Lookup table read by `transient_analysis`: the index is a normalised
/// energy value clamped to `0..=127`, and the looked-up entries are summed
/// into the "unmask" metric. Entries are non-increasing, so larger normalised
/// energy contributes less.
/// NOTE(review): presumably an approximate reciprocal table taken from the
/// reference encoder — confirm the intended formula.
const INV_TABLE: [u8; 128] = [
    255, 255, 156, 110, 86, 70, 59, 51, 45, 40, 37, 33, 31, 28, 26, 25, 23, 22, 21, 20, 19, 18, 17,
    16, 16, 15, 15, 14, 13, 13, 12, 12, 12, 12, 11, 11, 11, 10, 10, 10, 9, 9, 9, 9, 9, 9, 8, 8, 8,
    8, 8, 7, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
    5, 5, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 3, 3, 3,
    3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2,
];

/// Max len for transient_analysis: (MAX_FRAME_SIZE + overlap) = 2880 + 120 = 3000.
/// Only checked with a `debug_assert!` inside `transient_analysis`.
const MAX_TRANSIENT_LEN: usize = 3000;
94
/// Decides whether the frame in `input` should be treated as a transient.
///
/// Each of the `channels` channels (stored back to back, `len` samples each)
/// is filtered, turned into a smoothed energy envelope at half rate, and
/// scored with an "unmask" metric built from `INV_TABLE`. The largest
/// per-channel score becomes `mask_metric`.
///
/// Outputs:
/// * return value — `true` when `mask_metric > 200.0`, unless the tonal
///   override (`toneishness > 0.98 && _tone_freq < 0.026`) clears it;
/// * `tf_estimate` — `(mask_metric - 150.0)` clamped to `[0, 1]`;
/// * `tf_chan` — index of the channel with the largest score (only written
///   when some channel's score is above zero);
/// * `weak_transient` — always written as `false` by this function.
///
/// Scratch buffers: `tmp` must hold at least `max(len, 12)` values and
/// `tmp2` at least `len / 2`; both are overwritten.
///
/// NOTE(review): `allow_weak_transients` only changes `forward_decay` here;
/// nothing ever sets `*weak_transient = true` — confirm this is intended.
/// NOTE(review): `len2 - 5` underflows when `len < 10`, and `_tone_freq` is
/// read despite its leading underscore.
#[allow(clippy::too_many_arguments)]
fn transient_analysis(
    input: &[f32],
    len: usize,
    channels: usize,
    tf_estimate: &mut f32,
    tf_chan: &mut usize,
    allow_weak_transients: bool,
    weak_transient: &mut bool,
    _tone_freq: f32,
    toneishness: f32,
    tmp: &mut [f32],
    tmp2: &mut [f32],
) -> bool {
    let mut mask_metric = 0.0f32;
    let mut forward_decay = 0.0625f32;

    *weak_transient = false;
    if allow_weak_transients {
        // Slower forward decay when weak transients are allowed.
        forward_decay = 0.03125f32;
    }

    let len2 = len / 2;
    debug_assert!(len <= MAX_TRANSIENT_LEN);

    for c in 0..channels {
        let mut mem0 = 0.0f32;
        let mut mem1 = 0.0f32;

        // Recursive filter with two state variables applied to channel `c`.
        // NOTE(review): presumably the high-pass from the reference encoder —
        // confirm. Statement order matters: `mem00` keeps the old `mem0`.
        for i in 0..len {
            let x = input[c * len + i];
            let y = mem0 + x;
            let mem00 = mem0;
            mem0 = mem0 - x + 0.5 * mem1;
            mem1 = x - mem00;
            tmp[i] = y;
        }

        // Discard the first 12 filtered samples.
        tmp[..12].fill(0.0);

        // Forward pass: energy of sample pairs, smoothed with a one-pole
        // filter controlled by `forward_decay`; `mean` accumulates the
        // scaled pair energies.
        let mut mean = 0.0f32;
        mem0 = 0.0f32;
        for i in 0..len2 {
            let x2 = (tmp[2 * i] * tmp[2 * i] + tmp[2 * i + 1] * tmp[2 * i + 1]) / 16.0;
            mean += x2 / 4096.0;
            mem0 = x2 + (1.0 - forward_decay) * mem0;
            tmp2[i] = forward_decay * mem0;
        }

        // Backward pass: smooth the envelope again in reverse time with a
        // fixed 0.875 decay and track its maximum.
        mem0 = 0.0f32;
        let mut max_e = 0.0f32;
        for i in (0..len2).rev() {
            mem0 = tmp2[i] + 0.875 * mem0;
            tmp2[i] = 0.125 * mem0;
            if tmp2[i] > max_e {
                max_e = tmp2[i];
            }
        }

        // Normalisation factor derived from the mean energy and the envelope
        // maximum; the 1e-10 term avoids division by zero.
        mean = (mean * max_e * 0.5 * (len2 as f32)).sqrt();
        let norm = (len2 as f32) / (1e-10 + mean);

        // Sample the normalised envelope every 4th point, skipping the first
        // 12 and last 5 entries, and accumulate the table values.
        let mut unmask = 0.0f32;
        for i in (12..(len2 - 5)).step_by(4) {
            let id = (64.0 * norm * (tmp2[i] + 1e-10)).floor() as i32;
            let id = id.clamp(0, 127) as usize;
            unmask += INV_TABLE[id] as f32;
        }

        // Scale by the analysed length and keep the channel with the highest score.
        unmask = 64.0 * unmask * 4.0 / (6.0 * (len2 as f32 - 17.0));
        if unmask > mask_metric {
            *tf_chan = c;
            mask_metric = unmask;
        }
    }

    let mut is_transient = mask_metric > 200.0;

    // Tonal override: a strongly tonal, low-frequency signal is never
    // reported as transient.
    if toneishness > 0.98 && _tone_freq < 0.026 {
        is_transient = false;
        mask_metric = 0.0;
    }

    *tf_estimate = (mask_metric - 150.0).clamp(0.0, 1.0);

    is_transient
}
182
/// Biased L1 norm of the first `n` values of `tmp`.
///
/// Returns `l1 + lm * bias * l1`, where `l1` is the sum of absolute values.
/// On `aarch64`, inputs with `n >= 16` are handed to `l1_metric_neon`.
///
/// # Panics
///
/// Panics if `n > tmp.len()`. The slice is narrowed before dispatching, so an
/// oversized `n` can never reach the raw-pointer loads of the NEON path.
fn l1_metric(tmp: &[f32], n: usize, lm: i32, bias: f32) -> f32 {
    // Bounds-check once up front; both code paths work on exactly `n` values.
    let tmp = &tmp[..n];

    #[cfg(target_arch = "aarch64")]
    {
        if n >= 16 {
            // SAFETY: `tmp` was just narrowed to exactly `n` elements, so
            // every load performed by the NEON kernel is in bounds.
            return unsafe { l1_metric_neon(tmp, n, lm, bias) };
        }
    }

    let l1 = tmp.iter().fold(0.0f32, |acc, v| acc + v.abs());
    l1 + (lm as f32) * bias * l1
}
197
/// NEON kernel behind `l1_metric`: sums `|tmp[k]|` for `k < n` and returns
/// `l1 + lm * bias * l1`.
///
/// # Safety
///
/// The caller must guarantee `n <= tmp.len()`; the vector loads go through
/// raw pointers and are not bounds-checked.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn l1_metric_neon(tmp: &[f32], n: usize, lm: i32, bias: f32) -> f32 {
    let base = tmp.as_ptr();
    let mut pos = 0usize;

    let mut total = unsafe {
        let mut acc = vdupq_n_f32(0.0);

        // Blocks of 16 floats: four consecutive 4-lane vectors.
        while n - pos >= 16 {
            for group in 0..4 {
                let v = vld1q_f32(base.add(pos + 4 * group));
                acc = vaddq_f32(acc, vabsq_f32(v));
            }
            pos += 16;
        }

        // Whatever whole 4-lane vectors are left.
        while n - pos >= 4 {
            acc = vaddq_f32(acc, vabsq_f32(vld1q_f32(base.add(pos))));
            pos += 4;
        }

        // Two pairwise adds fold the four lanes into lane 0.
        let pairs = vpaddq_f32(acc, acc);
        let folded = vpaddq_f32(pairs, pairs);
        vgetq_lane_f32(folded, 0)
    };

    // Scalar clean-up for the last (fewer than four) values.
    for v in &tmp[pos..n] {
        total += v.abs();
    }

    total + (lm as f32) * bias * total
}
241
/// Max nb_ebands: capacity of the per-band scratch arrays in `tf_analysis`,
/// which debug-asserts that its `len` argument does not exceed this value.
const MAX_NB_EBANDS: usize = 21;
/// Max band width in tf_analysis: (e_bands[21] - e_bands[20]) << max_lm = 22 << 3 = 176.
/// Capacity of the two per-band coefficient scratch arrays in `tf_analysis`.
const MAX_TF_TMP: usize = 176;
246
247#[allow(clippy::too_many_arguments)]
248fn tf_analysis(
249    mode: &CeltMode,
250    len: usize,
251    is_transient: bool,
252    tf_res: &mut [i32],
253    lambda: i32,
254    x: &[f32],
255    n0: usize,
256    lm: i32,
257    tf_estimate: f32,
258    tf_chan: usize,
259) -> i32 {
260    debug_assert!(len <= MAX_NB_EBANDS);
261    let mut metric = [0i32; MAX_NB_EBANDS];
262    let mut tmp = [0.0f32; MAX_TF_TMP];
263    let mut tmp_1 = [0.0f32; MAX_TF_TMP];
264
265    let bias = 0.04 * (-0.25f32).max(0.5 - tf_estimate);
266
267    for (i, metric_i) in metric[..len].iter_mut().enumerate() {
268        let n = ((mode.e_bands[i + 1] - mode.e_bands[i]) as usize) << lm;
269        let narrow = (mode.e_bands[i + 1] - mode.e_bands[i]) == 1;
270        let offset = tf_chan * n0 + ((mode.e_bands[i] as usize) << lm);
271        tmp[..n].copy_from_slice(&x[offset..offset + n]);
272
273        let mut l1 = l1_metric(&tmp[..n], n, if is_transient { lm } else { 0 }, bias);
274        let mut best_l1 = l1;
275        let mut best_level = 0;
276
277        if is_transient && !narrow {
278            tmp_1[..n].copy_from_slice(&tmp[..n]);
279            haar1(&mut tmp_1[..n], n >> lm, 1 << lm);
280            l1 = l1_metric(&tmp_1[..n], n, lm + 1, bias);
281            if l1 < best_l1 {
282                best_l1 = l1;
283                best_level = -1;
284            }
285        }
286
287        for k in 0..(lm + if is_transient || narrow { 0 } else { 1 }) {
288            let b = if is_transient { lm - k - 1 } else { k + 1 };
289
290            haar1(&mut tmp[..n], n >> k, 1 << k);
291            l1 = l1_metric(&tmp[..n], n, b, bias);
292
293            if l1 < best_l1 {
294                best_l1 = l1;
295                best_level = k + 1;
296            }
297        }
298
299        if is_transient {
300            *metric_i = 2 * best_level;
301        } else {
302            *metric_i = -2 * best_level;
303        }
304
305        if narrow && (*metric_i == 0 || *metric_i == -2 * lm) {
306            *metric_i -= 1;
307        }
308    }
309
310    let mut tf_select = 0;
311    let importance = [1.0f32; MAX_NB_EBANDS];
312    let mut selcost = [0.0f32; 2];
313
314    for sel in 0..2 {
315        let mut cost0 = importance[0]
316            * ((metric[0]
317                - 2 * TF_SELECT_TABLE[lm as usize][4 * (is_transient as usize) + 2 * sel] as i32)
318                as f32)
319                .abs();
320        let mut cost1 = importance[0]
321            * ((metric[0]
322                - 2 * TF_SELECT_TABLE[lm as usize][4 * (is_transient as usize) + 2 * sel + 1]
323                    as i32) as f32)
324                .abs()
325            + (if is_transient { 0.0 } else { lambda as f32 });
326
327        for i in 1..len {
328            let curr0 = cost0.min(cost1 + lambda as f32);
329            let curr1 = (cost0 + lambda as f32).min(cost1);
330            cost0 = curr0
331                + importance[i]
332                    * ((metric[i]
333                        - 2 * TF_SELECT_TABLE[lm as usize][4 * (is_transient as usize) + 2 * sel]
334                            as i32) as f32)
335                        .abs();
336            cost1 = curr1
337                + importance[i]
338                    * ((metric[i]
339                        - 2 * TF_SELECT_TABLE[lm as usize]
340                            [4 * (is_transient as usize) + 2 * sel + 1]
341                            as i32) as f32)
342                        .abs();
343        }
344        selcost[sel] = cost0.min(cost1);
345    }
346
347    if selcost[1] < selcost[0] {
348        tf_select = 1;
349    }
350
351    let mut cost0 = importance[0]
352        * ((metric[0]
353            - 2 * TF_SELECT_TABLE[lm as usize][4 * (is_transient as usize) + 2 * tf_select] as i32)
354            as f32)
355            .abs();
356    let mut cost1 = importance[0]
357        * ((metric[0]
358            - 2 * TF_SELECT_TABLE[lm as usize][4 * (is_transient as usize) + 2 * tf_select + 1]
359                as i32) as f32)
360            .abs()
361        + (if is_transient { 0.0 } else { lambda as f32 });
362
363    tf_res[0] = if cost0 < cost1 { 0 } else { 1 };
364
365    for i in 1..len {
366        let curr0 = cost0.min(cost1 + lambda as f32);
367        let curr1 = (cost0 + lambda as f32).min(cost1);
368        cost0 = curr0
369            + importance[i]
370                * ((metric[i]
371                    - 2 * TF_SELECT_TABLE[lm as usize][4 * (is_transient as usize) + 2 * tf_select]
372                        as i32) as f32)
373                    .abs();
374        cost1 = curr1
375            + importance[i]
376                * ((metric[i]
377                    - 2 * TF_SELECT_TABLE[lm as usize]
378                        [4 * (is_transient as usize) + 2 * tf_select + 1]
379                        as i32) as f32)
380                    .abs();
381        tf_res[i] = if cost0 < cost1 { 0 } else { 1 };
382    }
383
384    tf_select as i32
385}
386
387fn tf_encode(
388    start: usize,
389    end: usize,
390    is_transient: bool,
391    tf_res: &mut [i32],
392    lm: i32,
393    mut tf_select: i32,
394    rc: &mut RangeCoder,
395) -> i32 {
396    let mut curr = 0;
397    let mut tf_changed = 0;
398    let mut logp = if is_transient { 2 } else { 4 };
399    let mut budget = rc.storage as i32 * 8;
400    let mut tell = rc.tell();
401
402    let tf_select_rsv = if lm > 0 && tell + logp < budget { 1 } else { 0 };
403    budget -= tf_select_rsv;
404
405    for tf_res_i in tf_res[start..end].iter_mut() {
406        if tell + logp <= budget {
407            rc.encode_bit_logp(*tf_res_i ^ curr != 0, logp as u32);
408            tell = rc.tell();
409            curr = *tf_res_i;
410            tf_changed |= curr;
411        } else {
412            *tf_res_i = curr;
413        }
414        logp = if is_transient { 4 } else { 5 };
415    }
416
417    if tf_select_rsv != 0
418        && TF_SELECT_TABLE[lm as usize][4 * (is_transient as usize) + (tf_changed as usize)]
419            != TF_SELECT_TABLE[lm as usize][4 * (is_transient as usize) + 2 + (tf_changed as usize)]
420    {
421        rc.encode_bit_logp(tf_select != 0, 1);
422    } else {
423        tf_select = 0;
424    }
425
426    for tf_res_i in tf_res[start..end].iter_mut() {
427        *tf_res_i = TF_SELECT_TABLE[lm as usize]
428            [4 * (is_transient as usize) + 2 * (tf_select as usize) + (*tf_res_i as usize)]
429            as i32;
430    }
431
432    tf_changed
433}
434
435fn tf_decode(
436    start: usize,
437    end: usize,
438    is_transient: bool,
439    tf_res: &mut [i32],
440    lm: i32,
441    rc: &mut RangeCoder,
442) {
443    let mut curr = 0;
444    let mut tf_changed = 0;
445    let mut logp = if is_transient { 2 } else { 4 };
446    let budget = rc.storage as i32 * 8;
447    let mut tell = rc.tell();
448
449    let tf_select_rsv = if lm > 0 && tell + logp < budget { 1 } else { 0 };
450    let budget = budget - tf_select_rsv;
451
452    for tf_res_i in tf_res[start..end].iter_mut() {
453        if tell + logp <= budget {
454            curr ^= if rc.decode_bit_logp(logp as u32) {
455                1
456            } else {
457                0
458            };
459            tell = rc.tell();
460            tf_changed |= curr;
461        }
462        *tf_res_i = curr;
463        logp = if is_transient { 4 } else { 5 };
464    }
465
466    let mut tf_select = 0;
467    let _budget = budget + tf_select_rsv;
468    if tf_select_rsv > 0
469        && TF_SELECT_TABLE[lm as usize][4 * (is_transient as usize) + (tf_changed as usize)]
470            != TF_SELECT_TABLE[lm as usize][4 * (is_transient as usize) + 2 + (tf_changed as usize)]
471    {
472        tf_select = if rc.decode_bit_logp(1) { 1 } else { 0 };
473    }
474
475    for tf_res_i in tf_res[start..end].iter_mut() {
476        *tf_res_i = TF_SELECT_TABLE[lm as usize]
477            [4 * (is_transient as usize) + 2 * (tf_select as usize) + (*tf_res_i as usize)]
478            as i32;
479    }
480}
481
482fn stereo_analysis(m: &CeltMode, x: &[f32], lm: i32, n0: usize) -> bool {
483    let mut sum_lr = 1e-9f32;
484    let mut sum_ms = 1e-9f32;
485
486    for i in 0..13 {
487        let start = (m.e_bands[i] as usize) << lm;
488        let end = (m.e_bands[i + 1] as usize) << lm;
489        for j in start..end {
490            let l = x[j];
491            let r = x[n0 + j];
492            let m_val = l + r;
493            let s_val = l - r;
494            sum_lr += l.abs() + r.abs();
495            sum_ms += m_val.abs() + s_val.abs();
496        }
497    }
498
499    sum_ms *= std::f32::consts::FRAC_1_SQRT_2;
500    let mut thetas = 13;
501    if lm <= 1 {
502        thetas -= 8;
503    }
504
505    let left = (((m.e_bands[13] as usize) << (lm + 1)) + thetas) as f32 * sum_ms;
506    let right = ((m.e_bands[13] as usize) << (lm + 1)) as f32 * sum_lr;
507
508    left > right
509}
510
/// Smallest comb-filter period; `comb_filter` and `run_prefilter` raise
/// shorter periods to this value.
const COMBFILTER_MINPERIOD: usize = 15;
/// Largest comb-filter period; also the per-channel length of the prefilter
/// history used by `run_prefilter`.
const COMBFILTER_MAXPERIOD: usize = 1024;

/// Comb-filter tap weights, one row per tapset. Within a row, index 0 weights
/// the sample at the pitch lag, index 1 the two samples at lag ± 1, and
/// index 2 the two samples at lag ± 2 (see `comb_filter`).
const PREFILTER_GAINS: [[f32; 3]; 3] = [
    [0.306_640_6, 0.217_041, 0.129_638_7],
    [0.463_867_2, 0.268_066_4, 0.0],
    [0.799_804_7, 0.100_097_7, 0.0],
];
519
520#[allow(clippy::too_many_arguments)]
521fn comb_filter_const(
522    y: &mut [f32],
523    x: &[f32],
524    y_idx: usize,
525    x_idx: usize,
526    t: usize,
527    n: usize,
528    g10: f32,
529    g11: f32,
530    g12: f32,
531) {
532    #[cfg(target_arch = "aarch64")]
533    {
534        comb_filter_const_neon(y, x, y_idx, x_idx, t, n, g10, g11, g12);
535        return;
536    }
537    #[cfg(not(target_arch = "aarch64"))]
538    {
539        comb_filter_const_scalar(y, x, y_idx, x_idx, t, n, g10, g11, g12);
540    }
541}
542
/// Scalar constant-gain comb filter.
///
/// For `i` in `0..n` writes
/// `y[y_idx + i] = x[x_idx + i] + g10 * x[c] + g11 * (x[c + 1] + x[c - 1])
///                 + g12 * (x[c + 2] + x[c - 2])` with `c = x_idx + i - t`.
///
/// The five delayed samples are carried in a small sliding window so each
/// iteration reads only one new delayed value from `x`.
///
/// # Panics
///
/// Panics if any index falls outside `x` or `y` (in particular `x_idx` must
/// be at least `t + 2`).
#[inline]
#[allow(clippy::too_many_arguments)]
fn comb_filter_const_scalar(
    y: &mut [f32],
    x: &[f32],
    y_idx: usize,
    x_idx: usize,
    t: usize,
    n: usize,
    g10: f32,
    g11: f32,
    g12: f32,
) {
    let lag_base = x_idx - t;

    // Window over x[c - 2], x[c - 1], x[c], x[c + 1] for the current centre `c`.
    let mut window = [
        x[lag_base - 2],
        x[lag_base - 1],
        x[lag_base],
        x[lag_base + 1],
    ];

    for i in 0..n {
        let newest = x[lag_base + i + 2];
        let [back2, back1, centre, ahead1] = window;

        y[y_idx + i] =
            x[x_idx + i] + g10 * centre + g11 * (ahead1 + back1) + g12 * (newest + back2);

        // Slide the window one sample forward.
        window = [back1, centre, ahead1, newest];
    }
}
575
/// `aarch64` entry point for the constant-gain comb filter.
///
/// Currently a plain forwarder to `comb_filter_const_scalar`; no NEON
/// intrinsics are used here yet.
#[cfg(target_arch = "aarch64")]
fn comb_filter_const_neon(
    y: &mut [f32],
    x: &[f32],
    y_idx: usize,
    x_idx: usize,
    t: usize,
    n: usize,
    g10: f32,
    g11: f32,
    g12: f32,
) {
    // For now, use the scalar version: a NEON implementation would need more
    // complex handling because of the sliding delay-line pattern.
    comb_filter_const_scalar(y, x, y_idx, x_idx, t, n, g10, g11, g12);
}
592
593#[allow(clippy::too_many_arguments)]
594fn comb_filter(
595    y: &mut [f32],
596    x: &[f32],
597    y_idx: usize,
598    x_idx: usize,
599    t0: usize,
600    t1: usize,
601    n: usize,
602    g0: f32,
603    g1: f32,
604    tapset0: i32,
605    tapset1: i32,
606    window: &[f32],
607    overlap: usize,
608) {
609    if g0 == 0.0 && g1 == 0.0 {
610        if x_idx != y_idx || !std::ptr::eq(x.as_ptr(), y.as_ptr()) {
611            y[y_idx..y_idx + n].copy_from_slice(&x[x_idx..x_idx + n]);
612        }
613        return;
614    }
615
616    let t0 = t0.max(COMBFILTER_MINPERIOD);
617    let t1 = t1.max(COMBFILTER_MINPERIOD);
618
619    let g00 = g0 * PREFILTER_GAINS[tapset0 as usize][0];
620    let g01 = g0 * PREFILTER_GAINS[tapset0 as usize][1];
621    let g02 = g0 * PREFILTER_GAINS[tapset0 as usize][2];
622
623    let g10 = g1 * PREFILTER_GAINS[tapset1 as usize][0];
624    let g11 = g1 * PREFILTER_GAINS[tapset1 as usize][1];
625    let g12 = g1 * PREFILTER_GAINS[tapset1 as usize][2];
626
627    let mut x1 = x[x_idx - t1 + 1];
628    let mut x2 = x[x_idx - t1];
629    let mut x3 = x[x_idx - t1 - 1];
630    let mut x4 = x[x_idx - t1 - 2];
631
632    let mut inner_overlap = overlap;
633    if g0 == g1 && t0 == t1 && tapset0 == tapset1 {
634        inner_overlap = 0;
635    }
636
637    let mut i = 0;
638    while i < inner_overlap && i < n {
639        let x0 = x[x_idx + i - t1 + 2];
640        let f = window[i] * window[i];
641        y[y_idx + i] = x[x_idx + i]
642            + (1.0 - f)
643                * (g00 * x[x_idx + i - t0]
644                    + g01 * (x[x_idx + i - t0 + 1] + x[x_idx + i - t0 - 1])
645                    + g02 * (x[x_idx + i - t0 + 2] + x[x_idx + i - t0 - 2]))
646            + f * (g10 * x2 + g11 * (x1 + x3) + g12 * (x0 + x4));
647
648        x4 = x3;
649        x3 = x2;
650        x2 = x1;
651        x1 = x0;
652        i += 1;
653    }
654
655    if i < n {
656        if g1 == 0.0 {
657            y[y_idx + i..y_idx + n].copy_from_slice(&x[x_idx + i..x_idx + n]);
658        } else {
659            comb_filter_const(y, x, y_idx + i, x_idx + i, t1, n - i, g10, g11, g12);
660        }
661    }
662}
663
664/// Compute CELT pitch pre-filter parameters and apply the filter to in_buf.
665/// Returns (pf_on, gain1, pitch_index).
666/// Matches C's run_prefilter() in celt_encoder.c.
667fn run_prefilter(
668    in_buf: &mut [f32],
669    prefilter_mem: &mut [f32],
670    prefilter_period: usize,
671    prefilter_gain: f32,
672    prefilter_tapset: i32,
673    tapset_decision: i32,
674    window: &[f32],
675    channels: usize,
676    frame_size: usize,
677    overlap: usize,
678    // Pre-allocated buffers to avoid vec! allocation
679    pre: &mut [f32],
680    pitch_buf: &mut [f32],
681    before: &mut [f32],
682    after: &mut [f32],
683) -> (bool, f32, usize) {
684    let max_period = COMBFILTER_MAXPERIOD; // 1024
685    let min_period = COMBFILTER_MINPERIOD; // 15
686    let buf_stride = frame_size + overlap;
687    let pre_size = max_period + frame_size; // 1984
688
689    // Build pre[c] = [prefilter_mem[c*max_period..(c+1)*max_period] | in_buf current frame]
690    for c in 0..channels {
691        pre[c * pre_size..c * pre_size + max_period]
692            .copy_from_slice(&prefilter_mem[c * max_period..(c + 1) * max_period]);
693        pre[c * pre_size + max_period..c * pre_size + pre_size].copy_from_slice(
694            &in_buf[c * buf_stride + overlap..c * buf_stride + overlap + frame_size],
695        );
696    }
697
698    // Downsample for pitch analysis
699    let pitch_buf_len = (max_period + frame_size) >> 1; // 992
700    {
701        let pre_slices: Vec<&[f32]> = (0..channels)
702            .map(|c| &pre[c * pre_size..c * pre_size + pre_size])
703            .collect();
704        crate::pitch::pitch_downsample(&pre_slices, pitch_buf, pitch_buf_len, channels, 2);
705    }
706
707    // Find pitch period
708    let search_max = max_period - 3 * min_period; // 979
709    let pitch_result = crate::pitch::pitch_search(
710        &pitch_buf[max_period >> 1..],
711        &pitch_buf,
712        frame_size,
713        search_max,
714    );
715    let mut pitch_index = (max_period - pitch_result).min(max_period - 2);
716
717    // Refine pitch and compute gain via remove_doubling
718    let gain1_raw = crate::pitch::remove_doubling(
719        &pitch_buf,
720        max_period,
721        min_period,
722        frame_size,
723        &mut pitch_index,
724        prefilter_period,
725        prefilter_gain,
726    );
727    let mut gain1 = gain1_raw * 0.7; // C: MULT16_16_Q15(0.7, gain1)
728
729    // Gain threshold
730    let mut pf_threshold = 0.2f32;
731    if (pitch_index as i32 - prefilter_period as i32).unsigned_abs() as usize * 10 > pitch_index {
732        pf_threshold += 0.2;
733    }
734    if prefilter_gain > 0.4 {
735        pf_threshold -= 0.1;
736    }
737    if prefilter_gain > 0.55 {
738        pf_threshold -= 0.1;
739    }
740    pf_threshold = pf_threshold.max(0.2);
741
742    let pf_on;
743    if gain1 < pf_threshold {
744        gain1 = 0.0;
745        pf_on = false;
746    } else {
747        if (gain1 - prefilter_gain).abs() < 0.1 {
748            gain1 = prefilter_gain;
749        }
750        let qg = ((gain1 * 32.0 / 3.0 + 0.5).floor() as i32 - 1).clamp(0, 7);
751        gain1 = 0.09375 * (qg + 1) as f32;
752        pf_on = true;
753    }
754
755    // Compute "before" energy to check if filter helps
756    let before = &mut before[..channels];
757    for c in 0..channels {
758        let start = c * buf_stride + overlap;
759        before[c] = sum_abs(&in_buf[start..start + frame_size]);
760    }
761
762    // Apply the comb pre-filter (negative gain) to in_buf
763    // offset = shortMdctSize - overlap = 120 - 120 = 0 for 20ms at 48kHz
764    let offset = 0usize; // mode.short_mdct_size - overlap (always 0 for 20ms frames)
765    let prev_period = prefilter_period.max(COMBFILTER_MINPERIOD);
766
767    for c in 0..channels {
768        if offset > 0 {
769            // First segment uses old period/gain only
770            let pre_c = &pre[c * pre_size..];
771            comb_filter(
772                in_buf,
773                pre_c,
774                c * buf_stride + overlap,
775                max_period,
776                prev_period,
777                prev_period,
778                offset,
779                -prefilter_gain,
780                -prefilter_gain,
781                prefilter_tapset,
782                prefilter_tapset,
783                window,
784                0,
785            );
786        }
787
788        // Second segment: transition from old period/gain to new
789        {
790            let pre_c = &pre[c * pre_size..];
791            comb_filter(
792                in_buf,
793                pre_c,
794                c * buf_stride + overlap + offset,
795                max_period + offset,
796                prev_period,
797                pitch_index,
798                frame_size - offset,
799                -prefilter_gain,
800                -gain1,
801                prefilter_tapset,
802                tapset_decision,
803                window,
804                overlap,
805            );
806        }
807    }
808
809    // Compute "after" energy
810    let after = &mut after[..channels];
811    for c in 0..channels {
812        let start = c * buf_stride + overlap;
813        after[c] = sum_abs(&in_buf[start..start + frame_size]);
814    }
815
816    // Check if filter helped: revert if any channel got worse
817    let cancel_pitch = (0..channels).any(|c| after[c] > before[c]);
818
819    if cancel_pitch {
820        // Restore original signal from pre
821        for c in 0..channels {
822            in_buf[c * buf_stride + overlap..c * buf_stride + overlap + frame_size]
823                .copy_from_slice(
824                    &pre[c * pre_size + max_period..c * pre_size + max_period + frame_size],
825                );
826        }
827        // Update prefilter_mem with current frame
828        for c in 0..channels {
829            if frame_size >= max_period {
830                prefilter_mem[c * max_period..(c + 1) * max_period].copy_from_slice(
831                    &pre[c * pre_size + frame_size..c * pre_size + frame_size + max_period],
832                );
833            } else {
834                let shift = max_period - frame_size;
835                prefilter_mem.copy_within(
836                    c * max_period + frame_size..(c + 1) * max_period,
837                    c * max_period,
838                );
839                prefilter_mem[c * max_period + shift..(c + 1) * max_period].copy_from_slice(
840                    &pre[c * pre_size + max_period..c * pre_size + max_period + frame_size],
841                );
842            }
843        }
844        return (false, 0.0, pitch_index);
845    }
846
847    // Update prefilter_mem with current frame
848    for c in 0..channels {
849        if frame_size >= max_period {
850            prefilter_mem[c * max_period..(c + 1) * max_period].copy_from_slice(
851                &pre[c * pre_size + frame_size..c * pre_size + frame_size + max_period],
852            );
853        } else {
854            let shift = max_period - frame_size;
855            prefilter_mem.copy_within(
856                c * max_period + frame_size..(c + 1) * max_period,
857                c * max_period,
858            );
859            prefilter_mem[c * max_period + shift..(c + 1) * max_period].copy_from_slice(
860                &pre[c * pre_size + max_period..c * pre_size + max_period + frame_size],
861            );
862        }
863    }
864
865    (pf_on, gain1, pitch_index)
866}
867
/// Max nb_ebands * max channels (21 bands, 2 channels).
#[allow(dead_code)]
const MAX_EBANDS_X_CH: usize = 21 * 2;
/// Max frame_size * max channels (2880 * 2).
#[allow(dead_code)]
const MAX_FRAME_X_CH: usize = MAX_FRAME_SIZE * 2;
/// Padding for stride-based access in alg_unquant/exp_rotation.
/// Max stride = 8 (1 << max_lm), max band = 352 (MAX_PVQ_N).
/// NOTE(review): the value 352 comes from `crate::pvq::MAX_PVQ_N`, which is
/// not visible here — confirm.
const STRIDE_ACCESS_PAD: usize = crate::pvq::MAX_PVQ_N * 8;
/// Max buf_stride * max channels ((2880 + 120) * 2), where 120 is the overlap
/// assumed by the sizing comments in this file.
#[allow(dead_code)]
const MAX_BUFSTRIDE_X_CH: usize = (MAX_FRAME_SIZE + 120) * 2;
880
/// State of one CELT encoder instance.
///
/// Holds the mode and channel count, the state carried from frame to frame
/// (filter memories, previous band energies, previous analysis decisions),
/// and a set of `w_*` scratch buffers that are allocated once so encoding a
/// frame does not allocate.
///
/// NOTE(review): the per-field notes below are inferred from field names and
/// from the helper functions in this file; confirm against `encode_impl`.
pub struct CeltEncoder {
    // Static mode description and number of channels.
    mode: &'static CeltMode,
    channels: usize,
    // Encoder complexity setting (the only public field).
    pub complexity: i32,
    // Synthesis / decode memories and previous-frame band energies.
    syn_mem: Vec<f32>,
    enc_decode_mem: Vec<f32>,
    old_band_e: Vec<f32>,
    // Pre-emphasis filter state.
    preemph_mem: Vec<f32>,
    // Running analysis state and decisions carried between frames.
    tonal_average: i32,
    hf_average: i32,
    tapset_decision: i32,
    spread_decision: i32,
    intensity: i32,
    last_coded_bands: i32,
    // Pitch pre-filter state from the previous frame; presumably the values
    // passed to `run_prefilter` as `prefilter_mem`, `prefilter_period`,
    // `prefilter_gain` and `prefilter_tapset`.
    prefilter_mem: Vec<f32>,
    prefilter_period: usize,
    prefilter_gain: f32,
    prefilter_tapset: i32,
    // Older band-energy histories.
    old_band_e2: Vec<f32>,
    old_band_e3: Vec<f32>,
    last_band_log_e: Vec<f32>,
    // Pre-allocated working buffers for encode_impl
    w_in_buf: Vec<f32>,
    w_freq: Vec<f32>,
    w_band_e: Vec<f32>,
    w_x: Vec<f32>,
    w_band_log_e: Vec<f32>,
    w_error: Vec<f32>,
    w_tf_res: Vec<i32>,
    w_cap: Vec<i32>,
    w_offsets: Vec<i32>,
    w_pulses: Vec<i32>,
    w_ebits: Vec<i32>,
    w_fine_priority: Vec<i32>,
    w_collapse_masks: Vec<u32>,
    w_band_amp_synth: Vec<f32>,
    w_freq_synth: Vec<f32>,
    // Presumably a count of consecutive transient frames — confirm.
    consec_transient: i32,
    // Pre-allocated buffers for run_prefilter to avoid vec! allocation
    w_prefilter_pre: Vec<f32>,
    w_prefilter_pitch_buf: Vec<f32>,
    w_prefilter_before: Vec<f32>,
    w_prefilter_after: Vec<f32>,
    // Pre-allocated buffers for transient_analysis to avoid 18KB stack arrays per frame
    w_transient_tmp: Vec<f32>,
    w_transient_tmp2: Vec<f32>,
}
928
/// Ascending threshold table with 21 entries, shaped for use as the
/// `thresholds` argument of `hysteresis_decision`.
/// NOTE(review): presumably the rate thresholds for the intensity-stereo
/// start band — confirm units and the call site.
const INTEN_THRESHOLDS: [i32; 21] = [
    1, 2, 3, 4, 5, 6, 7, 8, 16, 24, 36, 44, 50, 56, 62, 67, 72, 79, 88, 106, 134,
];
/// Hysteresis margins, one per entry of `INTEN_THRESHOLDS`, shaped for use as
/// the `hysteresis` argument of `hysteresis_decision`.
const INTEN_HYSTERESIS: [i32; 21] = [
    1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 4, 5, 6, 8, 8,
];
935
/// Maps `val` to a threshold index, with hysteresis around the previous
/// decision `prev`.
///
/// The raw decision is the index of the first threshold strictly greater than
/// `val` (or `thresholds.len()` if there is none). The previous decision is
/// kept when:
/// * the raw decision is higher, but `val` has not yet reached
///   `thresholds[prev] + hysteresis[prev]`; or
/// * the raw decision is lower, but `val` is still above
///   `thresholds[prev - 1] - hysteresis[prev - 1]`.
///
/// The downward rule applies for every raw decision, including 0. A lower raw
/// decision implies `prev >= 1`, so `prev - 1` is always a valid subtraction.
///
/// # Panics
///
/// Panics if `prev` is negative or selects an entry outside the tables.
fn hysteresis_decision(val: i32, thresholds: &[i32], hysteresis: &[i32], prev: i32) -> i32 {
    let raw = thresholds
        .iter()
        .position(|&limit| val < limit)
        .unwrap_or(thresholds.len()) as i32;
    let prev_idx = prev as usize;

    // Moving up: require `val` to clear the threshold by the margin.
    if raw > prev && val < thresholds[prev_idx] + hysteresis[prev_idx] {
        return prev;
    }
    // Moving down: require `val` to drop below the threshold by the margin.
    if raw < prev && val > thresholds[prev_idx - 1] - hysteresis[prev_idx - 1] {
        return prev;
    }
    raw
}
954
955#[allow(clippy::too_many_arguments)]
956fn alloc_trim_analysis(
957    mode: &CeltMode,
958    x: &[f32],
959    band_log_e: &[f32],
960    end: usize,
961    lm: i32,
962    channels: usize,
963    n0: usize,
964    stereo_saving: &mut f32,
965    tf_estimate: f32,
966    intensity: i32,
967    surround_trim: f32,
968    equiv_rate: i32,
969) -> i32 {
970    let mut trim = 5.0f32;
971    if equiv_rate < 64000 {
972        trim = 4.0;
973    } else if equiv_rate < 80000 {
974        let frac = (equiv_rate - 64000) as f32 / 1024.0;
975        trim = 4.0 + (1.0 / 16.0) * frac;
976    }
977
978    if channels == 2 {
979        let mut sum = 0.0f32;
980        for i in 0..8 {
981            let offset = (mode.e_bands[i] as usize) << lm;
982            let n = ((mode.e_bands[i + 1] - mode.e_bands[i]) as usize) << lm;
983            let mut partial = 0.0f32;
984            for j in 0..n {
985                partial += x[offset + j] * x[n0 + offset + j];
986            }
987            sum += partial;
988        }
989        sum = (sum / 8.0).abs().min(1.0);
990        let mut min_xc = sum;
991        for i in 8..intensity as usize {
992            let offset = (mode.e_bands[i] as usize) << lm;
993            let n = ((mode.e_bands[i + 1] - mode.e_bands[i]) as usize) << lm;
994            let mut partial = 0.0f32;
995            for j in 0..n {
996                partial += x[offset + j] * x[n0 + offset + j];
997            }
998            min_xc = min_xc.min(partial.abs());
999        }
1000        min_xc = min_xc.min(1.0);
1001
1002        let log_xc = (1.001 - sum * sum).log2();
1003        let log_xc2 = (log_xc * 0.5).max((1.001 - min_xc * min_xc).log2());
1004
1005        trim += (-4.0f32).max(0.75 * log_xc);
1006        *stereo_saving = (*stereo_saving + 0.25).min(-0.5 * log_xc2);
1007    }
1008
1009    let mut diff = 0.0f32;
1010    for c in 0..channels {
1011        for i in 0..end - 1 {
1012            diff += band_log_e[c * mode.nb_ebands + i] * (2 + 2 * i as i32 - end as i32) as f32;
1013        }
1014    }
1015    diff /= (channels * (end - 1)) as f32;
1016    trim -= (-2.0f32).max(2.0f32.min((diff + 1.0) / 6.0));
1017    trim -= surround_trim;
1018    trim -= 2.0 * tf_estimate;
1019
1020    let trim_index = (trim + 0.5).floor() as i32;
1021    trim_index.clamp(0, 10)
1022}
1023
1024impl CeltEncoder {
1025    pub fn new(mode: &'static CeltMode, channels: usize) -> Self {
1026        let overlap = mode.overlap;
1027        let channel_mem_size = 2048 + overlap;
1028        let syn_mem_size = channels * channel_mem_size;
1029        let nb_ebands = mode.nb_ebands;
1030        let nb_x_ch = nb_ebands * channels;
1031        let frame_x_ch = MAX_FRAME_SIZE * channels;
1032        let bufstride_x_ch = (MAX_FRAME_SIZE + overlap) * channels;
1033        Self {
1034            mode,
1035            channels,
1036            complexity: 9,
1037            syn_mem: vec![0.0; syn_mem_size],
1038            enc_decode_mem: vec![0.0; syn_mem_size],
1039            old_band_e: vec![-28.0; nb_x_ch],
1040            preemph_mem: vec![0.0; channels],
1041            tonal_average: 256,
1042            hf_average: 0,
1043            tapset_decision: 0,
1044            spread_decision: SPREAD_NORMAL,
1045            intensity: 0,
1046            last_coded_bands: 0,
1047            prefilter_mem: vec![0.0; channels * COMBFILTER_MAXPERIOD],
1048            prefilter_period: COMBFILTER_MINPERIOD,
1049            prefilter_gain: 0.0,
1050            prefilter_tapset: 0,
1051            old_band_e2: vec![-28.0; nb_x_ch],
1052            old_band_e3: vec![-28.0; nb_x_ch],
1053            last_band_log_e: vec![-28.0; nb_x_ch],
1054            // Pre-allocate working buffers
1055            w_in_buf: vec![0.0; bufstride_x_ch],
1056            w_freq: vec![0.0; frame_x_ch],
1057            w_band_e: vec![0.0; nb_x_ch],
1058            // Extra padding for stride-based access in alg_unquant/exp_rotation
1059            w_x: vec![0.0; frame_x_ch + STRIDE_ACCESS_PAD],
1060            w_band_log_e: vec![0.0; nb_x_ch],
1061            w_error: vec![0.0; nb_x_ch],
1062            w_tf_res: vec![0; nb_ebands],
1063            w_cap: vec![0; nb_ebands],
1064            w_offsets: vec![0; nb_ebands],
1065            w_pulses: vec![0; nb_ebands],
1066            w_ebits: vec![0; nb_x_ch],
1067            w_fine_priority: vec![0; nb_x_ch],
1068            w_collapse_masks: vec![0; nb_x_ch],
1069            w_band_amp_synth: vec![0.0; nb_x_ch],
1070            w_freq_synth: vec![0.0; frame_x_ch],
1071            // Max sizes for run_prefilter buffers (for max frame_size=2880)
1072            // pre_size = max_period + frame_size = 1024 + 2880 = 3904 per channel
1073            // pitch_buf_len = (max_period + frame_size) >> 1 = 1952
1074            w_prefilter_pre: vec![0.0; channels * (COMBFILTER_MAXPERIOD + MAX_FRAME_SIZE)],
1075            w_prefilter_pitch_buf: vec![0.0; (COMBFILTER_MAXPERIOD + MAX_FRAME_SIZE) >> 1],
1076            w_prefilter_before: vec![0.0; channels],
1077            w_prefilter_after: vec![0.0; channels],
1078            w_transient_tmp: vec![0.0; MAX_TRANSIENT_LEN],
1079            w_transient_tmp2: vec![0.0; MAX_TRANSIENT_LEN / 2],
1080            consec_transient: 0,
1081        }
1082    }
1083
1084    pub fn encode(&mut self, pcm: &[f32], frame_size: usize, rc: &mut RangeCoder) {
1085        self.encode_impl(pcm, frame_size, rc, 0, None)
1086    }
1087
1088    pub fn encode_with_start_band(
1089        &mut self,
1090        pcm: &[f32],
1091        frame_size: usize,
1092        rc: &mut RangeCoder,
1093        start_band: usize,
1094    ) {
1095        self.encode_impl(pcm, frame_size, rc, start_band, None)
1096    }
1097
1098    /// Encode with explicit total_bits (for Hybrid mode where SILK has already used some bits)
1099    pub fn encode_with_budget(
1100        &mut self,
1101        pcm: &[f32],
1102        frame_size: usize,
1103        rc: &mut RangeCoder,
1104        start_band: usize,
1105        total_bits: i32,
1106    ) {
1107        self.encode_impl(pcm, frame_size, rc, start_band, Some(total_bits))
1108    }
1109
1110    fn encode_impl(
1111        &mut self,
1112        pcm: &[f32],
1113        frame_size: usize,
1114        rc: &mut RangeCoder,
1115        start_band: usize,
1116        explicit_total_bits: Option<i32>,
1117    ) {
1118        let mode = self.mode;
1119        let channels = self.channels;
1120        let nb_ebands = mode.nb_ebands;
1121        let overlap = mode.overlap;
1122
1123        let mut lm = 0;
1124        while (mode.short_mdct_size << lm) != frame_size {
1125            lm += 1;
1126            if lm > mode.max_lm {
1127                break;
1128            }
1129        }
1130        if (mode.short_mdct_size << lm) != frame_size {
1131            lm = 0;
1132        }
1133
1134        let syn_mem_size = 2048 + overlap;
1135        for c in 0..channels {
1136            let channel_offset = c * syn_mem_size;
1137
1138            self.syn_mem.copy_within(
1139                channel_offset + frame_size..channel_offset + syn_mem_size,
1140                channel_offset,
1141            );
1142
1143            let mut m = self.preemph_mem[c];
1144            let coef = mode.preemph[0];
1145            for i in 0..frame_size {
1146                let x = pcm[c * frame_size + i];
1147                let val = x - m;
1148                self.syn_mem[channel_offset + syn_mem_size - frame_size + i] = val;
1149                m = x * coef;
1150            }
1151            self.preemph_mem[c] = m;
1152        }
1153
1154        let buf_stride = frame_size + overlap;
1155        let in_buf = &mut self.w_in_buf[..buf_stride * channels];
1156        for c in 0..channels {
1157            let channel_offset = c * syn_mem_size;
1158            let in_buf_offset = c * buf_stride;
1159
1160            let src_start = syn_mem_size - frame_size - overlap;
1161            in_buf[in_buf_offset..in_buf_offset + buf_stride].copy_from_slice(
1162                &self.syn_mem[channel_offset + src_start..channel_offset + syn_mem_size],
1163            );
1164        }
1165
1166        let mut tf_estimate = 0.0f32;
1167        let mut tf_chan = 0;
1168        let mut weak_transient = false;
1169        // C opus skips transient_analysis at complexity < 1
1170        let is_transient = if self.complexity >= 1 {
1171            transient_analysis(
1172                &in_buf,
1173                buf_stride,
1174                channels,
1175                &mut tf_estimate,
1176                &mut tf_chan,
1177                false,
1178                &mut weak_transient,
1179                0.0,
1180                0.0,
1181                &mut self.w_transient_tmp,
1182                &mut self.w_transient_tmp2,
1183            )
1184        } else {
1185            false
1186        };
1187
1188        // C opus skips pitch prefilter when complexity < 5
1189        let pf_enabled = start_band == 0 && self.complexity >= 5;
1190        let (pf_on, gain1, pitch_index) = if pf_enabled {
1191            run_prefilter(
1192                in_buf,
1193                &mut self.prefilter_mem,
1194                self.prefilter_period,
1195                self.prefilter_gain,
1196                self.prefilter_tapset,
1197                self.tapset_decision,
1198                mode.window,
1199                channels,
1200                frame_size,
1201                overlap,
1202                &mut self.w_prefilter_pre,
1203                &mut self.w_prefilter_pitch_buf,
1204                &mut self.w_prefilter_before,
1205                &mut self.w_prefilter_after,
1206            )
1207        } else {
1208            (false, 0.0f32, COMBFILTER_MINPERIOD)
1209        };
1210
1211        let freq = &mut self.w_freq[..frame_size * channels];
1212        let (shift, b) = if is_transient {
1213            (mode.max_lm, 1 << lm)
1214        } else {
1215            (mode.max_lm - lm, 1)
1216        };
1217        let n = frame_size / b;
1218
1219        for c in 0..channels {
1220            let c_buf_offset = c * buf_stride;
1221
1222            if c == 0 && b == 1 && channels == 1 {
1223                let mut max_val = 0.0f32;
1224                let check_len = (frame_size + overlap).min(buf_stride);
1225                for j in 0..check_len {
1226                    max_val = max_val.max(in_buf[c_buf_offset + j].abs());
1227                }
1228            }
1229
1230            for i in 0..b {
1231                mode.mdct.forward(
1232                    &in_buf[c_buf_offset + i * n..],
1233                    &mut freq[c * frame_size + i..],
1234                    mode.window,
1235                    overlap,
1236                    shift,
1237                    b,
1238                );
1239            }
1240        }
1241
1242        let band_e = &mut self.w_band_e[..nb_ebands * channels];
1243        compute_band_energies(mode, &freq, band_e, nb_ebands, channels, lm);
1244
1245        // Include stride-access padding so alg_unquant/exp_rotation can use
1246        // x[i*stride] without going out of bounds (matches C's raw-pointer access).
1247        let x_pad_end = (frame_size * channels + STRIDE_ACCESS_PAD).min(self.w_x.len());
1248        let x = &mut self.w_x[..x_pad_end];
1249        normalise_bands(
1250            mode,
1251            &freq,
1252            x,
1253            &band_e,
1254            nb_ebands,
1255            channels,
1256            (1 << lm) as usize,
1257        );
1258
1259        if channels == 1 {
1260            let _ = freq[0];
1261        }
1262
1263        let band_log_e = &mut self.w_band_log_e[..nb_ebands * channels];
1264        crate::bands::amp2log2(mode, nb_ebands, nb_ebands, &band_e, band_log_e, channels);
1265
1266        // Use explicit total_bits if provided (for Hybrid mode), otherwise calculate from buffer
1267        let total_bits = explicit_total_bits.unwrap_or_else(|| (rc.buf.len() * 8) as i32);
1268        self.w_error[..nb_ebands * channels].fill(0.0);
1269        let error = &mut self.w_error[..nb_ebands * channels];
1270
1271        let _celt_dbg = std::env::var("CELT_DBG").is_ok();
1272        if _celt_dbg {
1273            eprintln!("[ENC] band_e (linear): {:?}", &band_e[..nb_ebands.min(6)]);
1274            eprintln!("[ENC] band_log_e: {:?}", &band_log_e[..nb_ebands.min(6)]);
1275        }
1276
1277        let tell = rc.tell();
1278        let silence = false;
1279        if tell == 1 {
1280            rc.encode_bit_logp(silence, 15);
1281        }
1282        if _celt_dbg {
1283            eprintln!(
1284                "[ENC] start_band={} total_bits={} after_silence tell={}",
1285                start_band,
1286                total_bits,
1287                rc.tell()
1288            );
1289        }
1290
1291        // Prefilter bit is only written in non-hybrid mode (start_band == 0)
1292        if start_band == 0 && !silence && rc.tell() + 16 <= total_bits {
1293            rc.encode_bit_logp(pf_on, 1);
1294            if _celt_dbg {
1295                eprintln!("[ENC] pf_on={} after_prefilter tell={}", pf_on, rc.tell());
1296            }
1297            if pf_on {
1298                let qg = (gain1 / 0.09375 - 1.0 + 0.5).floor() as i32;
1299                let qg = qg.clamp(0, 7);
1300                let pi = (pitch_index + 1) as u32;
1301                let octave = 31 - pi.leading_zeros();
1302                let octave = (octave as i32 - 5).max(0) as u32;
1303                rc.enc_uint(octave, 6);
1304                rc.enc_bits(pi - (16 << octave), 4 + octave);
1305                rc.enc_bits(qg as u32, 3);
1306                rc.encode_icdf(self.tapset_decision, &TAPSET_ICDF, 2);
1307            }
1308        }
1309
1310        let mut short_blocks = false;
1311        if lm > 0 && rc.tell() + 3 <= total_bits {
1312            rc.encode_bit_logp(is_transient, 3);
1313            if is_transient {
1314                short_blocks = true;
1315            }
1316        }
1317        if _celt_dbg {
1318            eprintln!(
1319                "[ENC] is_transient={} short_blocks={} after_transient tell={}",
1320                is_transient,
1321                short_blocks,
1322                rc.tell()
1323            );
1324        }
1325
1326        if short_blocks {
1327            let b = 1 << lm;
1328            let n = frame_size / b;
1329            for c in 0..channels {
1330                let c_offset = c * buf_stride;
1331                for i in 0..b {
1332                    mode.mdct.forward(
1333                        &in_buf[c_offset + i * n..c_offset + buf_stride],
1334                        &mut freq[c * frame_size + i..],
1335                        mode.window,
1336                        overlap,
1337                        mode.max_lm,
1338                        b,
1339                    );
1340                }
1341            }
1342
1343            compute_band_energies(mode, &freq, band_e, nb_ebands, channels, lm);
1344            normalise_bands(
1345                mode,
1346                &freq,
1347                x,
1348                &band_e,
1349                nb_ebands,
1350                channels,
1351                (1 << lm) as usize,
1352            );
1353        }
1354
1355        // C: with complexity >= 4 (two_pass=true), intra is never forced (only force_intra which defaults false)
1356        // C: with complexity < 4, intra triggers via delayedIntra mechanism (not implemented; use first-frame approximation)
1357        let intra_ener = if self.complexity >= 4 {
1358            false
1359        } else {
1360            self.old_band_e[..nb_ebands * channels]
1361                .iter()
1362                .all(|&e| e <= -27.0)
1363        };
1364        quant_coarse_energy(
1365            mode,
1366            start_band,
1367            nb_ebands,
1368            &band_log_e,
1369            &mut self.old_band_e,
1370            total_bits as u32,
1371            error,
1372            rc,
1373            channels,
1374            lm,
1375            is_transient || intra_ener,
1376            (total_bits / 8) as usize,
1377        );
1378        if _celt_dbg {
1379            eprintln!(
1380                "[ENC] old_band_e after coarse: {:?}",
1381                &self.old_band_e[..nb_ebands.min(6)]
1382            );
1383        }
1384
1385        self.w_tf_res[..nb_ebands].fill(0);
1386        let tf_res = &mut self.w_tf_res[..nb_ebands];
1387        let effective_bytes = ((total_bits / 8) as usize).max(1);
1388        let lambda = 80.max(20480 / effective_bytes + 2) as i32;
1389
1390        // C opus skips tf_analysis at complexity < 2
1391        let tf_select = if self.complexity >= 2 && effective_bytes >= 15 * channels {
1392            tf_analysis(
1393                mode,
1394                nb_ebands,
1395                is_transient,
1396                tf_res,
1397                lambda,
1398                &x,
1399                frame_size,
1400                lm as i32,
1401                tf_estimate,
1402                tf_chan,
1403            )
1404        } else {
1405            0
1406        };
1407        tf_encode(
1408            start_band,
1409            nb_ebands,
1410            is_transient,
1411            tf_res,
1412            lm as i32,
1413            tf_select,
1414            rc,
1415        );
1416        if _celt_dbg {
1417            eprintln!("[ENC] after_coarse+tf tell={}", rc.tell());
1418        }
1419
1420        let mut dual_stereo_val = if channels == 2 {
1421            stereo_analysis(mode, &x, lm as i32, frame_size) as i32
1422        } else {
1423            0
1424        };
1425
1426        let mut stereo_saving = 0.0f32;
1427        let equiv_rate = (total_bits * 48000) / frame_size as i32;
1428        if channels == 2 {
1429            self.intensity = hysteresis_decision(
1430                equiv_rate / 1000,
1431                &INTEN_THRESHOLDS,
1432                &INTEN_HYSTERESIS,
1433                self.intensity,
1434            );
1435            self.intensity = self.intensity.clamp(0, nb_ebands as i32);
1436        }
1437
1438        // C opus uses SPREAD_NONE at complexity 0
1439        if self.complexity == 0 {
1440            self.spread_decision = SPREAD_NONE;
1441            if rc.tell() + 4 <= total_bits {
1442                rc.encode_icdf(self.spread_decision, &SPREAD_ICDF, 5);
1443            }
1444        } else if rc.tell() + 4 <= total_bits {
1445            // C: for shortBlocks (transients), complexity < 3, or few available bytes: use SPREAD_NORMAL
1446            // otherwise call spreading_decision()
1447            if is_transient || self.complexity < 3 || effective_bytes < 10 * channels {
1448                self.spread_decision = SPREAD_NORMAL;
1449            } else {
1450                let update_hf = lm == mode.max_lm;
1451                let spread_weights = [32i32; 21];
1452                self.spread_decision = spreading_decision(
1453                    mode,
1454                    &x,
1455                    &mut self.tonal_average,
1456                    self.spread_decision,
1457                    &mut self.hf_average,
1458                    &mut self.tapset_decision,
1459                    update_hf,
1460                    nb_ebands,
1461                    channels,
1462                    (1 << lm) as usize,
1463                    &spread_weights,
1464                );
1465            }
1466            rc.encode_icdf(self.spread_decision, &SPREAD_ICDF, 5);
1467        } else {
1468            self.spread_decision = SPREAD_NORMAL;
1469        }
1470        if _celt_dbg {
1471            eprintln!(
1472                "[ENC] spread={} after_spread tell={}",
1473                self.spread_decision,
1474                rc.tell()
1475            );
1476        }
1477
1478        self.w_cap[..nb_ebands].fill(0);
1479        let cap = &mut self.w_cap[..nb_ebands];
1480        for (i, cap_i) in cap.iter_mut().enumerate() {
1481            *cap_i = (mode.cache.caps[nb_ebands * (2 * lm + channels - 1) + i] as i32 + 64)
1482                * channels as i32
1483                * 2;
1484        }
1485
1486        self.w_offsets[..nb_ebands].fill(0);
1487        let offsets = &mut self.w_offsets[..nb_ebands];
1488        let dynalloc_logp = 6i32;
1489        let total_bits_bitres = total_bits << BITRES;
1490        let total_boost = 0i32;
1491        // Dynamic allocation: for each band, write one FALSE bit to indicate no boost.
1492        // The decoder reads matching FALSE bits to confirm zero allocation increase.
1493        for i in 0..nb_ebands {
1494            let tell_frac = rc.tell() << BITRES;
1495            if tell_frac + (dynalloc_logp << BITRES) >= total_bits_bitres - total_boost {
1496                break;
1497            }
1498            rc.encode_bit_logp(false, dynalloc_logp as u32);
1499            offsets[i] = 0;
1500        }
1501        if _celt_dbg {
1502            eprintln!("[ENC] after_dynalloc tell={}", rc.tell());
1503        }
1504
1505        let alloc_trim = alloc_trim_analysis(
1506            mode,
1507            &x,
1508            &band_log_e,
1509            nb_ebands,
1510            lm as i32,
1511            channels,
1512            frame_size,
1513            &mut stereo_saving,
1514            tf_estimate,
1515            self.intensity,
1516            0.0,
1517            equiv_rate,
1518        );
1519        if (rc.tell() << BITRES) + (6 << BITRES) <= total_bits_bitres - total_boost {
1520            rc.encode_icdf(alloc_trim, &TRIM_ICDF, 7);
1521        }
1522        if _celt_dbg {
1523            eprintln!(
1524                "[ENC] alloc_trim={} after_trim tell={}",
1525                alloc_trim,
1526                rc.tell()
1527            );
1528        }
1529
1530        let mut intensity = self.intensity;
1531        self.w_pulses[..nb_ebands].fill(0);
1532        let pulses = &mut self.w_pulses[..nb_ebands];
1533
1534        let stereo = channels > 1;
1535        let ebands_stereo = if stereo {
1536            nb_ebands * channels
1537        } else {
1538            nb_ebands
1539        };
1540        self.w_fine_priority[..ebands_stereo].fill(0);
1541        let fine_priority = &mut self.w_fine_priority[..ebands_stereo];
1542        self.w_ebits[..ebands_stereo].fill(0);
1543        let ebits = &mut self.w_ebits[..ebands_stereo];
1544        let mut balance = 0;
1545
1546        self.last_coded_bands = clt_compute_allocation(
1547            mode,
1548            start_band,
1549            nb_ebands,
1550            &offsets,
1551            &cap,
1552            alloc_trim,
1553            &mut intensity,
1554            &mut dual_stereo_val,
1555            total_bits << 3,
1556            &mut balance,
1557            pulses,
1558            ebits,
1559            fine_priority,
1560            channels as i32,
1561            lm as i32,
1562            rc,
1563            true,
1564            0,
1565            nb_ebands as i32 - 1,
1566        );
1567        if _celt_dbg {
1568            eprintln!(
1569                "[ENC] coded_bands={} after_alloc tell={}",
1570                self.last_coded_bands,
1571                rc.tell()
1572            );
1573            eprintln!("[ENC] pulses={:?}", &pulses[..nb_ebands]);
1574            eprintln!("[ENC] ebits={:?}", &ebits[..nb_ebands]);
1575        }
1576
1577        quant_fine_energy(
1578            mode,
1579            start_band,
1580            nb_ebands,
1581            &mut self.old_band_e,
1582            error,
1583            &ebits,
1584            rc,
1585            channels,
1586        );
1587
1588        self.w_collapse_masks[..nb_ebands * channels].fill(0);
1589        let collapse_masks = &mut self.w_collapse_masks[..nb_ebands * channels];
1590        let (x_split, y_split) = x.split_at_mut(frame_size);
1591        let y_opt = if channels == 2 { Some(y_split) } else { None };
1592
1593        // Reserve bits for anti-collapse (matching C reference)
1594        let anti_collapse_rsv = if is_transient && lm >= 2 {
1595            let remaining = (total_bits << BITRES) - (rc.tell() << BITRES) - 1;
1596            if remaining >= ((lm as i32 + 2) << BITRES) {
1597                1i32 << BITRES
1598            } else {
1599                0
1600            }
1601        } else {
1602            0
1603        };
1604
1605        let mut dual_stereo = dual_stereo_val != 0;
1606        // theta_rdo requires stereo + !dual_stereo + complexity >= 8
1607        let theta_rdo = channels == 2 && !dual_stereo && self.complexity >= 8;
1608        let resynth = theta_rdo;
1609
1610        quant_all_bands(
1611            true,
1612            mode,
1613            start_band,
1614            nb_ebands,
1615            x_split,
1616            y_opt,
1617            collapse_masks,
1618            &band_e,
1619            &pulses,
1620            short_blocks,
1621            self.spread_decision,
1622            &mut dual_stereo,
1623            intensity as usize,
1624            &tf_res,
1625            (total_bits << 3) - anti_collapse_rsv,
1626            &mut balance,
1627            rc,
1628            lm as i32,
1629            self.last_coded_bands,
1630            resynth,
1631            &mut 0u32, // encoder doesn't need stateful seed for noise fill
1632        );
1633        if _celt_dbg {
1634            eprintln!("[ENC] after_quant_all_bands tell={}", rc.tell());
1635        }
1636        if _celt_dbg {
1637            eprintln!("[ENC] freq[0..10] after quant: {:?}", &freq[..10]);
1638            eprintln!("[ENC] x[0..10] after quant: {:?}", &x[..10]);
1639        }
1640
1641        // Write anti-collapse bit (matching C reference: after quant_all_bands, before quant_energy_finalise)
1642        if anti_collapse_rsv > 0 {
1643            let anti_collapse_on = if self.consec_transient < 2 {
1644                1u32
1645            } else {
1646                0u32
1647            };
1648            rc.enc_bits(anti_collapse_on, 1);
1649        }
1650
1651        quant_energy_finalise(
1652            mode,
1653            start_band,
1654            nb_ebands,
1655            &mut self.old_band_e,
1656            error,
1657            &ebits,
1658            &fine_priority,
1659            (total_bits - rc.tell()) << 3,
1660            rc,
1661            channels,
1662        );
1663        if _celt_dbg {
1664            eprintln!(
1665                "[ENC] after_energy_finalise tell={}/{}",
1666                rc.tell(),
1667                total_bits
1668            );
1669        }
1670        if _celt_dbg {
1671            eprintln!(
1672                "[ENC] old_band_e after ALL energy quant: {:?}",
1673                &self.old_band_e[..nb_ebands.min(6)]
1674            );
1675        }
1676        if resynth {
1677            let band_amp_synth = &mut self.w_band_amp_synth[..nb_ebands * channels];
1678            log2amp(mode, nb_ebands, band_amp_synth, &self.old_band_e, channels);
1679            self.w_freq_synth[..frame_size * channels].fill(0.0);
1680            let freq_synth = &mut self.w_freq_synth[..frame_size * channels];
1681            denormalise_bands(
1682                mode,
1683                &x,
1684                freq_synth,
1685                &band_amp_synth,
1686                start_band,
1687                nb_ebands,
1688                channels,
1689                (1 << lm) as usize,
1690            );
1691            let (syn_shift, syn_b) = if is_transient {
1692                (mode.max_lm, 1 << lm)
1693            } else {
1694                (mode.max_lm - lm, 1)
1695            };
1696            let syn_n = frame_size / syn_b;
1697            let decode_buf_size = 2048;
1698
1699            for c in 0..channels {
1700                let co = c * syn_mem_size;
1701                self.enc_decode_mem
1702                    .copy_within(co + frame_size..co + decode_buf_size + overlap, co);
1703            }
1704
1705            for c in 0..channels {
1706                let co = c * syn_mem_size;
1707                let out_syn_idx = decode_buf_size - frame_size;
1708                for bi in 0..syn_b {
1709                    mode.mdct.backward(
1710                        &freq_synth[c * frame_size + bi..],
1711                        &mut self.enc_decode_mem[co + out_syn_idx + bi * syn_n..],
1712                        mode.window,
1713                        overlap,
1714                        syn_shift,
1715                        syn_b,
1716                    );
1717                }
1718            }
1719        }
1720
1721        self.last_band_log_e.copy_from_slice(&self.old_band_e);
1722
1723        if !is_transient {
1724            self.old_band_e3.copy_from_slice(&self.old_band_e2);
1725            self.old_band_e2.copy_from_slice(&self.old_band_e);
1726        } else {
1727            for i in 0..channels * nb_ebands {
1728                self.old_band_e2[i] = self.old_band_e2[i].min(self.old_band_e[i]);
1729            }
1730        }
1731
1732        rc.pad_to_bits(total_bits);
1733
1734        if pf_on {
1735            self.prefilter_period = pitch_index;
1736            self.prefilter_gain = gain1;
1737            self.prefilter_tapset = self.tapset_decision;
1738        } else {
1739            self.prefilter_period = COMBFILTER_MINPERIOD;
1740            self.prefilter_gain = 0.0;
1741            self.prefilter_tapset = self.tapset_decision;
1742        }
1743
1744        let syn_mem_size = 2048 + overlap;
1745
1746        for c in 0..channels {
1747            let channel_offset = c * syn_mem_size;
1748            let n = frame_size;
1749            let max_period = COMBFILTER_MAXPERIOD;
1750            if n >= max_period {
1751                self.prefilter_mem[c * max_period..(c + 1) * max_period].copy_from_slice(
1752                    &self.syn_mem
1753                        [channel_offset + syn_mem_size - max_period..channel_offset + syn_mem_size],
1754                );
1755            } else {
1756                let mut new_mem = [0.0f32; COMBFILTER_MAXPERIOD];
1757                new_mem[..max_period - n]
1758                    .copy_from_slice(&self.prefilter_mem[c * max_period + n..(c + 1) * max_period]);
1759                new_mem[max_period - n..].copy_from_slice(
1760                    &self.syn_mem[channel_offset + syn_mem_size - n..channel_offset + syn_mem_size],
1761                );
1762                self.prefilter_mem[c * max_period..(c + 1) * max_period].copy_from_slice(&new_mem);
1763            }
1764        }
1765
1766        // Update consec_transient counter (matching C reference)
1767        if is_transient {
1768            self.consec_transient += 1;
1769        } else {
1770            self.consec_transient = 0;
1771        }
1772    }
1773}
1774
/// State of a CELT decoder: the mode it was created for, the history carried
/// from frame to frame, and scratch buffers allocated once in `new`.
pub struct CeltDecoder {
    // Static mode description; `new` reads `overlap` and `nb_ebands` from it.
    mode: &'static CeltMode,
    // Channel count given to `new`; most buffers below scale with it.
    channels: usize,
    // `channels * (DECODE_BUFFER_SIZE + overlap)` samples, zero-initialised.
    decode_mem: Vec<f32>,
    // One entry per band and channel, initialised to -28.0.
    old_band_e: Vec<f32>,
    // One entry per channel, initialised to 0.0.
    preemph_mem: Vec<f32>,
    // `channels * COMBFILTER_MAXPERIOD` samples, zero-initialised.
    prefilter_mem: Vec<f32>,
    // Current and previous period/gain/tapset values; `new` starts both periods
    // at COMBFILTER_MINPERIOD and the gains and tapsets at zero.
    prefilter_period: usize,
    prefilter_period_old: usize,
    prefilter_gain: f32,
    prefilter_gain_old: f32,
    prefilter_tapset: i32,
    prefilter_tapset_old: i32,
    // Two further band-energy histories, same size and initial value as `old_band_e`.
    old_band_e2: Vec<f32>,
    old_band_e3: Vec<f32>,
    // Starts at 0; presumably the noise-fill random seed — TODO confirm in decode_impl.
    rng: u32,
    // Pre-allocated working buffers for decode_impl
    // (`nb_ebands` entries each: w_tf_res, w_cap, w_offsets, w_pulses).
    w_tf_res: Vec<i32>,
    w_cap: Vec<i32>,
    w_offsets: Vec<i32>,
    w_pulses: Vec<i32>,
    // `nb_ebands * channels` entries each.
    w_ebits: Vec<i32>,
    w_fine_priority: Vec<i32>,
    // `DECODE_BUFFER_SIZE * channels + STRIDE_ACCESS_PAD` entries.
    w_x: Vec<f32>,
    // `nb_ebands * channels` entries.
    w_collapse_masks: Vec<u32>,
    // `DECODE_BUFFER_SIZE * channels` entries.
    w_freq: Vec<f32>,
    // `nb_ebands * channels` entries.
    w_band_amp: Vec<f32>,
    // `DECODE_BUFFER_SIZE` entries each.
    w_pcm_frame: Vec<f32>,
    w_filtered: Vec<f32>,
    // `DECODE_BUFFER_SIZE + COMBFILTER_MAXPERIOD` entries.
    w_post: Vec<f32>,
}
1806
1807impl CeltDecoder {
1808    pub fn new(mode: &'static CeltMode, channels: usize) -> Self {
1809        let overlap = mode.overlap;
1810        let nb_ebands = mode.nb_ebands;
1811        let nb_x_ch = nb_ebands * channels;
1812        let dec_frame_x_ch = DECODE_BUFFER_SIZE * channels;
1813        Self {
1814            mode,
1815            channels,
1816            decode_mem: vec![0.0; channels * (DECODE_BUFFER_SIZE + overlap)],
1817            old_band_e: vec![-28.0; nb_x_ch],
1818            preemph_mem: vec![0.0; channels],
1819            prefilter_mem: vec![0.0; channels * COMBFILTER_MAXPERIOD],
1820            prefilter_period: COMBFILTER_MINPERIOD,
1821            prefilter_period_old: COMBFILTER_MINPERIOD,
1822            prefilter_gain: 0.0,
1823            prefilter_gain_old: 0.0,
1824            prefilter_tapset: 0,
1825            prefilter_tapset_old: 0,
1826            old_band_e2: vec![-28.0; nb_x_ch],
1827            old_band_e3: vec![-28.0; nb_x_ch],
1828            rng: 0,
1829            // Pre-allocate working buffers
1830            w_tf_res: vec![0; nb_ebands],
1831            w_cap: vec![0; nb_ebands],
1832            w_offsets: vec![0; nb_ebands],
1833            w_pulses: vec![0; nb_ebands],
1834            w_ebits: vec![0; nb_x_ch],
1835            w_fine_priority: vec![0; nb_x_ch],
1836            // Extra padding for stride-based access in alg_unquant/exp_rotation
1837            w_x: vec![0.0; dec_frame_x_ch + STRIDE_ACCESS_PAD],
1838            w_collapse_masks: vec![0; nb_x_ch],
1839            w_freq: vec![0.0; dec_frame_x_ch],
1840            w_band_amp: vec![0.0; nb_x_ch],
1841            w_pcm_frame: vec![0.0; DECODE_BUFFER_SIZE],
1842            w_filtered: vec![0.0; DECODE_BUFFER_SIZE],
1843            w_post: vec![0.0; DECODE_BUFFER_SIZE + COMBFILTER_MAXPERIOD],
1844        }
1845    }
1846
1847    pub fn decode(&mut self, compressed: &[u8], frame_size: usize, pcm: &mut [f32]) -> usize {
1848        self.decode_impl(compressed, frame_size, pcm, 0)
1849    }
1850
1851    pub fn decode_with_start_band(
1852        &mut self,
1853        compressed: &[u8],
1854        frame_size: usize,
1855        pcm: &mut [f32],
1856        start_band: usize,
1857    ) -> usize {
1858        self.decode_impl(compressed, frame_size, pcm, start_band)
1859    }
1860
1861    /// Decode from an existing RangeCoder (for Hybrid mode where SILK has already consumed bits)
1862    pub fn decode_from_range_coder(
1863        &mut self,
1864        rc: &mut RangeCoder,
1865        total_bits: i32,
1866        frame_size: usize,
1867        pcm: &mut [f32],
1868        start_band: usize,
1869    ) -> usize {
1870        self.decode_impl_from_rc(rc, total_bits, frame_size, pcm, start_band)
1871    }
1872
1873    fn decode_impl(
1874        &mut self,
1875        compressed: &[u8],
1876        frame_size: usize,
1877        pcm: &mut [f32],
1878        start_band: usize,
1879    ) -> usize {
1880        let total_bits = (compressed.len() * 8) as i32;
1881        let mut rc = RangeCoder::new_decoder(compressed);
1882        self.decode_impl_from_rc(&mut rc, total_bits, frame_size, pcm, start_band)
1883    }
1884
1885    fn decode_impl_from_rc(
1886        &mut self,
1887        rc: &mut RangeCoder,
1888        total_bits: i32,
1889        frame_size: usize,
1890        pcm: &mut [f32],
1891        start_band: usize,
1892    ) -> usize {
1893        let _celt_dbg = std::env::var("CELT_DBG").is_ok();
1894        let mode = self.mode;
1895        let channels = self.channels;
1896        let nb_ebands = mode.nb_ebands;
1897        let overlap = mode.overlap;
1898
1899        let mut lm = 0;
1900        while (mode.short_mdct_size << lm) != frame_size {
1901            lm += 1;
1902            if lm > mode.max_lm {
1903                break;
1904            }
1905        }
1906        if (mode.short_mdct_size << lm) != frame_size {
1907            lm = 0;
1908        }
1909
1910        let tell = rc.tell();
1911        let mut silence = false;
1912        if tell >= total_bits {
1913            silence = true;
1914        } else if tell == 1 {
1915            silence = rc.decode_bit_logp(15);
1916        }
1917        if _celt_dbg {
1918            eprintln!(
1919                "[DEC] start_band={} total_bits={} after_silence tell={}",
1920                start_band,
1921                total_bits,
1922                rc.tell()
1923            );
1924        }
1925
1926        // Handle silence: output zeros and return early
1927        if silence {
1928            pcm[..frame_size * channels].fill(0.0);
1929            return frame_size;
1930        }
1931
1932        let mut pf_on = false;
1933        let mut pitch_index = COMBFILTER_MINPERIOD;
1934        let mut gain1 = 0.0f32;
1935        let mut prefilter_tapset = 0;
1936        // Prefilter bit is only present in non-hybrid mode (start_band == 0)
1937        if start_band == 0 && !silence && rc.tell() + 16 <= total_bits {
1938            pf_on = rc.decode_bit_logp(1);
1939            if pf_on {
1940                let octave = rc.dec_uint(6);
1941                pitch_index = ((16 << octave) + rc.dec_bits(4 + octave)) as usize - 1;
1942                let qg = rc.dec_bits(3);
1943                if rc.tell() + 2 <= total_bits {
1944                    prefilter_tapset = rc.decode_icdf(&TAPSET_ICDF, 2) as usize;
1945                }
1946                gain1 = 0.09375 * (qg as f32 + 1.0);
1947            }
1948        }
1949        if _celt_dbg {
1950            eprintln!("[DEC] pf_on={} after_prefilter tell={}", pf_on, rc.tell());
1951        }
1952        // In hybrid mode, ensure the combfilter doesn't run from stale previous state
1953        if start_band != 0 {
1954            self.prefilter_gain = 0.0;
1955        }
1956
1957        let mut is_transient = false;
1958        if lm > 0 && rc.tell() + 3 <= total_bits {
1959            is_transient = rc.decode_bit_logp(3);
1960        }
1961        let short_blocks = is_transient;
1962        if _celt_dbg {
1963            eprintln!(
1964                "[DEC] is_transient={} after_transient tell={}",
1965                is_transient,
1966                rc.tell()
1967            );
1968        }
1969        let intra_ener = false;
1970
1971        unquant_coarse_energy(
1972            mode,
1973            start_band,
1974            nb_ebands,
1975            &mut self.old_band_e,
1976            total_bits as u32,
1977            rc,
1978            channels,
1979            lm,
1980            is_transient || intra_ener,
1981        );
1982        if _celt_dbg {
1983            eprintln!(
1984                "[DEC] old_band_e after coarse: {:?}",
1985                &self.old_band_e[..nb_ebands.min(6)]
1986            );
1987        }
1988
1989        self.w_tf_res[..nb_ebands].fill(0);
1990        let tf_res = &mut self.w_tf_res[..nb_ebands];
1991        tf_decode(start_band, nb_ebands, is_transient, tf_res, lm as i32, rc);
1992        if _celt_dbg {
1993            eprintln!("[DEC] after_coarse+tf tell={}", rc.tell());
1994        }
1995
1996        let spread_decision = if rc.tell() + 4 <= total_bits {
1997            rc.decode_icdf(&SPREAD_ICDF, 5)
1998        } else {
1999            SPREAD_NORMAL
2000        };
2001        if _celt_dbg {
2002            eprintln!(
2003                "[DEC] spread={} after_spread tell={}",
2004                spread_decision,
2005                rc.tell()
2006            );
2007        }
2008
2009        self.w_cap[..nb_ebands].fill(0);
2010        let cap = &mut self.w_cap[..nb_ebands];
2011        for (i, cap_i) in cap.iter_mut().enumerate() {
2012            *cap_i = (mode.cache.caps[nb_ebands * (2 * lm + channels - 1) + i] as i32 + 64)
2013                * channels as i32
2014                * 2;
2015        }
2016
2017        self.w_offsets[..nb_ebands].fill(0);
2018        let offsets = &mut self.w_offsets[..nb_ebands];
2019        let mut dynalloc_logp = 6i32;
2020        let mut total_bits_bitres = total_bits << BITRES;
2021        let mut tell_frac = rc.tell() << BITRES;
2022        for i in 0..nb_ebands {
2023            let width =
2024                channels as i32 * (mode.e_bands[i + 1] - mode.e_bands[i]) as i32 * (1 << lm);
2025            let quanta = (width << BITRES).min((6i32 << BITRES).max(width));
2026            let mut dynalloc_loop_logp = dynalloc_logp;
2027            let mut boost = 0i32;
2028            while tell_frac + (dynalloc_loop_logp << BITRES) < total_bits_bitres && boost < cap[i] {
2029                let flag = rc.decode_bit_logp(dynalloc_loop_logp as u32);
2030                tell_frac = rc.tell() << BITRES;
2031                if !flag {
2032                    break;
2033                }
2034                boost += quanta;
2035                total_bits_bitres -= quanta;
2036                dynalloc_loop_logp = 1;
2037            }
2038            offsets[i] = boost;
2039            if boost > 0 {
2040                dynalloc_logp = dynalloc_logp.max(2) - 1;
2041                dynalloc_logp = dynalloc_logp.max(2);
2042            }
2043        }
2044        if _celt_dbg {
2045            eprintln!("[DEC] after_dynalloc tell={}", rc.tell());
2046        }
2047
2048        let alloc_trim = if (rc.tell() << BITRES) + (6 << BITRES) <= total_bits_bitres {
2049            rc.decode_icdf(&TRIM_ICDF, 7)
2050        } else {
2051            5
2052        };
2053        if _celt_dbg {
2054            eprintln!(
2055                "[DEC] alloc_trim={} after_trim tell={}",
2056                alloc_trim,
2057                rc.tell()
2058            );
2059        }
2060        let anti_collapse_rsv = if is_transient && lm >= 2 {
2061            let remaining = (total_bits << BITRES) - (rc.tell() << BITRES) - 1;
2062            if remaining >= ((lm as i32 + 2) << BITRES) {
2063                1i32 << BITRES
2064            } else {
2065                0
2066            }
2067        } else {
2068            0
2069        };
2070
2071        let mut intensity = 0;
2072        let mut dual_stereo_val = if channels == 2 { 1 } else { 0 };
2073        let mut balance = 0;
2074        self.w_pulses[..nb_ebands].fill(0);
2075        let pulses = &mut self.w_pulses[..nb_ebands];
2076
2077        let ebands_stereo = if channels > 1 {
2078            nb_ebands * channels
2079        } else {
2080            nb_ebands
2081        };
2082        self.w_fine_priority[..ebands_stereo].fill(0);
2083        let fine_priority = &mut self.w_fine_priority[..ebands_stereo];
2084        self.w_ebits[..ebands_stereo].fill(0);
2085        let ebits = &mut self.w_ebits[..ebands_stereo];
2086
2087        let coded_bands = clt_compute_allocation(
2088            mode,
2089            start_band,
2090            nb_ebands,
2091            &offsets,
2092            &cap,
2093            alloc_trim,
2094            &mut intensity,
2095            &mut dual_stereo_val,
2096            (total_bits << 3) - anti_collapse_rsv,
2097            &mut balance,
2098            pulses,
2099            ebits,
2100            fine_priority,
2101            channels as i32,
2102            lm as i32,
2103            rc,
2104            false,
2105            0,
2106            nb_ebands as i32 - 1,
2107        );
2108        if _celt_dbg {
2109            eprintln!(
2110                "[DEC] coded_bands={} after_alloc tell={}",
2111                coded_bands,
2112                rc.tell()
2113            );
2114            eprintln!("[DEC] pulses={:?}", &pulses[..nb_ebands]);
2115            eprintln!("[DEC] ebits={:?}", &ebits[..nb_ebands]);
2116        }
2117
2118        unquant_fine_energy(
2119            mode,
2120            start_band,
2121            nb_ebands,
2122            &mut self.old_band_e,
2123            &ebits,
2124            rc,
2125            channels,
2126        );
2127
2128        if frame_size > DECODE_BUFFER_SIZE + overlap {
2129            return 0;
2130        }
2131
2132        self.w_x[..frame_size * channels].fill(0.0);
2133        // Include stride-access padding so alg_unquant/exp_rotation can use
2134        // x[i*stride] without going out of bounds (matches C's raw-pointer access).
2135        let x_pad_end = (frame_size * channels + STRIDE_ACCESS_PAD).min(self.w_x.len());
2136        let x = &mut self.w_x[..x_pad_end];
2137        self.w_collapse_masks[..nb_ebands * channels].fill(0);
2138        let collapse_masks = &mut self.w_collapse_masks[..nb_ebands * channels];
2139
2140        // NOTE: Buffer shift must happen AFTER MDCT backward, not before.
2141        // The C code does OPUS_MOVE after deemphasis, which preserves the overlap
2142        // data in out_syn[0..overlap-1] for the next frame's TDAC.
2143        // We'll shift the buffer at the end of decode instead.
2144
2145        let (x_split, y_split) = x.split_at_mut(frame_size);
2146        let y_opt = if channels == 2 { Some(y_split) } else { None };
2147
2148        let mut dual_stereo = dual_stereo_val != 0;
2149        self.w_band_amp[..nb_ebands * channels].fill(0.0);
2150        let band_amp = &mut self.w_band_amp[..nb_ebands * channels];
2151        log2amp(mode, nb_ebands, band_amp, &self.old_band_e, channels);
2152        if _celt_dbg {
2153            eprintln!("[DEC] band_amp (log2): {:?}", &band_amp[..nb_ebands.min(6)]);
2154        }
2155
2156        quant_all_bands(
2157            false,
2158            mode,
2159            start_band,
2160            nb_ebands,
2161            x_split,
2162            y_opt,
2163            collapse_masks,
2164            &band_amp,
2165            &pulses,
2166            short_blocks,
2167            spread_decision,
2168            &mut dual_stereo,
2169            intensity as usize,
2170            &tf_res,
2171            (total_bits << 3) - anti_collapse_rsv,
2172            &mut balance,
2173            rc,
2174            lm as i32,
2175            coded_bands,
2176            true,
2177            &mut self.rng,
2178        );
2179        if _celt_dbg {
2180            eprintln!("[DEC] after_quant_all_bands tell={}", rc.tell());
2181        }
2182        if _celt_dbg {
2183            eprintln!("[DEC] x[0..10] after quant_all_bands: {:?}", &x[..10]);
2184        }
2185
2186        let mut anti_collapse_on = false;
2187        if anti_collapse_rsv > 0 {
2188            anti_collapse_on = rc.dec_bits(1) != 0;
2189        }
2190
2191        unquant_energy_finalise(
2192            mode,
2193            start_band,
2194            nb_ebands,
2195            &mut self.old_band_e,
2196            &ebits,
2197            &fine_priority,
2198            (total_bits - rc.tell()) << 3,
2199            rc,
2200            channels,
2201        );
2202        if _celt_dbg {
2203            eprintln!(
2204                "[DEC] after_energy_finalise tell={}/{}",
2205                rc.tell(),
2206                total_bits
2207            );
2208        }
2209        if _celt_dbg {
2210            eprintln!(
2211                "[DEC] old_band_e after ALL energy dequant: {:?}",
2212                &self.old_band_e[..nb_ebands.min(6)]
2213            );
2214        }
2215
2216        if anti_collapse_on {
2217            self.rng = crate::bands::anti_collapse(
2218                mode,
2219                x,
2220                &collapse_masks,
2221                lm as i32,
2222                channels,
2223                frame_size,
2224                start_band,
2225                nb_ebands,
2226                &self.old_band_e,
2227                &self.old_band_e2,
2228                &self.old_band_e3,
2229                &pulses,
2230                self.rng,
2231            );
2232        }
2233
2234        self.w_freq[..frame_size * channels].fill(0.0);
2235        let freq = &mut self.w_freq[..frame_size * channels];
2236        denormalise_bands(
2237            mode,
2238            &x,
2239            freq,
2240            &band_amp,
2241            start_band,
2242            nb_ebands,
2243            channels,
2244            (1 << lm) as usize,
2245        );
2246        if _celt_dbg {
2247            eprintln!("[DEC] freq[0..10] after denorm: {:?}", &freq[..10]);
2248        }
2249
2250        let (shift, b) = if short_blocks {
2251            (mode.max_lm, 1 << lm)
2252        } else {
2253            (mode.max_lm - lm, 1)
2254        };
2255        let n = frame_size / b;
2256
2257        for c in 0..channels {
2258            let channel_mem_offset = c * (DECODE_BUFFER_SIZE + overlap);
2259
2260            // Shift decode_mem left by frame_size (matches C's OPUS_MOVE).
2261            // This moves the previous frame's "future overlap" (at decode_mem[DECODE_BUFFER_SIZE..])
2262            // to decode_mem[DECODE_BUFFER_SIZE - frame_size..] = the TDAC x2 read position,
2263            // ensuring correct MDCT-IV aliasing cancellation across frames.
2264            let mem_size = DECODE_BUFFER_SIZE + overlap;
2265            self.decode_mem.copy_within(
2266                channel_mem_offset + frame_size..channel_mem_offset + mem_size,
2267                channel_mem_offset,
2268            );
2269
2270            let out_syn_idx = DECODE_BUFFER_SIZE - frame_size;
2271
2272            for i in 0..b {
2273                let block_freq_idx = c * frame_size + i;
2274                let block_out_idx = channel_mem_offset + out_syn_idx + i * n;
2275                let available_len = self.decode_mem.len() - block_out_idx;
2276                if available_len < n + overlap {
2277                    panic!(
2278                        "MDCT backward buffer too small: need {}, have {} (out_syn_idx={}, n={}, overlap={})",
2279                        n + overlap,
2280                        available_len,
2281                        out_syn_idx,
2282                        n,
2283                        overlap
2284                    );
2285                }
2286                self.mode.mdct.backward(
2287                    &freq[block_freq_idx..],
2288                    &mut self.decode_mem[block_out_idx..],
2289                    mode.window,
2290                    overlap,
2291                    shift,
2292                    b,
2293                );
2294            }
2295
2296            self.w_pcm_frame[..frame_size].fill(0.0);
2297            let pcm_frame = &mut self.w_pcm_frame[..frame_size];
2298
2299            pcm_frame.copy_from_slice(
2300                &self.decode_mem[channel_mem_offset + out_syn_idx
2301                    ..channel_mem_offset + out_syn_idx + frame_size],
2302            );
2303
2304            if pf_on || self.prefilter_gain > 0.0 {
2305                // Build input buffer for postfilter: [prefilter_mem | pcm_frame]
2306                self.w_post[..frame_size + COMBFILTER_MAXPERIOD].fill(0.0);
2307                {
2308                    let post = &mut self.w_post[..frame_size + COMBFILTER_MAXPERIOD];
2309                    post[..COMBFILTER_MAXPERIOD].copy_from_slice(
2310                        &self.prefilter_mem
2311                            [c * COMBFILTER_MAXPERIOD..(c + 1) * COMBFILTER_MAXPERIOD],
2312                    );
2313                    post[COMBFILTER_MAXPERIOD..].copy_from_slice(&pcm_frame);
2314                }
2315
2316                self.w_filtered[..frame_size].fill(0.0);
2317                {
2318                    let post = &self.w_post[..frame_size + COMBFILTER_MAXPERIOD];
2319                    let filtered = &mut self.w_filtered[..frame_size];
2320                    comb_filter(
2321                        filtered,
2322                        post,
2323                        0,
2324                        COMBFILTER_MAXPERIOD,
2325                        self.prefilter_period,
2326                        pitch_index,
2327                        frame_size,
2328                        self.prefilter_gain,
2329                        gain1,
2330                        self.prefilter_tapset,
2331                        prefilter_tapset as i32,
2332                        mode.window,
2333                        overlap,
2334                    );
2335                }
2336
2337                pcm_frame.copy_from_slice(&self.w_filtered[..frame_size]);
2338
2339                self.decode_mem[channel_mem_offset + out_syn_idx
2340                    ..channel_mem_offset + out_syn_idx + frame_size]
2341                    .copy_from_slice(&pcm_frame);
2342            }
2343
2344            let mut new_mem = [0.0f32; COMBFILTER_MAXPERIOD];
2345            if frame_size >= COMBFILTER_MAXPERIOD {
2346                new_mem.copy_from_slice(&pcm_frame[frame_size - COMBFILTER_MAXPERIOD..frame_size]);
2347            } else {
2348                new_mem[..COMBFILTER_MAXPERIOD - frame_size].copy_from_slice(
2349                    &self.prefilter_mem
2350                        [c * COMBFILTER_MAXPERIOD + frame_size..(c + 1) * COMBFILTER_MAXPERIOD],
2351                );
2352                new_mem[COMBFILTER_MAXPERIOD - frame_size..].copy_from_slice(&pcm_frame);
2353            }
2354            self.prefilter_mem[c * COMBFILTER_MAXPERIOD..(c + 1) * COMBFILTER_MAXPERIOD]
2355                .copy_from_slice(&new_mem);
2356
2357            let coef = mode.preemph[0];
2358            let mut m = self.preemph_mem[c];
2359            for i in 0..frame_size {
2360                let x = pcm_frame[i];
2361                let val = x + m;
2362                pcm[c * frame_size + i] = val;
2363                m = val * coef;
2364            }
2365            self.preemph_mem[c] = m;
2366        }
2367
2368        // Update postfilter state (matching C reference)
2369        self.prefilter_period_old = self.prefilter_period;
2370        self.prefilter_gain_old = self.prefilter_gain;
2371        self.prefilter_tapset_old = self.prefilter_tapset;
2372
2373        if pf_on {
2374            self.prefilter_period = pitch_index;
2375            self.prefilter_gain = gain1;
2376            self.prefilter_tapset = prefilter_tapset as i32;
2377        } else {
2378            self.prefilter_period = COMBFILTER_MINPERIOD;
2379            self.prefilter_gain = 0.0;
2380            self.prefilter_tapset = 0;
2381        }
2382
2383        if lm > 0 {
2384            self.prefilter_period_old = self.prefilter_period;
2385            self.prefilter_gain_old = self.prefilter_gain;
2386            self.prefilter_tapset_old = self.prefilter_tapset;
2387        }
2388
2389        if !is_transient {
2390            self.old_band_e3.copy_from_slice(&self.old_band_e2);
2391            self.old_band_e2.copy_from_slice(&self.old_band_e);
2392        } else {
2393            let nb_ebands = mode.nb_ebands;
2394            for i in 0..channels * nb_ebands {
2395                self.old_band_e2[i] = self.old_band_e2[i].min(self.old_band_e[i]);
2396            }
2397        }
2398
2399        // Update RNG from range coder for next frame's anti-collapse
2400        self.rng = rc.rng;
2401
2402        frame_size
2403    }
2404}