// opus_rs/bands.rs
1use crate::modes::CeltMode;
2use crate::pvq::*;
3use crate::range_coder::RangeCoder;
4use crate::rate::{BITRES, bits2pulses, get_pulses, pulses2bits};
5use crate::tell_frac_inline;
6
7const MIN_STEREO_ENERGY: f32 = 1e-10;
8
/// Shared per-band state threaded through the band quantisation routines
/// (compute_theta, quant_partition, ...). Mirrors `struct band_ctx` from
/// C opus celt/bands.c.
pub struct BandCtx<'a> {
    /// true when encoding, false when decoding.
    pub encode: bool,
    /// Mode tables (band layout, log_n, short_mdct_size, ...).
    pub m: &'a CeltMode,
    /// Index of the band currently being processed.
    pub i: usize,
    /// Per-band energies (presumably one entry per band/channel — confirm with caller).
    pub band_e: &'a [f32],
    /// Range coder used for both encoding and decoding paths.
    pub rc: &'a mut RangeCoder,
    /// Spreading decision (one of the SPREAD_* constants), forwarded to alg_quant.
    pub spread: i32,
    /// Remaining bit budget in eighth-bit units; decremented as pulses are spent.
    pub remaining_bits: i32,
    /// Whether the encoder must also resynthesise the decoded signal.
    pub resynth: bool,
    /// Time/frequency resolution change for this band.
    pub tf_change: i32,
    /// First band coded with intensity stereo.
    pub intensity: usize,
    /// Theta rounding mode: 0 = nearest, <0 = down, >0 = up (see compute_theta).
    pub theta_round: i32,
    /// When set, biases theta away from values that would inject split noise.
    pub avoid_split_noise: bool,
    /// Architecture selector (kept for parity with C; not used in this file's visible code).
    pub arch: i32,
    /// When set, the side-inversion flag is suppressed (not used in the visible code — see callers).
    pub disable_inv: bool,
    /// PRNG state for noise filling (not used in the visible portion — see quant_partition).
    pub seed: u32,
}
26
/// Bit-exact cosine approximation matching C opus's bitexact_cos().
/// Input: x in [0, 16384] representing angle * 16384 / (π/2).
/// Output: cos(x * π/2 / 16384) * 32768 as i16.
/// Matches C opus celt/bands.c bitexact_cos() exactly.
#[inline]
fn bitexact_cos(x: i16) -> i16 {
    // FRAC_MUL16(a,b) = (16384 + (a as i16 as i32) * (b as i16 as i32)) >> 15
    #[inline(always)]
    fn frac_mul16(a: i16, b: i16) -> i16 {
        ((16384i32 + (a as i32) * (b as i32)) >> 15) as i16
    }

    // x^2 in Q13 with rounding, then a cubic polynomial refinement in Q15.
    let tmp = (4096i32 + (x as i32) * (x as i32)) >> 13;
    let x2 = tmp as i16;
    let x2 = (32767 - x2 as i32
        + frac_mul16(x2, -7651 + frac_mul16(x2, 8277 + frac_mul16(-626, x2))) as i32)
        as i16;
    // C computes `1 + x2` with 16-bit wraparound. For x == 0, x2 == 32767 and
    // a plain `1 + x2` would panic with overflow in debug builds; wrapping_add
    // reproduces C's wrap (-32768) bit-exactly for the whole documented range.
    x2.wrapping_add(1)
}
46
/// Q11 approximation of log2(isin / icos), bit-exact with C opus's
/// bitexact_log2tan() (celt/bands.c). Inputs are non-negative Q15 values;
/// negative inputs are clamped to 0 defensively.
#[inline]
pub fn bitexact_log2tan(isin: i32, icos: i32) -> i32 {
    // EC_ILOG: 1-based index of the highest set bit; 0 for v == 0.
    #[inline(always)]
    fn ilog(v: u32) -> i32 {
        match v {
            0 => 0,
            _ => 32 - v.leading_zeros() as i32,
        }
    }
    // FRAC_MUL16: Q15 multiply with rounding.
    #[inline(always)]
    fn q15_mul(a: i32, b: i32) -> i32 {
        (a * b + 16384) >> 15
    }
    // Normalise a clamped operand so its top bit lands at bit position 14.
    #[inline(always)]
    fn norm15(v: i32, lg: i32) -> i32 {
        if lg > 0 {
            v.max(0) << (15 - lg).max(0)
        } else {
            0
        }
    }

    let lc = ilog(icos.max(0) as u32);
    let ls = ilog(isin.max(0) as u32);
    let sn = norm15(isin, ls);
    let cn = norm15(icos, lc);
    // Exponent difference in Q11 plus a small odd polynomial correction term.
    (ls - lc) * (1 << 11) + q15_mul(sn, q15_mul(sn, -2597) + 7932)
        - q15_mul(cn, q15_mul(cn, -2597) + 7932)
}
72
/// Signed division rounding to nearest (ties away from zero), used by the
/// bit-allocation maths in compute_qn.
///
/// NOTE(review): C opus defines celt_sudiv as a plain truncating `n/d`;
/// this port rounds — confirm the divergence is intentional.
#[inline(always)]
fn celt_sudiv(n: i32, d: i32) -> i32 {
    let half = d >> 1;
    if n >= 0 {
        (n + half) / d
    } else {
        -((half - n) / d)
    }
}
81
/// Integer square root: floor(sqrt(val)), matching C opus isqrt32()
/// (celt/mathops.c) digit-by-digit.
///
/// Guards `val == 0` explicitly: without the guard, `bshift` would be -1 and
/// `1u32 << bshift` would panic with shift overflow in debug builds (callers
/// in the theta decoder always pass >= 1, but the guard makes the helper total).
#[inline]
fn isqrt32(mut val: u32) -> u32 {
    if val == 0 {
        return 0;
    }
    let mut g = 0u32;
    // Start probing at 2^floor(log2(val) / 2).
    let mut bshift = ((32 - val.leading_zeros()) as i32 - 1) >> 1;
    let mut b = 1u32 << bshift;
    // Classic binary (digit-by-digit) square-root loop.
    while bshift >= 0 {
        // Widen to u64: (2g + b) << bshift can exceed u32 for large inputs.
        let t = (((g << 1) + b) as u64) << bshift;
        if t <= val as u64 {
            g += b;
            val -= t as u32;
        }
        b >>= 1;
        bshift -= 1;
    }
    g
}
98
// Spreading (rotation) decisions returned by spreading_decision and consumed
// by the PVQ quantiser; values match the SPREAD_* constants in C opus celt/bands.h.
pub const SPREAD_NONE: i32 = 0;
pub const SPREAD_LIGHT: i32 = 1;
pub const SPREAD_NORMAL: i32 = 2;
pub const SPREAD_AGGRESSIVE: i32 = 3;
103
/// Decide how much spectral spreading (rotation) the PVQ quantiser should
/// use, based on how many low-magnitude samples each band contains.
/// Port of spreading_decision() from C opus celt/bands.c.
///
/// When `update_hf` is set, also maintains the running high-frequency
/// tonality average `hf_average` and the postfilter `tapset_decision`.
/// `average` is a running (Q8) tonality average updated in place;
/// `last_decision` is the previous frame's spreading decision, used for
/// hysteresis. Returns one of the SPREAD_* constants.
#[allow(clippy::too_many_arguments)]
pub fn spreading_decision(
    m: &CeltMode,
    x_buf: &[f32],
    average: &mut i32,
    last_decision: i32,
    hf_average: &mut i32,
    tapset_decision: &mut i32,
    update_hf: bool,
    end: usize,
    channels: usize,
    m_val: usize,
    spread_weight: &[i32],
) -> i32 {
    let mut sum = 0;
    let mut nb_bands = 0;
    // Samples per channel for one interleaved frame.
    let n0 = m_val * m.short_mdct_size;
    let mut hf_sum = 0;

    // Too few samples in the last band to make a meaningful decision.
    if m_val * (m.e_bands[end] as usize - m.e_bands[end - 1] as usize) <= 8 {
        return SPREAD_NONE;
    }

    for c in 0..channels {
        for (i, &sw) in spread_weight[..end].iter().enumerate() {
            let n = m_val * (m.e_bands[i + 1] as usize - m.e_bands[i] as usize);
            // Skip bands too narrow to produce stable statistics.
            if n <= 8 {
                continue;
            }

            // tcount[k] counts samples below successively smaller energy
            // thresholds (relative to a flat unit-norm band).
            let mut tcount = [0; 3];
            let offset = m_val * m.e_bands[i] as usize + c * n0;
            let x = &x_buf[offset..offset + n];

            for xv in x.iter().copied() {
                let x2n = xv * xv * (n as f32);
                if x2n < 0.25 {
                    tcount[0] += 1;
                }
                if x2n < 0.0625 {
                    tcount[1] += 1;
                }
                if x2n < 0.015625 {
                    tcount[2] += 1;
                }
            }

            // Accumulate an HF tonality metric over the top few bands.
            if i > m.nb_ebands - 4 {
                hf_sum += 32 * (tcount[1] + tcount[0]) / (n as i32);
            }

            // 0..3 score: how many thresholds have a majority of samples below them.
            let tmp = (if 2 * tcount[2] >= (n as i32) { 1 } else { 0 })
                + (if 2 * tcount[1] >= (n as i32) { 1 } else { 0 })
                + (if 2 * tcount[0] >= (n as i32) { 1 } else { 0 });
            sum += tmp * sw;
            nb_bands += sw;
        }
    }

    if update_hf {
        if hf_sum > 0 {
            hf_sum /= (channels as i32) * (4 - m.nb_ebands as i32 + end as i32);
        }
        // Leaky average of the HF metric, then hysteresis from the
        // previous tapset decision before re-thresholding.
        *hf_average = (*hf_average + hf_sum) >> 1;
        hf_sum = *hf_average;

        if *tapset_decision == 2 {
            hf_sum += 4;
        } else if *tapset_decision == 0 {
            hf_sum -= 4;
        }

        if hf_sum > 22 {
            *tapset_decision = 2;
        } else if hf_sum > 18 {
            *tapset_decision = 1;
        } else {
            *tapset_decision = 0;
        }
    }

    if nb_bands == 0 {
        return SPREAD_NORMAL;
    }

    // Q8 average of the per-band score, smoothed across frames.
    let mut sum_scaled = (sum << 8) / nb_bands;
    sum_scaled = (sum_scaled + *average) >> 1;
    *average = sum_scaled;

    // Bias towards keeping the previous decision (hysteresis).
    let sum_final = (3 * sum_scaled + (((3 - last_decision) << 7) + 64) + 2) >> 2;

    if sum_final < 80 {
        SPREAD_AGGRESSIVE
    } else if sum_final < 256 {
        SPREAD_NORMAL
    } else if sum_final < 384 {
        SPREAD_LIGHT
    } else {
        SPREAD_NONE
    }
}
205
206pub fn haar1(x: &mut [f32], n0: usize, stride: usize) {
207    #[cfg(target_arch = "aarch64")]
208    {
209        haar1_neon(x, n0, stride);
210        return;
211    }
212    #[cfg(not(target_arch = "aarch64"))]
213    {
214        haar1_scalar(x, n0, stride);
215    }
216}
217
/// Scalar Haar butterfly: for every interleaved column, each (even, odd)
/// row pair is replaced by its normalised sum and difference.
#[cfg(not(target_arch = "aarch64"))]
#[inline]
fn haar1_scalar(x: &mut [f32], n0: usize, stride: usize) {
    const INV_SQRT2: f32 = std::f32::consts::FRAC_1_SQRT_2;
    let pairs = n0 / 2;
    for col in 0..stride {
        for p in 0..pairs {
            let top = 2 * p * stride + col;
            let bot = top + stride;
            let a = INV_SQRT2 * x[top];
            let b = INV_SQRT2 * x[bot];
            x[top] = a + b;
            x[bot] = a - b;
        }
    }
}
234
/// NEON Haar butterfly, equivalent to haar1_scalar.
///
/// BUG FIX: the previous version stepped the column index `i` by 1 while
/// loading 4 *consecutive* floats per vector. Consecutive memory holds
/// adjacent columns (or, for stride == 1, alternating even/odd rows), so the
/// lanes paired the wrong elements and overlapping `i` iterations
/// re-transformed data that had already been rewritten.
///
/// The correct vectorisation processes 4 adjacent columns per vector for a
/// single row pair: `x[idx1..idx1+4]` is row 2j of columns i..i+3 and
/// `x[idx1+stride..]` is row 2j+1 of the same columns. Columns advance by 4,
/// so no element is touched twice. Leftover columns (and stride < 4,
/// including the common stride == 1 case) fall back to the scalar butterfly.
#[cfg(target_arch = "aarch64")]
fn haar1_neon(x: &mut [f32], n0: usize, stride: usize) {
    use std::arch::aarch64::*;

    let n = n0 >> 1;
    let scale = std::f32::consts::FRAC_1_SQRT_2;

    let mut i = 0usize;
    // Vectorised path: 4 independent column butterflies per iteration.
    // SAFETY: for i + 4 <= stride and j < n, the furthest element read or
    // written is idx2 + 3 = stride*(2j+1) + i + 3 <= stride*2n - 1 < n0*stride,
    // which the caller guarantees is within x (same contract as the scalar path).
    unsafe {
        let vscale = vdupq_n_f32(scale);
        while i + 4 <= stride {
            for j in 0..n {
                let idx1 = stride * 2 * j + i;
                let idx2 = idx1 + stride;
                let a = vmulq_f32(vld1q_f32(x.as_ptr().add(idx1)), vscale);
                let b = vmulq_f32(vld1q_f32(x.as_ptr().add(idx2)), vscale);
                vst1q_f32(x.as_mut_ptr().add(idx1), vaddq_f32(a, b));
                vst1q_f32(x.as_mut_ptr().add(idx2), vsubq_f32(a, b));
            }
            i += 4;
        }
    }

    // Scalar tail: remaining columns (all of them when stride < 4).
    for col in i..stride {
        for j in 0..n {
            let idx1 = stride * 2 * j + col;
            let idx2 = idx1 + stride;
            let tmp1 = scale * x[idx1];
            let tmp2 = scale * x[idx2];
            x[idx1] = tmp1 + tmp2;
            x[idx2] = tmp1 - tmp2;
        }
    }
}
319
320#[inline(always)]
321pub fn compute_qn(n: usize, b: i32, offset: i32, pulse_cap: i32, stereo: bool) -> i32 {
322    static EXP2_TABLE8: [i16; 8] = [16384, 17866, 19483, 21247, 23170, 25267, 27554, 30048];
323    let mut n2 = (2 * n as i32) - 1;
324    if stereo && n == 2 {
325        n2 -= 1;
326    }
327    let mut qb = celt_sudiv(b + n2 * offset, n2);
328    qb = qb.min(b - pulse_cap - (4 << BITRES));
329    qb = qb.min(8 << BITRES);
330    if qb < (1i32 << BITRES >> 1) {
331        1
332    } else {
333        let val = EXP2_TABLE8[(qb & 0x7) as usize] as i32;
334        let shift = 14 - (qb >> BITRES);
335        let raw = if (0..32).contains(&shift) {
336            val >> shift
337        } else {
338            0
339        };
340        let qn = (raw + 1) >> 1 << 1;
341        qn.min(256)
342    }
343}
344
/// NEON-optimized stereo_itheta computation
///
/// Computes the stereo angle itheta = atan2(sqrt(E_side), sqrt(E_mid))
/// scaled to Q14 (0..16384). With `stereo` set, mid/side are formed as
/// (x + y, x - y); otherwise x and y are treated as the two vectors directly.
///
/// # Safety
/// `x` and `y` must each contain at least `n` elements; the vector loads
/// below read up to index `n - 1`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn stereo_itheta_neon(x: &[f32], y: &[f32], stereo: bool, n: usize) -> i32 {
    use std::arch::aarch64::*;

    // Tiny floors keep sqrt/atan well-defined for an all-zero band.
    let mut emid = 1e-15f32;
    let mut eside = 1e-15f32;

    if stereo {
        let mut sum_mid = vdupq_n_f32(0.0);
        let mut sum_side = vdupq_n_f32(0.0);
        let mut i = 0;

        // Process 16 elements at a time
        while i + 16 <= n {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let x1 = vld1q_f32(x.as_ptr().add(i + 4));
            let x2 = vld1q_f32(x.as_ptr().add(i + 8));
            let x3 = vld1q_f32(x.as_ptr().add(i + 12));
            let y0 = vld1q_f32(y.as_ptr().add(i));
            let y1 = vld1q_f32(y.as_ptr().add(i + 4));
            let y2 = vld1q_f32(y.as_ptr().add(i + 8));
            let y3 = vld1q_f32(y.as_ptr().add(i + 12));

            // Mid = x + y, side = x - y; accumulate squared energy of each.
            let m0 = vaddq_f32(x0, y0);
            let m1 = vaddq_f32(x1, y1);
            let m2 = vaddq_f32(x2, y2);
            let m3 = vaddq_f32(x3, y3);
            let s0 = vsubq_f32(x0, y0);
            let s1 = vsubq_f32(x1, y1);
            let s2 = vsubq_f32(x2, y2);
            let s3 = vsubq_f32(x3, y3);

            sum_mid = vfmaq_f32(sum_mid, m0, m0);
            sum_mid = vfmaq_f32(sum_mid, m1, m1);
            sum_mid = vfmaq_f32(sum_mid, m2, m2);
            sum_mid = vfmaq_f32(sum_mid, m3, m3);
            sum_side = vfmaq_f32(sum_side, s0, s0);
            sum_side = vfmaq_f32(sum_side, s1, s1);
            sum_side = vfmaq_f32(sum_side, s2, s2);
            sum_side = vfmaq_f32(sum_side, s3, s3);

            i += 16;
        }

        // Process 8 elements
        while i + 8 <= n {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let x1 = vld1q_f32(x.as_ptr().add(i + 4));
            let y0 = vld1q_f32(y.as_ptr().add(i));
            let y1 = vld1q_f32(y.as_ptr().add(i + 4));

            let m0 = vaddq_f32(x0, y0);
            let m1 = vaddq_f32(x1, y1);
            let s0 = vsubq_f32(x0, y0);
            let s1 = vsubq_f32(x1, y1);

            sum_mid = vfmaq_f32(sum_mid, m0, m0);
            sum_mid = vfmaq_f32(sum_mid, m1, m1);
            sum_side = vfmaq_f32(sum_side, s0, s0);
            sum_side = vfmaq_f32(sum_side, s1, s1);

            i += 8;
        }

        // Process 4 elements
        while i + 4 <= n {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let y0 = vld1q_f32(y.as_ptr().add(i));
            let m0 = vaddq_f32(x0, y0);
            let s0 = vsubq_f32(x0, y0);
            sum_mid = vfmaq_f32(sum_mid, m0, m0);
            sum_side = vfmaq_f32(sum_side, s0, s0);
            i += 4;
        }

        // Horizontal reduction of the lane accumulators.
        emid += vaddvq_f32(sum_mid);
        eside += vaddvq_f32(sum_side);

        // Scalar tail
        for j in i..n {
            let m = x[j] + y[j];
            let s = x[j] - y[j];
            emid += m * m;
            eside += s * s;
        }
    } else {
        // Non-stereo path: x and y are already the two vectors whose
        // energies are compared; no mid/side combination needed.
        let mut sum_mid = vdupq_n_f32(0.0);
        let mut sum_side = vdupq_n_f32(0.0);
        let mut i = 0;

        // Process 16 elements at a time
        while i + 16 <= n {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let x1 = vld1q_f32(x.as_ptr().add(i + 4));
            let x2 = vld1q_f32(x.as_ptr().add(i + 8));
            let x3 = vld1q_f32(x.as_ptr().add(i + 12));
            let y0 = vld1q_f32(y.as_ptr().add(i));
            let y1 = vld1q_f32(y.as_ptr().add(i + 4));
            let y2 = vld1q_f32(y.as_ptr().add(i + 8));
            let y3 = vld1q_f32(y.as_ptr().add(i + 12));

            sum_mid = vfmaq_f32(sum_mid, x0, x0);
            sum_mid = vfmaq_f32(sum_mid, x1, x1);
            sum_mid = vfmaq_f32(sum_mid, x2, x2);
            sum_mid = vfmaq_f32(sum_mid, x3, x3);
            sum_side = vfmaq_f32(sum_side, y0, y0);
            sum_side = vfmaq_f32(sum_side, y1, y1);
            sum_side = vfmaq_f32(sum_side, y2, y2);
            sum_side = vfmaq_f32(sum_side, y3, y3);

            i += 16;
        }

        // Process 8 elements
        while i + 8 <= n {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let x1 = vld1q_f32(x.as_ptr().add(i + 4));
            let y0 = vld1q_f32(y.as_ptr().add(i));
            let y1 = vld1q_f32(y.as_ptr().add(i + 4));

            sum_mid = vfmaq_f32(sum_mid, x0, x0);
            sum_mid = vfmaq_f32(sum_mid, x1, x1);
            sum_side = vfmaq_f32(sum_side, y0, y0);
            sum_side = vfmaq_f32(sum_side, y1, y1);

            i += 8;
        }

        // Process 4 elements
        while i + 4 <= n {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let y0 = vld1q_f32(y.as_ptr().add(i));
            sum_mid = vfmaq_f32(sum_mid, x0, x0);
            sum_side = vfmaq_f32(sum_side, y0, y0);
            i += 4;
        }

        emid += vaddvq_f32(sum_mid);
        eside += vaddvq_f32(sum_side);

        // Scalar tail
        for j in i..n {
            emid += x[j] * x[j];
            eside += y[j] * y[j];
        }
    }

    // Angle between the two energy vectors, normalised to [0, 1] then Q14.
    let mid = emid.sqrt();
    let side = eside.sqrt();
    let theta_norm = celt_atan2p_norm(side, mid);
    (0.5 + 16384.0 * theta_norm) as i32
}
500
/// Stereo angle in Q14 (see stereo_itheta_neon) — aarch64 dispatch.
#[inline(always)]
#[cfg(target_arch = "aarch64")]
pub fn stereo_itheta(x: &[f32], y: &[f32], stereo: bool, n: usize) -> i32 {
    // SAFETY: NEON is a baseline feature of aarch64, and the slices handed
    // in are read only up to `n` elements by the callee.
    unsafe { stereo_itheta_neon(x, y, stereo, n) }
}
508
509#[inline(always)]
510#[cfg(not(target_arch = "aarch64"))]
511pub fn stereo_itheta(x: &[f32], y: &[f32], stereo: bool, n: usize) -> i32 {
512    // Use NEON on aarch64
513    #[cfg(target_arch = "aarch64")]
514    unsafe {
515        return stereo_itheta_neon(x, y, stereo, n);
516    }
517    #[cfg(not(target_arch = "aarch64"))]
518    {
519        let mut emid = 1e-15f32;
520        let mut eside = 1e-15f32;
521        if stereo {
522            for i in 0..n {
523                let m = x[i] + y[i];
524                let s = x[i] - y[i];
525                emid += m * m;
526                eside += s * s;
527            }
528        } else {
529            for i in 0..n {
530                emid += x[i] * x[i];
531                eside += y[i] * y[i];
532            }
533        }
534        let mid = emid.sqrt();
535        let side = eside.sqrt();
536        let theta_norm = celt_atan2p_norm(side, mid);
537        (0.5 + 16384.0 * theta_norm) as i32
538    }
539}
540
/// Polynomial approximation of atan2(y, x) / (π/2) for x, y >= 0.
/// Matches C opus celt_atan2p_norm() from celt/mathops.h (Remez degree 7).
/// Returns a value in [0, 1] mapping [0°, 90°].
#[inline(always)]
fn celt_atan2p_norm(y: f32, x: f32) -> f32 {
    // atan(t) * 2/π for t in [0, 1]: odd degree-15 polynomial in Horner form.
    #[inline(always)]
    fn atan_norm(t: f32) -> f32 {
        const TWO_OVER_PI: f32 = 0.636_619_77_f32;
        const C3: f32 = -3.333_165_9e-1_f32;
        const C5: f32 = 1.996_270_4e-1_f32;
        const C7: f32 = -1.397_658_3e-1_f32;
        const C9: f32 = 9.794_234_e-2_f32;
        const C11: f32 = -5.777_359_e-2_f32;
        const C13: f32 = 2.304_014e-2_f32;
        const C15: f32 = -4.355_406e-3_f32;
        let t2 = t * t;
        let poly = 1.0
            + t2 * (C3 + t2 * (C5 + t2 * (C7 + t2 * (C9 + t2 * (C11 + t2 * (C13 + t2 * C15))))));
        TWO_OVER_PI * t * poly
    }

    if x * x + y * y < 1e-18 {
        // Degenerate near-zero vector: define the angle as 0.
        0.0
    } else if y < x {
        // First octant: evaluate directly.
        atan_norm(y / x)
    } else {
        // Second octant: reflect about 45°.
        1.0 - atan_norm(x / y)
    }
}
572
/// Results of the mid/side angle (theta) analysis for one band split,
/// filled in by compute_theta / compute_theta_encode.
pub struct SplitCtx {
    /// Side-inversion flag (intensity-stereo bands only).
    pub inv: bool,
    /// bitexact_cos(itheta): mid gain in Q15.
    pub imid: i32,
    /// bitexact_cos(16384 - itheta): side gain in Q15.
    pub iside: i32,
    /// Bit-allocation offset between the two halves derived from theta.
    pub delta: i32,
    /// Quantised angle in Q14 (0..=16384).
    pub itheta: i32,
    /// Eighth-bits consumed coding theta (0 when theta is not coded).
    pub qalloc: i32,
}
581
/// Encode-only compute_theta - eliminates all decode branches at compile time.
///
/// Measures the mid/side angle of the split halves (x, y), quantises it to
/// `qn` steps, writes it to the range coder, and fills `sctx` with the
/// derived imid/iside/delta used to split the bit budget `b`. `fill` is the
/// collapse mask, masked down when one half ends up with zero energy.
#[allow(clippy::too_many_arguments)]
#[inline(always)]
fn compute_theta_encode(
    ctx: &mut BandCtx,
    sctx: &mut SplitCtx,
    x: &[f32],
    y: &[f32],
    n: usize,
    b: &mut i32,
    b_blocks: i32,
    _b0: i32,
    lm: i32,
    stereo: bool,
    fill: &mut u32,
) {
    // Resolution of the theta quantiser for this band size and budget.
    let pulse_cap = ctx.m.log_n[ctx.i] as i32 + (lm << BITRES);
    let offset = (pulse_cap >> 1) - if stereo && n == 2 { 16 } else { 4 };
    let mut qn = compute_qn(n, *b, offset, pulse_cap, stereo);

    // Intensity-stereo bands code only an inversion flag, never an angle.
    if stereo && ctx.i >= ctx.intensity {
        qn = 1;
    }

    // Fast path: when qn == 1 and not stereo-intensity, theta is fixed at 8192.
    // Skip expensive stereo_itheta (NEON vector op over n elements) and tell_frac.
    // This is common for low-bitrate bands where split energy ratio isn't encoded.
    if qn == 1 && !(stereo && ctx.i >= ctx.intensity) {
        sctx.itheta = 8192;
        sctx.qalloc = 0;
        // imid == iside at 45°; delta reduces to the same value the generic
        // tail below would compute for itheta == 8192.
        let imid = bitexact_cos(8192i16);
        sctx.imid = imid as i32;
        let iside = bitexact_cos(8192i16);
        sctx.iside = iside as i32;
        sctx.delta =
            (((n as i32 - 1) << 7) * bitexact_log2tan(sctx.iside, sctx.imid) + 16384) >> 15;
        return;
    }

    // Measured (unquantised) angle in Q14.
    let mut itheta = stereo_itheta(x, y, stereo, n);

    let tell_start = tell_frac_inline!(ctx.rc);

    if qn != 1 {
        if !stereo || ctx.theta_round == 0 {
            // Round to nearest quantiser step.
            itheta = (itheta * qn + 8192) >> 14;
            // Bias theta to 0 or qn when the implied delta would exceed the
            // budget, avoiding noise injection at transients.
            if !stereo && ctx.avoid_split_noise && itheta > 0 && itheta < qn {
                let unquantized = (itheta * 16384) / qn;
                let imid = bitexact_cos(unquantized as i16) as i32;
                let iside = bitexact_cos((16384 - unquantized) as i16) as i32;
                let delta = (((n as i32 - 1) << 7) * bitexact_log2tan(iside, imid) + 16384) >> 15;
                if delta > *b {
                    itheta = qn;
                } else if delta < -*b {
                    itheta = 0;
                }
            }
        } else {
            // Directed rounding (theta_round != 0): bias away from 8192,
            // then round down or up as requested.
            let bias = if itheta > 8192 {
                32767 / qn
            } else {
                -32767 / qn
            };
            let down = (itheta * qn + bias) >> 14;
            let down = down.clamp(0, qn - 1);
            if ctx.theta_round < 0 {
                itheta = down;
            } else {
                itheta = down + 1;
            }
        }

        if stereo && n > 2 {
            // Stereo, n > 2: step-distribution model with p0 = 3 favouring
            // small angles.
            let p0 = 3;
            let x0 = qn / 2;
            let ft = p0 * (x0 + 1) + x0;
            let fl = if itheta <= x0 {
                p0 * itheta
            } else {
                (itheta - 1 - x0) + (x0 + 1) * p0
            };
            let fh = if itheta <= x0 {
                p0 * (itheta + 1)
            } else {
                (itheta - x0) + (x0 + 1) * p0
            };
            ctx.rc.encode(fl as u32, fh as u32, ft as u32);
        } else if b_blocks > 1 || stereo {
            // Uniform code over qn + 1 values.
            ctx.rc.enc_uint(itheta as u32, (qn + 1) as u32);
        } else {
            // Mono single-block: triangular distribution peaked at qn/2.
            let ft = ((qn >> 1) + 1) * ((qn >> 1) + 1);
            let fs = if itheta <= (qn >> 1) {
                itheta + 1
            } else {
                qn + 1 - itheta
            };
            let fl = if itheta <= (qn >> 1) {
                (itheta * (itheta + 1)) >> 1
            } else {
                ft - (((qn + 1 - itheta) * (qn + 2 - itheta)) >> 1)
            };
            ctx.rc.encode(fl as u32, (fl + fs) as u32, ft as u32);
        }
        // Back to Q14 with rounding.
        itheta = (itheta * 16384 + (qn >> 1)) / qn;
    } else {
        if stereo && ctx.i >= ctx.intensity {
            // Intensity stereo: code only whether the side dominates the mid.
            let mut emid = 1e-15f32;
            let mut eside = 1e-15f32;
            for i in 0..n {
                let m = x[i] + y[i];
                let s = x[i] - y[i];
                emid += m * m;
                eside += s * s;
            }
            let inv = eside > emid;
            ctx.rc.encode_bit_logp(inv, 1);
            itheta = 0;
            sctx.inv = inv;
        } else {
            itheta = 8192;
        }
    }

    sctx.itheta = itheta;

    // qalloc: eighth-bits actually spent on theta (0 when nothing was coded).
    sctx.qalloc = if qn == 1 && !(stereo && ctx.i >= ctx.intensity) {
        0
    } else {
        tell_frac_inline!(ctx.rc) - tell_start
    };

    if itheta == 0 {
        // All energy in mid: side collapse bits cleared.
        sctx.imid = 32767;
        sctx.iside = 0;
        sctx.delta = -16384;
        *fill &= (1 << b_blocks) - 1;
    } else if itheta == 16384 {
        // All energy in side: mid collapse bits cleared.
        sctx.imid = 0;
        sctx.iside = 32767;
        sctx.delta = 16384;
        *fill &= !((1 << b_blocks) - 1);
    } else {
        let imid = bitexact_cos(itheta as i16);
        sctx.imid = imid as i32;
        let iside = bitexact_cos((16384 - itheta) as i16);
        sctx.iside = iside as i32;
        sctx.delta =
            (((n as i32 - 1) << 7) * bitexact_log2tan(sctx.iside, sctx.imid) + 16384) >> 15;
    }
}
732
/// Quantise (encode) or decode the mid/side angle theta for a band split.
/// Port of compute_theta() from C opus celt/bands.c; fills `sctx` with
/// itheta, imid/iside (Q15), delta, and the eighth-bits spent (qalloc).
/// `fill` is the collapse mask, masked down when one half has zero energy.
#[allow(clippy::too_many_arguments)]
#[inline(always)]
pub fn compute_theta(
    ctx: &mut BandCtx,
    sctx: &mut SplitCtx,
    x: &[f32],
    y: &[f32],
    n: usize,
    b: &mut i32,
    b_blocks: i32,
    _b0: i32,
    lm: i32,
    stereo: bool,
    fill: &mut u32,
) {
    // Resolution of the theta quantiser for this band size and budget.
    let pulse_cap = ctx.m.log_n[ctx.i] as i32 + (lm << BITRES);
    let offset = (pulse_cap >> 1) - if stereo && n == 2 { 16 } else { 4 };
    let mut qn = compute_qn(n, *b, offset, pulse_cap, stereo);

    // Intensity-stereo bands code only an inversion flag, never an angle.
    if stereo && ctx.i >= ctx.intensity {
        qn = 1;
    }

    // Encoder measures the angle; decoder reads it from the bitstream below.
    let mut itheta = 0;
    if ctx.encode {
        itheta = stereo_itheta(x, y, stereo, n);
    }

    let tell_start = tell_frac_inline!(ctx.rc);

    if qn != 1 {
        if ctx.encode {
            if !stereo || ctx.theta_round == 0 {
                // Round to nearest quantiser step.
                itheta = (itheta * qn + 8192) >> 14;
                // Bias theta to 0 or qn when the implied delta exceeds the
                // budget, avoiding noise injection at transients.
                if !stereo && ctx.avoid_split_noise && itheta > 0 && itheta < qn {
                    let unquantized = (itheta * 16384) / qn;
                    let imid = bitexact_cos(unquantized as i16) as i32;
                    let iside = bitexact_cos((16384 - unquantized) as i16) as i32;
                    let delta =
                        (((n as i32 - 1) << 7) * bitexact_log2tan(iside, imid) + 16384) >> 15;
                    if delta > *b {
                        itheta = qn;
                    } else if delta < -*b {
                        itheta = 0;
                    }
                }
            } else {
                // Directed rounding (theta_round != 0): bias away from 8192,
                // then round down or up as requested.
                let bias = if itheta > 8192 {
                    32767 / qn
                } else {
                    -32767 / qn
                };
                let down = (itheta * qn + bias) >> 14;
                let down = down.clamp(0, qn - 1);
                if ctx.theta_round < 0 {
                    itheta = down;
                } else {
                    itheta = down + 1;
                }
            }
        }

        if stereo && n > 2 {
            // Stereo, n > 2: step-distribution model with p0 = 3 favouring
            // small angles; encode and decode share the fl/fh computation.
            let p0 = 3;
            let x0 = qn / 2;
            let ft = p0 * (x0 + 1) + x0;
            if ctx.encode {
                let fl = if itheta <= x0 {
                    p0 * itheta
                } else {
                    (itheta - 1 - x0) + (x0 + 1) * p0
                };
                let fh = if itheta <= x0 {
                    p0 * (itheta + 1)
                } else {
                    (itheta - x0) + (x0 + 1) * p0
                };
                ctx.rc.encode(fl as u32, fh as u32, ft as u32);
            } else {
                let fs = ctx.rc.decode(ft as u32);
                if fs < (x0 + 1) as u32 * p0 as u32 {
                    itheta = fs as i32 / p0;
                } else {
                    itheta = (x0 + 1) + (fs as i32 - (x0 + 1) * p0);
                }
                let fl = if itheta <= x0 {
                    p0 * itheta
                } else {
                    (itheta - 1 - x0) + (x0 + 1) * p0
                };
                let fh = if itheta <= x0 {
                    p0 * (itheta + 1)
                } else {
                    (itheta - x0) + (x0 + 1) * p0
                };
                ctx.rc.update(fl as u32, fh as u32, ft as u32);
            }
        } else if b_blocks > 1 || stereo {
            // Uniform code over qn + 1 values.
            if ctx.encode {
                ctx.rc.enc_uint(itheta as u32, (qn + 1) as u32);
            } else {
                itheta = ctx.rc.dec_uint((qn + 1) as u32) as i32;
            }
        } else {
            // Mono single-block: triangular distribution peaked at qn/2; the
            // decoder inverts the cumulative triangle with isqrt32.
            let ft = ((qn >> 1) + 1) * ((qn >> 1) + 1);
            if ctx.encode {
                let fs = if itheta <= (qn >> 1) {
                    itheta + 1
                } else {
                    qn + 1 - itheta
                };
                let fl = if itheta <= (qn >> 1) {
                    (itheta * (itheta + 1)) >> 1
                } else {
                    ft - (((qn + 1 - itheta) * (qn + 2 - itheta)) >> 1)
                };
                ctx.rc.encode(fl as u32, (fl + fs) as u32, ft as u32);
            } else {
                let fm = ctx.rc.decode(ft as u32) as i32;
                if fm < (((qn >> 1) * ((qn >> 1) + 1)) >> 1) {
                    // Rising half of the triangle.
                    itheta = (isqrt32((8 * fm + 1) as u32) as i32 - 1) >> 1;
                    let fl = (itheta * (itheta + 1)) >> 1;
                    let fs = itheta + 1;
                    ctx.rc.update(fl as u32, (fl + fs) as u32, ft as u32);
                } else {
                    // Falling half of the triangle.
                    itheta = (2 * (qn + 1) - isqrt32((8 * (ft - fm - 1) + 1) as u32) as i32) >> 1;
                    let fs = qn + 1 - itheta;
                    let fl = ft - (((qn + 1 - itheta) * (qn + 2 - itheta)) >> 1);
                    ctx.rc.update(fl as u32, (fl + fs) as u32, ft as u32);
                }
            }
        }
        // Back to Q14 with rounding.
        itheta = (itheta * 16384 + (qn >> 1)) / qn;
    } else {
        if stereo && ctx.i >= ctx.intensity {
            // Intensity stereo: only code whether the side dominates the mid.
            if ctx.encode {
                let mut emid = 1e-15f32;
                let mut eside = 1e-15f32;
                for i in 0..n {
                    let m = x[i] + y[i];
                    let s = x[i] - y[i];
                    emid += m * m;
                    eside += s * s;
                }
                let inv = eside > emid;
                ctx.rc.encode_bit_logp(inv, 1);
                itheta = 0;
                sctx.inv = inv;
            } else {
                sctx.inv = ctx.rc.decode_bit_logp(1);
                itheta = 0;
            }
        } else {
            itheta = 8192;
        }
    }

    sctx.itheta = itheta;

    // Fast path: when qn == 1 and no stereo intensity, no bits are used
    // Avoid second tell_frac call which saves ~3-5ns per call
    sctx.qalloc = if qn == 1 && !(stereo && ctx.i >= ctx.intensity) {
        0 // No encoding happened, qalloc is 0
    } else {
        tell_frac_inline!(ctx.rc) - tell_start
    };

    if itheta == 0 {
        // All energy in mid: side collapse bits cleared.
        sctx.imid = 32767;
        sctx.iside = 0;
        sctx.delta = -16384;
        *fill &= (1 << b_blocks) - 1;
    } else if itheta == 16384 {
        // All energy in side: mid collapse bits cleared.
        sctx.imid = 0;
        sctx.iside = 32767;
        sctx.delta = 16384;
        *fill &= !((1 << b_blocks) - 1);
    } else {
        let imid = bitexact_cos(itheta as i16);
        sctx.imid = imid as i32;
        let iside = bitexact_cos((16384 - itheta) as i16);
        sctx.iside = iside as i32;
        sctx.delta =
            (((n as i32 - 1) << 7) * bitexact_log2tan(sctx.iside, sctx.imid) + 16384) >> 15;
    }
}
919
/// Encode-only fast path for quant_partition with n=2
///
/// Converts the bit budget `b` into a pulse count, charges it against
/// `ctx.remaining_bits` (backing off one pulse at a time while the global
/// budget is overshot), then PVQ-encodes the 2-sample vector. Returns the
/// collapse-mask bits for this partition.
#[inline(always)]
fn quant_partition_n2_encode(
    ctx: &mut BandCtx,
    x: &mut [f32],
    b: i32,
    b_blocks: i32,
    lowband: Option<&mut [f32]>,
    lm: i32,
    gain: f32,
    fill: u32,
) -> u32 {
    // Bits -> pulse index for this band/LM, then charge the budget.
    let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
    let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
    ctx.remaining_bits -= curr_bits;

    // Drop pulses until the remaining budget is no longer negative.
    while ctx.remaining_bits < 0 && q > 0 {
        ctx.remaining_bits += curr_bits;
        q -= 1;
        curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
        ctx.remaining_bits -= curr_bits;
    }

    if q != 0 {
        let k = get_pulses(q);
        // NOTE(review): resynth is hard-coded to false here, unlike
        // quant_partition_n2 which forwards ctx.resynth — confirm this path
        // is only taken when ctx.resynth is false.
        alg_quant(x, 2, k, ctx.spread, b_blocks as usize, ctx.rc, gain, false)
    } else {
        // No pulses coded: propagate the lowband's collapse mask if one
        // exists, otherwise mark all blocks as filled.
        let has_lowband = lowband.is_some();
        if has_lowband {
            fill
        } else {
            (1u32 << b_blocks) - 1
        }
    }
}
955
956/// Fast path for quant_partition with n=2
957#[inline(always)]
958fn quant_partition_n2(
959    ctx: &mut BandCtx,
960    x: &mut [f32],
961    b: i32,
962    b_blocks: i32,
963    lowband: Option<&mut [f32]>,
964    lm: i32,
965    gain: f32,
966    fill: u32,
967) -> u32 {
968    // For n=2, we can skip the split and directly quantize
969    let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
970    let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
971    ctx.remaining_bits -= curr_bits;
972
973    while ctx.remaining_bits < 0 && q > 0 {
974        ctx.remaining_bits += curr_bits;
975        q -= 1;
976        curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
977        ctx.remaining_bits -= curr_bits;
978    }
979
980    if q != 0 {
981        let k = get_pulses(q);
982        if ctx.encode {
983            alg_quant(
984                x,
985                2,
986                k,
987                ctx.spread,
988                b_blocks as usize,
989                ctx.rc,
990                gain,
991                ctx.resynth,
992            )
993        } else {
994            alg_unquant(x, 2, k, ctx.spread, b_blocks as usize, ctx.rc, gain)
995        }
996    } else {
997        let has_lowband = lowband.is_some();
998        if ctx.resynth {
999            let mut seed = ctx.rc.tell() as u32;
1000            if has_lowband {
1001                let lb = lowband.unwrap();
1002                for i in 0..2 {
1003                    seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
1004                    x[i] = lb[i]
1005                        + if seed & 0x8000 != 0 {
1006                            1.0 / 256.0
1007                        } else {
1008                            -1.0 / 256.0
1009                        };
1010                }
1011            } else {
1012                for i in 0..2 {
1013                    seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
1014                    x[i] = ((seed as i32 >> 20) as f32) / 16384.0;
1015                }
1016            }
1017            renormalise_vector(x, 2, gain);
1018        }
1019        if has_lowband {
1020            fill
1021        } else {
1022            (1u32 << b_blocks) - 1
1023        }
1024    }
1025}
1026
1027/// Encode-only fast path for quant_partition with n=4
1028#[inline(always)]
1029fn quant_partition_n4_encode(
1030    ctx: &mut BandCtx,
1031    x: &mut [f32],
1032    b: i32,
1033    b_blocks: i32,
1034    lowband: Option<&mut [f32]>,
1035    lm: i32,
1036    gain: f32,
1037    fill: u32,
1038) -> u32 {
1039    let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
1040    let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1041    ctx.remaining_bits -= curr_bits;
1042
1043    while ctx.remaining_bits < 0 && q > 0 {
1044        ctx.remaining_bits += curr_bits;
1045        q -= 1;
1046        curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1047        ctx.remaining_bits -= curr_bits;
1048    }
1049
1050    if q != 0 {
1051        let k = get_pulses(q);
1052        alg_quant(x, 4, k, ctx.spread, b_blocks as usize, ctx.rc, gain, false)
1053    } else {
1054        let has_lowband = lowband.is_some();
1055        if has_lowband {
1056            fill
1057        } else {
1058            (1u32 << b_blocks) - 1
1059        }
1060    }
1061}
1062
1063/// Encode-only fast path for quant_partition with n=8
1064#[inline(always)]
1065fn quant_partition_n8_encode(
1066    ctx: &mut BandCtx,
1067    x: &mut [f32],
1068    b: i32,
1069    b_blocks: i32,
1070    lowband: Option<&mut [f32]>,
1071    lm: i32,
1072    gain: f32,
1073    fill: u32,
1074) -> u32 {
1075    let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
1076    let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1077    ctx.remaining_bits -= curr_bits;
1078
1079    while ctx.remaining_bits < 0 && q > 0 {
1080        ctx.remaining_bits += curr_bits;
1081        q -= 1;
1082        curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1083        ctx.remaining_bits -= curr_bits;
1084    }
1085
1086    if q != 0 {
1087        let k = get_pulses(q);
1088        alg_quant(x, 8, k, ctx.spread, b_blocks as usize, ctx.rc, gain, false)
1089    } else {
1090        let has_lowband = lowband.is_some();
1091        if has_lowband {
1092            fill
1093        } else {
1094            (1u32 << b_blocks) - 1
1095        }
1096    }
1097}
1098
1099/// Encode-only direct quantization for n=16
1100#[inline(always)]
1101#[allow(clippy::too_many_arguments)]
1102fn quant_partition_direct_encode(
1103    ctx: &mut BandCtx,
1104    x: &mut [f32],
1105    n: usize,
1106    b: i32,
1107    b_blocks: i32,
1108    lowband: Option<&mut [f32]>,
1109    lm: i32,
1110    gain: f32,
1111    fill: u32,
1112) -> u32 {
1113    let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
1114    let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1115    ctx.remaining_bits -= curr_bits;
1116
1117    while ctx.remaining_bits < 0 && q > 0 {
1118        ctx.remaining_bits += curr_bits;
1119        q -= 1;
1120        curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1121        ctx.remaining_bits -= curr_bits;
1122    }
1123
1124    if q != 0 {
1125        let k = get_pulses(q);
1126        alg_quant(x, n, k, ctx.spread, b_blocks as usize, ctx.rc, gain, false)
1127    } else {
1128        let has_lowband = lowband.is_some();
1129        if has_lowband {
1130            fill
1131        } else {
1132            (1u32 << b_blocks) - 1
1133        }
1134    }
1135}
1136
/// Encode-only quant_partition - eliminates all decode branches at compile time.
/// This is the hot path for encoding, called for each band in quant_all_bands.
///
/// Recursively codes one partition of a band: sizes 2/4/8/16 take a direct
/// PVQ fast path; larger partitions may be split in half (mid/side), with
/// the angle coded by `compute_theta_encode` before recursing on each half.
///
/// * `b` — bit budget in 1/8-bit (BITRES) units; adjusted via
///   `ctx.remaining_bits` as bits are consumed.
/// * `b_blocks` — number of transient sub-blocks spanned by this partition.
/// * `fill` — incoming collapse mask (one bit per sub-block).
///
/// Returns the collapse mask of sub-blocks that received energy.
#[inline(always)]
#[allow(clippy::too_many_arguments)]
fn quant_partition_encode(
    ctx: &mut BandCtx,
    x: &mut [f32],
    n: usize,
    b: i32,
    b_blocks: i32,
    lowband: Option<&mut [f32]>,
    lm: i32,
    gain: f32,
    fill: u32,
) -> u32 {
    if n == 2 {
        return quant_partition_n2_encode(ctx, x, b, b_blocks, lowband, lm, gain, fill);
    }
    if n == 4 {
        return quant_partition_n4_encode(ctx, x, b, b_blocks, lowband, lm, gain, fill);
    }
    if n == 8 {
        return quant_partition_n8_encode(ctx, x, b, b_blocks, lowband, lm, gain, fill);
    }
    if n == 16 {
        return quant_partition_direct_encode(ctx, x, n, b, b_blocks, lowband, lm, gain, fill);
    }

    // Split decision — same condition as C opus: split only when the budget
    // exceeds the max-quality PVQ cost (cache[cache[0]]) by more than 12.
    let should_split = if lm >= 0 && n > 2 {
        let cache_idx = (lm + 1) as usize * ctx.m.nb_ebands + ctx.i;
        // Indexing is assumed in-bounds for the mode's cache tables.
        let cache_base = unsafe { *ctx.m.cache.index.get_unchecked(cache_idx) } as usize;
        if cache_base > 0 {
            let cache_ptr = ctx.m.cache.bits.as_ptr().wrapping_add(cache_base);
            let max_q = unsafe { *cache_ptr } as usize;
            b > (unsafe { *cache_ptr.add(max_q) } as i32) + 12
        } else {
            false
        }
    } else {
        false
    };

    if should_split {
        let mut sctx = SplitCtx {
            inv: false,
            imid: 0,
            iside: 0,
            delta: 0,
            itheta: 0,
            qalloc: 0,
        };
        let mut b_mut = b;
        let mut fill_mut = fill;
        let mid = n / 2;
        let (x_mid, x_side) = x.split_at_mut(mid);

        // Code the mid/side angle; may adjust b_mut and fill_mut.
        // NOTE(review): C opus decrements LM and, when B==1, doubles the fill
        // mask before recursing; this port keeps `lm` unchanged and shifts
        // fill by the un-halved `b_blocks` below — TODO confirm intentional.
        compute_theta_encode(
            ctx,
            &mut sctx,
            x_mid,
            x_side,
            mid,
            &mut b_mut,
            (b_blocks + 1) >> 1,
            b_blocks,
            lm,
            false,
            &mut fill_mut,
        );

        ctx.remaining_bits -= sctx.qalloc;
        // Split the budget between mid and side according to theta's delta.
        let mbits = (0).max((b_mut - sctx.delta) / 2).min(b_mut);
        let mut sbits = b_mut - mbits;
        let mut mbits = mbits;

        let mut rebalance = ctx.remaining_bits;
        let mut cm;

        // In the encode-only path, gain is passed to alg_quant with resynth=false,
        // so it's never used. Skip the float multiplication to save work.
        if mbits >= sbits {
            // Code the larger (mid) half first; any surplus it leaves beyond
            // 3 whole bits is rebalanced onto the side half.
            cm = quant_partition_encode(
                ctx,
                x_mid,
                mid,
                mbits,
                (b_blocks + 1) >> 1,
                lowband,
                lm,
                gain,
                fill_mut,
            );
            rebalance = mbits - (rebalance - ctx.remaining_bits);
            if rebalance > (3 << 3) && sctx.itheta != 0 {
                sbits += rebalance - (3 << 3);
            }
            cm |= quant_partition_encode(
                ctx,
                x_side,
                mid,
                sbits,
                (b_blocks + 1) >> 1,
                None,
                lm,
                gain,
                fill_mut >> b_blocks,
            ) << (b_blocks >> 1);
        } else {
            // Side half carries more bits: code it first, rebalance to mid.
            cm = quant_partition_encode(
                ctx,
                x_side,
                mid,
                sbits,
                (b_blocks + 1) >> 1,
                None,
                lm,
                gain,
                fill_mut >> b_blocks,
            ) << (b_blocks >> 1);
            rebalance = sbits - (rebalance - ctx.remaining_bits);
            if rebalance > (3 << 3) && sctx.itheta != 16384 {
                mbits += rebalance - (3 << 3);
            }
            cm |= quant_partition_encode(
                ctx,
                x_mid,
                mid,
                mbits,
                (b_blocks + 1) >> 1,
                lowband,
                lm,
                gain,
                fill_mut,
            );
        }
        cm
    } else {
        // No split: pick the largest pulse count that fits the bit budget.
        let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
        let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
        ctx.remaining_bits -= curr_bits;

        while ctx.remaining_bits < 0 && q > 0 {
            ctx.remaining_bits += curr_bits;
            q -= 1;
            curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
            ctx.remaining_bits -= curr_bits;
        }

        if q != 0 {
            let k = get_pulses(q);
            // Encode-only: resynth is hard-wired false, gain unused by alg_quant.
            alg_quant(x, n, k, ctx.spread, b_blocks as usize, ctx.rc, gain, false)
        } else {
            // No pulses: propagate fold mask, or mark all sub-blocks alive.
            if lowband.is_some() {
                fill
            } else {
                (1 << b_blocks) - 1
            }
        }
    }
}
1298
1299/// Fast path for quant_partition with n=4 and small b
1300#[inline(always)]
1301fn quant_partition_n4(
1302    ctx: &mut BandCtx,
1303    x: &mut [f32],
1304    b: i32,
1305    b_blocks: i32,
1306    lowband: Option<&mut [f32]>,
1307    lm: i32,
1308    gain: f32,
1309    fill: u32,
1310) -> u32 {
1311    // For n=4 with small b, skip the split and directly quantize
1312    let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
1313    let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1314    ctx.remaining_bits -= curr_bits;
1315
1316    while ctx.remaining_bits < 0 && q > 0 {
1317        ctx.remaining_bits += curr_bits;
1318        q -= 1;
1319        curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1320        ctx.remaining_bits -= curr_bits;
1321    }
1322
1323    if q != 0 {
1324        let k = get_pulses(q);
1325        if ctx.encode {
1326            alg_quant(
1327                x,
1328                4,
1329                k,
1330                ctx.spread,
1331                b_blocks as usize,
1332                ctx.rc,
1333                gain,
1334                ctx.resynth,
1335            )
1336        } else {
1337            alg_unquant(x, 4, k, ctx.spread, b_blocks as usize, ctx.rc, gain)
1338        }
1339    } else {
1340        let has_lowband = lowband.is_some();
1341        if ctx.resynth {
1342            let mut seed = ctx.rc.tell() as u32;
1343            if has_lowband {
1344                let lb = lowband.unwrap();
1345                for i in 0..4 {
1346                    seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
1347                    x[i] = lb[i]
1348                        + if seed & 0x8000 != 0 {
1349                            1.0 / 256.0
1350                        } else {
1351                            -1.0 / 256.0
1352                        };
1353                }
1354            } else {
1355                for i in 0..4 {
1356                    seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
1357                    x[i] = ((seed as i32 >> 20) as f32) / 16384.0;
1358                }
1359            }
1360            renormalise_vector(x, 4, gain);
1361        }
1362        if has_lowband {
1363            fill
1364        } else {
1365            (1u32 << b_blocks) - 1
1366        }
1367    }
1368}
1369
1370/// Fast path for quant_partition with n=8 and small b
1371#[inline(always)]
1372fn quant_partition_n8(
1373    ctx: &mut BandCtx,
1374    x: &mut [f32],
1375    b: i32,
1376    b_blocks: i32,
1377    lowband: Option<&mut [f32]>,
1378    lm: i32,
1379    gain: f32,
1380    fill: u32,
1381) -> u32 {
1382    let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
1383    let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1384    ctx.remaining_bits -= curr_bits;
1385
1386    while ctx.remaining_bits < 0 && q > 0 {
1387        ctx.remaining_bits += curr_bits;
1388        q -= 1;
1389        curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1390        ctx.remaining_bits -= curr_bits;
1391    }
1392
1393    if q != 0 {
1394        let k = get_pulses(q);
1395        if ctx.encode {
1396            alg_quant(
1397                x,
1398                8,
1399                k,
1400                ctx.spread,
1401                b_blocks as usize,
1402                ctx.rc,
1403                gain,
1404                ctx.resynth,
1405            )
1406        } else {
1407            alg_unquant(x, 8, k, ctx.spread, b_blocks as usize, ctx.rc, gain)
1408        }
1409    } else {
1410        let has_lowband = lowband.is_some();
1411        if ctx.resynth {
1412            let mut seed = ctx.rc.tell() as u32;
1413            if has_lowband {
1414                let lb = lowband.unwrap();
1415                for i in 0..8 {
1416                    seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
1417                    x[i] = lb[i]
1418                        + if seed & 0x8000 != 0 {
1419                            1.0 / 256.0
1420                        } else {
1421                            -1.0 / 256.0
1422                        };
1423                }
1424            } else {
1425                for i in 0..8 {
1426                    seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
1427                    x[i] = ((seed as i32 >> 20) as f32) / 16384.0;
1428                }
1429            }
1430            renormalise_vector(x, 8, gain);
1431        }
1432        if has_lowband {
1433            fill
1434        } else {
1435            (1u32 << b_blocks) - 1
1436        }
1437    }
1438}
1439
1440/// Direct quantization for n=16 (skip split recursion).
1441#[inline(always)]
1442#[allow(clippy::too_many_arguments)]
1443fn quant_partition_direct(
1444    ctx: &mut BandCtx,
1445    x: &mut [f32],
1446    n: usize,
1447    b: i32,
1448    b_blocks: i32,
1449    lowband: Option<&mut [f32]>,
1450    lm: i32,
1451    gain: f32,
1452    fill: u32,
1453) -> u32 {
1454    let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
1455    let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1456    ctx.remaining_bits -= curr_bits;
1457
1458    while ctx.remaining_bits < 0 && q > 0 {
1459        ctx.remaining_bits += curr_bits;
1460        q -= 1;
1461        curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1462        ctx.remaining_bits -= curr_bits;
1463    }
1464
1465    if q != 0 {
1466        let k = get_pulses(q);
1467        if ctx.encode {
1468            alg_quant(
1469                x,
1470                n,
1471                k,
1472                ctx.spread,
1473                b_blocks as usize,
1474                ctx.rc,
1475                gain,
1476                ctx.resynth,
1477            )
1478        } else {
1479            alg_unquant(x, n, k, ctx.spread, b_blocks as usize, ctx.rc, gain)
1480        }
1481    } else {
1482        let has_lowband = lowband.is_some();
1483        if ctx.resynth {
1484            let mut seed = ctx.rc.tell() as u32;
1485            if let Some(lb) = lowband {
1486                for i in 0..n {
1487                    seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
1488                    x[i] = lb[i]
1489                        + if seed & 0x8000 != 0 {
1490                            1.0 / 256.0
1491                        } else {
1492                            -1.0 / 256.0
1493                        };
1494                }
1495            } else {
1496                for i in 0..n {
1497                    seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
1498                    x[i] = ((seed as i32 >> 20) as f32) / 16384.0;
1499                }
1500            }
1501            renormalise_vector(x, n, gain);
1502        }
1503        if has_lowband {
1504            fill
1505        } else {
1506            (1u32 << b_blocks) - 1
1507        }
1508    }
1509}
1510
/// Recursively quantize (encode) or dequantize (decode) one partition of a
/// band. Port of C opus celt/bands.c `quant_partition()`.
///
/// * `n` — partition size in samples.
/// * `b` — bit budget for this partition, in 1/8-bit (BITRES) units.
/// * `b_blocks` — number of transient sub-blocks spanned by the partition.
/// * `lowband` — lower-frequency spectrum to fold from when no pulses fit.
/// * `fill` — incoming collapse mask (one bit per sub-block).
///
/// Returns the collapse mask of sub-blocks that received energy.
#[inline(always)]
#[allow(clippy::too_many_arguments)]
pub fn quant_partition(
    ctx: &mut BandCtx,
    x: &mut [f32],
    n: usize,
    b: i32,
    b_blocks: i32,
    lowband: Option<&mut [f32]>,
    lm: i32,
    gain: f32,
    fill: u32,
) -> u32 {
    // Fast paths for small n: skip split recursion entirely.
    // Direct quantization is correct and avoids compute_theta + recursive calls.
    if n == 2 {
        return quant_partition_n2(ctx, x, b, b_blocks, lowband, lm, gain, fill);
    }

    if n == 4 {
        return quant_partition_n4(ctx, x, b, b_blocks, lowband, lm, gain, fill);
    }

    if n == 8 {
        return quant_partition_n8(ctx, x, b, b_blocks, lowband, lm, gain, fill);
    }

    // Fast path for n=16: skip split recursion (common at lm=3)
    if n == 16 {
        return quant_partition_direct(ctx, x, n, b, b_blocks, lowband, lm, gain, fill);
    }

    // Match C opus split condition: only split if the band has more bits
    // than needed for max quality + 12. Uses the pulse cache table.
    // C: if (LM != -1 && b > cache[cache[0]]+12 && N>2)
    let should_split = if lm >= 0 && n > 2 {
        let cache_idx = (lm + 1) as usize * ctx.m.nb_ebands + ctx.i;
        // Indexing is assumed in-bounds for the mode's cache tables.
        let cache_base = unsafe { *ctx.m.cache.index.get_unchecked(cache_idx) } as usize;
        if cache_base > 0 {
            let cache_ptr = ctx.m.cache.bits.as_ptr().wrapping_add(cache_base);
            let max_q = unsafe { *cache_ptr } as usize;
            b > (unsafe { *cache_ptr.add(max_q) } as i32) + 12
        } else {
            false
        }
    } else {
        false
    };
    if should_split {
        let mut sctx = SplitCtx {
            inv: false,
            imid: 0,
            iside: 0,
            delta: 0,
            itheta: 0,
            qalloc: 0,
        };
        let mut b_mut = b;
        let mut fill_mut = fill;
        let mid = n / 2;
        let (x_mid, x_side) = x.split_at_mut(mid);

        // Code the mid/side angle; may adjust b_mut and fill_mut.
        // NOTE(review): C opus decrements LM and, when B==1, doubles the fill
        // mask before recursing; this port keeps `lm` unchanged and shifts
        // fill by the un-halved `b_blocks` below — TODO confirm intentional.
        compute_theta(
            ctx,
            &mut sctx,
            x_mid,
            x_side,
            mid,
            &mut b_mut,
            (b_blocks + 1) >> 1,
            b_blocks,
            lm,
            false,
            &mut fill_mut,
        );

        ctx.remaining_bits -= sctx.qalloc;
        // Split the budget between mid and side according to theta's delta.
        let mbits = (0).max((b_mut - sctx.delta) / 2).min(b_mut);
        let mut sbits = b_mut - mbits;
        let mut mbits = mbits;

        let mut rebalance = ctx.remaining_bits;
        let mut cm;

        if mbits >= sbits {
            // Code the larger (mid) half first; any surplus it leaves beyond
            // 3 whole bits is rebalanced onto the side half.
            cm = quant_partition(
                ctx,
                x_mid,
                mid,
                mbits,
                (b_blocks + 1) >> 1,
                lowband,
                lm,
                gain * (sctx.imid as f32 / 32768.0),
                fill_mut,
            );
            rebalance = mbits - (rebalance - ctx.remaining_bits);
            if rebalance > (3 << 3) && sctx.itheta != 0 {
                sbits += rebalance - (3 << 3);
            }
            cm |= quant_partition(
                ctx,
                x_side,
                mid,
                sbits,
                (b_blocks + 1) >> 1,
                None,
                lm,
                gain * (sctx.iside as f32 / 32768.0),
                fill_mut >> b_blocks,
            ) << (b_blocks >> 1);
        } else {
            // Side half carries more bits: code it first, rebalance to mid.
            cm = quant_partition(
                ctx,
                x_side,
                mid,
                sbits,
                (b_blocks + 1) >> 1,
                None,
                lm,
                gain * (sctx.iside as f32 / 32768.0),
                fill_mut >> b_blocks,
            ) << (b_blocks >> 1);
            rebalance = sbits - (rebalance - ctx.remaining_bits);
            if rebalance > (3 << 3) && sctx.itheta != 16384 {
                mbits += rebalance - (3 << 3);
            }
            cm |= quant_partition(
                ctx,
                x_mid,
                mid,
                mbits,
                (b_blocks + 1) >> 1,
                lowband,
                lm,
                gain * (sctx.imid as f32 / 32768.0),
                fill_mut,
            );
        }
        cm
    } else {
        let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
        let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
        ctx.remaining_bits -= curr_bits;

        // Match C opus: never exceed the budget.
        // If selected q would make remaining_bits negative, reduce q until it fits.
        while ctx.remaining_bits < 0 && q > 0 {
            ctx.remaining_bits += curr_bits;
            q -= 1;
            curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
            ctx.remaining_bits -= curr_bits;
        }

        if q != 0 {
            let k = get_pulses(q);
            if ctx.encode {
                alg_quant(
                    x,
                    n,
                    k,
                    ctx.spread,
                    b_blocks as usize,
                    ctx.rc,
                    gain,
                    ctx.resynth,
                )
            } else {
                alg_unquant(x, n, k, ctx.spread, b_blocks as usize, ctx.rc, gain)
            }
        } else {
            // q == 0: no pulses, fill with noise if resynth
            let has_lowband = lowband.is_some();
            if ctx.resynth {
                let cm_mask = (1u32 << b_blocks) - 1;
                let fill_masked = fill & cm_mask;
                if fill_masked == 0 {
                    // All sub-blocks collapsed — produce silence
                    x[..n].fill(0.0);
                } else if has_lowband {
                    // Fold from the lower band, adding noise ~48 dB below the
                    // normal folding level (±1/256).
                    let lb = lowband.unwrap();
                    #[cfg(target_arch = "aarch64")]
                    unsafe {
                        use std::arch::aarch64::*;
                        // Vectorize the lowband+noise add in chunks of 8.
                        let n8 = n & !7;
                        let mut i = 0;
                        while i < n8 {
                            // Generate 8 random values
                            let mut vals = [0.0f32; 8];
                            for j in 0..8 {
                                ctx.seed = celt_lcg_rand(ctx.seed);
                                vals[j] = if ctx.seed & 0x8000 != 0 {
                                    1.0 / 256.0
                                } else {
                                    -1.0 / 256.0
                                };
                            }
                            let vnoise = vld1q_f32(vals.as_ptr());
                            let vnoise1 = vld1q_f32(vals.as_ptr().add(4));
                            let vlb = vld1q_f32(lb.as_ptr().add(i));
                            let vlb1 = vld1q_f32(lb.as_ptr().add(i + 4));
                            let vres = vaddq_f32(vlb, vnoise);
                            let vres1 = vaddq_f32(vlb1, vnoise1);
                            vst1q_f32(x.as_mut_ptr().add(i), vres);
                            vst1q_f32(x.as_mut_ptr().add(i + 4), vres1);
                            i += 8;
                        }
                        // Scalar tail for the remaining (n % 8) samples.
                        for j in i..n {
                            ctx.seed = celt_lcg_rand(ctx.seed);
                            x[j] = lb[j]
                                + if ctx.seed & 0x8000 != 0 {
                                    1.0 / 256.0
                                } else {
                                    -1.0 / 256.0
                                };
                        }
                    }
                    #[cfg(not(target_arch = "aarch64"))]
                    {
                        for j in 0..n {
                            ctx.seed = celt_lcg_rand(ctx.seed);
                            x[j] = lb[j]
                                + if ctx.seed & 0x8000 != 0 {
                                    1.0 / 256.0
                                } else {
                                    -1.0 / 256.0
                                };
                        }
                    }
                    renormalise_vector(x, n, gain);
                } else {
                    // No lowband to fold from: pure noise fill.
                    for xv in x[..n].iter_mut() {
                        ctx.seed = celt_lcg_rand(ctx.seed);
                        *xv = ((ctx.seed as i32 >> 20) as f32) / 16384.0;
                    }
                    renormalise_vector(x, n, gain);
                }
            }
            if has_lowband {
                fill
            } else {
                (1 << b_blocks) - 1
            }
        }
    }
}
1757
/// Deinterleave for the non-hadamard case on aarch64.
///
/// NOTE(review): despite the `_neon` suffix this is a scalar gather — it
/// contains no NEON intrinsics and performs the same transform as the
/// portable fallback in `deinterleave_hadamard`.
///
/// # Safety
/// Caller must guarantee `n0 * stride <= MAX_PVQ_N` and
/// `x.len() >= n0 * stride`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn deinterleave_hadamard_neon(x: &mut [f32], n0: usize, stride: usize) {
    let n = n0 * stride;
    debug_assert!(n <= MAX_PVQ_N);
    let mut tmp_buf = [std::mem::MaybeUninit::<f32>::uninit(); MAX_PVQ_N];

    // Write through MaybeUninit::new rather than a &mut [f32] view so no
    // reference to uninitialized floats is ever materialized (that is UB).
    for i in 0..stride {
        for j in 0..n0 {
            tmp_buf[i * n0 + j] = std::mem::MaybeUninit::new(x[j * stride + i]);
        }
    }

    // SAFETY: tmp_buf[0..n] was fully written by the loop above.
    let tmp = std::slice::from_raw_parts(tmp_buf.as_ptr() as *const f32, n);
    x[..n].copy_from_slice(tmp);
}
1776
1777pub fn deinterleave_hadamard(x: &mut [f32], n0: usize, stride: usize, hadamard: bool) {
1778    let n = n0 * stride;
1779    // Use MaybeUninit to avoid zero-initializing 1408 bytes when n is typically 2-16
1780    let mut tmp_buf = [std::mem::MaybeUninit::<f32>::uninit(); MAX_PVQ_N];
1781    // SAFETY: we write tmp[0..n] fully before any read below
1782    let tmp = unsafe { std::slice::from_raw_parts_mut(tmp_buf.as_mut_ptr() as *mut f32, n) };
1783    if hadamard {
1784        let offset = match stride {
1785            2 => 0,
1786            4 => 2,
1787            8 => 6,
1788            16 => 14,
1789            _ => 0,
1790        };
1791        let ordery = &ORDERY_TABLE[offset..offset + stride];
1792        for i in 0..stride {
1793            for j in 0..n0 {
1794                tmp[ordery[i] as usize * n0 + j] = x[j * stride + i];
1795            }
1796        }
1797    } else {
1798        // Use NEON on aarch64 for better performance
1799        #[cfg(target_arch = "aarch64")]
1800        unsafe {
1801            if n0 >= 4 {
1802                deinterleave_hadamard_neon(x, n0, stride);
1803                return;
1804            }
1805        }
1806        for i in 0..stride {
1807            for j in 0..n0 {
1808                tmp[i * n0 + j] = x[j * stride + i];
1809            }
1810        }
1811    }
1812    x[..n].copy_from_slice(tmp);
1813}
1814
/// Interleave for the non-hadamard case on aarch64.
///
/// NOTE(review): despite the `_neon` suffix this is a scalar scatter — it
/// contains no NEON intrinsics and performs the same transform as the
/// portable fallback in `interleave_hadamard`.
///
/// # Safety
/// Caller must guarantee `n0 * stride <= MAX_PVQ_N` and
/// `x.len() >= n0 * stride`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn interleave_hadamard_neon(x: &mut [f32], n0: usize, stride: usize) {
    let n = n0 * stride;
    debug_assert!(n <= MAX_PVQ_N);
    let mut tmp_buf = [std::mem::MaybeUninit::<f32>::uninit(); MAX_PVQ_N];

    // Write through MaybeUninit::new rather than a &mut [f32] view so no
    // reference to uninitialized floats is ever materialized (that is UB).
    for i in 0..stride {
        for j in 0..n0 {
            tmp_buf[j * stride + i] = std::mem::MaybeUninit::new(x[i * n0 + j]);
        }
    }

    // SAFETY: tmp_buf[0..n] was fully written by the loop above.
    let tmp = std::slice::from_raw_parts(tmp_buf.as_ptr() as *const f32, n);
    x[..n].copy_from_slice(tmp);
}
1833
1834pub fn interleave_hadamard(x: &mut [f32], n0: usize, stride: usize, hadamard: bool) {
1835    let n = n0 * stride;
1836    let mut tmp_buf = [std::mem::MaybeUninit::<f32>::uninit(); MAX_PVQ_N];
1837    let tmp = unsafe { std::slice::from_raw_parts_mut(tmp_buf.as_mut_ptr() as *mut f32, n) };
1838    if hadamard {
1839        let offset = match stride {
1840            2 => 0,
1841            4 => 2,
1842            8 => 6,
1843            16 => 14,
1844            _ => 0,
1845        };
1846        let ordery = &ORDERY_TABLE[offset..offset + stride];
1847        for i in 0..stride {
1848            for j in 0..n0 {
1849                tmp[j * stride + i] = x[ordery[i] as usize * n0 + j];
1850            }
1851        }
1852    } else {
1853        // Use NEON on aarch64 for better performance
1854        #[cfg(target_arch = "aarch64")]
1855        unsafe {
1856            if n0 >= 4 {
1857                interleave_hadamard_neon(x, n0, stride);
1858                return;
1859            }
1860        }
1861        for i in 0..stride {
1862            for j in 0..n0 {
1863                tmp[j * stride + i] = x[i * n0 + j];
1864            }
1865        }
1866    }
1867    x[..n].copy_from_slice(tmp);
1868}
1869
/// Hadamard row-reordering table (C opus celt/bands.c `ordery_table`).
/// Rows for each stride are concatenated: stride 2 starts at offset 0,
/// stride 4 at offset 2, stride 8 at offset 6, stride 16 at offset 14 —
/// see the `match stride` blocks in interleave/deinterleave_hadamard.
const ORDERY_TABLE: [i32; 30] = [
    1, 0, 3, 0, 2, 1, 7, 0, 4, 3, 6, 1, 5, 2, 15, 0, 8, 7, 12, 3, 11, 4, 14, 1, 9, 6, 13, 2, 10, 5,
];
1873
1874fn quant_band_n1(
1875    ctx: &mut BandCtx,
1876    x: &mut [f32],
1877    y: Option<&mut [f32]>,
1878    lowband_out: Option<&mut [f32]>,
1879) -> u32 {
1880    let mut sign = 0;
1881    if ctx.remaining_bits >= 1 << BITRES {
1882        if ctx.encode {
1883            sign = if x[0] < 0.0 { 1 } else { 0 };
1884            ctx.rc.enc_bits(sign as u32, 1);
1885        } else {
1886            sign = ctx.rc.dec_bits(1) as i32;
1887        }
1888        ctx.remaining_bits -= 1 << BITRES;
1889    }
1890    if ctx.resynth {
1891        x[0] = if sign != 0 { -1.0 } else { 1.0 };
1892    }
1893    if let Some(y_val) = y {
1894        let mut y_sign = 0;
1895        if ctx.remaining_bits >= 1 << BITRES {
1896            if ctx.encode {
1897                y_sign = if y_val[0] < 0.0 { 1 } else { 0 };
1898                ctx.rc.enc_bits(y_sign as u32, 1);
1899            } else {
1900                y_sign = ctx.rc.dec_bits(1) as i32;
1901            }
1902            ctx.remaining_bits -= 1 << BITRES;
1903        }
1904        if ctx.resynth {
1905            y_val[0] = if y_sign != 0 { -1.0 } else { 1.0 };
1906        }
1907    }
1908    if let Some(l_out) = lowband_out {
1909        l_out[0] = x[0] / 16.0;
1910    }
1911    1
1912}
1913
/// Quantise (encode) or dequantise (decode) one band of normalised MDCT
/// coefficients for a single channel; port of C opus celt/bands.c
/// quant_band().
///
/// Applies the time-frequency (Haar) transforms selected by
/// `ctx.tf_change`, de-interleaves short blocks, delegates the PVQ coding
/// to quant_partition / quant_partition_encode, then undoes the transforms
/// when resynthesising. `b` is the bit budget in 1/8-bit (BITRES) units,
/// `b_blocks` the number of MDCT blocks, `lowband` an optional folding
/// source, `lowband_out` the destination for this band's folding copy, and
/// `fill` the per-block collapse mask of the folding source.
///
/// Returns the collapse mask: one bit per block, set where the block
/// received energy, truncated to the final block count.
#[allow(clippy::too_many_arguments)]
#[inline(always)]
pub fn quant_band(
    ctx: &mut BandCtx,
    x: &mut [f32],
    n: usize,
    b: i32,
    b_blocks: i32,
    lowband: Option<&mut [f32]>,
    lm: i32,
    lowband_out: Option<&mut [f32]>,
    gain: f32,
    fill: u32,
) -> u32 {
    let n0 = n;
    let b0 = b_blocks;
    let long_blocks = b0 == 1;

    // A one-sample band carries only a sign bit.
    if n == 1 {
        return quant_band_n1(ctx, x, None, lowband_out);
    }

    let mut b_blocks = b_blocks;
    let mut n_b = n / b_blocks as usize;
    let mut time_divide = 0;
    let mut recombine = 0;
    let mut tf_change_local = ctx.tf_change;
    let mut fill = fill;

    // Positive tf_change: recombine pairs of short blocks (Haar merges).
    if tf_change_local > 0 {
        recombine = tf_change_local;
    }

    // The lowband data comes from quant_all_bands' lowband_scratch_buf, which is
    // already a copy from norm. We can modify it in-place since quant_all_bands
    // copies fresh data for each band. This eliminates a redundant 1408-byte
    // copy and MaybeUninit stack allocation.
    let mut lowband_buf = lowband;

    // Remaps the 4-bit halves of the fill mask as blocks are recombined;
    // matches bit_interleave_table in C opus.
    static BIT_INTERLEAVE_TABLE: [u8; 16] = [0, 1, 1, 1, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3];

    for k in 0..recombine {
        // Only the encoder transforms x here; the decoder reconstructs in
        // the transformed domain and undoes everything in the resynth step.
        if ctx.encode {
            haar1(x, n >> k, 1 << k);
        }
        if let Some(ref mut lb) = lowband_buf {
            haar1(lb, n >> k, 1 << k);
        }
        fill = (BIT_INTERLEAVE_TABLE[(fill & 0xF) as usize] as u32)
            | ((BIT_INTERLEAVE_TABLE[(fill >> 4) as usize] as u32) << 2);
    }
    b_blocks >>= recombine;
    n_b <<= recombine;

    // Negative tf_change: split blocks further (better time resolution),
    // as long as the block length stays even.
    while n_b & 1 == 0 && tf_change_local < 0 {
        if ctx.encode {
            haar1(x, n_b, b_blocks as usize);
        }
        if let Some(ref mut lb) = lowband_buf {
            haar1(lb, n_b, b_blocks as usize);
        }
        fill |= fill << b_blocks;
        b_blocks <<= 1;
        n_b >>= 1;
        time_divide += 1;
        tf_change_local += 1;
    }

    // Block geometry after the TF adjustments; needed to undo them later.
    let b0_after = b_blocks;
    let n_b0 = n_b;

    // Make each block's coefficients contiguous before partition coding
    // (Hadamard ordering when the frame uses long blocks).
    if b_blocks > 1 {
        if ctx.encode {
            deinterleave_hadamard(
                x,
                n_b >> recombine as usize,
                (b_blocks << recombine) as usize,
                long_blocks,
            );
        }
        if let Some(ref mut lb) = lowband_buf {
            deinterleave_hadamard(
                lb,
                n_b >> recombine as usize,
                (b_blocks << recombine) as usize,
                long_blocks,
            );
        }
    }

    // Actual PVQ coding of the (possibly recursively split) band.
    let cm = if ctx.encode {
        quant_partition_encode(
            ctx,
            x,
            n,
            b,
            b_blocks,
            lowband_buf.as_deref_mut(),
            lm,
            gain,
            fill,
        )
    } else {
        quant_partition(
            ctx,
            x,
            n,
            b,
            b_blocks,
            lowband_buf.as_deref_mut(),
            lm,
            gain,
            fill,
        )
    };

    // Undo all transforms so x is back in the signal domain.
    if ctx.resynth {
        let mut cm = cm;

        if b_blocks > 1 {
            interleave_hadamard(
                x,
                n_b >> recombine as usize,
                (b0_after << recombine) as usize,
                long_blocks,
            );
        }

        // Undo the time splits, merging collapse-mask bits of split blocks.
        let mut n_b_undo = n_b0;
        let mut b_undo = b0_after;
        for _ in 0..time_divide {
            b_undo >>= 1;
            n_b_undo <<= 1;
            cm |= cm >> b_undo;
            haar1(x, n_b_undo, b_undo as usize);
        }

        // Undo the block recombination; matches bit_deinterleave_table in C opus.
        static BIT_DEINTERLEAVE_TABLE: [u8; 16] = [
            0x00, 0x03, 0x0C, 0x0F, 0x30, 0x33, 0x3C, 0x3F, 0xC0, 0xC3, 0xCC, 0xCF, 0xF0, 0xF3,
            0xFC, 0xFF,
        ];
        for k in 0..recombine {
            cm = BIT_DEINTERLEAVE_TABLE[cm as usize & 0xF] as u32;
            haar1(x, n0 >> k, 1 << k);
        }
        let mut b_final = b0_after;
        b_final <<= recombine;

        // Scale the unit-norm output by sqrt(N) for use as a folding source
        // by later bands.
        if let Some(lb_out) = lowband_out {
            let scale = (n0 as f32).sqrt();
            for j in 0..n0 {
                lb_out[j] = scale * x[j];
            }
        }
        cm &= (1u32 << b_final) - 1;
        return cm;
    }

    cm
}
2074
2075pub fn stereo_merge(x: &mut [f32], y: &mut [f32], mid: f32, side: f32, n: usize) {
2076    #[cfg(target_arch = "aarch64")]
2077    {
2078        stereo_merge_neon(x, y, mid, side, n);
2079        return;
2080    }
2081    #[cfg(not(target_arch = "aarch64"))]
2082    {
2083        stereo_merge_scalar(x, y, mid, side, n);
2084    }
2085}
2086
/// Portable mid/side -> left/right rotation over the first `n` samples:
/// x[i] <- mid*x[i] - side*y[i], y[i] <- mid*x[i] + side*y[i].
#[cfg(not(target_arch = "aarch64"))]
#[inline]
fn stereo_merge_scalar(x: &mut [f32], y: &mut [f32], mid: f32, side: f32, n: usize) {
    for (l, r) in x[..n].iter_mut().zip(y[..n].iter_mut()) {
        let m = *l * mid;
        let s = *r * side;
        *l = m - s;
        *r = m + s;
    }
}
2097
/// NEON implementation of stereo_merge(): x[i] <- mid*x[i] - side*y[i],
/// y[i] <- mid*x[i] + side*y[i], processed 16 floats at a time, then 4,
/// then a scalar tail.
#[cfg(target_arch = "aarch64")]
fn stereo_merge_neon(x: &mut [f32], y: &mut [f32], mid: f32, side: f32, n: usize) {
    use std::arch::aarch64::*;

    // SAFETY: all loads/stores below stay inside x[..n] / y[..n]
    // (i + 15 < n16 <= n, i + 3 < n16 + n4 <= n); NEON is a mandatory
    // feature on aarch64, so the intrinsics are always available.
    unsafe {
        let vmid = vdupq_n_f32(mid);
        let vside = vdupq_n_f32(side);

        // Process 16 elements at a time (4 vectors)
        let n16 = n & !15;
        for i in (0..n16).step_by(16) {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let x1 = vld1q_f32(x.as_ptr().add(i + 4));
            let x2 = vld1q_f32(x.as_ptr().add(i + 8));
            let x3 = vld1q_f32(x.as_ptr().add(i + 12));

            let y0 = vld1q_f32(y.as_ptr().add(i));
            let y1 = vld1q_f32(y.as_ptr().add(i + 4));
            let y2 = vld1q_f32(y.as_ptr().add(i + 8));
            let y3 = vld1q_f32(y.as_ptr().add(i + 12));

            // x_val = x * mid, y_val = y * side
            let xv0 = vmulq_f32(x0, vmid);
            let xv1 = vmulq_f32(x1, vmid);
            let xv2 = vmulq_f32(x2, vmid);
            let xv3 = vmulq_f32(x3, vmid);

            let yv0 = vmulq_f32(y0, vside);
            let yv1 = vmulq_f32(y1, vside);
            let yv2 = vmulq_f32(y2, vside);
            let yv3 = vmulq_f32(y3, vside);

            // x = x_val - y_val, y = x_val + y_val
            vst1q_f32(x.as_mut_ptr().add(i), vsubq_f32(xv0, yv0));
            vst1q_f32(x.as_mut_ptr().add(i + 4), vsubq_f32(xv1, yv1));
            vst1q_f32(x.as_mut_ptr().add(i + 8), vsubq_f32(xv2, yv2));
            vst1q_f32(x.as_mut_ptr().add(i + 12), vsubq_f32(xv3, yv3));

            vst1q_f32(y.as_mut_ptr().add(i), vaddq_f32(xv0, yv0));
            vst1q_f32(y.as_mut_ptr().add(i + 4), vaddq_f32(xv1, yv1));
            vst1q_f32(y.as_mut_ptr().add(i + 8), vaddq_f32(xv2, yv2));
            vst1q_f32(y.as_mut_ptr().add(i + 12), vaddq_f32(xv3, yv3));
        }

        // Process remaining 4-element chunks
        // (n4 = count of whole 4-lane chunks left after the 16-lane region)
        let n4 = (n & !3) - n16;
        for i in (n16..n16 + n4).step_by(4) {
            let xv = vld1q_f32(x.as_ptr().add(i));
            let yv = vld1q_f32(y.as_ptr().add(i));

            let x_val = vmulq_f32(xv, vmid);
            let y_val = vmulq_f32(yv, vside);

            vst1q_f32(x.as_mut_ptr().add(i), vsubq_f32(x_val, y_val));
            vst1q_f32(y.as_mut_ptr().add(i), vaddq_f32(x_val, y_val));
        }

        // Process remaining elements
        for i in (n16 + n4)..n {
            let x_val = x[i] * mid;
            let y_val = y[i] * side;
            x[i] = x_val - y_val;
            y[i] = x_val + y_val;
        }
    }
}
2164
/// Stereo version of quant_band(): codes the mid/side decomposition of one
/// band; port of C opus celt/bands.c quant_band_stereo().
///
/// `x` is the left channel and `y` the right on entry; when resynthesising
/// they hold the reconstructed left/right on return. `b` is the bit budget
/// in 1/8-bit (BITRES) units, `fill` the folding collapse mask. Returns the
/// combined collapse mask for the band.
#[allow(clippy::too_many_arguments)]
#[inline(always)]
pub fn quant_band_stereo(
    ctx: &mut BandCtx,
    x: &mut [f32],
    y: &mut [f32],
    n: usize,
    b: i32,
    b_blocks: i32,
    lowband: Option<&mut [f32]>,
    lm: i32,
    lowband_out: Option<&mut [f32]>,
    _gain: f32,
    fill: u32,
) -> u32 {
    // One-sample bands: just a sign bit per channel.
    if n == 1 {
        return quant_band_n1(ctx, x, Some(y), lowband_out);
    }

    // If one channel is essentially silent, copy the louder channel over it
    // so the theta analysis doesn't operate on pure numerical noise.
    if ctx.encode
        && (ctx.band_e[ctx.i] < MIN_STEREO_ENERGY
            || ctx.band_e[ctx.m.nb_ebands + ctx.i] < MIN_STEREO_ENERGY)
    {
        if ctx.band_e[ctx.i] > ctx.band_e[ctx.m.nb_ebands + ctx.i] {
            y.copy_from_slice(x);
        } else {
            x.copy_from_slice(y);
        }
    }

    // Code/decode the mid/side angle; fills imid/iside/delta/itheta/qalloc
    // and may adjust the budget (b_mut) and fill mask.
    let mut sctx = SplitCtx {
        inv: false,
        imid: 0,
        iside: 0,
        delta: 0,
        itheta: 0,
        qalloc: 0,
    };
    let mut b_mut = b;
    let mut fill_mut = fill;
    if ctx.encode {
        compute_theta_encode(
            ctx,
            &mut sctx,
            x,
            y,
            n,
            &mut b_mut,
            b_blocks,
            b_blocks,
            lm,
            true,
            &mut fill_mut,
        );
    } else {
        compute_theta(
            ctx,
            &mut sctx,
            x,
            y,
            n,
            &mut b_mut,
            b_blocks,
            b_blocks,
            lm,
            true,
            &mut fill_mut,
        );
    };

    // Q15 mid/side gains derived from the quantised angle.
    let mid_gain = sctx.imid as f32 / 32768.0;
    let side_gain = sctx.iside as f32 / 32768.0;

    if n == 2 {
        // Two-sample stereo special case: the weaker channel is an exact
        // +/-90-degree rotation of the stronger one, so at most one sign bit
        // is needed on top of the coded channel.
        let mut mbits = b_mut;
        let mut sbits = 0;
        if sctx.itheta != 0 && sctx.itheta != 16384 {
            sbits = 1 << BITRES;
        }
        mbits -= sbits;
        // c: side channel (y) dominates when itheta is past 45 degrees.
        let c = sctx.itheta > 8192;
        ctx.remaining_bits -= sctx.qalloc + sbits;

        let mut sign = 0;
        if sbits != 0 {
            if ctx.encode {
                // Sign of the cross-product between channels selects the
                // rotation direction.
                sign = if c {
                    if (y[0] * x[1] - y[1] * x[0]) < 0.0 {
                        1
                    } else {
                        0
                    }
                } else {
                    if (x[0] * y[1] - x[1] * y[0]) < 0.0 {
                        1
                    } else {
                        0
                    }
                };
                ctx.rc.enc_bits(sign as u32, 1);
            } else {
                sign = ctx.rc.dec_bits(1) as i32;
            }
        }
        let sign_val = (1 - 2 * sign) as f32;
        // Code only the dominant channel; rebuild the other as a rotation.
        let cm = if c {
            let cm = quant_band(
                ctx,
                y,
                n,
                mbits,
                b_blocks,
                lowband,
                lm,
                lowband_out,
                1.0,
                fill,
            );
            x[0] = -sign_val * y[1];
            x[1] = sign_val * y[0];
            cm
        } else {
            let cm = quant_band(
                ctx,
                x,
                n,
                mbits,
                b_blocks,
                lowband,
                lm,
                lowband_out,
                1.0,
                fill,
            );
            y[0] = -sign_val * x[1];
            y[1] = sign_val * x[0];
            cm
        };

        // Apply the mid/side gains to get back left/right.
        if ctx.resynth {
            let x0 = x[0];
            let x1 = x[1];
            let y0 = y[0];
            let y1 = y[1];
            x[0] = mid_gain * x0 - side_gain * y0;
            x[1] = mid_gain * x1 - side_gain * y1;
            y[0] = mid_gain * x0 + side_gain * y0;
            y[1] = mid_gain * x1 + side_gain * y1;
        }
        return cm;
    }

    ctx.remaining_bits -= sctx.qalloc;
    // Split the budget between mid and side according to the angle's delta.
    let mut mbits = (0).max((b_mut - sctx.delta) / 2).min(b_mut);
    let mut sbits = b_mut - mbits;

    let mut rebalance = ctx.remaining_bits;
    let mut cm;

    // Code the channel with the larger budget first; any bits it left
    // unused (beyond a 3-bit margin) are handed to the second channel.
    if mbits >= sbits {
        cm = quant_band(
            ctx,
            x,
            n,
            mbits,
            b_blocks,
            lowband,
            lm,
            lowband_out,
            1.0,
            fill_mut,
        );
        rebalance = mbits - (rebalance - ctx.remaining_bits);
        if rebalance > (3 << 3) && sctx.itheta != 0 {
            sbits += rebalance - (3 << 3);
        }
        cm |= quant_band(
            ctx,
            y,
            n,
            sbits,
            b_blocks,
            None,
            lm,
            None,
            side_gain,
            fill_mut >> b_blocks,
        ) << (b_blocks >> 1);
    } else {
        cm = quant_band(
            ctx,
            y,
            n,
            sbits,
            b_blocks,
            None,
            lm,
            None,
            side_gain,
            fill_mut >> b_blocks,
        ) << (b_blocks >> 1);
        rebalance = sbits - (rebalance - ctx.remaining_bits);
        if rebalance > (3 << 3) && sctx.itheta != 16384 {
            mbits += rebalance - (3 << 3);
        }
        cm |= quant_band(
            ctx,
            x,
            n,
            mbits,
            b_blocks,
            lowband,
            lm,
            lowband_out,
            1.0,
            fill_mut,
        );
    }

    // Rotate mid/side back to left/right; apply the inversion flag coded
    // with theta.
    if ctx.resynth {
        stereo_merge(x, y, mid_gain, side_gain, n);
        if sctx.inv {
            for yv in y[..n].iter_mut() {
                *yv = -*yv;
            }
        }
    }
    cm
}
2394
2395#[allow(clippy::too_many_arguments)]
2396pub fn quant_all_bands(
2397    encode: bool,
2398    m: &CeltMode,
2399    start: usize,
2400    end: usize,
2401    x: &mut [f32],
2402    mut y: Option<&mut [f32]>,
2403    collapse_masks: &mut [u32],
2404    band_e: &[f32],
2405    pulses: &[i32],
2406    short_blocks: bool,
2407    spread: i32,
2408    dual_stereo: &mut bool,
2409    intensity: usize,
2410    tf_res: &[i32],
2411    total_bits: i32,
2412    balance: &mut i32,
2413    rc: &mut RangeCoder,
2414    lm: i32,
2415    coded_bands: i32,
2416    resynth: bool,
2417    seed: &mut u32,
2418) {
2419    let mut balance_val = *balance;
2420    let b_blocks = if short_blocks { 1 << lm } else { 1 };
2421    let c_channels = if y.is_some() { 2 } else { 1 };
2422    let m_val = 1usize << lm as usize;
2423
2424    let norm_offset = m_val * (m.e_bands[start] as usize);
2425    let norm_size = m_val * (m.e_bands[m.nb_ebands - 1] as usize) - norm_offset;
2426    // max norm_size = m_val(8) * e_bands_last(100) = 800
2427    const MAX_NORM_SIZE: usize = 800;
2428    debug_assert!(norm_size <= MAX_NORM_SIZE);
2429    // Use MaybeUninit to avoid zeroing 3200 bytes per frame.
2430    // SAFETY: norm is only read at positions previously written by a prior band's
2431    // lowband_out (which writes norm[norm_pos..norm_pos+n]). The first band never
2432    // reads from norm (lowband_offset starts at 0, effective_lowband stays -1).
2433    // This matches C opus which declares norm[] on the stack without initialization.
2434    let mut norm_buf = [std::mem::MaybeUninit::<f32>::uninit(); MAX_NORM_SIZE];
2435    let norm =
2436        unsafe { std::slice::from_raw_parts_mut(norm_buf.as_mut_ptr() as *mut f32, norm_size) };
2437
2438    // Stack scratch buffer for lowband data (avoids per-band heap allocation).
2439    // MAX_PVQ_N = 352 covers the largest possible band size.
2440    // Use MaybeUninit since it's always fully overwritten by copy_from_slice before any read.
2441    let mut lowband_scratch_buf = [std::mem::MaybeUninit::<f32>::uninit(); MAX_PVQ_N];
2442    let lowband_scratch_ptr = lowband_scratch_buf.as_mut_ptr() as *mut f32;
2443
2444    let mut lowband_offset: usize = 0;
2445    let mut update_lowband = true;
2446    let mut avoid_split_noise = b_blocks > 1;
2447
2448    // Cache e_bands locally to avoid repeated bound checks and indirection
2449    let e_bands = &m.e_bands;
2450
2451    // Local copy of seed that will be modified by BandCtx as bands are processed
2452    let mut ctx_seed = *seed;
2453
2454    for i in start..end {
2455        // Pre-fetch e_bands values to avoid repeated indexing
2456        let e_band_i = e_bands[i] as usize;
2457        let e_band_i1 = e_bands[i + 1] as usize;
2458        let offset = m_val * e_band_i;
2459        let n = m_val * (e_band_i1 - e_band_i);
2460        let last = i == end - 1;
2461
2462        // Use inline macro to avoid function call overhead
2463        let tell = tell_frac_inline!(rc);
2464        if i != start {
2465            balance_val -= tell;
2466        }
2467        let remaining_bits = total_bits - tell - 1;
2468
2469        let mut b = 0i32;
2470        if i < coded_bands as usize {
2471            let curr_balance = celt_sudiv(balance_val, 3i32.min(coded_bands - i as i32));
2472            b = 0i32.max(16383i32.min((remaining_bits + 1).min(pulses[i] + curr_balance)));
2473        }
2474
2475        let norm_pos = m_val * e_band_i - norm_offset;
2476
2477        let band_start = m_val * e_band_i;
2478        let bands_start = m_val * (e_bands[start] as usize);
2479        if resynth
2480            && (band_start as i32 - n as i32 >= bands_start as i32 || i == start + 1)
2481            && (update_lowband || lowband_offset == 0)
2482        {
2483            lowband_offset = i;
2484        }
2485
2486        let tf_change = tf_res[i];
2487
2488        let mut effective_lowband: i32 = -1;
2489        let mut x_cm: u32;
2490        let mut y_cm: u32;
2491
2492        if lowband_offset != 0 && (spread != SPREAD_AGGRESSIVE || b_blocks > 1 || tf_change < 0) {
2493            effective_lowband = 0i32.max(
2494                (m_val * e_bands[lowband_offset] as usize) as i32 - norm_offset as i32 - n as i32,
2495            );
2496            let el_abs = effective_lowband as usize + norm_offset;
2497
2498            let mut fold_start = lowband_offset;
2499            loop {
2500                if fold_start == 0 {
2501                    break;
2502                }
2503                fold_start -= 1;
2504                if m_val * (e_bands[fold_start] as usize) <= el_abs {
2505                    break;
2506                }
2507            }
2508            let mut fold_end = lowband_offset.saturating_sub(1);
2509            while fold_end + 1 < i && m_val * (e_bands[fold_end + 1] as usize) < el_abs + n {
2510                fold_end += 1;
2511            }
2512
2513            x_cm = 0;
2514            y_cm = 0;
2515            for fi in fold_start..fold_end {
2516                x_cm |= collapse_masks[fi * c_channels];
2517                y_cm |= collapse_masks[fi * c_channels + c_channels - 1];
2518            }
2519        } else {
2520            x_cm = (1u32 << b_blocks) - 1;
2521            y_cm = (1u32 << b_blocks) - 1;
2522        }
2523
2524        let mut ctx = BandCtx {
2525            encode,
2526            m,
2527            i,
2528            band_e,
2529            rc,
2530            spread,
2531            remaining_bits,
2532            resynth,
2533            tf_change,
2534            intensity,
2535            theta_round: 0,
2536            avoid_split_noise,
2537            arch: 0,
2538            disable_inv: false,
2539            seed: ctx_seed,
2540        };
2541
2542        if *dual_stereo && i == intensity {
2543            *dual_stereo = false;
2544        }
2545
2546        let mut lowband_scratch: Option<&mut [f32]> = if effective_lowband >= 0 {
2547            let lb_start = effective_lowband as usize;
2548            let lb_end = lb_start + n;
2549            if lb_end <= norm.len() {
2550                unsafe {
2551                    std::ptr::copy_nonoverlapping(
2552                        norm.as_ptr().add(lb_start),
2553                        lowband_scratch_ptr,
2554                        n,
2555                    )
2556                };
2557                Some(unsafe { std::slice::from_raw_parts_mut(lowband_scratch_ptr, n) })
2558            } else {
2559                None
2560            }
2561        } else {
2562            None
2563        };
2564
2565        // Pass a slice extending to the end of the frequency array, not just this band.
2566        // This is needed because alg_unquant/exp_rotation use stride-based access (x[i*stride])
2567        // which can go beyond this band's n elements. In the C reference, X is a raw pointer
2568        // into the full array so this works naturally. The stray writes into later bands'
2569        // memory are harmless because those bands are processed afterward and overwrite them.
2570        let x_slice = &mut x[offset..];
2571        if *dual_stereo {
2572            let y_slice = &mut y.as_mut().unwrap()[offset..];
2573            let lb_x = lowband_scratch.as_deref_mut();
2574            let lb_out_x = if !last && norm_pos + n <= norm.len() {
2575                Some(&mut norm[norm_pos..norm_pos + n])
2576            } else {
2577                None
2578            };
2579            x_cm = quant_band(
2580                &mut ctx,
2581                x_slice,
2582                n,
2583                b / 2,
2584                b_blocks,
2585                lb_x,
2586                lm,
2587                lb_out_x,
2588                1.0,
2589                x_cm,
2590            );
2591            y_cm = quant_band(
2592                &mut ctx,
2593                y_slice,
2594                n,
2595                b / 2,
2596                b_blocks,
2597                None,
2598                lm,
2599                None,
2600                1.0,
2601                y_cm,
2602            );
2603        } else {
2604            if let Some(y_all) = y.as_mut() {
2605                let y_slice = &mut y_all[offset..offset + n];
2606                let lb = lowband_scratch.as_deref_mut();
2607                let lb_out = if !last && norm_pos + n <= norm.len() {
2608                    Some(&mut norm[norm_pos..norm_pos + n])
2609                } else {
2610                    None
2611                };
2612                x_cm = quant_band_stereo(
2613                    &mut ctx,
2614                    x_slice,
2615                    y_slice,
2616                    n,
2617                    b,
2618                    b_blocks,
2619                    lb,
2620                    lm,
2621                    lb_out,
2622                    1.0,
2623                    x_cm | y_cm,
2624                );
2625                y_cm = x_cm;
2626            } else {
2627                let lb = lowband_scratch.as_deref_mut();
2628                let lb_out = if !last && norm_pos + n <= norm.len() {
2629                    Some(&mut norm[norm_pos..norm_pos + n])
2630                } else {
2631                    None
2632                };
2633                x_cm = quant_band(&mut ctx, x_slice, n, b, b_blocks, lb, lm, lb_out, 1.0, x_cm);
2634                y_cm = x_cm;
2635            }
2636        }
2637
2638        collapse_masks[i * c_channels] = (x_cm & 0xFF) as u8 as u32; // Truncate to 8-bit like libopus
2639        if c_channels == 2 {
2640            collapse_masks[i * c_channels + 1] = (y_cm & 0xFF) as u8 as u32;
2641        }
2642
2643        balance_val += pulses[i] + tell;
2644
2645        update_lowband = b > ((n as i32) << BITRES);
2646
2647        // Propagate seed back from ctx (modified by noise fill in quant_partition)
2648        ctx_seed = ctx.seed;
2649
2650        avoid_split_noise = false;
2651    }
2652    *balance = balance_val;
2653    *seed = ctx_seed;
2654}
2655
/// Compute one band's amplitude (sqrt of sum of squares, seeded with a
/// 1e-15 epsilon so sqrt never sees exactly zero) with NEON on ARM64.
/// Processes 16 floats per iteration in 4 independent accumulators, then
/// whole 4-float chunks, then a scalar tail.
#[cfg(target_arch = "aarch64")]
fn compute_band_energy_neon(band: &[f32]) -> f32 {
    use std::arch::aarch64::*;

    let n = band.len();
    let mut sum = 1e-15f32;

    // SAFETY: every load stays inside band[..n] (i + 15 < n16 <= n,
    // i + 3 < n16 + n4 <= n); NEON is mandatory on aarch64.
    unsafe {
        // Process 16 elements at a time (4 vectors of 4 floats each)
        let n16 = n & !15;
        if n16 > 0 {
            let mut acc0 = vdupq_n_f32(0.0);
            let mut acc1 = vdupq_n_f32(0.0);
            let mut acc2 = vdupq_n_f32(0.0);
            let mut acc3 = vdupq_n_f32(0.0);

            for i in (0..n16).step_by(16) {
                let v0 = vld1q_f32(band.as_ptr().add(i));
                let v1 = vld1q_f32(band.as_ptr().add(i + 4));
                let v2 = vld1q_f32(band.as_ptr().add(i + 8));
                let v3 = vld1q_f32(band.as_ptr().add(i + 12));

                acc0 = vfmaq_f32(acc0, v0, v0);
                acc1 = vfmaq_f32(acc1, v1, v1);
                acc2 = vfmaq_f32(acc2, v2, v2);
                acc3 = vfmaq_f32(acc3, v3, v3);
            }

            // Combine accumulators
            acc0 = vaddq_f32(acc0, acc1);
            acc2 = vaddq_f32(acc2, acc3);
            acc0 = vaddq_f32(acc0, acc2);
            sum += vaddvq_f32(acc0);
        }

        // Process remaining 4-element chunks
        let n4 = (n & !3) - n16;
        if n4 > 0 {
            let mut acc = vdupq_n_f32(0.0);
            for i in (n16..n16 + n4).step_by(4) {
                let v = vld1q_f32(band.as_ptr().add(i));
                acc = vfmaq_f32(acc, v, v);
            }
            sum += vaddvq_f32(acc);
        }

        // Process remaining elements
        for i in (n16 + n4)..n {
            let v = band[i];
            sum += v * v;
        }
    }

    sum.sqrt()
}
2712
2713pub fn compute_band_energies(
2714    m: &CeltMode,
2715    x: &[f32],
2716    band_e: &mut [f32],
2717    end: usize,
2718    channels: usize,
2719    lm: usize,
2720) {
2721    let frame_size = m.short_mdct_size << lm;
2722
2723    #[cfg(target_arch = "aarch64")]
2724    let _use_neon = true;
2725    #[cfg(not(target_arch = "aarch64"))]
2726    let _use_neon = false;
2727
2728    for c in 0..channels {
2729        let ch = &x[c * frame_size..(c + 1) * frame_size];
2730        for i in 0..end {
2731            let offset = (m.e_bands[i] as usize) << lm;
2732            let n = ((m.e_bands[i + 1] - m.e_bands[i]) as usize) << lm;
2733            let band = &ch[offset..offset + n];
2734
2735            #[cfg(target_arch = "aarch64")]
2736            {
2737                band_e[c * m.nb_ebands + i] = compute_band_energy_neon(band);
2738            }
2739            #[cfg(not(target_arch = "aarch64"))]
2740            {
2741                let sum = band.iter().fold(1e-15f32, |acc, &v| acc + v * v);
2742                band_e[c * m.nb_ebands + i] = sum.sqrt();
2743            }
2744        }
2745    }
2746}
2747
2748pub fn amp2log2(
2749    m: &CeltMode,
2750    eff_ebands: usize,
2751    end: usize,
2752    band_e: &[f32],
2753    band_log_e: &mut [f32],
2754    channels: usize,
2755) {
2756    for c in 0..channels {
2757        for i in 0..eff_ebands {
2758            let val = band_e[c * m.nb_ebands + i].max(1e-10);
2759            band_log_e[c * m.nb_ebands + i] = val.log2() - m.e_means[i];
2760        }
2761        for i in eff_ebands..end {
2762            band_log_e[c * m.nb_ebands + i] = -14.0;
2763        }
2764    }
2765}
2766
2767/// Convert log2 energy to amplitude (actually just adds eMeans).
2768/// This matches Opus reference implementation where log2Amp just adds the mean.
2769pub fn log2amp(m: &CeltMode, end: usize, band_e: &mut [f32], band_log_e: &[f32], channels: usize) {
2770    for c in 0..channels {
2771        for i in 0..end {
2772            // Just add eMeans, don't apply exp2 - denormalise_bands will do that
2773            band_e[c * m.nb_ebands + i] = band_log_e[c * m.nb_ebands + i] + m.e_means[i];
2774        }
2775    }
2776}
2777
2778pub fn normalise_bands(
2779    m: &CeltMode,
2780    freq: &[f32],
2781    x: &mut [f32],
2782    band_e: &[f32],
2783    end: usize,
2784    channels: usize,
2785    m_val: usize,
2786) {
2787    let lm = m_val.trailing_zeros() as usize;
2788    let frame_size = m.short_mdct_size << lm;
2789    for c in 0..channels {
2790        for i in 0..end {
2791            let base = c * frame_size + ((m.e_bands[i] as usize) << lm);
2792            let n = ((m.e_bands[i + 1] - m.e_bands[i]) as usize) << lm;
2793            let norm = 1.0 / (1e-15 + band_e[c * m.nb_ebands + i]);
2794            let src = &freq[base..base + n];
2795            let dst = &mut x[base..base + n];
2796            for (d, &s) in dst.iter_mut().zip(src) {
2797                *d = s * norm;
2798            }
2799        }
2800    }
2801}
2802
2803#[allow(clippy::too_many_arguments)]
2804pub fn denormalise_bands(
2805    m: &CeltMode,
2806    x: &[f32],
2807    freq: &mut [f32],
2808    band_e: &[f32],
2809    start: usize,
2810    end: usize,
2811    channels: usize,
2812    m_val: usize,
2813) {
2814    let lm = m_val.trailing_zeros() as usize;
2815    let frame_size = m.short_mdct_size << lm;
2816
2817    for c in 0..channels {
2818        for i in start..end {
2819            let base = c * frame_size + ((m.e_bands[i] as usize) << lm);
2820            let n = ((m.e_bands[i + 1] - m.e_bands[i]) as usize) << lm;
2821            let band_log = band_e[c * m.nb_ebands + i];
2822            // band_e is log2(amplitude) from log2amp(); convert to linear: amplitude = 2^log2(amplitude)
2823            let g = (2.0f32).powf(band_log);
2824            let src = &x[base..base + n];
2825            let dst = &mut freq[base..base + n];
2826            for (d, &s) in dst.iter_mut().zip(src) {
2827                *d = s * g;
2828            }
2829        }
2830    }
2831}
2832
/// Advance the noise-fill PRNG one step. Linear congruential generator with
/// the same constants as C opus celt_lcg_rand() (the classic
/// Numerical Recipes 32-bit LCG), using wrapping arithmetic.
pub fn celt_lcg_rand(seed: u32) -> u32 {
    const MUL: u32 = 1664525;
    const INC: u32 = 1013904223;
    seed.wrapping_mul(MUL).wrapping_add(INC)
}
2836
/// NEON-optimized vector renormalization: scales the first `n` elements of
/// `x` so that their L2 norm becomes `gain`. A `1e-15` bias on the energy
/// keeps the divide well-defined for an all-zero input.
///
/// # Safety
/// The caller must guarantee `n <= x.len()`: the raw-pointer loads/stores
/// below access `x[0..n]` without bounds checks. NEON is a baseline feature
/// on aarch64, so no runtime feature detection is required.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn renormalise_vector_neon(x: &mut [f32], n: usize, gain: f32) {
    use std::arch::aarch64::*;

    // Pass 1: compute the sum of squares with fused multiply-adds,
    // accumulating four partial sums per lane in `sum_vec`.
    let mut sum_vec = vdupq_n_f32(0.0);
    let mut i = 0;

    // Process 16 elements at a time (4 vectors) to expose ILP.
    while i + 16 <= n {
        let x0 = vld1q_f32(x.as_ptr().add(i));
        let x1 = vld1q_f32(x.as_ptr().add(i + 4));
        let x2 = vld1q_f32(x.as_ptr().add(i + 8));
        let x3 = vld1q_f32(x.as_ptr().add(i + 12));
        sum_vec = vfmaq_f32(sum_vec, x0, x0);
        sum_vec = vfmaq_f32(sum_vec, x1, x1);
        sum_vec = vfmaq_f32(sum_vec, x2, x2);
        sum_vec = vfmaq_f32(sum_vec, x3, x3);
        i += 16;
    }

    // Process 8 elements
    while i + 8 <= n {
        let x0 = vld1q_f32(x.as_ptr().add(i));
        let x1 = vld1q_f32(x.as_ptr().add(i + 4));
        sum_vec = vfmaq_f32(sum_vec, x0, x0);
        sum_vec = vfmaq_f32(sum_vec, x1, x1);
        i += 8;
    }

    // Process 4 elements
    while i + 4 <= n {
        let x0 = vld1q_f32(x.as_ptr().add(i));
        sum_vec = vfmaq_f32(sum_vec, x0, x0);
        i += 4;
    }

    // Horizontal-add the four lanes; seed with 1e-15 to avoid a zero divisor.
    let mut e = 1e-15f32 + vaddvq_f32(sum_vec);

    // Scalar tail for energy (remaining < 4 elements)
    for j in i..n {
        e += x[j] * x[j];
    }

    let norm = gain / e.sqrt();
    let vnorm = vdupq_n_f32(norm);

    // Pass 2: scale every element by `norm`, mirroring the 16/8/4 unrolling.
    i = 0;
    while i + 16 <= n {
        let x0 = vld1q_f32(x.as_ptr().add(i));
        let x1 = vld1q_f32(x.as_ptr().add(i + 4));
        let x2 = vld1q_f32(x.as_ptr().add(i + 8));
        let x3 = vld1q_f32(x.as_ptr().add(i + 12));
        vst1q_f32(x.as_mut_ptr().add(i), vmulq_f32(x0, vnorm));
        vst1q_f32(x.as_mut_ptr().add(i + 4), vmulq_f32(x1, vnorm));
        vst1q_f32(x.as_mut_ptr().add(i + 8), vmulq_f32(x2, vnorm));
        vst1q_f32(x.as_mut_ptr().add(i + 12), vmulq_f32(x3, vnorm));
        i += 16;
    }

    while i + 8 <= n {
        let x0 = vld1q_f32(x.as_ptr().add(i));
        let x1 = vld1q_f32(x.as_ptr().add(i + 4));
        vst1q_f32(x.as_mut_ptr().add(i), vmulq_f32(x0, vnorm));
        vst1q_f32(x.as_mut_ptr().add(i + 4), vmulq_f32(x1, vnorm));
        i += 8;
    }

    while i + 4 <= n {
        let x0 = vld1q_f32(x.as_ptr().add(i));
        vst1q_f32(x.as_mut_ptr().add(i), vmulq_f32(x0, vnorm));
        i += 4;
    }

    // Scalar tail
    for j in i..n {
        x[j] *= norm;
    }
}
2920
/// Rescale the first `n` elements of `x` so that their L2 norm equals `gain`.
///
/// On aarch64 this dispatches to the NEON implementation; elsewhere a scalar
/// fallback runs. A `1e-15` bias on the energy keeps the division
/// well-defined when the input is all zeros.
pub fn renormalise_vector(x: &mut [f32], n: usize, gain: f32) {
    #[cfg(target_arch = "aarch64")]
    unsafe {
        renormalise_vector_neon(x, n, gain);
        return;
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        // Fold in the same order a plain accumulation loop would use,
        // seeding with the epsilon, so the result is bit-identical.
        let energy = x[..n].iter().fold(1e-15f32, |acc, &v| acc + v * v);
        let scale = gain / energy.sqrt();
        x[..n].iter_mut().for_each(|v| *v *= scale);
    }
}
2939
/// Anti-collapse: refill short MDTCs that fully collapsed (quantised to
/// silence) with pseudo-random noise, then renormalise the affected band.
///
/// For each band `i` in `start..end` and channel `c`, every sub-MDCT `k`
/// whose bit is clear in `collapse_masks[i * channels + c]` is filled with
/// ±r noise driven by the LCG `seed`. The amplitude `r` is capped both by
/// the band's bit depth (derived from `pulses`) and by how far the current
/// band energy `log_e` sits above the smaller of the two previous frames'
/// energies (`prev1_log_e`, `prev2_log_e`). Returns the advanced PRNG seed
/// so the caller stays in sync with the encoder/decoder sequence.
#[allow(clippy::too_many_arguments)]
pub fn anti_collapse(
    m: &CeltMode,
    x_buf: &mut [f32],
    collapse_masks: &[u32],
    lm: i32,
    channels: usize,
    size: usize,
    start: usize,
    end: usize,
    log_e: &[f32],
    prev1_log_e: &[f32],
    prev2_log_e: &[f32],
    pulses: &[i32],
    mut seed: u32,
) -> u32 {
    for i in start..end {
        // n0: band width in bins at the shortest MDCT size.
        let n0 = (m.e_bands[i + 1] - m.e_bands[i]) as usize;
        // depth: approximate bits per sample allocated to this band.
        let depth = if n0 > 0 {
            ((1 + pulses[i]) / n0 as i32) >> lm
        } else {
            0
        };

        // Noise ceiling: 0.5 * 2^(-depth/8) — more bits means the band was
        // coded more precisely, so less noise may be injected.
        let thresh = 0.5 * (-(0.125 * depth as f32)).exp2();
        // 1/sqrt(total band length); scales noise toward unit band energy.
        let sqrt_1 = 1.0 / ((n0 << lm) as f32).sqrt();

        for c in 0..channels {
            let p1 = prev1_log_e[c * m.nb_ebands + i];
            let p2 = prev2_log_e[c * m.nb_ebands + i];

            // When decoding mono but the saved energy history holds two
            // channels, take the louder of the two per frame — presumably
            // covering a stereo→mono switch; mirrors the C reference's
            // `if (C==1 && ...) prev1 = MAX(...)` handling.
            let (p1_adj, p2_adj) = if channels == 1 && prev1_log_e.len() >= 2 * m.nb_ebands {
                (
                    p1.max(prev1_log_e[m.nb_ebands + i]),
                    p2.max(prev2_log_e[m.nb_ebands + i]),
                )
            } else {
                (p1, p2)
            };

            // Energy rise relative to the quieter of the last two frames,
            // clamped at zero (only a rise reduces the injected level).
            let e_diff = log_e[c * m.nb_ebands + i] - p1_adj.min(p2_adj);
            let e_diff = e_diff.max(0.0);

            // r = 2 * 2^(-e_diff), i.e. quieter history ⇒ quieter noise.
            let mut r = 2.0 * (-e_diff).exp2();
            if lm == 3 {
                // sqrt(2) compensation for the longest frame size (LM==3),
                // as in the C reference.
                r *= std::f32::consts::SQRT_2;
            }
            r = r.min(thresh);
            r *= sqrt_1;

            let x_offset = c * size + ((m.e_bands[i] as usize) << lm);
            let mut renormalize = false;
            // Bins are interleaved by sub-MDCT: bin j of sub-frame k lives
            // at (j << lm) + k. A clear mask bit k means that sub-frame
            // collapsed and gets ±r noise from the LCG.
            for k in 0..(1 << lm) {
                if (collapse_masks[i * channels + c] & (1 << k)) == 0 {
                    for j in 0..n0 {
                        seed = celt_lcg_rand(seed);
                        x_buf[x_offset + (j << lm) + k] = if (seed & 0x8000) != 0 { r } else { -r };
                    }
                    renormalize = true;
                }
            }
            // Restore unit norm over the whole band after noise injection.
            if renormalize {
                renormalise_vector(&mut x_buf[x_offset..x_offset + (n0 << lm)], n0 << lm, 1.0);
            }
        }
    }
    seed
}