Skip to main content

opus_rs/
bands.rs

1use crate::modes::CeltMode;
2use crate::pvq::*;
3use crate::range_coder::RangeCoder;
4use crate::rate::{BITRES, bits2pulses, get_pulses, pulses2bits};
5use crate::tell_frac_inline;
6
/// Small positive energy constant named for stereo processing.
/// NOTE(review): it is not referenced in the part of the file shown here —
/// confirm where it is consumed before changing its value.
const MIN_STEREO_ENERGY: f32 = 1e-10;
8
/// Shared state threaded through the band quantisation routines.
///
/// One instance is passed by mutable reference to `compute_theta` and the
/// `quant_partition_*` helpers, which read the mode tables, drive the range
/// coder and keep the running bit budget up to date.
pub struct BandCtx<'a> {
    /// `true` selects the encoder paths (symbols are written to `rc`);
    /// `false` selects the decoder paths (symbols are read from `rc`).
    pub encode: bool,
    /// Mode tables (`log_n`, `e_bands`, `nb_ebands`, pulse cache, ...).
    pub m: &'a CeltMode,
    /// Index of the band currently being processed; used to index `m.log_n`
    /// and the pulse cache.
    pub i: usize,
    /// Forwarded to `intensity_stereo` — presumably per-band energies; TODO confirm layout.
    pub band_e: &'a [f32],
    /// Range coder used for every symbol coded by these routines.
    pub rc: &'a mut RangeCoder,
    /// Spreading mode forwarded to `alg_quant` (presumably one of the `SPREAD_*` values).
    pub spread: i32,
    /// Running bit budget; decremented as pulses and split angles are coded.
    /// Compared against `2 << BITRES`, so presumably in fractional-bit units.
    pub remaining_bits: i32,
    /// Not read in the visible part of the file — presumably enables
    /// resynthesis of the decoded signal; TODO confirm.
    pub resynth: bool,
    /// Not read in the visible part of the file — TODO confirm meaning.
    pub tf_change: i32,
    /// First band coded as intensity stereo: for stereo bands with
    /// `i >= intensity`, `compute_theta` forces a single angle step.
    pub intensity: usize,
    /// Rounding mode for the stereo angle in `compute_theta`: `0` rounds to
    /// nearest, a negative value rounds down, a positive value rounds up.
    pub theta_round: i32,
    /// When set, `compute_theta` snaps a mono split angle to an extreme if
    /// the resulting allocation would starve one side.
    pub avoid_split_noise: bool,
    /// Not read in the visible part of the file — presumably a CPU
    /// architecture selector; TODO confirm.
    pub arch: i32,
    /// Forces the stereo inversion flag to `false`.
    pub disable_inv: bool,
    /// Not read in the visible part of the file — presumably the noise
    /// generator state; TODO confirm.
    pub seed: u32,
}
26
/// Integer-only cosine approximation.
///
/// `x` is an angle where `16384` corresponds to a quarter turn; the result is
/// the cosine scaled so that `32767` is (just under) one. Only integer
/// arithmetic is used, so the value is reproducible on every platform.
/// Callers in this file pass angles strictly between `0` and `16384`.
#[inline]
fn bitexact_cos(x: i16) -> i16 {
    /// Rounded Q15 product of two 16-bit values, truncated back to 16 bits.
    #[inline(always)]
    fn q15_mul(lhs: i16, rhs: i16) -> i16 {
        let product = i32::from(lhs) * i32::from(rhs);
        ((16384i32 + product) >> 15) as i16
    }

    // Squared angle, rounded and rescaled to 16 bits.
    let squared = i32::from(x) * i32::from(x);
    let arg = ((4096i32 + squared) >> 13) as i16;

    // Horner evaluation of the cubic correction polynomial in `arg`.
    let inner = 8277 + q15_mul(-626, arg);
    let middle = -7651 + q15_mul(arg, inner);
    let correction = q15_mul(arg, middle);

    let cos_minus_one = (32767 - i32::from(arg) + i32::from(correction)) as i16;
    1 + cos_minus_one
}
41
/// Integer-only approximation of `log2(isin / icos)`, scaled by `2^11`.
///
/// Each input is clamped at zero, normalised so that its top set bit lands at
/// bit 14, and run through a quadratic correction; the difference of the bit
/// lengths supplies the integer part of the logarithm. A zero (or negative)
/// input contributes a bit length of 0 and no correction term.
#[inline]
pub fn bitexact_log2tan(isin: i32, icos: i32) -> i32 {
    /// Returns `(bit_length, value normalised to 15 bits)` for a clamped input.
    fn normalise(value: i32) -> (i32, i32) {
        let clamped = value.max(0);
        // `leading_zeros` is 32 for zero, which yields a bit length of 0.
        let bits = 32 - (clamped as u32).leading_zeros() as i32;
        (bits, clamped << (15 - bits).max(0))
    }

    /// Rounded Q15 product.
    fn q15_mul(a: i32, b: i32) -> i32 {
        (a * b + 16384) >> 15
    }

    /// Quadratic correction term for a normalised mantissa.
    fn correction(mantissa: i32) -> i32 {
        q15_mul(mantissa, q15_mul(mantissa, -2597) + 7932)
    }

    let (sin_bits, sin_norm) = normalise(isin);
    let (cos_bits, cos_norm) = normalise(icos);

    (sin_bits - cos_bits) * (1 << 11) + correction(sin_norm) - correction(cos_norm)
}
67
/// Signed integer division `n / d`, truncating toward zero.
///
/// Uses Rust's built-in `/`, so it panics when `d == 0` or when the quotient
/// overflows (`i32::MIN / -1`).
#[inline(always)]
fn celt_sudiv(n: i32, d: i32) -> i32 {
    n / d
}
72
/// Integer square root: returns `floor(sqrt(val))`.
///
/// Works one result bit at a time, from the highest possible bit down: a bit
/// is kept whenever adding it keeps the square at or below the remaining
/// value. Zero is handled explicitly because it has no highest set bit (the
/// starting shift would otherwise be negative).
#[inline]
fn isqrt32(mut val: u32) -> u32 {
    if val == 0 {
        return 0;
    }

    let mut root = 0u32;
    // Index of the highest bit the root can have: half the index of the
    // highest set bit of `val`.
    let mut bshift = (31 - val.leading_zeros()) >> 1;
    loop {
        let bit = 1u32 << bshift;
        // (root + bit)^2 - root^2 == (2 * root + bit) * bit; widened to u64 so
        // the shifted product cannot overflow.
        let trial = u64::from((root << 1) + bit) << bshift;
        if trial <= u64::from(val) {
            root += bit;
            // `trial <= val`, so it fits back into 32 bits.
            val -= trial as u32;
        }
        if bshift == 0 {
            break;
        }
        bshift -= 1;
    }
    root
}
89
// Spreading decisions returned by `spreading_decision`; `BandCtx::spread`
// presumably carries one of these values to the quantiser.

/// No spreading; chosen for the highest decision scores (and when the last
/// band is too narrow to analyse).
pub const SPREAD_NONE: i32 = 0;
/// Light spreading.
pub const SPREAD_LIGHT: i32 = 1;
/// Normal spreading; also the fallback when no band contributed any weight.
pub const SPREAD_NORMAL: i32 = 2;
/// Aggressive spreading; chosen for the lowest decision scores.
pub const SPREAD_AGGRESSIVE: i32 = 3;
94
95#[allow(clippy::too_many_arguments)]
96pub fn spreading_decision(
97    m: &CeltMode,
98    x_buf: &[f32],
99    average: &mut i32,
100    last_decision: i32,
101    hf_average: &mut i32,
102    tapset_decision: &mut i32,
103    update_hf: bool,
104    end: usize,
105    channels: usize,
106    m_val: usize,
107    spread_weight: &[i32],
108) -> i32 {
109    let mut sum = 0;
110    let mut nb_bands = 0;
111    let n0 = m_val * m.short_mdct_size;
112    let mut hf_sum = 0;
113
114    if m_val * (m.e_bands[end] as usize - m.e_bands[end - 1] as usize) <= 8 {
115        return SPREAD_NONE;
116    }
117
118    for c in 0..channels {
119        for (i, &sw) in spread_weight[..end].iter().enumerate() {
120            let n = m_val * (m.e_bands[i + 1] as usize - m.e_bands[i] as usize);
121            if n <= 8 {
122                continue;
123            }
124
125            let mut tcount = [0; 3];
126            let offset = m_val * m.e_bands[i] as usize + c * n0;
127            let x = &x_buf[offset..offset + n];
128
129            for xv in x.iter().copied() {
130                let x2n = xv * xv * (n as f32);
131                if x2n < 0.25 {
132                    tcount[0] += 1;
133                }
134                if x2n < 0.0625 {
135                    tcount[1] += 1;
136                }
137                if x2n < 0.015625 {
138                    tcount[2] += 1;
139                }
140            }
141
142            if i > m.nb_ebands - 4 {
143                hf_sum += 32 * (tcount[1] + tcount[0]) / (n as i32);
144            }
145
146            let tmp = (if 2 * tcount[2] >= (n as i32) { 1 } else { 0 })
147                + (if 2 * tcount[1] >= (n as i32) { 1 } else { 0 })
148                + (if 2 * tcount[0] >= (n as i32) { 1 } else { 0 });
149            sum += tmp * sw;
150            nb_bands += sw;
151        }
152    }
153
154    if update_hf {
155        if hf_sum > 0 {
156            hf_sum /= (channels as i32) * (4 - m.nb_ebands as i32 + end as i32);
157        }
158        *hf_average = (*hf_average + hf_sum) >> 1;
159        hf_sum = *hf_average;
160
161        if *tapset_decision == 2 {
162            hf_sum += 4;
163        } else if *tapset_decision == 0 {
164            hf_sum -= 4;
165        }
166
167        if hf_sum > 22 {
168            *tapset_decision = 2;
169        } else if hf_sum > 18 {
170            *tapset_decision = 1;
171        } else {
172            *tapset_decision = 0;
173        }
174    }
175
176    if nb_bands == 0 {
177        return SPREAD_NORMAL;
178    }
179
180    let mut sum_scaled = (sum << 8) / nb_bands;
181    sum_scaled = (sum_scaled + *average) >> 1;
182    *average = sum_scaled;
183
184    let sum_final = (3 * sum_scaled + (((3 - last_decision) << 7) + 64) + 2) >> 2;
185
186    if sum_final < 80 {
187        SPREAD_AGGRESSIVE
188    } else if sum_final < 256 {
189        SPREAD_NORMAL
190    } else if sum_final < 384 {
191        SPREAD_LIGHT
192    } else {
193        SPREAD_NONE
194    }
195}
196
197pub fn haar1(x: &mut [f32], n0: usize, stride: usize) {
198    #[cfg(target_arch = "aarch64")]
199    {
200        if stride == 1 && n0 >= 64 {
201            haar1_neon(x, n0);
202        } else {
203            haar1_scalar(x, n0, stride);
204        }
205        return;
206    }
207    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
208    unsafe {
209        if stride == 1 && n0 >= 16 && is_x86_feature_detected!("avx") {
210            haar1_avx(x, n0);
211            return;
212        }
213    }
214    #[cfg(not(target_arch = "aarch64"))]
215    haar1_scalar(x, n0, stride);
216}
217
/// AVX kernel for the contiguous (`stride == 1`) Haar transform.
///
/// Processes `n0 / 2` adjacent pairs of `x` in place, 8 pairs (16 floats) per
/// vector iteration, and finishes the remainder with scalar code. Each value
/// is scaled by `1/sqrt(2)` *before* the add/subtract, exactly like the scalar
/// implementation, so both paths produce identical bits.
///
/// Panics if `x` holds fewer than `2 * (n0 / 2)` elements.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx")]
unsafe fn haar1_avx(x: &mut [f32], n0: usize) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    const SCALE: f32 = std::f32::consts::FRAC_1_SQRT_2;

    let pairs = n0 >> 1;
    // Single bounds check; every pointer access below stays inside this slice.
    let x = &mut x[..2 * pairs];

    let mut blocks = x.chunks_exact_mut(16);
    for block in blocks.by_ref() {
        let ptr = block.as_mut_ptr();
        // SAFETY: `block` is exactly 16 contiguous f32 values, so the two
        // unaligned 8-float loads/stores at `ptr` and `ptr + 8` are in bounds.
        // AVX availability is guaranteed by this function's contract.
        unsafe {
            let vscale = _mm256_set1_ps(SCALE);
            let lo = _mm256_mul_ps(_mm256_loadu_ps(ptr), vscale); // e0..e7
            let hi = _mm256_mul_ps(_mm256_loadu_ps(ptr.add(8)), vscale); // e8..e15

            // Per 128-bit lane: pick elements 0,2 of each source / 1,3 of each
            // source, giving pair order [0,1,4,5 | 2,3,6,7].
            let even = _mm256_shuffle_ps::<0x88>(lo, hi);
            let odd = _mm256_shuffle_ps::<0xDD>(lo, hi);

            let sum = _mm256_add_ps(even, odd);
            let diff = _mm256_sub_ps(even, odd);

            // Re-interleaving sum/diff per lane restores the original order:
            // unpacklo -> pairs 0..3, unpackhi -> pairs 4..7.
            _mm256_storeu_ps(ptr, _mm256_unpacklo_ps(sum, diff));
            _mm256_storeu_ps(ptr.add(8), _mm256_unpackhi_ps(sum, diff));
        }
    }

    // Scalar tail for the pairs that do not fill a whole 16-float block.
    for pair in blocks.into_remainder().chunks_exact_mut(2) {
        let a = SCALE * pair[0];
        let b = SCALE * pair[1];
        pair[0] = a + b;
        pair[1] = a - b;
    }
}
261
/// Portable Haar transform: for each of the `stride` interleaved lanes,
/// replaces the `n0 / 2` consecutive pairs `(a, b)` with
/// `(s*a + s*b, s*a - s*b)` where `s = 1/sqrt(2)`.
///
/// Panics if an addressed element lies outside `x`.
#[cfg_attr(target_arch = "aarch64", allow(dead_code))]
#[inline]
fn haar1_scalar(x: &mut [f32], n0: usize, stride: usize) {
    const INV_SQRT2: f32 = std::f32::consts::FRAC_1_SQRT_2;

    let pair_count = n0 >> 1;
    for lane in 0..stride {
        for pair in 0..pair_count {
            // The two members of a pair sit `stride` elements apart.
            let first = stride * 2 * pair + lane;
            let second = stride * (2 * pair + 1) + lane;

            let a = INV_SQRT2 * x[first];
            let b = INV_SQRT2 * x[second];
            x[first] = a + b;
            x[second] = a - b;
        }
    }
}
278
/// NEON kernel for the contiguous (`stride == 1`) Haar transform.
///
/// Processes `n0 / 2` adjacent pairs of `x` in place, 4 pairs (8 floats) per
/// vector iteration, then finishes the remainder with scalar code. Values are
/// scaled by `1/sqrt(2)` before the add/subtract, matching the scalar path.
///
/// Panics if `x` holds fewer than `2 * (n0 / 2)` elements; the length is
/// checked once up front so the raw-pointer loads and stores below can never
/// leave the slice.
#[cfg(target_arch = "aarch64")]
fn haar1_neon(x: &mut [f32], n0: usize) {
    use std::arch::aarch64::*;

    const SCALE: f32 = std::f32::consts::FRAC_1_SQRT_2;

    let pairs = n0 >> 1;
    // Single bounds check covering every access made below.
    let x = &mut x[..2 * pairs];

    let mut blocks = x.chunks_exact_mut(8);
    for block in blocks.by_ref() {
        let ptr = block.as_mut_ptr();
        // SAFETY: `block` is exactly 8 contiguous f32 values, which is what
        // one de-interleaving `vld2q_f32` / `vst2q_f32` reads and writes.
        unsafe {
            let vscale = vdupq_n_f32(SCALE);
            let lanes = vld2q_f32(ptr); // .0 = even elements, .1 = odd elements
            let even = vmulq_f32(lanes.0, vscale);
            let odd = vmulq_f32(lanes.1, vscale);
            vst2q_f32(
                ptr,
                float32x4x2_t(vaddq_f32(even, odd), vsubq_f32(even, odd)),
            );
        }
    }

    // Scalar tail for the pairs that do not fill a whole 8-float block.
    for pair in blocks.into_remainder().chunks_exact_mut(2) {
        let a = SCALE * pair[0];
        let b = SCALE * pair[1];
        pair[0] = a + b;
        pair[1] = a - b;
    }
}
315
316#[inline(always)]
317pub fn compute_qn(n: usize, b: i32, offset: i32, pulse_cap: i32, stereo: bool) -> i32 {
318    static EXP2_TABLE8: [i16; 8] = [16384, 17866, 19483, 21247, 23170, 25267, 27554, 30048];
319    let mut n2 = (2 * n as i32) - 1;
320    if stereo && n == 2 {
321        n2 -= 1;
322    }
323    let mut qb = celt_sudiv(b + n2 * offset, n2);
324    qb = qb.min(b - pulse_cap - (4 << BITRES));
325    qb = qb.min(8 << BITRES);
326    if qb < (1i32 << BITRES >> 1) {
327        1
328    } else {
329        let val = EXP2_TABLE8[(qb & 0x7) as usize] as i32;
330        let shift = 14 - (qb >> BITRES);
331        let raw = if (0..32).contains(&shift) {
332            val >> shift
333        } else {
334            0
335        };
336        let qn = (raw + 1) >> 1 << 1;
337        qn.min(256)
338    }
339}
340
/// NEON implementation of the split-angle estimate.
///
/// Accumulates the energy of the "mid" and "side" signals over the first `n`
/// samples and returns their angle scaled so that `16384` is a quarter turn.
/// With `stereo` set, mid/side are `x + y` / `x - y`; otherwise they are `x`
/// and `y` themselves.
///
/// Four samples are folded into each accumulator per step, in index order,
/// with one fused multiply-add per accumulator; samples past the last full
/// group of four are added afterwards in scalar code.
///
/// # Safety
///
/// `x` and `y` must each hold at least `n & !3` elements: the vector loop
/// reads through raw pointers without bounds checks.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn stereo_itheta_neon(x: &[f32], y: &[f32], stereo: bool, n: usize) -> i32 {
    use std::arch::aarch64::*;

    let x_ptr = x.as_ptr();
    let y_ptr = y.as_ptr();
    // Largest multiple of four not exceeding `n`.
    let simd_end = n & !3;

    let mut acc_mid = vdupq_n_f32(0.0);
    let mut acc_side = vdupq_n_f32(0.0);

    let mut pos = 0usize;
    while pos < simd_end {
        let xv = vld1q_f32(x_ptr.add(pos));
        let yv = vld1q_f32(y_ptr.add(pos));
        let (mid, side) = if stereo {
            (vaddq_f32(xv, yv), vsubq_f32(xv, yv))
        } else {
            (xv, yv)
        };
        acc_mid = vfmaq_f32(acc_mid, mid, mid);
        acc_side = vfmaq_f32(acc_side, side, side);
        pos += 4;
    }

    // Tiny bias keeps both energies strictly positive.
    let mut emid = 1e-15f32;
    let mut eside = 1e-15f32;
    emid += vaddvq_f32(acc_mid);
    eside += vaddvq_f32(acc_side);

    // Scalar tail (bounds-checked indexing).
    for j in simd_end..n {
        let (mid, side) = if stereo {
            (x[j] + y[j], x[j] - y[j])
        } else {
            (x[j], y[j])
        };
        emid += mid * mid;
        eside += side * side;
    }

    let theta_norm = celt_atan2p_norm(eside.sqrt(), emid.sqrt());
    (0.5 + 16384.0 * theta_norm) as i32
}
487
/// Estimates the split angle between `x` and `y` over their first `n`
/// samples (aarch64 build: forwards to the NEON kernel).
///
/// Returns the angle scaled so that `16384` is a quarter turn. With `stereo`
/// set the angle is taken between `x + y` and `x - y`, otherwise between `x`
/// and `y`.
///
/// Panics if either slice holds fewer than `n` samples, like the portable
/// implementation does.
#[inline(always)]
#[cfg(target_arch = "aarch64")]
pub fn stereo_itheta(x: &[f32], y: &[f32], stereo: bool, n: usize) -> i32 {
    // Bounds-check once here: the kernel reads through raw pointers.
    let (x, y) = (&x[..n], &y[..n]);
    // SAFETY: both slices now hold exactly `n` elements, which covers every
    // element the kernel loads.
    unsafe { stereo_itheta_neon(x, y, stereo, n) }
}
493
494#[inline(always)]
495#[cfg(not(target_arch = "aarch64"))]
496pub fn stereo_itheta(x: &[f32], y: &[f32], stereo: bool, n: usize) -> i32 {
497    #[cfg(target_arch = "aarch64")]
498    unsafe {
499        return stereo_itheta_neon(x, y, stereo, n);
500    }
501    #[cfg(not(target_arch = "aarch64"))]
502    {
503        let mut emid = 1e-15f32;
504        let mut eside = 1e-15f32;
505        if stereo {
506            for i in 0..n {
507                let m = x[i] + y[i];
508                let s = x[i] - y[i];
509                emid += m * m;
510                eside += s * s;
511            }
512        } else {
513            for i in 0..n {
514                emid += x[i] * x[i];
515                eside += y[i] * y[i];
516            }
517        }
518        let mid = emid.sqrt();
519        let side = eside.sqrt();
520        let theta_norm = celt_atan2p_norm(side, mid);
521        (0.5 + 16384.0 * theta_norm) as i32
522    }
523}
524
/// Approximates `atan2(y, x) * 2 / pi` for non-negative inputs, so the result
/// lies in `[0, 1]` (`0` along the x axis, `1` along the y axis).
///
/// Returns `0.0` when both inputs are so small that `x*x + y*y < 1e-18`.
#[inline(always)]
fn celt_atan2p_norm(y: f32, x: f32) -> f32 {
    /// Polynomial approximation of `atan(t) * 2 / pi`, used for `t` in `[0, 1]`.
    #[inline(always)]
    fn atan_norm(t: f32) -> f32 {
        const TWO_OVER_PI: f32 = std::f32::consts::FRAC_2_PI;
        // Coefficient of the highest-order term (t^15).
        const A15: f32 = -4.355_406e-3;
        // Remaining odd-power coefficients, from t^13 down to t^3.
        const LOWER_COEFFS: [f32; 6] = [
            2.304_014e-2,
            -5.777_359e-2,
            9.794_234e-2,
            -1.397_658_3e-1,
            1.996_270_4e-1,
            -3.333_166e-1,
        ];

        // Horner evaluation in t^2.
        let t2 = t * t;
        let series = LOWER_COEFFS.iter().fold(A15, |acc, &coeff| coeff + t2 * acc);
        TWO_OVER_PI * t * (1.0 + t2 * series)
    }

    // For vanishingly small vectors the angle is meaningless.
    if x * x + y * y < 1e-18 {
        return 0.0;
    }

    // Keep the polynomial argument in [0, 1] by reflecting about 45 degrees.
    if y < x {
        return atan_norm(y / x);
    }
    1.0 - atan_norm(x / y)
}
553
/// Results of `compute_theta` for one mid/side (or time) split.
pub struct SplitCtx {
    /// Stereo inversion flag; only written by `compute_theta` on the stereo
    /// path that uses a single angle step.
    pub inv: bool,
    /// Integer cosine of the angle (`32767` when the angle is 0, `0` when it
    /// is 16384); the caller divides it by 32768 to obtain the mid gain.
    pub imid: i32,
    /// Integer cosine of the complementary angle (`0` when the angle is 0,
    /// `32767` when it is 16384); divided by 32768 to obtain the side gain.
    pub iside: i32,
    /// Bit-allocation offset between the two halves, derived from
    /// `bitexact_log2tan(iside, imid)`; `-16384` / `16384` at the extremes.
    pub delta: i32,
    /// Quantised angle, scaled so that `16384` is a quarter turn.
    pub itheta: i32,
    /// Bits consumed while coding the angle (difference of two
    /// `tell_frac_inline!` readings).
    pub qalloc: i32,
}
562
563#[allow(clippy::too_many_arguments)]
564#[inline(always)]
565pub fn compute_theta(
566    ctx: &mut BandCtx,
567    sctx: &mut SplitCtx,
568    x: &[f32],
569    y: &[f32],
570    n: usize,
571    b: &mut i32,
572    b_blocks: i32,
573    b0: i32,
574    lm: i32,
575    stereo: bool,
576    fill: &mut u32,
577) {
578    let pulse_cap = ctx.m.log_n[ctx.i] as i32 + (lm << BITRES);
579    let offset = (pulse_cap >> 1) - if stereo && n == 2 { 16 } else { 4 };
580    let mut qn = compute_qn(n, *b, offset, pulse_cap, stereo);
581
582    if stereo && ctx.i >= ctx.intensity {
583        qn = 1;
584    }
585
586    let mut itheta = 0;
587    if ctx.encode {
588        itheta = stereo_itheta(x, y, stereo, n);
589    }
590
591    let tell_start = tell_frac_inline!(ctx.rc);
592
593    if qn != 1 {
594        if ctx.encode {
595            if !stereo || ctx.theta_round == 0 {
596                itheta = (itheta * qn + 8192) >> 14;
597                if !stereo && ctx.avoid_split_noise && itheta > 0 && itheta < qn {
598                    let unquantized = (itheta * 16384) / qn;
599                    let imid = bitexact_cos(unquantized as i16) as i32;
600                    let iside = bitexact_cos((16384 - unquantized) as i16) as i32;
601                    let delta =
602                        (((n as i32 - 1) << 7) * bitexact_log2tan(iside, imid) + 16384) >> 15;
603                    if delta > *b {
604                        itheta = qn;
605                    } else if delta < -*b {
606                        itheta = 0;
607                    }
608                }
609            } else {
610                let bias = if itheta > 8192 {
611                    32767 / qn
612                } else {
613                    -32767 / qn
614                };
615                let down = (itheta * qn + bias) >> 14;
616                let down = down.clamp(0, qn - 1);
617                if ctx.theta_round < 0 {
618                    itheta = down;
619                } else {
620                    itheta = down + 1;
621                }
622            }
623        }
624
625        if stereo && n > 2 {
626            let p0 = 3;
627            let x0 = qn / 2;
628            let ft = p0 * (x0 + 1) + x0;
629            if ctx.encode {
630                let fl = if itheta <= x0 {
631                    p0 * itheta
632                } else {
633                    (itheta - 1 - x0) + (x0 + 1) * p0
634                };
635                let fh = if itheta <= x0 {
636                    p0 * (itheta + 1)
637                } else {
638                    (itheta - x0) + (x0 + 1) * p0
639                };
640                ctx.rc.encode(fl as u32, fh as u32, ft as u32);
641            } else {
642                let fs = ctx.rc.decode(ft as u32);
643                if fs < (x0 + 1) as u32 * p0 as u32 {
644                    itheta = fs as i32 / p0;
645                } else {
646                    itheta = (x0 + 1) + (fs as i32 - (x0 + 1) * p0);
647                }
648                let fl = if itheta <= x0 {
649                    p0 * itheta
650                } else {
651                    (itheta - 1 - x0) + (x0 + 1) * p0
652                };
653                let fh = if itheta <= x0 {
654                    p0 * (itheta + 1)
655                } else {
656                    (itheta - x0) + (x0 + 1) * p0
657                };
658                ctx.rc.update(fl as u32, fh as u32, ft as u32);
659            }
660        } else if b0 > 1 || stereo {
661            if ctx.encode {
662                ctx.rc.enc_uint(itheta as u32, (qn + 1) as u32);
663            } else {
664                itheta = ctx.rc.dec_uint((qn + 1) as u32) as i32;
665            }
666        } else {
667            let ft = ((qn >> 1) + 1) * ((qn >> 1) + 1);
668            if ctx.encode {
669                let fs = if itheta <= (qn >> 1) {
670                    itheta + 1
671                } else {
672                    qn + 1 - itheta
673                };
674                let fl = if itheta <= (qn >> 1) {
675                    (itheta * (itheta + 1)) >> 1
676                } else {
677                    ft - (((qn + 1 - itheta) * (qn + 2 - itheta)) >> 1)
678                };
679                ctx.rc.encode(fl as u32, (fl + fs) as u32, ft as u32);
680            } else {
681                let fm = ctx.rc.decode(ft as u32) as i32;
682                if fm < (((qn >> 1) * ((qn >> 1) + 1)) >> 1) {
683                    itheta = (isqrt32((8 * fm + 1) as u32) as i32 - 1) >> 1;
684                    let fl = (itheta * (itheta + 1)) >> 1;
685                    let fs = itheta + 1;
686                    ctx.rc.update(fl as u32, (fl + fs) as u32, ft as u32);
687                } else {
688                    itheta = (2 * (qn + 1) - isqrt32((8 * (ft - fm - 1) + 1) as u32) as i32) >> 1;
689                    let fs = qn + 1 - itheta;
690                    let fl = ft - (((qn + 1 - itheta) * (qn + 2 - itheta)) >> 1);
691                    ctx.rc.update(fl as u32, (fl + fs) as u32, ft as u32);
692                }
693            }
694        }
695        itheta = (itheta as u32 * 16384 / qn as u32) as i32;
696        if ctx.encode && stereo {
697            let (bx, by) = (x.as_ptr() as *mut f32, y.as_ptr() as *mut f32);
698            let (sx, sy) = unsafe {
699                (
700                    std::slice::from_raw_parts_mut(bx, n),
701                    std::slice::from_raw_parts_mut(by, n),
702                )
703            };
704            if itheta == 0 {
705                intensity_stereo(ctx.m, sx, sy, ctx.band_e, ctx.i, n);
706            } else {
707                stereo_split(sx, sy, n);
708            }
709        }
710    } else if stereo {
711        if ctx.encode {
712            let inv = itheta > 8192 && !ctx.disable_inv;
713            let (bx, by) = (x.as_ptr() as *mut f32, y.as_ptr() as *mut f32);
714            let (sx, sy) = unsafe {
715                (
716                    std::slice::from_raw_parts_mut(bx, n),
717                    std::slice::from_raw_parts_mut(by, n),
718                )
719            };
720            if inv {
721                for yv in sy.iter_mut() {
722                    *yv = -*yv;
723                }
724            }
725            intensity_stereo(ctx.m, sx, sy, ctx.band_e, ctx.i, n);
726            if *b > (2 << BITRES) && ctx.remaining_bits > (2 << BITRES) {
727                ctx.rc.encode_bit_logp(inv, 2);
728            }
729            itheta = 0;
730            sctx.inv = inv;
731        } else {
732            if *b > (2 << BITRES) && ctx.remaining_bits > (2 << BITRES) {
733                sctx.inv = ctx.rc.decode_bit_logp(2);
734            } else {
735                sctx.inv = false;
736            }
737            if ctx.disable_inv {
738                sctx.inv = false;
739            }
740            itheta = 0;
741        }
742    }
743
744    sctx.itheta = itheta;
745
746    sctx.qalloc = tell_frac_inline!(ctx.rc) - tell_start;
747    *b -= sctx.qalloc; // matches C: *b -= qalloc
748
749    if itheta == 0 {
750        sctx.imid = 32767;
751        sctx.iside = 0;
752        sctx.delta = -16384;
753        *fill &= (1 << b_blocks) - 1;
754    } else if itheta == 16384 {
755        sctx.imid = 0;
756        sctx.iside = 32767;
757        sctx.delta = 16384;
758        *fill &= ((1 << b_blocks) - 1) << b_blocks;
759    } else {
760        let imid = bitexact_cos(itheta as i16);
761        sctx.imid = imid as i32;
762        let iside = bitexact_cos((16384 - itheta) as i16);
763        sctx.iside = iside as i32;
764        sctx.delta =
765            (((n as i32 - 1) << 7) * bitexact_log2tan(sctx.iside, sctx.imid) + 16384) >> 15;
766    }
767}
768
769#[inline(always)]
770fn quant_partition_n2_encode(
771    ctx: &mut BandCtx,
772    x: &mut [f32],
773    b: i32,
774    b_blocks: i32,
775    lowband: Option<&mut [f32]>,
776    lm: i32,
777    gain: f32,
778    fill: u32,
779) -> u32 {
780    let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
781    let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
782    ctx.remaining_bits -= curr_bits;
783
784    while ctx.remaining_bits < 0 && q > 0 {
785        ctx.remaining_bits += curr_bits;
786        q -= 1;
787        curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
788        ctx.remaining_bits -= curr_bits;
789    }
790
791    if q != 0 {
792        let k = get_pulses(q);
793        alg_quant(x, 2, k, ctx.spread, b_blocks as usize, ctx.rc, gain, false)
794    } else {
795        let has_lowband = lowband.is_some();
796        if has_lowband {
797            fill
798        } else {
799            (1u32 << b_blocks) - 1
800        }
801    }
802}
803
804#[inline(always)]
805fn quant_partition_n4_encode(
806    ctx: &mut BandCtx,
807    x: &mut [f32],
808    b: i32,
809    b_blocks: i32,
810    lowband: Option<&mut [f32]>,
811    lm: i32,
812    gain: f32,
813    fill: u32,
814) -> u32 {
815    let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
816    let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
817    ctx.remaining_bits -= curr_bits;
818
819    while ctx.remaining_bits < 0 && q > 0 {
820        ctx.remaining_bits += curr_bits;
821        q -= 1;
822        curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
823        ctx.remaining_bits -= curr_bits;
824    }
825
826    if q != 0 {
827        let k = get_pulses(q);
828        alg_quant(x, 4, k, ctx.spread, b_blocks as usize, ctx.rc, gain, false)
829    } else {
830        let has_lowband = lowband.is_some();
831        if has_lowband {
832            fill
833        } else {
834            (1u32 << b_blocks) - 1
835        }
836    }
837}
838
839#[inline(always)]
840fn quant_partition_n8_encode(
841    ctx: &mut BandCtx,
842    x: &mut [f32],
843    b: i32,
844    b_blocks: i32,
845    lowband: Option<&mut [f32]>,
846    lm: i32,
847    gain: f32,
848    fill: u32,
849) -> u32 {
850    let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
851    let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
852    ctx.remaining_bits -= curr_bits;
853
854    while ctx.remaining_bits < 0 && q > 0 {
855        ctx.remaining_bits += curr_bits;
856        q -= 1;
857        curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
858        ctx.remaining_bits -= curr_bits;
859    }
860
861    if q != 0 {
862        let k = get_pulses(q);
863        alg_quant(x, 8, k, ctx.spread, b_blocks as usize, ctx.rc, gain, false)
864    } else {
865        let has_lowband = lowband.is_some();
866        if has_lowband {
867            fill
868        } else {
869            (1u32 << b_blocks) - 1
870        }
871    }
872}
873
874#[inline(always)]
875#[allow(clippy::too_many_arguments)]
876fn quant_partition_direct_encode(
877    ctx: &mut BandCtx,
878    x: &mut [f32],
879    n: usize,
880    b: i32,
881    b_blocks: i32,
882    lowband: Option<&mut [f32]>,
883    lm: i32,
884    gain: f32,
885    fill: u32,
886) -> u32 {
887    let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
888    let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
889    ctx.remaining_bits -= curr_bits;
890
891    while ctx.remaining_bits < 0 && q > 0 {
892        ctx.remaining_bits += curr_bits;
893        q -= 1;
894        curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
895        ctx.remaining_bits -= curr_bits;
896    }
897
898    if q != 0 {
899        let k = get_pulses(q);
900        alg_quant(x, n, k, ctx.spread, b_blocks as usize, ctx.rc, gain, false)
901    } else {
902        let has_lowband = lowband.is_some();
903        if has_lowband {
904            fill
905        } else {
906            (1u32 << b_blocks) - 1
907        }
908    }
909}
910
911#[inline(always)]
912#[allow(clippy::too_many_arguments)]
913fn quant_partition_encode(
914    ctx: &mut BandCtx,
915    x: &mut [f32],
916    n: usize,
917    b: i32,
918    b_blocks: i32,
919    lowband: Option<&mut [f32]>,
920    lm: i32,
921    gain: f32,
922    fill: u32,
923) -> u32 {
924    // N==2 can never split (should_split requires n>2), dispatch immediately
925    if n == 2 {
926        return quant_partition_n2_encode(ctx, x, b, b_blocks, lowband, lm, gain, fill);
927    }
928
929    // Check split condition FIRST (matching C's quant_partition which checks this before dispatch)
930    let should_split = if lm >= 0 && n > 2 {
931        let cache_idx = (lm + 1) as usize * ctx.m.nb_ebands + ctx.i;
932        let cache_base = unsafe { *ctx.m.cache.index.get_unchecked(cache_idx) };
933        if cache_base >= 0 {
934            let cache_base = cache_base as usize;
935            let cache_ptr = ctx.m.cache.bits.as_ptr().wrapping_add(cache_base);
936            let max_q = unsafe { *cache_ptr } as usize;
937            b > (unsafe { *cache_ptr.add(max_q) } as i32) + 12
938        } else {
939            false
940        }
941    } else {
942        false
943    };
944
945    if should_split {
946        let mut sctx = SplitCtx {
947            inv: false,
948            imid: 0,
949            iside: 0,
950            delta: 0,
951            itheta: 0,
952            qalloc: 0,
953        };
954        let mut b_mut = b;
955        let mut fill_mut = fill;
956        let mid = n / 2;
957        let lm = lm - 1;
958        let b0 = b_blocks;
959        if b_blocks == 1 {
960            fill_mut = (fill_mut & 1) | (fill_mut << 1);
961        }
962        let b_blocks = (b_blocks + 1) >> 1;
963        let (x_mid, x_side) = x.split_at_mut(mid);
964
965        compute_theta(
966            ctx,
967            &mut sctx,
968            x_mid,
969            x_side,
970            mid,
971            &mut b_mut,
972            b_blocks,
973            b0,
974            lm,
975            false,
976            &mut fill_mut,
977        );
978
979        ctx.remaining_bits -= sctx.qalloc;
980        let mut delta = sctx.delta;
981        /* Give more bits to low-energy MDCTs than they would otherwise deserve */
982        if b0 > 1 && (sctx.itheta & 0x3fff) != 0 {
983            if sctx.itheta > 8192 {
984                delta -= delta >> (4 - lm);
985            } else {
986                delta = 0.min(delta + ((mid as i32) << BITRES >> (5 - lm)));
987            }
988        }
989        let mbits = (0).max((b_mut - delta) / 2).min(b_mut);
990        let mut sbits = b_mut - mbits;
991        let mut mbits = mbits;
992
993        let mut rebalance = ctx.remaining_bits;
994        let mut cm;
995        let mid_gain = gain * (sctx.imid as f32 / 32768.0);
996        let side_gain = gain * (sctx.iside as f32 / 32768.0);
997
998        if mbits >= sbits {
999            if let Some(lb) = lowband {
1000                let (lb_mid, lb_side) = lb.split_at_mut(mid);
1001                cm = quant_partition_encode(
1002                    ctx,
1003                    x_mid,
1004                    mid,
1005                    mbits,
1006                    b_blocks,
1007                    Some(lb_mid),
1008                    lm,
1009                    mid_gain,
1010                    fill_mut,
1011                );
1012                rebalance = mbits - (rebalance - ctx.remaining_bits);
1013                if rebalance > (3 << 3) && sctx.itheta != 0 {
1014                    sbits += rebalance - (3 << 3);
1015                }
1016                cm |= quant_partition_encode(
1017                    ctx,
1018                    x_side,
1019                    mid,
1020                    sbits,
1021                    b_blocks,
1022                    Some(lb_side),
1023                    lm,
1024                    side_gain,
1025                    fill_mut >> b_blocks,
1026                ) << (b0 >> 1);
1027            } else {
1028                cm = quant_partition_encode(
1029                    ctx, x_mid, mid, mbits, b_blocks, None, lm, mid_gain, fill_mut,
1030                );
1031                rebalance = mbits - (rebalance - ctx.remaining_bits);
1032                if rebalance > (3 << 3) && sctx.itheta != 0 {
1033                    sbits += rebalance - (3 << 3);
1034                }
1035                cm |= quant_partition_encode(
1036                    ctx,
1037                    x_side,
1038                    mid,
1039                    sbits,
1040                    b_blocks,
1041                    None,
1042                    lm,
1043                    side_gain,
1044                    fill_mut >> b_blocks,
1045                ) << (b0 >> 1);
1046            }
1047        } else if let Some(lb) = lowband {
1048            let (lb_mid, lb_side) = lb.split_at_mut(mid);
1049            cm = quant_partition_encode(
1050                ctx,
1051                x_side,
1052                mid,
1053                sbits,
1054                b_blocks,
1055                Some(lb_side),
1056                lm,
1057                side_gain,
1058                fill_mut >> b_blocks,
1059            ) << (b0 >> 1);
1060            rebalance = sbits - (rebalance - ctx.remaining_bits);
1061            if rebalance > (3 << 3) && sctx.itheta != 16384 {
1062                mbits += rebalance - (3 << 3);
1063            }
1064            cm |= quant_partition_encode(
1065                ctx,
1066                x_mid,
1067                mid,
1068                mbits,
1069                b_blocks,
1070                Some(lb_mid),
1071                lm,
1072                mid_gain,
1073                fill_mut,
1074            );
1075        } else {
1076            cm = quant_partition_encode(
1077                ctx,
1078                x_side,
1079                mid,
1080                sbits,
1081                b_blocks,
1082                None,
1083                lm,
1084                side_gain,
1085                fill_mut >> b_blocks,
1086            ) << (b0 >> 1);
1087            rebalance = sbits - (rebalance - ctx.remaining_bits);
1088            if rebalance > (3 << 3) && sctx.itheta != 16384 {
1089                mbits += rebalance - (3 << 3);
1090            }
1091            cm |= quant_partition_encode(
1092                ctx, x_mid, mid, mbits, b_blocks, None, lm, mid_gain, fill_mut,
1093            );
1094        }
1095        cm
1096    } else {
1097        // No split — dispatch to small-N specialized encoders or direct path
1098        if n == 4 {
1099            return quant_partition_n4_encode(ctx, x, b, b_blocks, lowband, lm, gain, fill);
1100        }
1101        if n == 8 {
1102            return quant_partition_n8_encode(ctx, x, b, b_blocks, lowband, lm, gain, fill);
1103        }
1104        if n == 16 {
1105            return quant_partition_direct_encode(ctx, x, n, b, b_blocks, lowband, lm, gain, fill);
1106        }
1107        let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
1108        let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1109        ctx.remaining_bits -= curr_bits;
1110
1111        while ctx.remaining_bits < 0 && q > 0 {
1112            ctx.remaining_bits += curr_bits;
1113            q -= 1;
1114            curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1115            ctx.remaining_bits -= curr_bits;
1116        }
1117
1118        if q != 0 {
1119            let k = get_pulses(q);
1120            alg_quant(x, n, k, ctx.spread, b_blocks as usize, ctx.rc, gain, false)
1121        } else if lowband.is_some() {
1122            fill
1123        } else {
1124            (1 << b_blocks) - 1
1125        }
1126    }
1127}
1128
1129#[inline(always)]
1130#[allow(clippy::too_many_arguments)]
1131pub fn quant_partition(
1132    ctx: &mut BandCtx,
1133    x: &mut [f32],
1134    n: usize,
1135    b: i32,
1136    b_blocks: i32,
1137    lowband: Option<&mut [f32]>,
1138    lm: i32,
1139    gain: f32,
1140    fill: u32,
1141) -> u32 {
1142    /* Check split condition FIRST, before dispatching to specialized handlers.
1143    This matches the C code which checks this at the top of quant_partition. */
1144    let should_split = if lm >= 0 && n > 2 {
1145        let cache_idx = (lm + 1) as usize * ctx.m.nb_ebands + ctx.i;
1146        let cache_base = unsafe { *ctx.m.cache.index.get_unchecked(cache_idx) };
1147        if cache_base >= 0 {
1148            let cache_base = cache_base as usize;
1149            let cache_ptr = ctx.m.cache.bits.as_ptr().wrapping_add(cache_base);
1150            let max_q = unsafe { *cache_ptr } as usize;
1151            b > (unsafe { *cache_ptr.add(max_q) } as i32) + 12
1152        } else {
1153            false
1154        }
1155    } else {
1156        false
1157    };
1158    if should_split {
1159        let mut sctx = SplitCtx {
1160            inv: false,
1161            imid: 0,
1162            iside: 0,
1163            delta: 0,
1164            itheta: 0,
1165            qalloc: 0,
1166        };
1167        let mut b_mut = b;
1168        let mut fill_mut = fill;
1169        let mid = n / 2;
1170        let lm = lm - 1;
1171        let b0 = b_blocks; // Save original B0
1172        if b_blocks == 1 {
1173            fill_mut = (fill_mut & 1) | (fill_mut << 1);
1174        }
1175        let b_blocks = (b_blocks + 1) >> 1;
1176        let (x_mid, x_side) = x.split_at_mut(mid);
1177
1178        compute_theta(
1179            ctx,
1180            &mut sctx,
1181            x_mid,
1182            x_side,
1183            mid,
1184            &mut b_mut,
1185            b_blocks,
1186            b0,
1187            lm,
1188            false,
1189            &mut fill_mut,
1190        );
1191
1192        ctx.remaining_bits -= sctx.qalloc;
1193        let mut delta = sctx.delta;
1194        /* Give more bits to low-energy MDCTs than they would otherwise deserve
1195        (matches C quant_partition's B0>1 adjustment) */
1196        if b0 > 1 && (sctx.itheta & 0x3fff) != 0 {
1197            if sctx.itheta > 8192 {
1198                delta -= delta >> (4 - lm);
1199            } else {
1200                delta = 0.min(delta + ((mid as i32) << BITRES >> (5 - lm)));
1201            }
1202        }
1203        let mbits = (0).max((b_mut - delta) / 2).min(b_mut);
1204        let mut sbits = b_mut - mbits;
1205        let mut mbits = mbits;
1206
1207        let mut rebalance = ctx.remaining_bits;
1208        let mut cm;
1209
1210        if mbits >= sbits {
1211            if let Some(lb) = lowband {
1212                let (lb_mid, lb_side) = lb.split_at_mut(mid);
1213                cm = quant_partition(
1214                    ctx,
1215                    x_mid,
1216                    mid,
1217                    mbits,
1218                    b_blocks,
1219                    Some(lb_mid),
1220                    lm,
1221                    gain * (sctx.imid as f32 / 32768.0),
1222                    fill_mut,
1223                );
1224                rebalance = mbits - (rebalance - ctx.remaining_bits);
1225                if rebalance > (3 << 3) && sctx.itheta != 0 {
1226                    sbits += rebalance - (3 << 3);
1227                }
1228                cm |= quant_partition(
1229                    ctx,
1230                    x_side,
1231                    mid,
1232                    sbits,
1233                    b_blocks,
1234                    Some(lb_side),
1235                    lm,
1236                    gain * (sctx.iside as f32 / 32768.0),
1237                    fill_mut >> b_blocks,
1238                ) << (b0 >> 1);
1239            } else {
1240                cm = quant_partition(
1241                    ctx,
1242                    x_mid,
1243                    mid,
1244                    mbits,
1245                    b_blocks,
1246                    None,
1247                    lm,
1248                    gain * (sctx.imid as f32 / 32768.0),
1249                    fill_mut,
1250                );
1251                rebalance = mbits - (rebalance - ctx.remaining_bits);
1252                if rebalance > (3 << 3) && sctx.itheta != 0 {
1253                    sbits += rebalance - (3 << 3);
1254                }
1255                cm |= quant_partition(
1256                    ctx,
1257                    x_side,
1258                    mid,
1259                    sbits,
1260                    b_blocks,
1261                    None,
1262                    lm,
1263                    gain * (sctx.iside as f32 / 32768.0),
1264                    fill_mut >> b_blocks,
1265                ) << (b0 >> 1);
1266            }
1267        } else if let Some(lb) = lowband {
1268            let (lb_mid, lb_side) = lb.split_at_mut(mid);
1269            cm = quant_partition(
1270                ctx,
1271                x_side,
1272                mid,
1273                sbits,
1274                b_blocks,
1275                Some(lb_side),
1276                lm,
1277                gain * (sctx.iside as f32 / 32768.0),
1278                fill_mut >> b_blocks,
1279            ) << (b0 >> 1);
1280            rebalance = sbits - (rebalance - ctx.remaining_bits);
1281            if rebalance > (3 << 3) && sctx.itheta != 16384 {
1282                mbits += rebalance - (3 << 3);
1283            }
1284            cm |= quant_partition(
1285                ctx,
1286                x_mid,
1287                mid,
1288                mbits,
1289                b_blocks,
1290                Some(lb_mid),
1291                lm,
1292                gain * (sctx.imid as f32 / 32768.0),
1293                fill_mut,
1294            );
1295        } else {
1296            cm = quant_partition(
1297                ctx,
1298                x_side,
1299                mid,
1300                sbits,
1301                b_blocks,
1302                None,
1303                lm,
1304                gain * (sctx.iside as f32 / 32768.0),
1305                fill_mut >> b_blocks,
1306            ) << (b0 >> 1);
1307            rebalance = sbits - (rebalance - ctx.remaining_bits);
1308            if rebalance > (3 << 3) && sctx.itheta != 16384 {
1309                mbits += rebalance - (3 << 3);
1310            }
1311            cm |= quant_partition(
1312                ctx,
1313                x_mid,
1314                mid,
1315                mbits,
1316                b_blocks,
1317                None,
1318                lm,
1319                gain * (sctx.imid as f32 / 32768.0),
1320                fill_mut,
1321            );
1322        }
1323        cm
1324    } else {
1325        let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
1326        let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1327        ctx.remaining_bits -= curr_bits;
1328
1329        while ctx.remaining_bits < 0 && q > 0 {
1330            ctx.remaining_bits += curr_bits;
1331            q -= 1;
1332            curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1333            ctx.remaining_bits -= curr_bits;
1334        }
1335
1336        if q != 0 {
1337            let k = get_pulses(q);
1338            if ctx.encode {
1339                alg_quant(
1340                    x,
1341                    n,
1342                    k,
1343                    ctx.spread,
1344                    b_blocks as usize,
1345                    ctx.rc,
1346                    gain,
1347                    ctx.resynth,
1348                )
1349            } else {
1350                alg_unquant(x, n, k, ctx.spread, b_blocks as usize, ctx.rc, gain)
1351            }
1352        } else {
1353            let mut cm = 0u32;
1354            if ctx.resynth {
1355                let cm_mask = (1u32 << b_blocks) - 1;
1356                let fill_masked = fill & cm_mask;
1357                if fill_masked == 0 {
1358                    x[..n].fill(0.0);
1359                } else if let Some(lb) = lowband {
1360                    #[cfg(target_arch = "aarch64")]
1361                    unsafe {
1362                        use std::arch::aarch64::*;
1363                        let n8 = n & !7;
1364                        let mut i = 0;
1365                        while i < n8 {
1366                            let mut vals = [0.0f32; 8];
1367                            for j in 0..8 {
1368                                ctx.seed = celt_lcg_rand(ctx.seed);
1369                                vals[j] = if ctx.seed & 0x8000 != 0 {
1370                                    1.0 / 256.0
1371                                } else {
1372                                    -1.0 / 256.0
1373                                };
1374                            }
1375                            let vnoise = vld1q_f32(vals.as_ptr());
1376                            let vnoise1 = vld1q_f32(vals.as_ptr().add(4));
1377                            let vlb = vld1q_f32(lb.as_ptr().add(i));
1378                            let vlb1 = vld1q_f32(lb.as_ptr().add(i + 4));
1379                            let vres = vaddq_f32(vlb, vnoise);
1380                            let vres1 = vaddq_f32(vlb1, vnoise1);
1381                            vst1q_f32(x.as_mut_ptr().add(i), vres);
1382                            vst1q_f32(x.as_mut_ptr().add(i + 4), vres1);
1383                            i += 8;
1384                        }
1385                        for j in i..n {
1386                            ctx.seed = celt_lcg_rand(ctx.seed);
1387                            x[j] = lb[j]
1388                                + if ctx.seed & 0x8000 != 0 {
1389                                    1.0 / 256.0
1390                                } else {
1391                                    -1.0 / 256.0
1392                                };
1393                        }
1394                    }
1395                    #[cfg(not(target_arch = "aarch64"))]
1396                    {
1397                        for j in 0..n {
1398                            ctx.seed = celt_lcg_rand(ctx.seed);
1399                            x[j] = lb[j]
1400                                + if ctx.seed & 0x8000 != 0 {
1401                                    1.0 / 256.0
1402                                } else {
1403                                    -1.0 / 256.0
1404                                };
1405                        }
1406                    }
1407                    renormalise_vector(x, n, gain);
1408                    cm = fill_masked;
1409                } else {
1410                    for xv in x[..n].iter_mut() {
1411                        ctx.seed = celt_lcg_rand(ctx.seed);
1412                        *xv = ((ctx.seed as i32 >> 20) as f32) / 16384.0;
1413                    }
1414                    renormalise_vector(x, n, gain);
1415                    cm = cm_mask;
1416                }
1417            }
1418            cm
1419        }
1420    }
1421}
1422
/// Regroups `x` from sample-interleaved order into `stride` contiguous runs of
/// `n0` samples each (natural ordering, no Hadamard permutation).
///
/// Despite the `_neon` suffix the body is plain scalar code; it is only compiled
/// for aarch64.
///
/// # Safety
///
/// `n0 * stride` must not exceed `MAX_PVQ_N`: the scratch slice is created over a
/// `MAX_PVQ_N`-element stack buffer without any check.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn deinterleave_hadamard_neon(x: &mut [f32], n0: usize, stride: usize) {
    let total = n0 * stride;
    let mut scratch = [std::mem::MaybeUninit::<f32>::uninit(); MAX_PVQ_N];
    let tmp = std::slice::from_raw_parts_mut(scratch.as_mut_ptr() as *mut f32, total);

    // Row `band` of the output gathers every `stride`-th sample starting at `band`.
    for band in 0..stride {
        let row = &mut tmp[band * n0..(band + 1) * n0];
        for (j, slot) in row.iter_mut().enumerate() {
            *slot = x[j * stride + band];
        }
    }

    x[..total].copy_from_slice(tmp);
}
1440
1441pub fn deinterleave_hadamard(x: &mut [f32], n0: usize, stride: usize, hadamard: bool) {
1442    let n = n0 * stride;
1443
1444    let mut tmp_buf = [std::mem::MaybeUninit::<f32>::uninit(); MAX_PVQ_N];
1445
1446    let tmp = unsafe { std::slice::from_raw_parts_mut(tmp_buf.as_mut_ptr() as *mut f32, n) };
1447    if hadamard {
1448        let offset = match stride {
1449            2 => 0,
1450            4 => 2,
1451            8 => 6,
1452            16 => 14,
1453            _ => 0,
1454        };
1455        let ordery = &ORDERY_TABLE[offset..offset + stride];
1456        for i in 0..stride {
1457            for j in 0..n0 {
1458                tmp[ordery[i] as usize * n0 + j] = x[j * stride + i];
1459            }
1460        }
1461    } else {
1462        #[cfg(target_arch = "aarch64")]
1463        unsafe {
1464            if n0 >= 4 {
1465                deinterleave_hadamard_neon(x, n0, stride);
1466                return;
1467            }
1468        }
1469        for i in 0..stride {
1470            for j in 0..n0 {
1471                tmp[i * n0 + j] = x[j * stride + i];
1472            }
1473        }
1474    }
1475    x[..n].copy_from_slice(tmp);
1476}
1477
/// Inverse of the natural-order deinterleave: takes `stride` contiguous runs of
/// `n0` samples from `x` and writes them back sample-interleaved.
///
/// Despite the `_neon` suffix the body is plain scalar code; it is only compiled
/// for aarch64.
///
/// # Safety
///
/// `n0 * stride` must not exceed `MAX_PVQ_N`: the scratch slice is created over a
/// `MAX_PVQ_N`-element stack buffer without any check.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn interleave_hadamard_neon(x: &mut [f32], n0: usize, stride: usize) {
    let total = n0 * stride;
    let mut scratch = [std::mem::MaybeUninit::<f32>::uninit(); MAX_PVQ_N];
    let tmp = std::slice::from_raw_parts_mut(scratch.as_mut_ptr() as *mut f32, total);

    // Sample `j` of run `band` lands at interleaved position `j * stride + band`.
    for band in 0..stride {
        let row = &x[band * n0..(band + 1) * n0];
        for (j, &value) in row.iter().enumerate() {
            tmp[j * stride + band] = value;
        }
    }

    x[..total].copy_from_slice(tmp);
}
1495
1496pub fn interleave_hadamard(x: &mut [f32], n0: usize, stride: usize, hadamard: bool) {
1497    let n = n0 * stride;
1498    let mut tmp_buf = [std::mem::MaybeUninit::<f32>::uninit(); MAX_PVQ_N];
1499    let tmp = unsafe { std::slice::from_raw_parts_mut(tmp_buf.as_mut_ptr() as *mut f32, n) };
1500    if hadamard {
1501        let offset = match stride {
1502            2 => 0,
1503            4 => 2,
1504            8 => 6,
1505            16 => 14,
1506            _ => 0,
1507        };
1508        let ordery = &ORDERY_TABLE[offset..offset + stride];
1509        for i in 0..stride {
1510            for j in 0..n0 {
1511                tmp[j * stride + i] = x[ordery[i] as usize * n0 + j];
1512            }
1513        }
1514    } else {
1515        #[cfg(target_arch = "aarch64")]
1516        unsafe {
1517            if n0 >= 4 {
1518                interleave_hadamard_neon(x, n0, stride);
1519                return;
1520            }
1521        }
1522        for i in 0..stride {
1523            for j in 0..n0 {
1524                tmp[j * stride + i] = x[i * n0 + j];
1525            }
1526        }
1527    }
1528    x[..n].copy_from_slice(tmp);
1529}
1530
/// Block permutations used by the Hadamard (de)interleavers, stored back to back
/// for strides 2, 4, 8 and 16 (starting at offsets 0, 2, 6 and 14 respectively).
const ORDERY_TABLE: [i32; 30] = [
    // stride 2
    1, 0,
    // stride 4
    3, 0, 2, 1,
    // stride 8
    7, 0, 4, 3, 6, 1, 5, 2,
    // stride 16
    15, 0, 8, 7, 12, 3, 11, 4, 14, 1, 9, 6, 13, 2, 10, 5,
];
1534
1535fn quant_band_n1(
1536    ctx: &mut BandCtx,
1537    x: &mut [f32],
1538    y: Option<&mut [f32]>,
1539    lowband_out: Option<&mut [f32]>,
1540) -> u32 {
1541    let mut sign = 0;
1542    if ctx.remaining_bits >= 1 << BITRES {
1543        if ctx.encode {
1544            sign = if x[0] < 0.0 { 1 } else { 0 };
1545            ctx.rc.enc_bits(sign as u32, 1);
1546        } else {
1547            sign = ctx.rc.dec_bits(1) as i32;
1548        }
1549        ctx.remaining_bits -= 1 << BITRES;
1550    }
1551    if ctx.resynth {
1552        x[0] = if sign != 0 { -1.0 } else { 1.0 };
1553    }
1554    if let Some(y_val) = y {
1555        let mut y_sign = 0;
1556        if ctx.remaining_bits >= 1 << BITRES {
1557            if ctx.encode {
1558                y_sign = if y_val[0] < 0.0 { 1 } else { 0 };
1559                ctx.rc.enc_bits(y_sign as u32, 1);
1560            } else {
1561                y_sign = ctx.rc.dec_bits(1) as i32;
1562            }
1563            ctx.remaining_bits -= 1 << BITRES;
1564        }
1565        if ctx.resynth {
1566            y_val[0] = if y_sign != 0 { -1.0 } else { 1.0 };
1567        }
1568    }
1569    if let Some(l_out) = lowband_out {
1570        l_out[0] = x[0] / 16.0;
1571    }
1572    1
1573}
1574
1575#[allow(clippy::too_many_arguments)]
1576#[inline(always)]
1577pub fn quant_band(
1578    ctx: &mut BandCtx,
1579    x: &mut [f32],
1580    n: usize,
1581    b: i32,
1582    b_blocks: i32,
1583    lowband: Option<&mut [f32]>,
1584    lm: i32,
1585    lowband_out: Option<&mut [f32]>,
1586    gain: f32,
1587    fill: u32,
1588) -> u32 {
1589    let n0 = n;
1590    let b0 = b_blocks;
1591    let long_blocks = b0 == 1;
1592
1593    if n == 1 {
1594        return quant_band_n1(ctx, x, None, lowband_out);
1595    }
1596
1597    let mut b_blocks = b_blocks;
1598    let mut n_b = n / b_blocks as usize;
1599    let mut time_divide = 0;
1600    let mut recombine = 0;
1601    let mut tf_change_local = ctx.tf_change;
1602    let mut fill = fill;
1603
1604    if tf_change_local > 0 {
1605        recombine = tf_change_local;
1606    }
1607
1608    let mut lowband_buf = lowband;
1609
1610    static BIT_INTERLEAVE_TABLE: [u8; 16] = [0, 1, 1, 1, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3];
1611
1612    for k in 0..recombine {
1613        if ctx.encode {
1614            haar1(x, n >> k, 1 << k);
1615        }
1616        if let Some(ref mut lb) = lowband_buf {
1617            haar1(lb, n >> k, 1 << k);
1618        }
1619        fill = (BIT_INTERLEAVE_TABLE[(fill & 0xF) as usize] as u32)
1620            | ((BIT_INTERLEAVE_TABLE[(fill >> 4) as usize] as u32) << 2);
1621    }
1622    b_blocks >>= recombine;
1623    n_b <<= recombine;
1624
1625    while n_b & 1 == 0 && tf_change_local < 0 {
1626        if ctx.encode {
1627            haar1(x, n_b, b_blocks as usize);
1628        }
1629        if let Some(ref mut lb) = lowband_buf {
1630            haar1(lb, n_b, b_blocks as usize);
1631        }
1632        fill |= fill << b_blocks;
1633        b_blocks <<= 1;
1634        n_b >>= 1;
1635        time_divide += 1;
1636        tf_change_local += 1;
1637    }
1638
1639    let b0_after = b_blocks;
1640    let n_b0 = n_b;
1641
1642    if b_blocks > 1 {
1643        if ctx.encode {
1644            deinterleave_hadamard(
1645                x,
1646                n_b >> recombine as usize,
1647                (b_blocks << recombine) as usize,
1648                long_blocks,
1649            );
1650        }
1651        if let Some(ref mut lb) = lowband_buf {
1652            deinterleave_hadamard(
1653                lb,
1654                n_b >> recombine as usize,
1655                (b_blocks << recombine) as usize,
1656                long_blocks,
1657            );
1658        }
1659    }
1660
1661    let cm = if ctx.encode {
1662        quant_partition_encode(ctx, x, n, b, b_blocks, lowband_buf, lm, gain, fill)
1663    } else {
1664        quant_partition(ctx, x, n, b, b_blocks, lowband_buf, lm, gain, fill)
1665    };
1666
1667    if ctx.resynth {
1668        let mut cm = cm;
1669
1670        if b_blocks > 1 {
1671            interleave_hadamard(
1672                x,
1673                n_b >> recombine as usize,
1674                (b0_after << recombine) as usize,
1675                long_blocks,
1676            );
1677        }
1678
1679        let mut n_b_undo = n_b0;
1680        let mut b_undo = b0_after;
1681        for _ in 0..time_divide {
1682            b_undo >>= 1;
1683            n_b_undo <<= 1;
1684            cm |= cm >> b_undo;
1685            haar1(x, n_b_undo, b_undo as usize);
1686        }
1687
1688        static BIT_DEINTERLEAVE_TABLE: [u8; 16] = [
1689            0x00, 0x03, 0x0C, 0x0F, 0x30, 0x33, 0x3C, 0x3F, 0xC0, 0xC3, 0xCC, 0xCF, 0xF0, 0xF3,
1690            0xFC, 0xFF,
1691        ];
1692        for k in 0..recombine {
1693            cm = BIT_DEINTERLEAVE_TABLE[cm as usize & 0xF] as u32;
1694            haar1(x, n0 >> k, 1 << k);
1695        }
1696        let mut b_final = b_undo;
1697        b_final <<= recombine;
1698
1699        if let Some(lb_out) = lowband_out {
1700            let scale = (n0 as f32).sqrt();
1701            for j in 0..n0 {
1702                lb_out[j] = scale * x[j];
1703            }
1704        }
1705        cm &= (1u32 << b_final) - 1;
1706        return cm;
1707    }
1708
1709    cm
1710}
1711
/// Converts the first `n` samples of a mid/side pair (`x` = mid, `y` = side) into
/// two separately normalised channels, in place.
///
/// `mid` scales `x` before the conversion; `_side` is accepted but not used.
/// If the energy of either output channel would fall below `6e-4`, `y` is simply
/// overwritten with a copy of `x` and `x` is left untouched.
///
/// # Panics
///
/// Panics if `x` or `y` holds fewer than `n` samples.
pub fn stereo_merge(x: &mut [f32], y: &mut [f32], mid: f32, _side: f32, n: usize) {
    let xs = &mut x[..n];
    let ys = &mut y[..n];

    // Cross-correlation and side energy, accumulated in sample order.
    let mut cross = 0.0f32;
    let mut side_energy = 0.0f32;
    for (&m, &s) in xs.iter().zip(ys.iter()) {
        cross += s * m;
        side_energy += s * s;
    }

    // Compensate for the mid scaling.
    cross *= mid;
    let energy_left = mid * mid + side_energy - 2.0 * cross;
    let energy_right = mid * mid + side_energy + 2.0 * cross;

    if energy_right < 6e-4f32 || energy_left < 6e-4f32 {
        ys.copy_from_slice(xs);
        return;
    }

    let gain_left = 1.0 / energy_left.sqrt();
    let gain_right = 1.0 / energy_right.sqrt();

    for (m, s) in xs.iter_mut().zip(ys.iter_mut()) {
        let l = mid * *m;
        let r = *s;
        *m = gain_left * (l - r);
        *s = gain_right * (l + r);
    }
}
1739
/// Rotates the first `n` samples of a left/right pair into mid/side form in
/// place: `x` becomes `(x + y) / sqrt(2)` and `y` becomes `(y - x) / sqrt(2)`.
///
/// # Panics
///
/// Panics if `x` or `y` holds fewer than `n` samples.
#[inline(always)]
fn stereo_split(x: &mut [f32], y: &mut [f32], n: usize) {
    let inv_sqrt2 = std::f32::consts::FRAC_1_SQRT_2;
    for (left, right) in x[..n].iter_mut().zip(y[..n].iter_mut()) {
        let l = inv_sqrt2 * *left;
        let r = inv_sqrt2 * *right;
        *left = l + r;
        *right = r - l;
    }
}
1750
1751#[inline(always)]
1752fn intensity_stereo(
1753    m: &CeltMode,
1754    x: &mut [f32],
1755    y: &mut [f32],
1756    band_e: &[f32],
1757    band: usize,
1758    n: usize,
1759) {
1760    let left = band_e[band].max(MIN_STEREO_ENERGY);
1761    let right = band_e[m.nb_ebands + band].max(MIN_STEREO_ENERGY);
1762    let norm = (left * left + right * right).sqrt().max(MIN_STEREO_ENERGY);
1763    let a1 = left / norm;
1764    let a2 = right / norm;
1765    for i in 0..n {
1766        x[i] = a1 * x[i] + a2 * y[i];
1767    }
1768}
1769
1770#[inline(always)]
1771fn special_hybrid_folding(m: &CeltMode, norm: &mut [f32], start: usize, m_val: usize) {
1772    if start + 2 >= m.e_bands.len() {
1773        return;
1774    }
1775    let n1 = m_val * (m.e_bands[start + 1] - m.e_bands[start]) as usize;
1776    let n2 = m_val * (m.e_bands[start + 2] - m.e_bands[start + 1]) as usize;
1777    if n2 <= n1 {
1778        return;
1779    }
1780    let len = n2 - n1;
1781    let src_start = 2 * n1 - n2;
1782    if src_start + len <= norm.len() && n1 + len <= norm.len() {
1783        norm.copy_within(src_start..src_start + len, n1);
1784    }
1785}
1786
/// Builds the `(lowband, lowband_out)` pair of views for one band.
///
/// * `lowband` covers `n` values starting at `effective_lowband`; it is `None`
///   when `effective_lowband` is negative or the range does not fit in `norm`.
/// * `lowband_out` covers `norm[norm_pos..norm_pos + n]`; it is `None` when
///   `want_out` is false or the range does not fit in `norm`.
///
/// When `allow_lowband_scratch` is set, the lowband values are copied to
/// `lowband_scratch_ptr` and the returned lowband view points at that copy, so
/// both views can be returned even if the ranges overlap inside `norm`.
/// Otherwise both views borrow from `norm`, and if the two ranges overlap only
/// the lowband view is returned.
///
/// NOTE(review): with `allow_lowband_scratch` set, `lowband_scratch_ptr` must be
/// valid for `n` writes, must not alias `norm`, and must outlive the returned
/// view; none of this is checked here even though the function is not `unsafe`.
fn prepare_lowband_views(
    norm: &mut [f32],
    lowband_scratch_ptr: *mut f32,
    allow_lowband_scratch: bool,
    effective_lowband: i32,
    norm_pos: usize,
    n: usize,
    want_out: bool,
) -> (Option<&mut [f32]>, Option<&mut [f32]>) {
    let len = norm.len();

    let out_span = if want_out && norm_pos + n <= len {
        Some((norm_pos, norm_pos + n))
    } else {
        None
    };

    let lb_span = if effective_lowband >= 0 {
        let start = effective_lowband as usize;
        let end = start + n;
        if end <= len {
            Some((start, end))
        } else {
            None
        }
    } else {
        None
    };

    // No usable folding source: only the output view can be produced.
    let (lb_start, lb_end) = match lb_span {
        Some(span) => span,
        None => {
            let lb_out = match out_span {
                Some((s, e)) => Some(&mut norm[s..e]),
                None => None,
            };
            return (None, lb_out);
        }
    };

    if allow_lowband_scratch {
        // The source range was bounds-checked above; the destination is trusted.
        let lb = unsafe {
            std::ptr::copy_nonoverlapping(norm.as_ptr().add(lb_start), lowband_scratch_ptr, n);
            std::slice::from_raw_parts_mut(lowband_scratch_ptr, n)
        };
        let lb_out = match out_span {
            Some((s, e)) => Some(&mut norm[s..e]),
            None => None,
        };
        return (Some(lb), lb_out);
    }

    match out_span {
        // Lowband lies entirely before the output range.
        Some((out_start, out_end)) if lb_end <= out_start => {
            let (head, tail) = norm.split_at_mut(out_start);
            (
                Some(&mut head[lb_start..lb_end]),
                Some(&mut tail[..(out_end - out_start)]),
            )
        }
        // Output range lies entirely before the lowband.
        Some((out_start, out_end)) if out_end <= lb_start => {
            let (head, tail) = norm.split_at_mut(lb_start);
            (Some(&mut tail[..n]), Some(&mut head[out_start..out_end]))
        }
        // Overlapping ranges, or no output requested: lowband only.
        _ => (Some(&mut norm[lb_start..lb_end]), None),
    }
}
1844
/// AVX2 variant of the mid/side butterfly: for the first `n` samples,
/// `x[i]` becomes `x[i] * mid - y[i] * side` and `y[i]` becomes
/// `x[i] * mid + y[i] * side` (both computed from the original values).
///
/// Samples are processed 16 at a time, then 8 at a time, with a scalar loop for
/// the remainder.
///
/// # Safety
///
/// The CPU must support AVX2, and both `x` and `y` must hold at least `n`
/// samples: the vector loads and stores are not bounds-checked.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[allow(dead_code)]
unsafe fn stereo_merge_avx2(x: &mut [f32], y: &mut [f32], mid: f32, side: f32, n: usize) {
    use std::arch::x86_64::*;

    /// Applies the butterfly to the eight samples starting at `px` / `py`.
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn butterfly8(px: *mut f32, py: *mut f32, mid: __m256, side: __m256) {
        let scaled_x = _mm256_mul_ps(_mm256_loadu_ps(px), mid);
        let scaled_y = _mm256_mul_ps(_mm256_loadu_ps(py), side);
        _mm256_storeu_ps(px, _mm256_sub_ps(scaled_x, scaled_y));
        _mm256_storeu_ps(py, _mm256_add_ps(scaled_x, scaled_y));
    }

    let mid_lanes = _mm256_set1_ps(mid);
    let side_lanes = _mm256_set1_ps(side);
    let mut done = 0;

    // Two vectors per iteration.
    while done + 16 <= n {
        butterfly8(
            x.as_mut_ptr().add(done),
            y.as_mut_ptr().add(done),
            mid_lanes,
            side_lanes,
        );
        butterfly8(
            x.as_mut_ptr().add(done + 8),
            y.as_mut_ptr().add(done + 8),
            mid_lanes,
            side_lanes,
        );
        done += 16;
    }

    // One vector per iteration.
    while done + 8 <= n {
        butterfly8(
            x.as_mut_ptr().add(done),
            y.as_mut_ptr().add(done),
            mid_lanes,
            side_lanes,
        );
        done += 8;
    }

    // Scalar remainder.
    for (xv, yv) in x[done..n].iter_mut().zip(y[done..n].iter_mut()) {
        let scaled_x = *xv * mid;
        let scaled_y = *yv * side;
        *xv = scaled_x - scaled_y;
        *yv = scaled_x + scaled_y;
    }
}
1903
/// Portable mid/side butterfly: for the first `n` samples, `x[i]` becomes
/// `x[i] * mid - y[i] * side` and `y[i]` becomes `x[i] * mid + y[i] * side`
/// (both computed from the original values).
///
/// # Panics
///
/// Panics if `x` or `y` holds fewer than `n` samples.
#[allow(dead_code)]
#[inline]
fn stereo_merge_scalar(x: &mut [f32], y: &mut [f32], mid: f32, side: f32, n: usize) {
    for (xv, yv) in x[..n].iter_mut().zip(y[..n].iter_mut()) {
        let scaled_x = *xv * mid;
        let scaled_y = *yv * side;
        *xv = scaled_x - scaled_y;
        *yv = scaled_x + scaled_y;
    }
}
1914
/// NEON stereo merge over the first `n` samples of `x` and `y`.
///
/// Each pair is replaced by `x[j] * mid - y[j] * side` (written to `x`) and
/// `x[j] * mid + y[j] * side` (written to `y`). Samples at index `n` and beyond
/// are left untouched.
///
/// # Panics
///
/// Panics if `n` is larger than either slice. The check happens before any
/// vector access so that this safe function can never touch memory outside
/// the slices.
#[cfg(target_arch = "aarch64")]
#[allow(dead_code)]
fn stereo_merge_neon(x: &mut [f32], y: &mut [f32], mid: f32, side: f32, n: usize) {
    use std::arch::aarch64::*;

    // Bound both slices to `n` up front: the vector loops below use raw
    // pointers and would otherwise read/write out of bounds when `n > len`.
    let x = &mut x[..n];
    let y = &mut y[..n];

    let n16 = n & !15;
    let n4 = n & !3;
    let xp = x.as_mut_ptr();
    let yp = y.as_mut_ptr();

    // SAFETY: `x` and `y` both hold exactly `n` elements. Every vector access
    // covers 4 lanes starting at an offset `i + k` with `i + k + 4 <= n4 <= n`,
    // and the two slices are distinct `&mut` borrows, so they cannot alias.
    unsafe {
        let vmid = vdupq_n_f32(mid);
        let vside = vdupq_n_f32(side);

        // Main loop: 16 samples (four 4-lane vectors per channel) per iteration.
        for i in (0..n16).step_by(16) {
            let xv0 = vmulq_f32(vld1q_f32(xp.add(i)), vmid);
            let xv1 = vmulq_f32(vld1q_f32(xp.add(i + 4)), vmid);
            let xv2 = vmulq_f32(vld1q_f32(xp.add(i + 8)), vmid);
            let xv3 = vmulq_f32(vld1q_f32(xp.add(i + 12)), vmid);

            let yv0 = vmulq_f32(vld1q_f32(yp.add(i)), vside);
            let yv1 = vmulq_f32(vld1q_f32(yp.add(i + 4)), vside);
            let yv2 = vmulq_f32(vld1q_f32(yp.add(i + 8)), vside);
            let yv3 = vmulq_f32(vld1q_f32(yp.add(i + 12)), vside);

            vst1q_f32(xp.add(i), vsubq_f32(xv0, yv0));
            vst1q_f32(xp.add(i + 4), vsubq_f32(xv1, yv1));
            vst1q_f32(xp.add(i + 8), vsubq_f32(xv2, yv2));
            vst1q_f32(xp.add(i + 12), vsubq_f32(xv3, yv3));

            vst1q_f32(yp.add(i), vaddq_f32(xv0, yv0));
            vst1q_f32(yp.add(i + 4), vaddq_f32(xv1, yv1));
            vst1q_f32(yp.add(i + 8), vaddq_f32(xv2, yv2));
            vst1q_f32(yp.add(i + 12), vaddq_f32(xv3, yv3));
        }

        // Remaining whole 4-sample groups.
        for i in (n16..n4).step_by(4) {
            let x_val = vmulq_f32(vld1q_f32(xp.add(i)), vmid);
            let y_val = vmulq_f32(vld1q_f32(yp.add(i)), vside);

            vst1q_f32(xp.add(i), vsubq_f32(x_val, y_val));
            vst1q_f32(yp.add(i), vaddq_f32(x_val, y_val));
        }
    }

    // Scalar tail for the last `n % 4` samples.
    for i in n4..n {
        let x_val = x[i] * mid;
        let y_val = y[i] * side;
        x[i] = x_val - y_val;
        y[i] = x_val + y_val;
    }
}
1977
1978#[allow(clippy::too_many_arguments)]
1979#[inline(always)]
1980pub fn quant_band_stereo(
1981    ctx: &mut BandCtx,
1982    x: &mut [f32],
1983    y: &mut [f32],
1984    n: usize,
1985    b: i32,
1986    b_blocks: i32,
1987    lowband: Option<&mut [f32]>,
1988    lm: i32,
1989    lowband_out: Option<&mut [f32]>,
1990    _gain: f32,
1991    fill: u32,
1992) -> u32 {
1993    if n == 1 {
1994        return quant_band_n1(ctx, x, Some(y), lowband_out);
1995    }
1996
1997    if ctx.encode
1998        && (ctx.band_e[ctx.i] < MIN_STEREO_ENERGY
1999            || ctx.band_e[ctx.m.nb_ebands + ctx.i] < MIN_STEREO_ENERGY)
2000    {
2001        if ctx.band_e[ctx.i] > ctx.band_e[ctx.m.nb_ebands + ctx.i] {
2002            y.copy_from_slice(x);
2003        } else {
2004            x.copy_from_slice(y);
2005        }
2006    }
2007
2008    let mut sctx = SplitCtx {
2009        inv: false,
2010        imid: 0,
2011        iside: 0,
2012        delta: 0,
2013        itheta: 0,
2014        qalloc: 0,
2015    };
2016    let mut b_mut = b;
2017    let mut fill_mut = fill;
2018    compute_theta(
2019        ctx,
2020        &mut sctx,
2021        x,
2022        y,
2023        n,
2024        &mut b_mut,
2025        b_blocks,
2026        b_blocks,
2027        lm,
2028        true,
2029        &mut fill_mut,
2030    );
2031
2032    let mid_gain = sctx.imid as f32 / 32768.0;
2033    let side_gain = sctx.iside as f32 / 32768.0;
2034
2035    if n == 2 {
2036        let orig_fill = fill;
2037        let mut mbits = b_mut;
2038        let mut sbits = 0;
2039        if sctx.itheta != 0 && sctx.itheta != 16384 {
2040            sbits = 1 << BITRES;
2041        }
2042        mbits -= sbits;
2043        let c = sctx.itheta > 8192;
2044        ctx.remaining_bits -= sctx.qalloc + sbits;
2045
2046        let mut sign = 0;
2047        if sbits != 0 {
2048            if ctx.encode {
2049                sign = if c {
2050                    if (y[0] * x[1] - y[1] * x[0]) < 0.0 {
2051                        1
2052                    } else {
2053                        0
2054                    }
2055                } else if (x[0] * y[1] - x[1] * y[0]) < 0.0 {
2056                    1
2057                } else {
2058                    0
2059                };
2060                ctx.rc.enc_bits(sign as u32, 1);
2061            } else {
2062                sign = ctx.rc.dec_bits(1) as i32;
2063            }
2064        }
2065        let sign_val = (1 - 2 * sign) as f32;
2066        let cm = if c {
2067            let cm = quant_band(
2068                ctx,
2069                y,
2070                n,
2071                mbits,
2072                b_blocks,
2073                lowband,
2074                lm,
2075                lowband_out,
2076                1.0,
2077                orig_fill,
2078            );
2079            x[0] = -sign_val * y[1];
2080            x[1] = sign_val * y[0];
2081            cm
2082        } else {
2083            let cm = quant_band(
2084                ctx,
2085                x,
2086                n,
2087                mbits,
2088                b_blocks,
2089                lowband,
2090                lm,
2091                lowband_out,
2092                1.0,
2093                orig_fill,
2094            );
2095            y[0] = -sign_val * x[1];
2096            y[1] = sign_val * x[0];
2097            cm
2098        };
2099
2100        if ctx.resynth {
2101            let x0 = x[0];
2102            let x1 = x[1];
2103            let y0 = y[0];
2104            let y1 = y[1];
2105            let mx0 = mid_gain * x0;
2106            let mx1 = mid_gain * x1;
2107            let sy0 = side_gain * y0;
2108            let sy1 = side_gain * y1;
2109            x[0] = mx0 - sy0;
2110            x[1] = mx1 - sy1;
2111            y[0] = mx0 + sy0;
2112            y[1] = mx1 + sy1;
2113        }
2114        return cm;
2115    }
2116
2117    ctx.remaining_bits -= sctx.qalloc;
2118    let mut mbits = (0).max((b_mut - sctx.delta) / 2).min(b_mut);
2119    let mut sbits = b_mut - mbits;
2120
2121    let mut rebalance = ctx.remaining_bits;
2122    let mut cm;
2123
2124    if mbits >= sbits {
2125        cm = quant_band(
2126            ctx,
2127            x,
2128            n,
2129            mbits,
2130            b_blocks,
2131            lowband,
2132            lm,
2133            lowband_out,
2134            1.0,
2135            fill_mut,
2136        );
2137        rebalance = mbits - (rebalance - ctx.remaining_bits);
2138        if rebalance > (3 << 3) && sctx.itheta != 0 {
2139            sbits += rebalance - (3 << 3);
2140        }
2141        cm |= quant_band(
2142            ctx,
2143            y,
2144            n,
2145            sbits,
2146            b_blocks,
2147            None,
2148            lm,
2149            None,
2150            side_gain,
2151            fill_mut >> b_blocks,
2152        );
2153    } else {
2154        cm = quant_band(
2155            ctx,
2156            y,
2157            n,
2158            sbits,
2159            b_blocks,
2160            None,
2161            lm,
2162            None,
2163            side_gain,
2164            fill_mut >> b_blocks,
2165        );
2166        rebalance = sbits - (rebalance - ctx.remaining_bits);
2167        if rebalance > (3 << 3) && sctx.itheta != 16384 {
2168            mbits += rebalance - (3 << 3);
2169        }
2170        cm |= quant_band(
2171            ctx,
2172            x,
2173            n,
2174            mbits,
2175            b_blocks,
2176            lowband,
2177            lm,
2178            lowband_out,
2179            1.0,
2180            fill_mut,
2181        );
2182    }
2183
2184    if ctx.resynth {
2185        stereo_merge(x, y, mid_gain, side_gain, n);
2186        if sctx.inv {
2187            for yv in y[..n].iter_mut() {
2188                *yv = -*yv;
2189            }
2190        }
2191    }
2192    cm
2193}
2194
2195#[allow(clippy::too_many_arguments)]
2196pub fn quant_all_bands(
2197    encode: bool,
2198    m: &CeltMode,
2199    start: usize,
2200    end: usize,
2201    x: &mut [f32],
2202    mut y: Option<&mut [f32]>,
2203    collapse_masks: &mut [u32],
2204    band_e: &[f32],
2205    pulses: &[i32],
2206    short_blocks: bool,
2207    spread: i32,
2208    dual_stereo: &mut bool,
2209    intensity: usize,
2210    tf_res: &[i32],
2211    total_bits: i32,
2212    balance: &mut i32,
2213    rc: &mut RangeCoder,
2214    lm: i32,
2215    coded_bands: i32,
2216    resynth: bool,
2217    disable_inv: bool,
2218    seed: &mut u32,
2219) {
2220    let mut balance_val = *balance;
2221    let b_blocks = if short_blocks { 1 << lm } else { 1 };
2222    let c_channels = if y.is_some() { 2 } else { 1 };
2223    let m_val = 1usize << lm as usize;
2224
2225    let norm_offset = m_val * (m.e_bands[start] as usize);
2226    let norm_size = m_val * (m.e_bands[m.nb_ebands - 1] as usize) - norm_offset;
2227
2228    const MAX_NORM_SIZE: usize = 800;
2229    debug_assert!(norm_size <= MAX_NORM_SIZE);
2230
2231    let mut norm_buf = [std::mem::MaybeUninit::<f32>::uninit(); MAX_NORM_SIZE];
2232    let norm =
2233        unsafe { std::slice::from_raw_parts_mut(norm_buf.as_mut_ptr() as *mut f32, norm_size) };
2234    let mut norm2_buf = [std::mem::MaybeUninit::<f32>::uninit(); MAX_NORM_SIZE];
2235    let norm2 =
2236        unsafe { std::slice::from_raw_parts_mut(norm2_buf.as_mut_ptr() as *mut f32, norm_size) };
2237
2238    let mut lowband_scratch_buf = [std::mem::MaybeUninit::<f32>::uninit(); MAX_PVQ_N];
2239    let lowband_scratch_ptr = lowband_scratch_buf.as_mut_ptr() as *mut f32;
2240
2241    let mut lowband_offset: usize = 0;
2242    let mut update_lowband = true;
2243    let mut avoid_split_noise = b_blocks > 1;
2244
2245    let e_bands = &m.e_bands;
2246    let mut ctx_seed = *seed;
2247
2248    for i in start..end {
2249        let e_band_i = e_bands[i] as usize;
2250        let e_band_i1 = e_bands[i + 1] as usize;
2251        let offset = m_val * e_band_i;
2252        let n = m_val * (e_band_i1 - e_band_i);
2253        let last = i == end - 1;
2254
2255        let tell = tell_frac_inline!(rc);
2256        if i != start {
2257            balance_val -= tell;
2258        }
2259        let remaining_bits = total_bits - tell - 1;
2260
2261        let mut b = 0i32;
2262        if i < coded_bands as usize {
2263            let curr_balance = celt_sudiv(balance_val, 3i32.min(coded_bands - i as i32));
2264            b = 0i32.max(16383i32.min((remaining_bits + 1).min(pulses[i] + curr_balance)));
2265        }
2266
2267        let norm_pos = m_val * e_band_i - norm_offset;
2268        let tf_change = tf_res[i];
2269
2270        let mut effective_lowband: i32 = -1;
2271        let mut x_cm: u32;
2272        let mut y_cm: u32;
2273
2274        let band_start_abs = m_val * e_band_i;
2275        let start_abs = m_val * (e_bands[start] as usize);
2276        if resynth
2277            && ((band_start_abs as isize - n as isize >= start_abs as isize) || i == start + 1)
2278            && (update_lowband || lowband_offset == 0)
2279        {
2280            lowband_offset = i;
2281        }
2282
2283        if resynth && i == start + 1 {
2284            special_hybrid_folding(m, norm, start, m_val);
2285            if *dual_stereo {
2286                special_hybrid_folding(m, norm2, start, m_val);
2287            }
2288        }
2289
2290        if lowband_offset != 0 && (spread != SPREAD_AGGRESSIVE || b_blocks > 1 || tf_change < 0) {
2291            effective_lowband = 0i32.max(
2292                (m_val * e_bands[lowband_offset] as usize) as i32 - norm_offset as i32 - n as i32,
2293            );
2294            let el_abs = effective_lowband as usize + norm_offset;
2295
2296            let mut fold_start = lowband_offset;
2297            while fold_start > 0 {
2298                fold_start -= 1;
2299                if m_val * (e_bands[fold_start] as usize) <= el_abs {
2300                    break;
2301                }
2302            }
2303
2304            let mut fold_end = lowband_offset.saturating_sub(1);
2305            loop {
2306                fold_end += 1;
2307                if fold_end >= i || m_val * (e_bands[fold_end] as usize) >= el_abs + n {
2308                    break;
2309                }
2310            }
2311
2312            x_cm = 0;
2313            y_cm = 0;
2314            let mut fi = fold_start;
2315            loop {
2316                x_cm |= collapse_masks[fi * c_channels];
2317                y_cm |= collapse_masks[fi * c_channels + c_channels - 1];
2318                fi += 1;
2319                if fi >= fold_end {
2320                    break;
2321                }
2322            }
2323        } else {
2324            x_cm = (1u32 << b_blocks) - 1;
2325            y_cm = (1u32 << b_blocks) - 1;
2326        }
2327
2328        let mut ctx = BandCtx {
2329            encode,
2330            m,
2331            i,
2332            band_e,
2333            rc,
2334            spread,
2335            remaining_bits,
2336            resynth,
2337            tf_change,
2338            intensity,
2339            theta_round: 0,
2340            avoid_split_noise,
2341            arch: 0,
2342            disable_inv,
2343            seed: ctx_seed,
2344        };
2345
2346        let x_slice = &mut x[offset..offset + n];
2347        let band_uses_direct_norm = i >= m.eff_ebands;
2348        let allow_lowband_scratch = !(band_uses_direct_norm || (last && !encode));
2349        if *dual_stereo && i == intensity {
2350            *dual_stereo = false;
2351            if resynth {
2352                for j in 0..norm_pos {
2353                    norm[j] = 0.5 * (norm[j] + norm2[j]);
2354                }
2355            }
2356        }
2357
2358        if *dual_stereo {
2359            let y_slice = &mut y.as_mut().unwrap()[offset..offset + n];
2360
2361            let (lb_x, lb_out_x) = prepare_lowband_views(
2362                norm,
2363                lowband_scratch_ptr,
2364                allow_lowband_scratch,
2365                effective_lowband,
2366                norm_pos,
2367                n,
2368                !last,
2369            );
2370            x_cm = quant_band(
2371                &mut ctx,
2372                x_slice,
2373                n,
2374                b / 2,
2375                b_blocks,
2376                lb_x,
2377                lm,
2378                lb_out_x,
2379                1.0,
2380                x_cm,
2381            );
2382
2383            let (lb_y, lb_out_y) = prepare_lowband_views(
2384                norm2,
2385                lowband_scratch_ptr,
2386                allow_lowband_scratch,
2387                effective_lowband,
2388                norm_pos,
2389                n,
2390                !last,
2391            );
2392            y_cm = quant_band(
2393                &mut ctx,
2394                y_slice,
2395                n,
2396                b / 2,
2397                b_blocks,
2398                lb_y,
2399                lm,
2400                lb_out_y,
2401                1.0,
2402                y_cm,
2403            );
2404        } else if let Some(y_all) = y.as_mut() {
2405            let y_slice = &mut y_all[offset..offset + n];
2406            let (lb, lb_out) = prepare_lowband_views(
2407                norm,
2408                lowband_scratch_ptr,
2409                allow_lowband_scratch,
2410                effective_lowband,
2411                norm_pos,
2412                n,
2413                !last,
2414            );
2415            x_cm = quant_band_stereo(
2416                &mut ctx,
2417                x_slice,
2418                y_slice,
2419                n,
2420                b,
2421                b_blocks,
2422                lb,
2423                lm,
2424                lb_out,
2425                1.0,
2426                x_cm | y_cm,
2427            );
2428            y_cm = x_cm;
2429        } else {
2430            let (lb, lb_out) = prepare_lowband_views(
2431                norm,
2432                lowband_scratch_ptr,
2433                allow_lowband_scratch,
2434                effective_lowband,
2435                norm_pos,
2436                n,
2437                !last,
2438            );
2439            x_cm = quant_band(&mut ctx, x_slice, n, b, b_blocks, lb, lm, lb_out, 1.0, x_cm);
2440            y_cm = x_cm;
2441        }
2442
2443        collapse_masks[i * c_channels] = (x_cm & 0xFF) as u8 as u32;
2444        if c_channels == 2 {
2445            collapse_masks[i * c_channels + 1] = (y_cm & 0xFF) as u8 as u32;
2446        }
2447
2448        balance_val += pulses[i] + tell;
2449        ctx_seed = ctx.seed;
2450        update_lowband = b > ((n as i32) << BITRES);
2451
2452        avoid_split_noise = false;
2453    }
2454    *balance = balance_val;
2455    *seed = ctx_seed;
2456}
2457
/// Returns `sqrt(1e-27 + sum(v * v))` over `band`, using NEON fused
/// multiply-adds for the bulk of the samples.
///
/// Samples are consumed in three stages: blocks of 16 (four independent
/// accumulators), remaining blocks of 4 (one accumulator), then a scalar tail.
#[cfg(target_arch = "aarch64")]
fn compute_band_energy_neon(band: &[f32]) -> f32 {
    use std::arch::aarch64::*;

    let len = band.len();
    let wide_end = len & !15;
    let quad_end = len & !3;
    let base = band.as_ptr();
    let mut energy = 1e-27f32;

    // SAFETY: every load reads 4 lanes starting at an offset `p` with
    // `p + 4 <= quad_end <= band.len()`.
    unsafe {
        if wide_end > 0 {
            let mut lanes = [vdupq_n_f32(0.0); 4];
            let mut pos = 0;
            while pos < wide_end {
                for (k, lane) in lanes.iter_mut().enumerate() {
                    let v = vld1q_f32(base.add(pos + 4 * k));
                    *lane = vfmaq_f32(*lane, v, v);
                }
                pos += 16;
            }
            // Pairwise combine the accumulators, then sum across lanes.
            let front = vaddq_f32(lanes[0], lanes[1]);
            let back = vaddq_f32(lanes[2], lanes[3]);
            energy += vaddvq_f32(vaddq_f32(front, back));
        }

        if quad_end > wide_end {
            let mut lane = vdupq_n_f32(0.0);
            let mut pos = wide_end;
            while pos < quad_end {
                let v = vld1q_f32(base.add(pos));
                lane = vfmaq_f32(lane, v, v);
                pos += 4;
            }
            energy += vaddvq_f32(lane);
        }
    }

    // Scalar tail for the last `len % 4` samples.
    for &v in &band[quad_end..] {
        energy += v * v;
    }

    energy.sqrt()
}
2509
/// Returns `sqrt(1e-27 + sum(v * v))` over `band`, using AVX2/FMA for the
/// bulk of the samples.
///
/// Blocks of 16 samples feed two independent accumulators, one optional
/// block of 8 feeds the first accumulator, and the rest is summed in scalar
/// code.
///
/// # Safety
///
/// The CPU must support both AVX2 and FMA.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn compute_band_energy_avx2(band: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let len = band.len();
    let data = band.as_ptr();
    let mut pos = 0usize;

    let mut even = _mm256_setzero_ps();
    let mut odd = _mm256_setzero_ps();

    while len - pos >= 16 {
        let a = _mm256_loadu_ps(data.add(pos));
        let b = _mm256_loadu_ps(data.add(pos + 8));
        even = _mm256_fmadd_ps(a, a, even);
        odd = _mm256_fmadd_ps(b, b, odd);
        pos += 16;
    }

    if len - pos >= 8 {
        let a = _mm256_loadu_ps(data.add(pos));
        even = _mm256_fmadd_ps(a, a, even);
        pos += 8;
    }

    // Horizontal reduction: 8 lanes -> 4 -> 2 -> 1.
    let total = _mm256_add_ps(even, odd);
    let upper = _mm256_extractf128_ps(total, 1);
    let lower = _mm256_castps256_ps128(total);
    let quad = _mm_add_ps(lower, upper);
    let pair = _mm_add_ps(quad, _mm_movehl_ps(quad, quad));
    let lane1 = _mm_shuffle_ps(pair, pair, 0x55);
    let mut energy = 1e-27f32 + _mm_cvtss_f32(_mm_add_ss(pair, lane1));

    // Scalar tail.
    for &v in &band[pos..] {
        energy += v * v;
    }

    energy.sqrt()
}
2550
2551pub fn compute_band_energies(
2552    m: &CeltMode,
2553    x: &[f32],
2554    band_e: &mut [f32],
2555    end: usize,
2556    channels: usize,
2557    lm: usize,
2558) {
2559    let frame_size = m.short_mdct_size << lm;
2560
2561    #[cfg(target_arch = "x86_64")]
2562    let use_avx2 = std::arch::is_x86_feature_detected!("avx2");
2563
2564    for c in 0..channels {
2565        let ch = &x[c * frame_size..(c + 1) * frame_size];
2566        for i in 0..end {
2567            let offset = (m.e_bands[i] as usize) << lm;
2568            let n = ((m.e_bands[i + 1] - m.e_bands[i]) as usize) << lm;
2569            let band = &ch[offset..offset + n];
2570
2571            #[cfg(target_arch = "aarch64")]
2572            {
2573                band_e[c * m.nb_ebands + i] = compute_band_energy_neon(band);
2574            }
2575            #[cfg(target_arch = "x86_64")]
2576            {
2577                if n >= 8 && use_avx2 {
2578                    band_e[c * m.nb_ebands + i] = unsafe { compute_band_energy_avx2(band) };
2579                } else {
2580                    let sum = band.iter().fold(1e-27f32, |acc, &v| acc + v * v);
2581                    band_e[c * m.nb_ebands + i] = sum.sqrt();
2582                }
2583            }
2584            #[cfg(all(not(target_arch = "aarch64"), not(target_arch = "x86_64")))]
2585            {
2586                let sum = band.iter().fold(1e-27f32, |acc, &v| acc + v * v);
2587                band_e[c * m.nb_ebands + i] = sum.sqrt();
2588            }
2589        }
2590    }
2591}
2592
2593pub fn amp2log2(
2594    m: &CeltMode,
2595    start: usize,
2596    end: usize,
2597    band_e: &[f32],
2598    band_log_e: &mut [f32],
2599    channels: usize,
2600) {
2601    for c in 0..channels {
2602        for i in 0..start {
2603            band_log_e[c * m.nb_ebands + i] = -14.0;
2604        }
2605        for i in start..end {
2606            let val = band_e[c * m.nb_ebands + i].max(1e-10);
2607            band_log_e[c * m.nb_ebands + i] = val.log2() - m.e_means[i];
2608        }
2609    }
2610}
2611
2612pub fn log2amp(m: &CeltMode, end: usize, band_e: &mut [f32], band_log_e: &[f32], channels: usize) {
2613    for c in 0..channels {
2614        for i in 0..end {
2615            band_e[c * m.nb_ebands + i] = band_log_e[c * m.nb_ebands + i] + m.e_means[i];
2616        }
2617    }
2618}
2619
2620pub fn normalise_bands(
2621    m: &CeltMode,
2622    freq: &[f32],
2623    x: &mut [f32],
2624    band_e: &[f32],
2625    end: usize,
2626    channels: usize,
2627    m_val: usize,
2628) {
2629    let lm = m_val.trailing_zeros() as usize;
2630    let frame_size = m.short_mdct_size << lm;
2631    #[cfg(target_arch = "x86_64")]
2632    let use_avx2 = std::arch::is_x86_feature_detected!("avx2");
2633    for c in 0..channels {
2634        for i in 0..end {
2635            let base = c * frame_size + ((m.e_bands[i] as usize) << lm);
2636            let n = ((m.e_bands[i + 1] - m.e_bands[i]) as usize) << lm;
2637            let norm = 1.0 / (1e-27 + band_e[c * m.nb_ebands + i]);
2638            let src = &freq[base..base + n];
2639            let dst = &mut x[base..base + n];
2640            #[cfg(target_arch = "x86_64")]
2641            if n >= 8 && use_avx2 {
2642                unsafe { scale_slice_avx2(src, dst, norm, n) };
2643                continue;
2644            }
2645            #[cfg(target_arch = "aarch64")]
2646            if n >= 8 {
2647                unsafe { scale_slice_neon(src, dst, norm, n) };
2648                continue;
2649            }
2650            for (d, &s) in dst.iter_mut().zip(src) {
2651                *d = s * norm;
2652            }
2653        }
2654    }
2655}
2656
/// Writes `src[j] * scale` to `dst[j]` for `j < n` using AVX2.
///
/// Samples are handled in blocks of 16, then blocks of 8, then one by one.
///
/// # Safety
///
/// The CPU must support AVX2, and `n` must not exceed the length of `src` or
/// `dst` (the vector loops access memory through raw pointers).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn scale_slice_avx2(src: &[f32], dst: &mut [f32], scale: f32, n: usize) {
    use std::arch::x86_64::*;

    let factor = _mm256_set1_ps(scale);
    let input = src.as_ptr();
    let output = dst.as_mut_ptr();
    let mut pos = 0usize;

    while n - pos >= 16 {
        let lo = _mm256_mul_ps(_mm256_loadu_ps(input.add(pos)), factor);
        let hi = _mm256_mul_ps(_mm256_loadu_ps(input.add(pos + 8)), factor);
        _mm256_storeu_ps(output.add(pos), lo);
        _mm256_storeu_ps(output.add(pos + 8), hi);
        pos += 16;
    }

    while n - pos >= 8 {
        let scaled = _mm256_mul_ps(_mm256_loadu_ps(input.add(pos)), factor);
        _mm256_storeu_ps(output.add(pos), scaled);
        pos += 8;
    }

    // Scalar tail (fewer than 8 samples left).
    for k in pos..n {
        dst[k] = src[k] * scale;
    }
}
2680
/// Writes `src[j] * scale` to `dst[j]` for `j < n` using NEON.
///
/// Samples are handled in blocks of 16, then blocks of 4, then one by one.
/// Each output depends only on its own input, so the block size does not
/// affect the results.
///
/// # Safety
///
/// `n` must not exceed the length of `src` or `dst` (the vector loops access
/// memory through raw pointers).
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn scale_slice_neon(src: &[f32], dst: &mut [f32], scale: f32, n: usize) {
    use std::arch::aarch64::*;

    let factor = vdupq_n_f32(scale);
    let input = src.as_ptr();
    let output = dst.as_mut_ptr();
    let mut pos = 0usize;

    while n - pos >= 16 {
        for lane in 0..4 {
            let at = pos + 4 * lane;
            vst1q_f32(output.add(at), vmulq_f32(vld1q_f32(input.add(at)), factor));
        }
        pos += 16;
    }

    while n - pos >= 4 {
        vst1q_f32(output.add(pos), vmulq_f32(vld1q_f32(input.add(pos)), factor));
        pos += 4;
    }

    // Scalar tail (fewer than 4 samples left).
    for k in pos..n {
        dst[k] = src[k] * scale;
    }
}
2716
2717#[allow(clippy::too_many_arguments)]
2718pub fn denormalise_bands(
2719    m: &CeltMode,
2720    x: &[f32],
2721    freq: &mut [f32],
2722    band_e: &[f32],
2723    start: usize,
2724    end: usize,
2725    channels: usize,
2726    m_val: usize,
2727) {
2728    let lm = m_val.trailing_zeros() as usize;
2729    let frame_size = m.short_mdct_size << lm;
2730    #[cfg(target_arch = "x86_64")]
2731    let use_avx2 = std::arch::is_x86_feature_detected!("avx2");
2732
2733    for c in 0..channels {
2734        for i in start..end {
2735            let base = c * frame_size + ((m.e_bands[i] as usize) << lm);
2736            let n = ((m.e_bands[i + 1] - m.e_bands[i]) as usize) << lm;
2737            let band_log = band_e[c * m.nb_ebands + i];
2738
2739            // Match C: celt_exp2_db(MIN32(32.f, lg)) — cap gain to prevent overflow
2740            let g = (2.0f32).powf(band_log.min(32.0));
2741            let src = &x[base..base + n];
2742            let dst = &mut freq[base..base + n];
2743            #[cfg(target_arch = "x86_64")]
2744            if n >= 8 && use_avx2 {
2745                unsafe { scale_slice_avx2(src, dst, g, n) };
2746                continue;
2747            }
2748            #[cfg(target_arch = "aarch64")]
2749            if n >= 8 {
2750                unsafe { scale_slice_neon(src, dst, g, n) };
2751                continue;
2752            }
2753            for (d, &s) in dst.iter_mut().zip(src) {
2754                *d = s * g;
2755            }
2756        }
2757    }
2758}
2759
/// Advances a 32-bit linear congruential generator by one step:
/// `seed * 1664525 + 1013904223`, with wrap-around on overflow.
pub fn celt_lcg_rand(seed: u32) -> u32 {
    const MULTIPLIER: u32 = 1_664_525;
    const INCREMENT: u32 = 1_013_904_223;
    MULTIPLIER.wrapping_mul(seed).wrapping_add(INCREMENT)
}
2763
/// Rescales the first `n` samples of `x` so that their Euclidean norm becomes
/// `gain`, using NEON.
///
/// The energy is `1e-15 + sum(x[j]^2)`; every sample is then multiplied by
/// `gain / sqrt(energy)`.
///
/// # Safety
///
/// `n` must not exceed `x.len()` (the vector loops access memory through raw
/// pointers).
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn renormalise_vector_neon(x: &mut [f32], n: usize, gain: f32) {
    use std::arch::aarch64::*;

    // Number of samples covered by whole 4-lane vectors.
    let quad_end = n & !3;

    // Pass 1: energy. A single accumulator is fed one 4-lane vector at a time
    // in increasing address order.
    let mut acc = vdupq_n_f32(0.0);
    {
        let data = x.as_ptr();
        let mut pos = 0usize;
        while pos < quad_end {
            let v = vld1q_f32(data.add(pos));
            acc = vfmaq_f32(acc, v, v);
            pos += 4;
        }
    }

    let mut energy = 1e-15f32 + vaddvq_f32(acc);
    for k in quad_end..n {
        energy += x[k] * x[k];
    }

    let norm = gain / energy.sqrt();
    let vnorm = vdupq_n_f32(norm);

    // Pass 2: scale in place.
    {
        let data = x.as_mut_ptr();
        let mut pos = 0usize;
        while pos < quad_end {
            vst1q_f32(data.add(pos), vmulq_f32(vld1q_f32(data.add(pos)), vnorm));
            pos += 4;
        }
    }

    for k in quad_end..n {
        x[k] *= norm;
    }
}
2839
/// Rescales the first `n` samples of `x` so that their Euclidean norm becomes
/// `gain`, using AVX2/FMA.
///
/// The energy is `1e-15 + sum(x[j]^2)`; every sample is then multiplied by
/// `gain / sqrt(energy)`.
///
/// # Safety
///
/// The CPU must support both AVX2 and FMA, and `n` must not exceed `x.len()`
/// (the vector loops access memory through raw pointers).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn renormalise_vector_avx2(x: &mut [f32], n: usize, gain: f32) {
    use std::arch::x86_64::*;

    // Pass 1: energy. Blocks of 16 feed two independent accumulators; one
    // optional block of 8 goes to the first accumulator.
    let mut pos = 0usize;
    let mut even = _mm256_setzero_ps();
    let mut odd = _mm256_setzero_ps();
    {
        let data = x.as_ptr();
        while n - pos >= 16 {
            let a = _mm256_loadu_ps(data.add(pos));
            let b = _mm256_loadu_ps(data.add(pos + 8));
            even = _mm256_fmadd_ps(a, a, even);
            odd = _mm256_fmadd_ps(b, b, odd);
            pos += 16;
        }
        if n - pos >= 8 {
            let a = _mm256_loadu_ps(data.add(pos));
            even = _mm256_fmadd_ps(a, a, even);
            pos += 8;
        }
    }

    // Horizontal reduction: 8 lanes -> 4 -> 2 -> 1.
    let total = _mm256_add_ps(even, odd);
    let upper = _mm256_extractf128_ps(total, 1);
    let lower = _mm256_castps256_ps128(total);
    let quad = _mm_add_ps(lower, upper);
    let pair = _mm_add_ps(quad, _mm_movehl_ps(quad, quad));
    let lane1 = _mm_shuffle_ps(pair, pair, 0x55);
    let mut energy = 1e-15f32 + _mm_cvtss_f32(_mm_add_ss(pair, lane1));

    for &v in &x[pos..n] {
        energy += v * v;
    }

    let norm = gain / energy.sqrt();
    let vnorm = _mm256_set1_ps(norm);

    // Pass 2: scale in place.
    pos = 0;
    {
        let data = x.as_mut_ptr();
        while n - pos >= 16 {
            let a = _mm256_mul_ps(_mm256_loadu_ps(data.add(pos)), vnorm);
            let b = _mm256_mul_ps(_mm256_loadu_ps(data.add(pos + 8)), vnorm);
            _mm256_storeu_ps(data.add(pos), a);
            _mm256_storeu_ps(data.add(pos + 8), b);
            pos += 16;
        }
        while n - pos >= 8 {
            let a = _mm256_mul_ps(_mm256_loadu_ps(data.add(pos)), vnorm);
            _mm256_storeu_ps(data.add(pos), a);
            pos += 8;
        }
    }

    for v in &mut x[pos..n] {
        *v *= norm;
    }
}
2897
2898pub fn renormalise_vector(x: &mut [f32], n: usize, gain: f32) {
2899    #[cfg(target_arch = "aarch64")]
2900    unsafe {
2901        renormalise_vector_neon(x, n, gain);
2902    }
2903    #[cfg(target_arch = "x86_64")]
2904    unsafe {
2905        if n >= 16 && std::arch::is_x86_feature_detected!("avx2") {
2906            renormalise_vector_avx2(x, n, gain);
2907            return;
2908        }
2909    }
2910    #[cfg(all(not(target_arch = "aarch64"), not(target_arch = "x86_64")))]
2911    {
2912        let mut e = 1e-15f32;
2913        for &xv in x[..n].iter() {
2914            e += xv * xv;
2915        }
2916        let norm = gain / e.sqrt();
2917        for xv in x[..n].iter_mut() {
2918            *xv *= norm;
2919        }
2920    }
2921    #[cfg(target_arch = "x86_64")]
2922    {
2923        let mut e = 1e-15f32;
2924        for &xv in x[..n].iter() {
2925            e += xv * xv;
2926        }
2927        let norm = gain / e.sqrt();
2928        for xv in x[..n].iter_mut() {
2929            *xv *= norm;
2930        }
2931    }
2932}
2933
2934#[allow(clippy::too_many_arguments)]
2935pub fn anti_collapse(
2936    m: &CeltMode,
2937    x_buf: &mut [f32],
2938    collapse_masks: &[u32],
2939    lm: i32,
2940    channels: usize,
2941    size: usize,
2942    start: usize,
2943    end: usize,
2944    log_e: &[f32],
2945    prev1_log_e: &[f32],
2946    prev2_log_e: &[f32],
2947    pulses: &[i32],
2948    mut seed: u32,
2949) -> u32 {
2950    for i in start..end {
2951        let n0 = (m.e_bands[i + 1] - m.e_bands[i]) as usize;
2952        let depth = if n0 > 0 {
2953            ((1 + pulses[i]) / n0 as i32) >> lm
2954        } else {
2955            0
2956        };
2957
2958        let thresh = 0.5 * (-(0.125 * depth as f32)).exp2();
2959        let sqrt_1 = 1.0 / ((n0 << lm) as f32).sqrt();
2960
2961        for c in 0..channels {
2962            let p1 = prev1_log_e[c * m.nb_ebands + i];
2963            let p2 = prev2_log_e[c * m.nb_ebands + i];
2964
2965            let (p1_adj, p2_adj) = if channels == 1 && prev1_log_e.len() >= 2 * m.nb_ebands {
2966                (
2967                    p1.max(prev1_log_e[m.nb_ebands + i]),
2968                    p2.max(prev2_log_e[m.nb_ebands + i]),
2969                )
2970            } else {
2971                (p1, p2)
2972            };
2973
2974            let e_diff = log_e[c * m.nb_ebands + i] - p1_adj.min(p2_adj);
2975            let e_diff = e_diff.max(0.0);
2976
2977            let mut r = 2.0 * (-e_diff).exp2();
2978            if lm == 3 {
2979                r *= std::f32::consts::SQRT_2;
2980            }
2981            r = r.min(thresh);
2982            r *= sqrt_1;
2983
2984            let x_offset = c * size + ((m.e_bands[i] as usize) << lm);
2985            let mut renormalize = false;
2986            for k in 0..(1 << lm) {
2987                if (collapse_masks[i * channels + c] & (1 << k)) == 0 {
2988                    for j in 0..n0 {
2989                        seed = celt_lcg_rand(seed);
2990                        x_buf[x_offset + (j << lm) + k] = if (seed & 0x8000) != 0 { r } else { -r };
2991                    }
2992                    renormalize = true;
2993                }
2994            }
2995            if renormalize {
2996                renormalise_vector(&mut x_buf[x_offset..x_offset + (n0 << lm)], n0 << lm, 1.0);
2997            }
2998        }
2999    }
3000    seed
3001}
3002
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the fixed-point `bitexact_cos` / `bitexact_log2tan` helpers to known
    /// reference outputs.
    #[test]
    fn test_bitexact_primitives_reference_values() {
        // (input, expected output)
        let cos_cases: [(i16, i16); 3] = [(64, 32767), (8192, 23171), (16320, 200)];
        for (arg, want) in cos_cases {
            assert_eq!(bitexact_cos(arg), want);
        }

        // (isin, icos, expected output); the first and last rows swap the arguments
        // and expect results of opposite sign.
        let log2tan_cases: [(i32, i32, i32); 4] = [
            (32767, 200, 15059),
            (30274, 12540, 2611),
            (23171, 23171, 0),
            (200, 32767, -15059),
        ];
        for (isin, icos, want) in log2tan_cases {
            assert_eq!(bitexact_log2tan(isin, icos), want);
        }
    }
}