// opus_rs/bands.rs — CELT band quantization helpers (Rust port of celt/bands.c).
1use crate::modes::CeltMode;
2use crate::pvq::*;
3use crate::range_coder::RangeCoder;
4use crate::rate::{BITRES, bits2pulses, get_pulses, pulses2bits};
5use crate::tell_frac_inline;
6
// Floor used to avoid degenerate (near-zero) stereo band energies.
// NOTE(review): not referenced in this chunk — presumably used by band code
// further down the file; confirm before removing.
const MIN_STEREO_ENERGY: f32 = 1e-10;
8
/// Per-band state threaded through the band quantization routines.
pub struct BandCtx<'a> {
    /// True when encoding, false when decoding (branches in `compute_theta`).
    pub encode: bool,
    /// Mode configuration: band edges (`e_bands`), `log_n` table, MDCT size.
    pub m: &'a CeltMode,
    /// Index of the band currently being processed.
    pub i: usize,
    /// Per-band energies. NOTE(review): exact layout (channel-major?) is not
    /// visible in this chunk — confirm against the caller.
    pub band_e: &'a [f32],
    /// Range coder shared by the encode and decode paths.
    pub rc: &'a mut RangeCoder,
    /// Spreading decision (one of the SPREAD_* constants), passed to `alg_quant`.
    pub spread: i32,
    /// Remaining bit budget (fractional-bit units); decremented as pulses are
    /// allocated in the `quant_partition_*` helpers.
    pub remaining_bits: i32,
    /// NOTE(review): presumably requests resynthesis of the decoded band on
    /// the encoder side — not used in this chunk; confirm.
    pub resynth: bool,
    /// Time/frequency resolution change for this band (not used in this chunk).
    pub tf_change: i32,
    /// First band coded with intensity stereo; bands `i >= intensity` force
    /// `qn = 1` in `compute_theta`.
    pub intensity: usize,
    /// Theta rounding mode for stereo: 0 = nearest, < 0 = down, > 0 = up.
    pub theta_round: i32,
    /// When set, pushes a quantized theta to 0 or `qn` if the implied bit
    /// imbalance exceeds the budget (transient noise avoidance).
    pub avoid_split_noise: bool,
    /// Architecture selector carried over from the C API (unused here).
    pub arch: i32,
    /// When set, disables sign inversion of the side channel.
    /// NOTE(review): not read in this chunk — confirm usage downstream.
    pub disable_inv: bool,
    /// PRNG seed. NOTE(review): not used in this chunk — presumably for
    /// noise folding further down the file.
    pub seed: u32,
}
26
/// Bit-exact cosine approximation in fixed point.
///
/// `x` is an angle in Q14 (0..16384 spans 0..pi/2); the result is the
/// cosine in Q15. Must stay bit-identical to the C `bitexact_cos`.
///
/// NOTE(review): inputs that make the polynomial evaluate to 32767 (x <= 63)
/// would overflow the `1 + c` below; callers only pass quantized angles well
/// inside (0, 16384).
#[inline]
fn bitexact_cos(x: i16) -> i16 {
    // Q15 fractional multiply with rounding (matches the C FRAC_MUL16 macro).
    #[inline(always)]
    fn fmul(a: i16, b: i16) -> i16 {
        ((16384i32 + i32::from(a) * i32::from(b)) >> 15) as i16
    }

    // Square the Q14 angle down to a Q15 value in [0, 32767].
    let sq = ((4096i32 + i32::from(x) * i32::from(x)) >> 13) as i16;
    // Polynomial correction term, evaluated in Horner form.
    let poly = fmul(sq, -7651 + fmul(sq, 8277 + fmul(-626, sq)));
    let c = (32767 - i32::from(sq) + i32::from(poly)) as i16;
    1 + c
}
41
/// Bit-exact `log2(tan(theta))` in Q11 from Q15 sine/cosine magnitudes.
///
/// Must stay bit-identical to the C `bitexact_log2tan`; the integer part
/// comes from the operands' exponents, the fractional part from a small
/// polynomial applied to each operand normalized into [16384, 32767].
#[inline]
pub fn bitexact_log2tan(isin: i32, icos: i32) -> i32 {
    // Position of the highest set bit (1-based); 0 for a zero input.
    fn ilog(v: u32) -> i32 {
        if v == 0 { 0 } else { 32 - v.leading_zeros() as i32 }
    }
    // Q15 fractional multiply with rounding.
    fn fmul(a: i32, b: i32) -> i32 {
        (a * b + 16384) >> 15
    }

    let lc = ilog(icos.max(0) as u32);
    let ls = ilog(isin.max(0) as u32);
    // Normalize each magnitude so its top bit sits at bit 14.
    let c = if lc > 0 { icos.max(0) << (15 - lc).max(0) } else { 0 };
    let s = if ls > 0 { isin.max(0) << (15 - ls).max(0) } else { 0 };

    // Fractional log2 correction for a normalized operand.
    let poly = |v: i32| fmul(v, fmul(v, -2597) + 7932);

    (ls - lc) * (1 << 11) + poly(s) - poly(c)
}
67
#[inline(always)]
fn celt_sudiv(n: i32, d: i32) -> i32 {
    // Mirrors the C helper celt_sudiv. NOTE(review): the C counterpart
    // assumes n, d > 0; the caller here (compute_qn) passes d = 2N-1 >= 1.
    // Panics on d == 0 like any Rust integer division.
    n / d
}
72
/// Integer square root: returns `floor(sqrt(val))`.
///
/// Bug fix: for `val == 0`, `leading_zeros()` is 32, making `bshift` equal
/// to -1 and `1u32 << bshift` a negative-amount shift — a panic in debug
/// builds (and a masked, accidental `1 << 31` in release). Zero is now
/// handled explicitly; all other inputs follow the original digit-by-digit
/// (binary restoring) method, one result bit per iteration.
#[inline]
fn isqrt32(mut val: u32) -> u32 {
    if val == 0 {
        return 0;
    }
    let mut g = 0u32;
    let mut bshift = ((32 - val.leading_zeros()) as i32 - 1) >> 1;
    let mut b = 1u32 << bshift;
    while bshift >= 0 {
        // Candidate increment: (g + b)^2 - g^2 = (2g + b) << bshift.
        // Widened to u64 so the trial value cannot overflow for large inputs.
        let t = (((g << 1) + b) as u64) << bshift;
        if t <= val as u64 {
            g += b;
            val -= t as u32;
        }
        b >>= 1;
        bshift -= 1;
    }
    g
}
89
// Spreading (rotation) decision levels returned by `spreading_decision`,
// matching the C SPREAD_* constants, from no rotation to maximum rotation.
pub const SPREAD_NONE: i32 = 0;
pub const SPREAD_LIGHT: i32 = 1;
pub const SPREAD_NORMAL: i32 = 2;
pub const SPREAD_AGGRESSIVE: i32 = 3;
94
/// Decides how much spectral spreading (rotation) to apply based on how
/// sparse the normalized coefficients are: bands dominated by near-zero
/// bins push the decision toward stronger spreading. Also maintains the
/// high-frequency tapset decision when `update_hf` is set.
///
/// `*average`, `*hf_average` and `*tapset_decision` are hysteresis state
/// carried across frames; `last_decision` is the previous return value.
#[allow(clippy::too_many_arguments)]
pub fn spreading_decision(
    m: &CeltMode,
    x_buf: &[f32],
    average: &mut i32,
    last_decision: i32,
    hf_average: &mut i32,
    tapset_decision: &mut i32,
    update_hf: bool,
    end: usize,
    channels: usize,
    m_val: usize,
    spread_weight: &[i32],
) -> i32 {
    let mut sum = 0;
    let mut nb_bands = 0;
    // Offset between the channels' coefficient planes.
    let n0 = m_val * m.short_mdct_size;
    let mut hf_sum = 0;

    // Too few bins in the last band to make a meaningful decision.
    if m_val * (m.e_bands[end] as usize - m.e_bands[end - 1] as usize) <= 8 {
        return SPREAD_NONE;
    }

    for c in 0..channels {
        for (i, &sw) in spread_weight[..end].iter().enumerate() {
            let n = m_val * (m.e_bands[i + 1] as usize - m.e_bands[i] as usize);
            if n <= 8 {
                continue;
            }

            // Count bins whose energy (relative to a flat band) falls below
            // three successively smaller thresholds.
            let mut tcount = [0; 3];
            let offset = m_val * m.e_bands[i] as usize + c * n0;
            let x = &x_buf[offset..offset + n];

            for xv in x.iter().copied() {
                let x2n = xv * xv * (n as f32);
                if x2n < 0.25 {
                    tcount[0] += 1;
                }
                if x2n < 0.0625 {
                    tcount[1] += 1;
                }
                if x2n < 0.015625 {
                    tcount[2] += 1;
                }
            }

            // Track sparsity of the top few bands separately for the
            // tapset decision.
            if i > m.nb_ebands - 4 {
                hf_sum += 32 * (tcount[1] + tcount[0]) / (n as i32);
            }

            // 0..3: number of thresholds under which at least half of the
            // band's bins fall (higher = sparser band).
            let tmp = (if 2 * tcount[2] >= (n as i32) { 1 } else { 0 })
                + (if 2 * tcount[1] >= (n as i32) { 1 } else { 0 })
                + (if 2 * tcount[0] >= (n as i32) { 1 } else { 0 });
            sum += tmp * sw;
            nb_bands += sw;
        }
    }

    if update_hf {
        if hf_sum > 0 {
            hf_sum /= (channels as i32) * (4 - m.nb_ebands as i32 + end as i32);
        }
        // One-pole smoothing of the HF sparsity measure.
        *hf_average = (*hf_average + hf_sum) >> 1;
        hf_sum = *hf_average;

        // Hysteresis: bias toward keeping the current tapset.
        if *tapset_decision == 2 {
            hf_sum += 4;
        } else if *tapset_decision == 0 {
            hf_sum -= 4;
        }

        if hf_sum > 22 {
            *tapset_decision = 2;
        } else if hf_sum > 18 {
            *tapset_decision = 1;
        } else {
            *tapset_decision = 0;
        }
    }

    if nb_bands == 0 {
        return SPREAD_NORMAL;
    }

    // Weighted average sparsity in Q8, smoothed across frames.
    let mut sum_scaled = (sum << 8) / nb_bands;
    sum_scaled = (sum_scaled + *average) >> 1;
    *average = sum_scaled;

    // Add hysteresis from the previous decision before thresholding.
    let sum_final = (3 * sum_scaled + (((3 - last_decision) << 7) + 64) + 2) >> 2;

    if sum_final < 80 {
        SPREAD_AGGRESSIVE
    } else if sum_final < 256 {
        SPREAD_NORMAL
    } else if sum_final < 384 {
        SPREAD_LIGHT
    } else {
        SPREAD_NONE
    }
}
196
/// One level of the Haar transform over `n0 / 2` strided (even, odd) pairs:
/// each pair (a, b) becomes ((a + b)/sqrt(2), (a - b)/sqrt(2)) in place.
/// Dispatches to the best implementation for the target architecture.
pub fn haar1(x: &mut [f32], n0: usize, stride: usize) {
    // aarch64: NEON path handles everything (including the scalar tail).
    #[cfg(target_arch = "aarch64")]
    {
        haar1_neon(x, n0, stride);
    }
    // x86/x86_64: AVX for the common contiguous case, when available.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    unsafe {
        if stride == 1 && n0 >= 16 && is_x86_feature_detected!("avx") {
            haar1_avx(x, n0);
            return;
        }
    }
    // Portable fallback (compiled out on aarch64, where NEON already ran).
    #[cfg(not(target_arch = "aarch64"))]
    haar1_scalar(x, n0, stride);
}
212
/// AVX implementation of one Haar level for `stride == 1`:
/// `x[2j] = (x[2j] + x[2j+1]) / sqrt(2)`, `x[2j+1] = (x[2j] - x[2j+1]) / sqrt(2)`.
///
/// Bug fix: the previous version loaded the second vector from `ptr.add(4)`
/// (overlapping the first load and never reading floats 12..16 of each
/// 16-float chunk) and then recombined lanes with `_mm256_permute2f128_ps`,
/// which placed pair results at the wrong offsets. The two unpack passes
/// below already leave the results in memory order, so the fix is to load
/// the second half from `ptr.add(8)` and store `r0`/`r1` directly.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx")]
unsafe fn haar1_avx(x: &mut [f32], n0: usize) {
    // Use the arch-correct intrinsics module so 32-bit x86 also compiles.
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let n = n0 >> 1;
    let scale = _mm256_set1_ps(std::f32::consts::FRAC_1_SQRT_2);
    let mut j = 0;
    // Process 8 pairs (16 consecutive floats) per iteration.
    while j + 8 <= n {
        let ptr = x.as_mut_ptr().add(2 * j);
        let a = _mm256_loadu_ps(ptr);
        let b = _mm256_loadu_ps(ptr.add(8));

        // De-interleave even/odd elements. Within each 128-bit lane this
        // yields the pairs in the order p0 p1 p4 p5 | p2 p3 p6 p7; applying
        // the same unpack pattern again below undoes that permutation.
        let t0 = _mm256_unpacklo_ps(a, b);
        let t1 = _mm256_unpackhi_ps(a, b);
        let even = _mm256_unpacklo_ps(t0, t1);
        let odd = _mm256_unpackhi_ps(t0, t1);

        let sum = _mm256_mul_ps(_mm256_add_ps(even, odd), scale);
        let diff = _mm256_mul_ps(_mm256_sub_ps(even, odd), scale);

        // Re-interleave: r0 holds pairs 0..4 and r1 pairs 4..8, already in
        // memory order — no cross-lane permute required.
        let r0 = _mm256_unpacklo_ps(sum, diff);
        let r1 = _mm256_unpackhi_ps(sum, diff);

        _mm256_storeu_ps(ptr, r0);
        _mm256_storeu_ps(ptr.add(8), r1);
        j += 8;
    }

    // Scalar tail for the remaining (< 8) pairs.
    let scale = std::f32::consts::FRAC_1_SQRT_2;
    while j < n {
        let idx1 = 2 * j;
        let idx2 = 2 * j + 1;
        let tmp1 = scale * x[idx1];
        let tmp2 = scale * x[idx2];
        x[idx1] = tmp1 + tmp2;
        x[idx2] = tmp1 - tmp2;
        j += 1;
    }
}
256
/// Portable Haar butterfly: each strided pair (a, b) becomes
/// ((a + b)/sqrt(2), (a - b)/sqrt(2)) in place, for every interleaved lane.
#[cfg_attr(target_arch = "aarch64", allow(dead_code))]
#[inline]
fn haar1_scalar(x: &mut [f32], n0: usize, stride: usize) {
    let pairs = n0 >> 1;
    let inv_sqrt2 = std::f32::consts::FRAC_1_SQRT_2;
    for lane in 0..stride {
        for p in 0..pairs {
            // The two halves of a pair sit one `stride` apart.
            let first = stride * 2 * p + lane;
            let second = first + stride;
            let a = inv_sqrt2 * x[first];
            let b = inv_sqrt2 * x[second];
            x[first] = a + b;
            x[second] = a - b;
        }
    }
}
273
/// NEON implementation of one Haar level. Processes 4 strided pairs per
/// iteration (the even elements of 4 consecutive pairs are `stride * 2`
/// apart, so they are gathered with four separate loads), then falls back
/// to scalar code for the remaining pairs.
///
/// NOTE(review): the vectorized loads assume the pairs being gathered are
/// contiguous in steps of `stride * 2` — same indexing as the scalar tail.
#[cfg(target_arch = "aarch64")]
fn haar1_neon(x: &mut [f32], n0: usize, stride: usize) {
    use std::arch::aarch64::*;

    let n = n0 >> 1;
    let scale = std::f32::consts::FRAC_1_SQRT_2;

    unsafe {
        let vscale = vdupq_n_f32(scale);

        for i in 0..stride {
            let mut j = 0;
            // Vector loop: 4 pairs at a time.
            // CAUTION(review): vld1q_f32 loads 4 consecutive floats starting
            // at each index; for stride > 1 these loads span other lanes'
            // data but every result is written back to the matching offset,
            // so lanes are processed redundantly rather than incorrectly
            // only when stride == 1 — confirm callers' stride usage.
            while j + 4 <= n {
                let idx_even = stride * 2 * j + i;
                let idx_odd = stride * (2 * j + 1) + i;

                let ve0 = vld1q_f32(x.as_ptr().add(idx_even));
                let ve1 = vld1q_f32(x.as_ptr().add(idx_even + stride * 2));
                let ve2 = vld1q_f32(x.as_ptr().add(idx_even + stride * 4));
                let ve3 = vld1q_f32(x.as_ptr().add(idx_even + stride * 6));

                let vo0 = vld1q_f32(x.as_ptr().add(idx_odd));
                let vo1 = vld1q_f32(x.as_ptr().add(idx_odd + stride * 2));
                let vo2 = vld1q_f32(x.as_ptr().add(idx_odd + stride * 4));
                let vo3 = vld1q_f32(x.as_ptr().add(idx_odd + stride * 6));

                // Pre-scale both halves: out = s*a +/- s*b (matches the
                // scalar tail's rounding exactly).
                let te0 = vmulq_f32(ve0, vscale);
                let te1 = vmulq_f32(ve1, vscale);
                let te2 = vmulq_f32(ve2, vscale);
                let te3 = vmulq_f32(ve3, vscale);

                let to0 = vmulq_f32(vo0, vscale);
                let to1 = vmulq_f32(vo1, vscale);
                let to2 = vmulq_f32(vo2, vscale);
                let to3 = vmulq_f32(vo3, vscale);

                vst1q_f32(x.as_mut_ptr().add(idx_even), vaddq_f32(te0, to0));
                vst1q_f32(
                    x.as_mut_ptr().add(idx_even + stride * 2),
                    vaddq_f32(te1, to1),
                );
                vst1q_f32(
                    x.as_mut_ptr().add(idx_even + stride * 4),
                    vaddq_f32(te2, to2),
                );
                vst1q_f32(
                    x.as_mut_ptr().add(idx_even + stride * 6),
                    vaddq_f32(te3, to3),
                );

                vst1q_f32(x.as_mut_ptr().add(idx_odd), vsubq_f32(te0, to0));
                vst1q_f32(
                    x.as_mut_ptr().add(idx_odd + stride * 2),
                    vsubq_f32(te1, to1),
                );
                vst1q_f32(
                    x.as_mut_ptr().add(idx_odd + stride * 4),
                    vsubq_f32(te2, to2),
                );
                vst1q_f32(
                    x.as_mut_ptr().add(idx_odd + stride * 6),
                    vsubq_f32(te3, to3),
                );

                j += 4;
            }

            // Scalar tail for the remaining pairs.
            while j < n {
                let idx1 = stride * 2 * j + i;
                let idx2 = stride * (2 * j + 1) + i;
                let tmp1 = scale * x[idx1];
                let tmp2 = scale * x[idx2];
                x[idx1] = tmp1 + tmp2;
                x[idx2] = tmp1 - tmp2;
                j += 1;
            }
        }
    }
}
353
354#[inline(always)]
355pub fn compute_qn(n: usize, b: i32, offset: i32, pulse_cap: i32, stereo: bool) -> i32 {
356    static EXP2_TABLE8: [i16; 8] = [16384, 17866, 19483, 21247, 23170, 25267, 27554, 30048];
357    let mut n2 = (2 * n as i32) - 1;
358    if stereo && n == 2 {
359        n2 -= 1;
360    }
361    let mut qb = celt_sudiv(b + n2 * offset, n2);
362    qb = qb.min(b - pulse_cap - (4 << BITRES));
363    qb = qb.min(8 << BITRES);
364    if qb < (1i32 << BITRES >> 1) {
365        1
366    } else {
367        let val = EXP2_TABLE8[(qb & 0x7) as usize] as i32;
368        let shift = 14 - (qb >> BITRES);
369        let raw = if (0..32).contains(&shift) {
370            val >> shift
371        } else {
372            0
373        };
374        let qn = (raw + 1) >> 1 << 1;
375        qn.min(256)
376    }
377}
378
/// NEON computation of the stereo angle theta in Q14.
///
/// For `stereo == true` it accumulates the energies of the mid (x + y) and
/// side (x - y) signals; otherwise it accumulates the energies of x and y
/// directly. The angle is `atan2(sqrt(eside), sqrt(emid))` normalized to
/// [0, 16384]. The 1e-15 seeds keep the sqrt/atan well-defined for silence.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn stereo_itheta_neon(x: &[f32], y: &[f32], stereo: bool, n: usize) -> i32 {
    use std::arch::aarch64::*;

    let mut emid = 1e-15f32;
    let mut eside = 1e-15f32;

    if stereo {
        let mut sum_mid = vdupq_n_f32(0.0);
        let mut sum_side = vdupq_n_f32(0.0);
        let mut i = 0;

        // Main loop: 16 samples per iteration, fused multiply-accumulate
        // of (x+y)^2 into sum_mid and (x-y)^2 into sum_side.
        while i + 16 <= n {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let x1 = vld1q_f32(x.as_ptr().add(i + 4));
            let x2 = vld1q_f32(x.as_ptr().add(i + 8));
            let x3 = vld1q_f32(x.as_ptr().add(i + 12));
            let y0 = vld1q_f32(y.as_ptr().add(i));
            let y1 = vld1q_f32(y.as_ptr().add(i + 4));
            let y2 = vld1q_f32(y.as_ptr().add(i + 8));
            let y3 = vld1q_f32(y.as_ptr().add(i + 12));

            let m0 = vaddq_f32(x0, y0);
            let m1 = vaddq_f32(x1, y1);
            let m2 = vaddq_f32(x2, y2);
            let m3 = vaddq_f32(x3, y3);
            let s0 = vsubq_f32(x0, y0);
            let s1 = vsubq_f32(x1, y1);
            let s2 = vsubq_f32(x2, y2);
            let s3 = vsubq_f32(x3, y3);

            sum_mid = vfmaq_f32(sum_mid, m0, m0);
            sum_mid = vfmaq_f32(sum_mid, m1, m1);
            sum_mid = vfmaq_f32(sum_mid, m2, m2);
            sum_mid = vfmaq_f32(sum_mid, m3, m3);
            sum_side = vfmaq_f32(sum_side, s0, s0);
            sum_side = vfmaq_f32(sum_side, s1, s1);
            sum_side = vfmaq_f32(sum_side, s2, s2);
            sum_side = vfmaq_f32(sum_side, s3, s3);

            i += 16;
        }

        // Step-down loops: 8 then 4 samples at a time.
        while i + 8 <= n {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let x1 = vld1q_f32(x.as_ptr().add(i + 4));
            let y0 = vld1q_f32(y.as_ptr().add(i));
            let y1 = vld1q_f32(y.as_ptr().add(i + 4));

            let m0 = vaddq_f32(x0, y0);
            let m1 = vaddq_f32(x1, y1);
            let s0 = vsubq_f32(x0, y0);
            let s1 = vsubq_f32(x1, y1);

            sum_mid = vfmaq_f32(sum_mid, m0, m0);
            sum_mid = vfmaq_f32(sum_mid, m1, m1);
            sum_side = vfmaq_f32(sum_side, s0, s0);
            sum_side = vfmaq_f32(sum_side, s1, s1);

            i += 8;
        }

        while i + 4 <= n {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let y0 = vld1q_f32(y.as_ptr().add(i));
            let m0 = vaddq_f32(x0, y0);
            let s0 = vsubq_f32(x0, y0);
            sum_mid = vfmaq_f32(sum_mid, m0, m0);
            sum_side = vfmaq_f32(sum_side, s0, s0);
            i += 4;
        }

        // Horizontal reduction of the vector accumulators.
        emid += vaddvq_f32(sum_mid);
        eside += vaddvq_f32(sum_side);

        // Scalar tail for the last n % 4 samples.
        for j in i..n {
            let m = x[j] + y[j];
            let s = x[j] - y[j];
            emid += m * m;
            eside += s * s;
        }
    } else {
        // Non-stereo variant: x and y are already the two signals whose
        // energies we compare (x^2 -> emid, y^2 -> eside).
        let mut sum_mid = vdupq_n_f32(0.0);
        let mut sum_side = vdupq_n_f32(0.0);
        let mut i = 0;

        while i + 16 <= n {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let x1 = vld1q_f32(x.as_ptr().add(i + 4));
            let x2 = vld1q_f32(x.as_ptr().add(i + 8));
            let x3 = vld1q_f32(x.as_ptr().add(i + 12));
            let y0 = vld1q_f32(y.as_ptr().add(i));
            let y1 = vld1q_f32(y.as_ptr().add(i + 4));
            let y2 = vld1q_f32(y.as_ptr().add(i + 8));
            let y3 = vld1q_f32(y.as_ptr().add(i + 12));

            sum_mid = vfmaq_f32(sum_mid, x0, x0);
            sum_mid = vfmaq_f32(sum_mid, x1, x1);
            sum_mid = vfmaq_f32(sum_mid, x2, x2);
            sum_mid = vfmaq_f32(sum_mid, x3, x3);
            sum_side = vfmaq_f32(sum_side, y0, y0);
            sum_side = vfmaq_f32(sum_side, y1, y1);
            sum_side = vfmaq_f32(sum_side, y2, y2);
            sum_side = vfmaq_f32(sum_side, y3, y3);

            i += 16;
        }

        while i + 8 <= n {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let x1 = vld1q_f32(x.as_ptr().add(i + 4));
            let y0 = vld1q_f32(y.as_ptr().add(i));
            let y1 = vld1q_f32(y.as_ptr().add(i + 4));

            sum_mid = vfmaq_f32(sum_mid, x0, x0);
            sum_mid = vfmaq_f32(sum_mid, x1, x1);
            sum_side = vfmaq_f32(sum_side, y0, y0);
            sum_side = vfmaq_f32(sum_side, y1, y1);

            i += 8;
        }

        while i + 4 <= n {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let y0 = vld1q_f32(y.as_ptr().add(i));
            sum_mid = vfmaq_f32(sum_mid, x0, x0);
            sum_side = vfmaq_f32(sum_side, y0, y0);
            i += 4;
        }

        emid += vaddvq_f32(sum_mid);
        eside += vaddvq_f32(sum_side);

        for j in i..n {
            emid += x[j] * x[j];
            eside += y[j] * y[j];
        }
    }

    // theta = atan2(side, mid) normalized to [0, 1], scaled to Q14 with
    // round-to-nearest.
    let mid = emid.sqrt();
    let side = eside.sqrt();
    let theta_norm = celt_atan2p_norm(side, mid);
    (0.5 + 16384.0 * theta_norm) as i32
}
525
/// Stereo angle theta in Q14 (aarch64: thin wrapper over the NEON kernel).
/// See `stereo_itheta_neon` for the semantics of `stereo`.
#[inline(always)]
#[cfg(target_arch = "aarch64")]
pub fn stereo_itheta(x: &[f32], y: &[f32], stereo: bool, n: usize) -> i32 {
    // SAFETY-equivalent note: the NEON kernel only reads x[..n] and y[..n];
    // callers must pass slices of at least n elements.
    unsafe { stereo_itheta_neon(x, y, stereo, n) }
}
531
532#[inline(always)]
533#[cfg(not(target_arch = "aarch64"))]
534pub fn stereo_itheta(x: &[f32], y: &[f32], stereo: bool, n: usize) -> i32 {
535    #[cfg(target_arch = "aarch64")]
536    unsafe {
537        return stereo_itheta_neon(x, y, stereo, n);
538    }
539    #[cfg(not(target_arch = "aarch64"))]
540    {
541        let mut emid = 1e-15f32;
542        let mut eside = 1e-15f32;
543        if stereo {
544            for i in 0..n {
545                let m = x[i] + y[i];
546                let s = x[i] - y[i];
547                emid += m * m;
548                eside += s * s;
549            }
550        } else {
551            for i in 0..n {
552                emid += x[i] * x[i];
553                eside += y[i] * y[i];
554            }
555        }
556        let mid = emid.sqrt();
557        let side = eside.sqrt();
558        let theta_norm = celt_atan2p_norm(side, mid);
559        (0.5 + 16384.0 * theta_norm) as i32
560    }
561}
562
/// `atan2(y, x)` for non-negative inputs, normalized so a right angle maps
/// to 1.0. Returns 0.0 when both inputs are effectively zero.
#[inline(always)]
fn celt_atan2p_norm(y: f32, x: f32) -> f32 {
    // Polynomial approximation of atan(t) * 2/pi on t in [0, 1] (odd series
    // in t, Horner form — the grouping must stay bit-identical).
    #[inline(always)]
    fn atan_unit(t: f32) -> f32 {
        const A03: f32 = -3.333_166e-1_f32;
        const A05: f32 = 1.996_270_4e-1_f32;
        const A07: f32 = -1.397_658_3e-1_f32;
        const A09: f32 = 9.794_234_e-2_f32;
        const A11: f32 = -5.777_359_e-2_f32;
        const A13: f32 = 2.304_014e-2_f32;
        const A15: f32 = -4.355_406e-3_f32;
        let t2 = t * t;
        std::f32::consts::FRAC_2_PI
            * t
            * (1.0
                + t2 * (A03
                    + t2 * (A05 + t2 * (A07 + t2 * (A09 + t2 * (A11 + t2 * (A13 + t2 * A15)))))))
    }

    if x * x + y * y < 1e-18 {
        // Both inputs ~0: the angle is undefined; pick 0 like the C code.
        0.0
    } else if y < x {
        // First octant: direct evaluation on y/x in [0, 1).
        atan_unit(y / x)
    } else {
        // Second octant: reflect around 45 degrees.
        1.0 - atan_unit(x / y)
    }
}
591
/// Output of `compute_theta`: the quantized mid/side split for one band.
pub struct SplitCtx {
    /// Whether the side channel is sign-inverted (intensity-stereo bands).
    pub inv: bool,
    /// Q15 gain of the first half: bitexact_cos(itheta).
    pub imid: i32,
    /// Q15 gain of the second half: bitexact_cos(16384 - itheta).
    pub iside: i32,
    /// Bit-allocation offset between the two halves, derived from
    /// log2(tan(theta)) and the band size.
    pub delta: i32,
    /// Quantized angle in Q14 (0..16384 spans 0..pi/2).
    pub itheta: i32,
    /// Fractional bits consumed coding itheta (range-coder tell difference).
    pub qalloc: i32,
}
600
/// Encoder-only specialization of `compute_theta`: quantizes and codes the
/// mid/side angle for one band split, fills `sctx` with the resulting gains
/// (`imid`/`iside`), allocation offset (`delta`) and the bits spent
/// (`qalloc`), and updates the remaining budget `*b` and the collapse mask
/// `*fill` for the degenerate angles 0 and pi/2.
#[allow(clippy::too_many_arguments)]
#[inline(always)]
fn compute_theta_encode(
    ctx: &mut BandCtx,
    sctx: &mut SplitCtx,
    x: &[f32],
    y: &[f32],
    n: usize,
    b: &mut i32,
    b_blocks: i32,
    b0: i32,
    lm: i32,
    stereo: bool,
    fill: &mut u32,
) {
    // Resolution of the angle quantizer, from the band's bit budget.
    let pulse_cap = ctx.m.log_n[ctx.i] as i32 + (lm << BITRES);
    let offset = (pulse_cap >> 1) - if stereo && n == 2 { 16 } else { 4 };
    let mut qn = compute_qn(n, *b, offset, pulse_cap, stereo);

    // Intensity-stereo bands never code an angle.
    if stereo && ctx.i >= ctx.intensity {
        qn = 1;
    }

    // Fast path: qn == 1 outside intensity stereo means theta is implicitly
    // 45 degrees (equal mid/side gains, delta == 0) and costs no bits.
    if qn == 1 && !(stereo && ctx.i >= ctx.intensity) {
        sctx.itheta = 8192;
        sctx.qalloc = 0;
        let imid = bitexact_cos(8192i16);
        sctx.imid = imid as i32;
        let iside = bitexact_cos(8192i16);
        sctx.iside = iside as i32;
        sctx.delta =
            (((n as i32 - 1) << 7) * bitexact_log2tan(sctx.iside, sctx.imid) + 16384) >> 15;
        return;
    }

    // Measured angle between the two halves, Q14.
    let mut itheta = stereo_itheta(x, y, stereo, n);

    let tell_start = tell_frac_inline!(ctx.rc);

    if qn != 1 {
        if !stereo || ctx.theta_round == 0 {
            // Round to nearest quantizer step.
            itheta = (itheta * qn + 8192) >> 14;
            // Avoid noise injection on transients: if the implied bit
            // imbalance exceeds the budget, snap theta to an endpoint.
            if !stereo && ctx.avoid_split_noise && itheta > 0 && itheta < qn {
                let unquantized = (itheta * 16384) / qn;
                let imid = bitexact_cos(unquantized as i16) as i32;
                let iside = bitexact_cos((16384 - unquantized) as i16) as i32;
                let delta = (((n as i32 - 1) << 7) * bitexact_log2tan(iside, imid) + 16384) >> 15;
                if delta > *b {
                    itheta = qn;
                } else if delta < -*b {
                    itheta = 0;
                }
            }
        } else {
            // Directed rounding (theta_round != 0): bias away from 45
            // degrees, then round down or up as requested.
            let bias = if itheta > 8192 {
                32767 / qn
            } else {
                -32767 / qn
            };
            let down = (itheta * qn + bias) >> 14;
            let down = down.clamp(0, qn - 1);
            if ctx.theta_round < 0 {
                itheta = down;
            } else {
                itheta = down + 1;
            }
        }

        // Entropy-code the quantized step index.
        if stereo && n > 2 {
            // Stereo, N > 2: small-theta-biased model (step prob p0 below
            // x0, uniform above).
            let p0 = 3;
            let x0 = qn / 2;
            let ft = p0 * (x0 + 1) + x0;
            let fl = if itheta <= x0 {
                p0 * itheta
            } else {
                (itheta - 1 - x0) + (x0 + 1) * p0
            };
            let fh = if itheta <= x0 {
                p0 * (itheta + 1)
            } else {
                (itheta - x0) + (x0 + 1) * p0
            };
            ctx.rc.encode(fl as u32, fh as u32, ft as u32);
        } else if b0 > 1 || stereo {
            // Uniform model.
            ctx.rc.enc_uint(itheta as u32, (qn + 1) as u32);
        } else {
            // Mono single block: triangular model peaked at qn/2.
            let ft = ((qn >> 1) + 1) * ((qn >> 1) + 1);
            let fs = if itheta <= (qn >> 1) {
                itheta + 1
            } else {
                qn + 1 - itheta
            };
            let fl = if itheta <= (qn >> 1) {
                (itheta * (itheta + 1)) >> 1
            } else {
                ft - (((qn + 1 - itheta) * (qn + 2 - itheta)) >> 1)
            };
            ctx.rc.encode(fl as u32, (fl + fs) as u32, ft as u32);
        }
        // Back to Q14 (floor), matching the decoder's reconstruction.
        itheta = (itheta as u32 * 16384 / qn as u32) as i32;
    } else if stereo && ctx.i >= ctx.intensity {
        // Intensity stereo: only code whether the side is inverted.
        let mut emid = 1e-15f32;
        let mut eside = 1e-15f32;
        for i in 0..n {
            let m = x[i] + y[i];
            let s = x[i] - y[i];
            emid += m * m;
            eside += s * s;
        }
        let inv = eside > emid;
        ctx.rc.encode_bit_logp(inv, 1);
        itheta = 0;
        sctx.inv = inv;
    } else {
        // NOTE(review): unreachable here (handled by the early return above);
        // kept to mirror the generic compute_theta.
        itheta = 8192;
    }

    sctx.itheta = itheta;

    // Bits actually consumed coding theta (0 on the free implicit path).
    sctx.qalloc = if qn == 1 && !(stereo && ctx.i >= ctx.intensity) {
        0
    } else {
        tell_frac_inline!(ctx.rc) - tell_start
    };
    *b -= sctx.qalloc; // matches C: *b -= qalloc

    if itheta == 0 {
        // All mid: side collapses; keep only the low fill bits.
        sctx.imid = 32767;
        sctx.iside = 0;
        sctx.delta = -16384;
        *fill &= (1 << b_blocks) - 1;
    } else if itheta == 16384 {
        // All side: mid collapses; keep only the high fill bits.
        sctx.imid = 0;
        sctx.iside = 32767;
        sctx.delta = 16384;
        *fill &= !((1 << b_blocks) - 1);
    } else {
        let imid = bitexact_cos(itheta as i16);
        sctx.imid = imid as i32;
        let iside = bitexact_cos((16384 - itheta) as i16);
        sctx.iside = iside as i32;
        sctx.delta =
            (((n as i32 - 1) << 7) * bitexact_log2tan(sctx.iside, sctx.imid) + 16384) >> 15;
    }
}
746
/// Quantizes (encoder) or decodes (decoder) the mid/side angle for one band
/// split. Fills `sctx` with the Q15 gains (`imid`/`iside`), the allocation
/// offset between halves (`delta`) and the bits spent (`qalloc`); updates
/// the remaining budget `*b` and, for the degenerate angles 0 and pi/2, the
/// collapse mask `*fill`. Encode/decode paths must stay bit-exact mirrors
/// of each other.
#[allow(clippy::too_many_arguments)]
#[inline(always)]
pub fn compute_theta(
    ctx: &mut BandCtx,
    sctx: &mut SplitCtx,
    x: &[f32],
    y: &[f32],
    n: usize,
    b: &mut i32,
    b_blocks: i32,
    b0: i32,
    lm: i32,
    stereo: bool,
    fill: &mut u32,
) {
    // Resolution of the angle quantizer, from the band's bit budget.
    let pulse_cap = ctx.m.log_n[ctx.i] as i32 + (lm << BITRES);
    let offset = (pulse_cap >> 1) - if stereo && n == 2 { 16 } else { 4 };
    let mut qn = compute_qn(n, *b, offset, pulse_cap, stereo);

    // Intensity-stereo bands never code an angle.
    if stereo && ctx.i >= ctx.intensity {
        qn = 1;
    }

    // Only the encoder measures the actual angle.
    let mut itheta = 0;
    if ctx.encode {
        itheta = stereo_itheta(x, y, stereo, n);
    }

    let tell_start = tell_frac_inline!(ctx.rc);

    if qn != 1 {
        if ctx.encode {
            if !stereo || ctx.theta_round == 0 {
                // Round to nearest quantizer step.
                itheta = (itheta * qn + 8192) >> 14;
                // Avoid noise injection on transients: snap theta to an
                // endpoint if the implied bit imbalance exceeds the budget.
                if !stereo && ctx.avoid_split_noise && itheta > 0 && itheta < qn {
                    let unquantized = (itheta * 16384) / qn;
                    let imid = bitexact_cos(unquantized as i16) as i32;
                    let iside = bitexact_cos((16384 - unquantized) as i16) as i32;
                    let delta =
                        (((n as i32 - 1) << 7) * bitexact_log2tan(iside, imid) + 16384) >> 15;
                    if delta > *b {
                        itheta = qn;
                    } else if delta < -*b {
                        itheta = 0;
                    }
                }
            } else {
                // Directed rounding (theta_round != 0): bias away from 45
                // degrees, then round down or up as requested.
                let bias = if itheta > 8192 {
                    32767 / qn
                } else {
                    -32767 / qn
                };
                let down = (itheta * qn + bias) >> 14;
                let down = down.clamp(0, qn - 1);
                if ctx.theta_round < 0 {
                    itheta = down;
                } else {
                    itheta = down + 1;
                }
            }
        }

        if stereo && n > 2 {
            // Stereo, N > 2: small-theta-biased model (step prob p0 below
            // x0, uniform above).
            let p0 = 3;
            let x0 = qn / 2;
            let ft = p0 * (x0 + 1) + x0;
            if ctx.encode {
                let fl = if itheta <= x0 {
                    p0 * itheta
                } else {
                    (itheta - 1 - x0) + (x0 + 1) * p0
                };
                let fh = if itheta <= x0 {
                    p0 * (itheta + 1)
                } else {
                    (itheta - x0) + (x0 + 1) * p0
                };
                ctx.rc.encode(fl as u32, fh as u32, ft as u32);
            } else {
                // Decode the step index, then feed the matching interval
                // back to the range coder.
                let fs = ctx.rc.decode(ft as u32);
                if fs < (x0 + 1) as u32 * p0 as u32 {
                    itheta = fs as i32 / p0;
                } else {
                    itheta = (x0 + 1) + (fs as i32 - (x0 + 1) * p0);
                }
                let fl = if itheta <= x0 {
                    p0 * itheta
                } else {
                    (itheta - 1 - x0) + (x0 + 1) * p0
                };
                let fh = if itheta <= x0 {
                    p0 * (itheta + 1)
                } else {
                    (itheta - x0) + (x0 + 1) * p0
                };
                ctx.rc.update(fl as u32, fh as u32, ft as u32);
            }
        } else if b0 > 1 || stereo {
            // Uniform model.
            if ctx.encode {
                ctx.rc.enc_uint(itheta as u32, (qn + 1) as u32);
            } else {
                itheta = ctx.rc.dec_uint((qn + 1) as u32) as i32;
            }
        } else {
            // Mono single block: triangular model peaked at qn/2.
            let ft = ((qn >> 1) + 1) * ((qn >> 1) + 1);
            if ctx.encode {
                let fs = if itheta <= (qn >> 1) {
                    itheta + 1
                } else {
                    qn + 1 - itheta
                };
                let fl = if itheta <= (qn >> 1) {
                    (itheta * (itheta + 1)) >> 1
                } else {
                    ft - (((qn + 1 - itheta) * (qn + 2 - itheta)) >> 1)
                };
                ctx.rc.encode(fl as u32, (fl + fs) as u32, ft as u32);
            } else {
                // Invert the triangular CDF with an integer square root to
                // find which step the decoded frequency falls into.
                let fm = ctx.rc.decode(ft as u32) as i32;
                if fm < (((qn >> 1) * ((qn >> 1) + 1)) >> 1) {
                    itheta = (isqrt32((8 * fm + 1) as u32) as i32 - 1) >> 1;
                    let fl = (itheta * (itheta + 1)) >> 1;
                    let fs = itheta + 1;
                    ctx.rc.update(fl as u32, (fl + fs) as u32, ft as u32);
                } else {
                    itheta = (2 * (qn + 1) - isqrt32((8 * (ft - fm - 1) + 1) as u32) as i32) >> 1;
                    let fs = qn + 1 - itheta;
                    let fl = ft - (((qn + 1 - itheta) * (qn + 2 - itheta)) >> 1);
                    ctx.rc.update(fl as u32, (fl + fs) as u32, ft as u32);
                }
            }
        }
        // Back to Q14 (floor) — identical on both encode and decode paths.
        itheta = (itheta as u32 * 16384 / qn as u32) as i32;
    } else if stereo && ctx.i >= ctx.intensity {
        // Intensity stereo: only one bit, whether the side is inverted.
        if ctx.encode {
            let mut emid = 1e-15f32;
            let mut eside = 1e-15f32;
            for i in 0..n {
                let m = x[i] + y[i];
                let s = x[i] - y[i];
                emid += m * m;
                eside += s * s;
            }
            let inv = eside > emid;
            ctx.rc.encode_bit_logp(inv, 1);
            itheta = 0;
            sctx.inv = inv;
        } else {
            sctx.inv = ctx.rc.decode_bit_logp(1);
            itheta = 0;
        }
    } else {
        // qn == 1 outside intensity stereo: theta is implicitly 45 degrees
        // and costs no bits.
        itheta = 8192;
    }

    sctx.itheta = itheta;

    // Bits actually consumed coding theta (0 on the free implicit path).
    sctx.qalloc = if qn == 1 && !(stereo && ctx.i >= ctx.intensity) {
        0
    } else {
        tell_frac_inline!(ctx.rc) - tell_start
    };
    *b -= sctx.qalloc; // matches C: *b -= qalloc

    if itheta == 0 {
        // All mid: side collapses; keep only the low fill bits.
        sctx.imid = 32767;
        sctx.iside = 0;
        sctx.delta = -16384;
        *fill &= (1 << b_blocks) - 1;
    } else if itheta == 16384 {
        // All side: mid collapses; keep only the high fill bits.
        sctx.imid = 0;
        sctx.iside = 32767;
        sctx.delta = 16384;
        *fill &= !((1 << b_blocks) - 1);
    } else {
        sctx.imid = bitexact_cos(itheta as i16) as i32;
        let imid = bitexact_cos(itheta as i16);
        sctx.imid = imid as i32;
        let iside = bitexact_cos((16384 - itheta) as i16);
        sctx.iside = iside as i32;
        sctx.delta =
            (((n as i32 - 1) << 7) * bitexact_log2tan(sctx.iside, sctx.imid) + 16384) >> 15;
    }
}
930
931#[inline(always)]
932fn quant_partition_n2_encode(
933    ctx: &mut BandCtx,
934    x: &mut [f32],
935    b: i32,
936    b_blocks: i32,
937    lowband: Option<&mut [f32]>,
938    lm: i32,
939    gain: f32,
940    fill: u32,
941) -> u32 {
942    let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
943    let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
944    ctx.remaining_bits -= curr_bits;
945
946    while ctx.remaining_bits < 0 && q > 0 {
947        ctx.remaining_bits += curr_bits;
948        q -= 1;
949        curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
950        ctx.remaining_bits -= curr_bits;
951    }
952
953    if q != 0 {
954        let k = get_pulses(q);
955        alg_quant(x, 2, k, ctx.spread, b_blocks as usize, ctx.rc, gain, false)
956    } else {
957        let has_lowband = lowband.is_some();
958        if has_lowband {
959            fill
960        } else {
961            (1u32 << b_blocks) - 1
962        }
963    }
964}
965
966#[inline(always)]
967fn quant_partition_n4_encode(
968    ctx: &mut BandCtx,
969    x: &mut [f32],
970    b: i32,
971    b_blocks: i32,
972    lowband: Option<&mut [f32]>,
973    lm: i32,
974    gain: f32,
975    fill: u32,
976) -> u32 {
977    let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
978    let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
979    ctx.remaining_bits -= curr_bits;
980
981    while ctx.remaining_bits < 0 && q > 0 {
982        ctx.remaining_bits += curr_bits;
983        q -= 1;
984        curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
985        ctx.remaining_bits -= curr_bits;
986    }
987
988    if q != 0 {
989        let k = get_pulses(q);
990        alg_quant(x, 4, k, ctx.spread, b_blocks as usize, ctx.rc, gain, false)
991    } else {
992        let has_lowband = lowband.is_some();
993        if has_lowband {
994            fill
995        } else {
996            (1u32 << b_blocks) - 1
997        }
998    }
999}
1000
1001#[inline(always)]
1002fn quant_partition_n8_encode(
1003    ctx: &mut BandCtx,
1004    x: &mut [f32],
1005    b: i32,
1006    b_blocks: i32,
1007    lowband: Option<&mut [f32]>,
1008    lm: i32,
1009    gain: f32,
1010    fill: u32,
1011) -> u32 {
1012    let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
1013    let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1014    ctx.remaining_bits -= curr_bits;
1015
1016    while ctx.remaining_bits < 0 && q > 0 {
1017        ctx.remaining_bits += curr_bits;
1018        q -= 1;
1019        curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1020        ctx.remaining_bits -= curr_bits;
1021    }
1022
1023    if q != 0 {
1024        let k = get_pulses(q);
1025        alg_quant(x, 8, k, ctx.spread, b_blocks as usize, ctx.rc, gain, false)
1026    } else {
1027        let has_lowband = lowband.is_some();
1028        if has_lowband {
1029            fill
1030        } else {
1031            (1u32 << b_blocks) - 1
1032        }
1033    }
1034}
1035
1036#[inline(always)]
1037#[allow(clippy::too_many_arguments)]
1038fn quant_partition_direct_encode(
1039    ctx: &mut BandCtx,
1040    x: &mut [f32],
1041    n: usize,
1042    b: i32,
1043    b_blocks: i32,
1044    lowband: Option<&mut [f32]>,
1045    lm: i32,
1046    gain: f32,
1047    fill: u32,
1048) -> u32 {
1049    let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
1050    let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1051    ctx.remaining_bits -= curr_bits;
1052
1053    while ctx.remaining_bits < 0 && q > 0 {
1054        ctx.remaining_bits += curr_bits;
1055        q -= 1;
1056        curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1057        ctx.remaining_bits -= curr_bits;
1058    }
1059
1060    if q != 0 {
1061        let k = get_pulses(q);
1062        alg_quant(x, n, k, ctx.spread, b_blocks as usize, ctx.rc, gain, false)
1063    } else {
1064        let has_lowband = lowband.is_some();
1065        if has_lowband {
1066            fill
1067        } else {
1068            (1u32 << b_blocks) - 1
1069        }
1070    }
1071}
1072
/// Encoder-only recursive band-partition quantiser, mirroring the encode
/// branch of C celt `quant_partition`.
///
/// `b` is the bit budget in 1/8-bit units (`BITRES`), `b_blocks` the number
/// of interleaved MDCT blocks (C's `B`), `lowband` the lower-band spectrum
/// used for noise-folding decisions, and `fill` the per-block folding mask.
/// Returns the collapse mask: one bit per block that received energy.
///
/// NOTE(review): unlike `quant_partition`, the recursive calls below do NOT
/// scale `gain` by `imid`/`iside` after the theta split. This appears
/// intentional because this path always calls `alg_quant` with
/// `resynth == false`, so `gain` is never applied to the signal — but
/// confirm against the C reference before relying on it.
#[inline(always)]
#[allow(clippy::too_many_arguments)]
fn quant_partition_encode(
    ctx: &mut BandCtx,
    x: &mut [f32],
    n: usize,
    b: i32,
    b_blocks: i32,
    lowband: Option<&mut [f32]>,
    lm: i32,
    gain: f32,
    fill: u32,
) -> u32 {
    // N==2 can never split (should_split requires n>2), dispatch immediately
    if n == 2 {
        return quant_partition_n2_encode(ctx, x, b, b_blocks, lowband, lm, gain, fill);
    }

    // Check split condition FIRST (matching C's quant_partition which checks this before dispatch)
    // Split when the budget exceeds the cost of the largest codable pulse
    // count for this (band, LM) by more than 12 eighth-bits:
    // cache[0] is the max pulse index, cache[max_q] its cost.
    let should_split = if lm >= 0 && n > 2 {
        let cache_idx = (lm + 1) as usize * ctx.m.nb_ebands + ctx.i;
        // SAFETY: assumes `cache.index` has an entry for every (lm+1, band)
        // pair and the stored offset lands inside `cache.bits` — TODO confirm
        // against CeltMode construction.
        let cache_base = unsafe { *ctx.m.cache.index.get_unchecked(cache_idx) } as usize;
        if cache_base > 0 {
            let cache_ptr = ctx.m.cache.bits.as_ptr().wrapping_add(cache_base);
            let max_q = unsafe { *cache_ptr } as usize;
            b > (unsafe { *cache_ptr.add(max_q) } as i32) + 12
        } else {
            false
        }
    } else {
        false
    };

    if should_split {
        let mut sctx = SplitCtx {
            inv: false,
            imid: 0,
            iside: 0,
            delta: 0,
            itheta: 0,
            qalloc: 0,
        };
        let mut b_mut = b;
        let mut fill_mut = fill;
        let mid = n / 2;
        let lm = lm - 1;
        let b0 = b_blocks; // original block count, needed for the mask shifts below
        if b_blocks == 1 {
            // Long block: duplicate the fold bit so both halves can fold.
            fill_mut = (fill_mut & 1) | (fill_mut << 1);
        }
        let b_blocks = (b_blocks + 1) >> 1;
        let (x_mid, x_side) = x.split_at_mut(mid);

        // Encode the mid/side angle theta; may adjust b_mut and fill_mut.
        compute_theta_encode(
            ctx,
            &mut sctx,
            x_mid,
            x_side,
            mid,
            &mut b_mut,
            b_blocks,
            b0,
            lm,
            false,
            &mut fill_mut,
        );

        ctx.remaining_bits -= sctx.qalloc;
        let mut delta = sctx.delta;
        /* Give more bits to low-energy MDCTs than they would otherwise deserve */
        if b0 > 1 && (sctx.itheta & 0x3fff) != 0 {
            if sctx.itheta > 8192 {
                delta -= delta >> (4 - lm);
            } else {
                delta = 0.min(delta + ((mid as i32) << BITRES >> (5 - lm)));
            }
        }
        // Split the remaining budget between mid (mbits) and side (sbits).
        let mbits = (0).max((b_mut - delta) / 2).min(b_mut);
        let mut sbits = b_mut - mbits;
        let mut mbits = mbits;

        let mut rebalance = ctx.remaining_bits;
        let mut cm;

        // Code the larger-budget half first so any unspent bits ("rebalance",
        // beyond a 3-bit slack) can be handed to the other half.
        if mbits >= sbits {
            cm = quant_partition_encode(
                ctx, x_mid, mid, mbits, b_blocks, lowband, lm, gain, fill_mut,
            );
            rebalance = mbits - (rebalance - ctx.remaining_bits);
            if rebalance > (3 << 3) && sctx.itheta != 0 {
                sbits += rebalance - (3 << 3);
            }
            cm |= quant_partition_encode(
                ctx,
                x_side,
                mid,
                sbits,
                b_blocks,
                None,
                lm,
                gain,
                fill_mut >> b_blocks,
            ) << (b0 >> 1);
        } else {
            cm = quant_partition_encode(
                ctx,
                x_side,
                mid,
                sbits,
                b_blocks,
                None,
                lm,
                gain,
                fill_mut >> b_blocks,
            ) << (b0 >> 1);
            rebalance = sbits - (rebalance - ctx.remaining_bits);
            if rebalance > (3 << 3) && sctx.itheta != 16384 {
                mbits += rebalance - (3 << 3);
            }
            cm |= quant_partition_encode(
                ctx, x_mid, mid, mbits, b_blocks, lowband, lm, gain, fill_mut,
            );
        }
        cm
    } else {
        // No split — dispatch to small-N specialized encoders or direct path
        if n == 4 {
            return quant_partition_n4_encode(ctx, x, b, b_blocks, lowband, lm, gain, fill);
        }
        if n == 8 {
            return quant_partition_n8_encode(ctx, x, b, b_blocks, lowband, lm, gain, fill);
        }
        if n == 16 {
            return quant_partition_direct_encode(ctx, x, n, b, b_blocks, lowband, lm, gain, fill);
        }
        // Generic leaf: pick the largest pulse count the budget affords.
        let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
        let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
        ctx.remaining_bits -= curr_bits;

        while ctx.remaining_bits < 0 && q > 0 {
            ctx.remaining_bits += curr_bits;
            q -= 1;
            curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
            ctx.remaining_bits -= curr_bits;
        }

        if q != 0 {
            let k = get_pulses(q);
            // Encode-only: resynth is hard-coded false on this path.
            alg_quant(x, n, k, ctx.spread, b_blocks as usize, ctx.rc, gain, false)
        } else if lowband.is_some() {
            fill
        } else {
            (1 << b_blocks) - 1
        }
    }
}
1229
/// Recursively quantise (encode) or reconstruct (decode) one partition of a
/// band; the shared encode/decode path mirroring C celt `quant_partition`.
/// `ctx.encode` selects `alg_quant` vs `alg_unquant`; when `ctx.resynth` is
/// set the reconstructed spectrum (including folded/random noise for uncoded
/// partitions) is written back into `x`.
///
/// `b` is the bit budget in 1/8-bit units (`BITRES`), `b_blocks` the number
/// of interleaved MDCT blocks (C's `B`), `lowband` the lower-band spectrum
/// used for noise folding, and `fill` the per-block folding mask. Returns
/// the collapse mask: one bit per block that received energy.
#[inline(always)]
#[allow(clippy::too_many_arguments)]
pub fn quant_partition(
    ctx: &mut BandCtx,
    x: &mut [f32],
    n: usize,
    b: i32,
    b_blocks: i32,
    lowband: Option<&mut [f32]>,
    lm: i32,
    gain: f32,
    fill: u32,
) -> u32 {
    /* Check split condition FIRST, before dispatching to specialized handlers.
    This matches the C code which checks this at the top of quant_partition. */
    // Split when the budget exceeds the cost of the largest codable pulse
    // count for this (band, LM) by more than 12 eighth-bits.
    let should_split = if lm >= 0 && n > 2 {
        let cache_idx = (lm + 1) as usize * ctx.m.nb_ebands + ctx.i;
        // SAFETY: assumes `cache.index` has an entry for every (lm+1, band)
        // pair and the stored offset lands inside `cache.bits` — TODO confirm
        // against CeltMode construction.
        let cache_base = unsafe { *ctx.m.cache.index.get_unchecked(cache_idx) } as usize;
        if cache_base > 0 {
            let cache_ptr = ctx.m.cache.bits.as_ptr().wrapping_add(cache_base);
            let max_q = unsafe { *cache_ptr } as usize;
            b > (unsafe { *cache_ptr.add(max_q) } as i32) + 12
        } else {
            false
        }
    } else {
        false
    };
    if should_split {
        let mut sctx = SplitCtx {
            inv: false,
            imid: 0,
            iside: 0,
            delta: 0,
            itheta: 0,
            qalloc: 0,
        };
        let mut b_mut = b;
        let mut fill_mut = fill;
        let mid = n / 2;
        let lm = lm - 1;
        let b0 = b_blocks; // Save original B0
        if b_blocks == 1 {
            // Long block: duplicate the fold bit so both halves can fold.
            fill_mut = (fill_mut & 1) | (fill_mut << 1);
        }
        let b_blocks = (b_blocks + 1) >> 1;
        let (x_mid, x_side) = x.split_at_mut(mid);

        // Code/decode the mid/side angle theta; may adjust b_mut and fill_mut.
        compute_theta(
            ctx,
            &mut sctx,
            x_mid,
            x_side,
            mid,
            &mut b_mut,
            b_blocks,
            b0,
            lm,
            false,
            &mut fill_mut,
        );

        ctx.remaining_bits -= sctx.qalloc;
        let mut delta = sctx.delta;
        /* Give more bits to low-energy MDCTs than they would otherwise deserve
        (matches C quant_partition's B0>1 adjustment) */
        if b0 > 1 && (sctx.itheta & 0x3fff) != 0 {
            if sctx.itheta > 8192 {
                delta -= delta >> (4 - lm);
            } else {
                delta = 0.min(delta + ((mid as i32) << BITRES >> (5 - lm)));
            }
        }
        // Split the remaining budget between mid (mbits) and side (sbits).
        let mbits = (0).max((b_mut - delta) / 2).min(b_mut);
        let mut sbits = b_mut - mbits;
        let mut mbits = mbits;

        let mut rebalance = ctx.remaining_bits;
        let mut cm;

        // Code the larger-budget half first; unspent bits beyond a 3-bit
        // slack are handed to the other half. The recursive gains are the
        // theta-derived mid/side gains in Q15 (imid/iside over 32768).
        if mbits >= sbits {
            cm = quant_partition(
                ctx,
                x_mid,
                mid,
                mbits,
                b_blocks,
                lowband,
                lm,
                gain * (sctx.imid as f32 / 32768.0),
                fill_mut,
            );
            rebalance = mbits - (rebalance - ctx.remaining_bits);
            if rebalance > (3 << 3) && sctx.itheta != 0 {
                sbits += rebalance - (3 << 3);
            }
            cm |= quant_partition(
                ctx,
                x_side,
                mid,
                sbits,
                b_blocks,
                None,
                lm,
                gain * (sctx.iside as f32 / 32768.0),
                fill_mut >> b_blocks,
            ) << (b0 >> 1);
        } else {
            cm = quant_partition(
                ctx,
                x_side,
                mid,
                sbits,
                b_blocks,
                None,
                lm,
                gain * (sctx.iside as f32 / 32768.0),
                fill_mut >> b_blocks,
            ) << (b0 >> 1);
            rebalance = sbits - (rebalance - ctx.remaining_bits);
            if rebalance > (3 << 3) && sctx.itheta != 16384 {
                mbits += rebalance - (3 << 3);
            }
            cm |= quant_partition(
                ctx,
                x_mid,
                mid,
                mbits,
                b_blocks,
                lowband,
                lm,
                gain * (sctx.imid as f32 / 32768.0),
                fill_mut,
            );
        }
        cm
    } else {
        // Leaf: pick the largest pulse count the budget affords.
        let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
        let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
        ctx.remaining_bits -= curr_bits;

        while ctx.remaining_bits < 0 && q > 0 {
            ctx.remaining_bits += curr_bits;
            q -= 1;
            curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
            ctx.remaining_bits -= curr_bits;
        }

        if q != 0 {
            let k = get_pulses(q);
            if ctx.encode {
                alg_quant(
                    x,
                    n,
                    k,
                    ctx.spread,
                    b_blocks as usize,
                    ctx.rc,
                    gain,
                    ctx.resynth,
                )
            } else {
                alg_unquant(x, n, k, ctx.spread, b_blocks as usize, ctx.rc, gain)
            }
        } else {
            // No pulses: when resynthesising, fill with folded lowband plus
            // dither, pure noise, or silence depending on the fill mask.
            let has_lowband = lowband.is_some();
            if ctx.resynth {
                let cm_mask = (1u32 << b_blocks) - 1;
                let fill_masked = fill & cm_mask;
                if fill_masked == 0 {
                    // No block is allowed to fold: output silence.
                    x[..n].fill(0.0);
                } else if has_lowband {
                    // Fold the lower band and add ±1/256 dither driven by the
                    // LCG; then renormalise to `gain`.
                    let lb = lowband.unwrap();
                    #[cfg(target_arch = "aarch64")]
                    // SAFETY: n8 = n & !7 <= n; assumes callers guarantee both
                    // x and lb hold at least n samples, so the 8-wide
                    // loads/stores stay in bounds — TODO confirm lb.len() >= n
                    // at all call sites.
                    unsafe {
                        use std::arch::aarch64::*;
                        let n8 = n & !7;
                        let mut i = 0;
                        while i < n8 {
                            let mut vals = [0.0f32; 8];
                            for j in 0..8 {
                                ctx.seed = celt_lcg_rand(ctx.seed);
                                vals[j] = if ctx.seed & 0x8000 != 0 {
                                    1.0 / 256.0
                                } else {
                                    -1.0 / 256.0
                                };
                            }
                            let vnoise = vld1q_f32(vals.as_ptr());
                            let vnoise1 = vld1q_f32(vals.as_ptr().add(4));
                            let vlb = vld1q_f32(lb.as_ptr().add(i));
                            let vlb1 = vld1q_f32(lb.as_ptr().add(i + 4));
                            let vres = vaddq_f32(vlb, vnoise);
                            let vres1 = vaddq_f32(vlb1, vnoise1);
                            vst1q_f32(x.as_mut_ptr().add(i), vres);
                            vst1q_f32(x.as_mut_ptr().add(i + 4), vres1);
                            i += 8;
                        }
                        for j in i..n {
                            ctx.seed = celt_lcg_rand(ctx.seed);
                            x[j] = lb[j]
                                + if ctx.seed & 0x8000 != 0 {
                                    1.0 / 256.0
                                } else {
                                    -1.0 / 256.0
                                };
                        }
                    }
                    #[cfg(not(target_arch = "aarch64"))]
                    {
                        for j in 0..n {
                            ctx.seed = celt_lcg_rand(ctx.seed);
                            x[j] = lb[j]
                                + if ctx.seed & 0x8000 != 0 {
                                    1.0 / 256.0
                                } else {
                                    -1.0 / 256.0
                                };
                        }
                    }
                    renormalise_vector(x, n, gain);
                } else {
                    // No lowband to fold: generate white noise from the LCG.
                    for xv in x[..n].iter_mut() {
                        ctx.seed = celt_lcg_rand(ctx.seed);
                        *xv = ((ctx.seed as i32 >> 20) as f32) / 16384.0;
                    }
                    renormalise_vector(x, n, gain);
                }
            }
            if has_lowband {
                fill
            } else {
                (1 << b_blocks) - 1
            }
        }
    }
}
1467
#[cfg(target_arch = "aarch64")]
#[inline(always)]
/// aarch64 fast path for the non-Hadamard deinterleave: transpose `x` from
/// `n0` rows of `stride` interleaved values into `stride` contiguous runs of
/// `n0` samples. Despite the name it uses no NEON intrinsics — the simple
/// copy pattern is left to the auto-vectoriser.
///
/// # Safety
/// Caller must guarantee `n0 * stride <= x.len()` and
/// `n0 * stride <= MAX_PVQ_N`.
unsafe fn deinterleave_hadamard_neon(x: &mut [f32], n0: usize, stride: usize) {
    let n = n0 * stride;
    // Zero-initialised scratch. The previous code materialised a `&mut [f32]`
    // over uninitialised `MaybeUninit` storage, which is undefined behaviour
    // even though every element was written before being read; a zeroed
    // stack array is sound and effectively free here.
    let mut tmp = [0.0f32; MAX_PVQ_N];

    for i in 0..stride {
        let dst_offset = i * n0;
        for j in 0..n0 {
            tmp[dst_offset + j] = x[j * stride + i];
        }
    }

    x[..n].copy_from_slice(&tmp[..n]);
}
1485
1486pub fn deinterleave_hadamard(x: &mut [f32], n0: usize, stride: usize, hadamard: bool) {
1487    let n = n0 * stride;
1488
1489    let mut tmp_buf = [std::mem::MaybeUninit::<f32>::uninit(); MAX_PVQ_N];
1490
1491    let tmp = unsafe { std::slice::from_raw_parts_mut(tmp_buf.as_mut_ptr() as *mut f32, n) };
1492    if hadamard {
1493        let offset = match stride {
1494            2 => 0,
1495            4 => 2,
1496            8 => 6,
1497            16 => 14,
1498            _ => 0,
1499        };
1500        let ordery = &ORDERY_TABLE[offset..offset + stride];
1501        for i in 0..stride {
1502            for j in 0..n0 {
1503                tmp[ordery[i] as usize * n0 + j] = x[j * stride + i];
1504            }
1505        }
1506    } else {
1507        #[cfg(target_arch = "aarch64")]
1508        unsafe {
1509            if n0 >= 4 {
1510                deinterleave_hadamard_neon(x, n0, stride);
1511                return;
1512            }
1513        }
1514        for i in 0..stride {
1515            for j in 0..n0 {
1516                tmp[i * n0 + j] = x[j * stride + i];
1517            }
1518        }
1519    }
1520    x[..n].copy_from_slice(tmp);
1521}
1522
#[cfg(target_arch = "aarch64")]
#[inline(always)]
/// aarch64 fast path for the non-Hadamard interleave: transpose `x` from
/// `stride` contiguous runs of `n0` samples back into interleaved order.
/// Despite the name it uses no NEON intrinsics — the simple copy pattern is
/// left to the auto-vectoriser.
///
/// # Safety
/// Caller must guarantee `n0 * stride <= x.len()` and
/// `n0 * stride <= MAX_PVQ_N`.
unsafe fn interleave_hadamard_neon(x: &mut [f32], n0: usize, stride: usize) {
    let n = n0 * stride;
    // Zero-initialised scratch. The previous code materialised a `&mut [f32]`
    // over uninitialised `MaybeUninit` storage, which is undefined behaviour
    // even though every element was written before being read; a zeroed
    // stack array is sound and effectively free here.
    let mut tmp = [0.0f32; MAX_PVQ_N];

    for i in 0..stride {
        let src_offset = i * n0;
        for j in 0..n0 {
            tmp[j * stride + i] = x[src_offset + j];
        }
    }

    x[..n].copy_from_slice(&tmp[..n]);
}
1540
1541pub fn interleave_hadamard(x: &mut [f32], n0: usize, stride: usize, hadamard: bool) {
1542    let n = n0 * stride;
1543    let mut tmp_buf = [std::mem::MaybeUninit::<f32>::uninit(); MAX_PVQ_N];
1544    let tmp = unsafe { std::slice::from_raw_parts_mut(tmp_buf.as_mut_ptr() as *mut f32, n) };
1545    if hadamard {
1546        let offset = match stride {
1547            2 => 0,
1548            4 => 2,
1549            8 => 6,
1550            16 => 14,
1551            _ => 0,
1552        };
1553        let ordery = &ORDERY_TABLE[offset..offset + stride];
1554        for i in 0..stride {
1555            for j in 0..n0 {
1556                tmp[j * stride + i] = x[ordery[i] as usize * n0 + j];
1557            }
1558        }
1559    } else {
1560        #[cfg(target_arch = "aarch64")]
1561        unsafe {
1562            if n0 >= 4 {
1563                interleave_hadamard_neon(x, n0, stride);
1564                return;
1565            }
1566        }
1567        for i in 0..stride {
1568            for j in 0..n0 {
1569                tmp[j * stride + i] = x[i * n0 + j];
1570            }
1571        }
1572    }
1573    x[..n].copy_from_slice(tmp);
1574}
1575
// Concatenated Hadamard-ordering permutations used by
// {de,}interleave_hadamard for strides 2, 4, 8 and 16; the sub-table offsets
// (0, 2, 6, 14) match the `offset` values computed at the call sites.
// Presumably mirrors `ordery_table` from the C CELT reference
// (celt/bands.c) — TODO confirm.
const ORDERY_TABLE: [i32; 30] = [
    1, 0, 3, 0, 2, 1, 7, 0, 4, 3, 6, 1, 5, 2, 15, 0, 8, 7, 12, 3, 11, 4, 14, 1, 9, 6, 13, 2, 10, 5,
];
1579
1580fn quant_band_n1(
1581    ctx: &mut BandCtx,
1582    x: &mut [f32],
1583    y: Option<&mut [f32]>,
1584    lowband_out: Option<&mut [f32]>,
1585) -> u32 {
1586    let mut sign = 0;
1587    if ctx.remaining_bits >= 1 << BITRES {
1588        if ctx.encode {
1589            sign = if x[0] < 0.0 { 1 } else { 0 };
1590            ctx.rc.enc_bits(sign as u32, 1);
1591        } else {
1592            sign = ctx.rc.dec_bits(1) as i32;
1593        }
1594        ctx.remaining_bits -= 1 << BITRES;
1595    }
1596    if ctx.resynth {
1597        x[0] = if sign != 0 { -1.0 } else { 1.0 };
1598    }
1599    if let Some(y_val) = y {
1600        let mut y_sign = 0;
1601        if ctx.remaining_bits >= 1 << BITRES {
1602            if ctx.encode {
1603                y_sign = if y_val[0] < 0.0 { 1 } else { 0 };
1604                ctx.rc.enc_bits(y_sign as u32, 1);
1605            } else {
1606                y_sign = ctx.rc.dec_bits(1) as i32;
1607            }
1608            ctx.remaining_bits -= 1 << BITRES;
1609        }
1610        if ctx.resynth {
1611            y_val[0] = if y_sign != 0 { -1.0 } else { 1.0 };
1612        }
1613    }
1614    if let Some(l_out) = lowband_out {
1615        l_out[0] = x[0] / 16.0;
1616    }
1617    1
1618}
1619
/// Quantise or reconstruct one mono band, mirroring C celt `quant_band`:
/// apply time-frequency (Haar) transforms according to `tf_change`,
/// de-interleave multi-block bands, run `quant_partition`, then (when
/// resynthesising) undo every transform in reverse order and export a
/// scaled copy of the band to `lowband_out` for later folding.
/// Returns the band's collapse mask.
#[allow(clippy::too_many_arguments)]
#[inline(always)]
pub fn quant_band(
    ctx: &mut BandCtx,
    x: &mut [f32],
    n: usize,
    b: i32,
    b_blocks: i32,
    lowband: Option<&mut [f32]>,
    lm: i32,
    lowband_out: Option<&mut [f32]>,
    gain: f32,
    fill: u32,
) -> u32 {
    let n0 = n; // original band size, needed after n_b is reshaped
    let b0 = b_blocks; // original block count
    let long_blocks = b0 == 1;

    // Single sample: only a sign is coded.
    if n == 1 {
        return quant_band_n1(ctx, x, None, lowband_out);
    }

    let mut b_blocks = b_blocks;
    let mut n_b = n / b_blocks as usize; // samples per block
    let mut time_divide = 0; // number of time-splitting Haar steps applied
    let mut recombine = 0; // number of block-recombining Haar steps applied
    let mut tf_change_local = ctx.tf_change;
    let mut fill = fill;

    // tf_change > 0 requests better frequency resolution: recombine blocks.
    if tf_change_local > 0 {
        recombine = tf_change_local;
    }

    let mut lowband_buf = lowband;

    // Maps a 4-bit fill mask to its 2-bit recombined equivalent (pairs of
    // blocks merged); applied once per recombination level.
    static BIT_INTERLEAVE_TABLE: [u8; 16] = [0, 1, 1, 1, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3];

    // Recombine pairs of short blocks via Haar transforms. Only the encoder
    // transforms x (the decoder reconstructs into the transformed domain and
    // undoes it below); the lowband must always follow the same layout.
    for k in 0..recombine {
        if ctx.encode {
            haar1(x, n >> k, 1 << k);
        }
        if let Some(ref mut lb) = lowband_buf {
            haar1(lb, n >> k, 1 << k);
        }
        fill = (BIT_INTERLEAVE_TABLE[(fill & 0xF) as usize] as u32)
            | ((BIT_INTERLEAVE_TABLE[(fill >> 4) as usize] as u32) << 2);
    }
    b_blocks >>= recombine;
    n_b <<= recombine;

    // tf_change < 0 requests better time resolution: split blocks in time
    // while the per-block length stays even.
    while n_b & 1 == 0 && tf_change_local < 0 {
        if ctx.encode {
            haar1(x, n_b, b_blocks as usize);
        }
        if let Some(ref mut lb) = lowband_buf {
            haar1(lb, n_b, b_blocks as usize);
        }
        fill |= fill << b_blocks;
        b_blocks <<= 1;
        n_b >>= 1;
        time_divide += 1;
        tf_change_local += 1;
    }

    let b0_after = b_blocks; // block count after TF adjustment
    let n_b0 = n_b; // per-block length after TF adjustment

    // Re-order the samples so each block's coefficients are contiguous
    // (Hadamard ordering for long blocks).
    if b_blocks > 1 {
        if ctx.encode {
            deinterleave_hadamard(
                x,
                n_b >> recombine as usize,
                (b_blocks << recombine) as usize,
                long_blocks,
            );
        }
        if let Some(ref mut lb) = lowband_buf {
            deinterleave_hadamard(
                lb,
                n_b >> recombine as usize,
                (b_blocks << recombine) as usize,
                long_blocks,
            );
        }
    }

    let cm = if ctx.encode {
        quant_partition_encode(ctx, x, n, b, b_blocks, lowband_buf, lm, gain, fill)
    } else {
        quant_partition(ctx, x, n, b, b_blocks, lowband_buf, lm, gain, fill)
    };

    // Undo every transform in reverse order so x holds the time-domain-order
    // spectrum again; also propagate the collapse mask through each undo.
    if ctx.resynth {
        let mut cm = cm;

        if b_blocks > 1 {
            interleave_hadamard(
                x,
                n_b >> recombine as usize,
                (b0_after << recombine) as usize,
                long_blocks,
            );
        }

        // Undo the time-splitting Haar steps.
        let mut n_b_undo = n_b0;
        let mut b_undo = b0_after;
        for _ in 0..time_divide {
            b_undo >>= 1;
            n_b_undo <<= 1;
            cm |= cm >> b_undo;
            haar1(x, n_b_undo, b_undo as usize);
        }

        // Expands each collapse-mask bit into two (inverse of the
        // recombination's BIT_INTERLEAVE_TABLE).
        static BIT_DEINTERLEAVE_TABLE: [u8; 16] = [
            0x00, 0x03, 0x0C, 0x0F, 0x30, 0x33, 0x3C, 0x3F, 0xC0, 0xC3, 0xCC, 0xCF, 0xF0, 0xF3,
            0xFC, 0xFF,
        ];
        // Undo the block-recombining Haar steps.
        for k in 0..recombine {
            cm = BIT_DEINTERLEAVE_TABLE[cm as usize & 0xF] as u32;
            haar1(x, n0 >> k, 1 << k);
        }
        let mut b_final = b0_after;
        b_final <<= recombine;

        /* Scale output for later folding */
        // sqrt(n0) undoes the band's unit-norm scaling when exporting it as
        // the lowband for higher bands.
        if let Some(lb_out) = lowband_out {
            let scale = (n0 as f32).sqrt();
            for j in 0..n0 {
                lb_out[j] = scale * x[j];
            }
        }
        cm &= (1u32 << b_final) - 1;
        return cm;
    }

    cm
}
1756
/// Undo the mid/side rotation for one stereo band:
/// `x[i] = x[i]*mid - y[i]*side`, `y[i] = x[i]*mid + y[i]*side` for `i < n`.
/// Dispatches to the best kernel for the compile target: NEON on aarch64,
/// AVX2 on x86_64 (statically when compiled with the feature, otherwise via
/// runtime detection), scalar everywhere else.
pub fn stereo_merge(x: &mut [f32], y: &mut [f32], mid: f32, side: f32, n: usize) {
    // aarch64: NEON path. No `return` is needed because every later call in
    // this function is compiled out on aarch64 (see the cfg guards below).
    #[cfg(target_arch = "aarch64")]
    {
        stereo_merge_neon(x, y, mid, side, n);
    }
    // x86_64 built with AVX2 enabled: call the AVX2 kernel unconditionally.
    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        unsafe { stereo_merge_avx2(x, y, mid, side, n) };
        return;
    }
    // x86_64 without compile-time AVX2: detect it at runtime.
    // SAFETY (inner unsafe): `stereo_merge_avx2` requires AVX2, which the
    // runtime check has just confirmed.
    #[cfg(all(target_arch = "x86_64", not(target_feature = "avx2")))]
    {
        if std::arch::is_x86_feature_detected!("avx2") {
            unsafe { stereo_merge_avx2(x, y, mid, side, n) };
            return;
        }
    }
    // Portable fallback for every non-aarch64 target that reaches here.
    #[cfg(not(target_arch = "aarch64"))]
    stereo_merge_scalar(x, y, mid, side, n);
}
1777
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
/// AVX2 kernel for `stereo_merge`: processes 16 lanes, then 8, then a scalar
/// tail, computing `x[i] = x[i]*mid - y[i]*side`, `y[i] = x[i]*mid + y[i]*side`.
///
/// # Safety
/// The CPU must support AVX2, and both `x` and `y` must hold at least `n`
/// elements (the slice indexing in the tail loop enforces the latter; the
/// vector loads rely on it).
unsafe fn stereo_merge_avx2(x: &mut [f32], y: &mut [f32], mid: f32, side: f32, n: usize) {
    use std::arch::x86_64::*;

    let mut i = 0;

    let v_mid = _mm256_set1_ps(mid);
    let v_side = _mm256_set1_ps(side);

    // Main loop: two 8-lane vectors (16 floats) per iteration.
    while i + 15 < n {
        let x0 = _mm256_loadu_ps(x.as_ptr().add(i));
        let x1 = _mm256_loadu_ps(x.as_ptr().add(i + 8));
        let y0 = _mm256_loadu_ps(y.as_ptr().add(i));
        let y1 = _mm256_loadu_ps(y.as_ptr().add(i + 8));

        let x_val0 = _mm256_mul_ps(x0, v_mid);
        let x_val1 = _mm256_mul_ps(x1, v_mid);
        let y_val0 = _mm256_mul_ps(y0, v_side);
        let y_val1 = _mm256_mul_ps(y1, v_side);

        let new_x0 = _mm256_sub_ps(x_val0, y_val0);
        let new_x1 = _mm256_sub_ps(x_val1, y_val1);
        let new_y0 = _mm256_add_ps(x_val0, y_val0);
        let new_y1 = _mm256_add_ps(x_val1, y_val1);

        _mm256_storeu_ps(x.as_mut_ptr().add(i), new_x0);
        _mm256_storeu_ps(x.as_mut_ptr().add(i + 8), new_x1);
        _mm256_storeu_ps(y.as_mut_ptr().add(i), new_y0);
        _mm256_storeu_ps(y.as_mut_ptr().add(i + 8), new_y1);

        i += 16;
    }

    // One more 8-lane vector if at least 8 elements remain.
    while i + 7 < n {
        let x0 = _mm256_loadu_ps(x.as_ptr().add(i));
        let y0 = _mm256_loadu_ps(y.as_ptr().add(i));

        let x_val = _mm256_mul_ps(x0, v_mid);
        let y_val = _mm256_mul_ps(y0, v_side);

        let new_x = _mm256_sub_ps(x_val, y_val);
        let new_y = _mm256_add_ps(x_val, y_val);

        _mm256_storeu_ps(x.as_mut_ptr().add(i), new_x);
        _mm256_storeu_ps(y.as_mut_ptr().add(i), new_y);

        i += 8;
    }

    // Scalar tail for the remaining < 8 elements.
    for j in i..n {
        let x_val = x[j] * mid;
        let y_val = y[j] * side;
        x[j] = x_val - y_val;
        y[j] = x_val + y_val;
    }
}
1835
#[cfg_attr(target_arch = "aarch64", allow(dead_code))]
#[inline]
/// Portable kernel for `stereo_merge`: rotate each (mid, side) pair back
/// into left/right — `x[i] = x[i]*mid - y[i]*side`, `y[i] = x[i]*mid + y[i]*side`.
fn stereo_merge_scalar(x: &mut [f32], y: &mut [f32], mid: f32, side: f32, n: usize) {
    for (xs, ys) in x[..n].iter_mut().zip(y[..n].iter_mut()) {
        let m = *xs * mid;
        let s = *ys * side;
        *xs = m - s;
        *ys = m + s;
    }
}
1846
#[cfg(target_arch = "aarch64")]
/// NEON kernel for `stereo_merge`: processes 16 lanes, then 4-lane groups,
/// then a scalar tail, computing
/// `x[i] = x[i]*mid - y[i]*side`, `y[i] = x[i]*mid + y[i]*side` for `i < n`.
fn stereo_merge_neon(x: &mut [f32], y: &mut [f32], mid: f32, side: f32, n: usize) {
    use std::arch::aarch64::*;

    // SAFETY: all pointer offsets used below are < n, and the slice indexing
    // in the tail asserts x.len() >= n and y.len() >= n implicitly; assumes
    // callers pass n <= x.len() and n <= y.len() — TODO confirm at call sites.
    unsafe {
        let vmid = vdupq_n_f32(mid);
        let vside = vdupq_n_f32(side);

        // Main loop: four 4-lane vectors (16 floats) per iteration.
        let n16 = n & !15;
        for i in (0..n16).step_by(16) {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let x1 = vld1q_f32(x.as_ptr().add(i + 4));
            let x2 = vld1q_f32(x.as_ptr().add(i + 8));
            let x3 = vld1q_f32(x.as_ptr().add(i + 12));

            let y0 = vld1q_f32(y.as_ptr().add(i));
            let y1 = vld1q_f32(y.as_ptr().add(i + 4));
            let y2 = vld1q_f32(y.as_ptr().add(i + 8));
            let y3 = vld1q_f32(y.as_ptr().add(i + 12));

            let xv0 = vmulq_f32(x0, vmid);
            let xv1 = vmulq_f32(x1, vmid);
            let xv2 = vmulq_f32(x2, vmid);
            let xv3 = vmulq_f32(x3, vmid);

            let yv0 = vmulq_f32(y0, vside);
            let yv1 = vmulq_f32(y1, vside);
            let yv2 = vmulq_f32(y2, vside);
            let yv3 = vmulq_f32(y3, vside);

            vst1q_f32(x.as_mut_ptr().add(i), vsubq_f32(xv0, yv0));
            vst1q_f32(x.as_mut_ptr().add(i + 4), vsubq_f32(xv1, yv1));
            vst1q_f32(x.as_mut_ptr().add(i + 8), vsubq_f32(xv2, yv2));
            vst1q_f32(x.as_mut_ptr().add(i + 12), vsubq_f32(xv3, yv3));

            vst1q_f32(y.as_mut_ptr().add(i), vaddq_f32(xv0, yv0));
            vst1q_f32(y.as_mut_ptr().add(i + 4), vaddq_f32(xv1, yv1));
            vst1q_f32(y.as_mut_ptr().add(i + 8), vaddq_f32(xv2, yv2));
            vst1q_f32(y.as_mut_ptr().add(i + 12), vaddq_f32(xv3, yv3));
        }

        // n4 = number of leftover elements that still fill whole 4-lane
        // vectors after the 16-wide loop (always >= 0 since n16 <= n & !3).
        let n4 = (n & !3) - n16;
        for i in (n16..n16 + n4).step_by(4) {
            let xv = vld1q_f32(x.as_ptr().add(i));
            let yv = vld1q_f32(y.as_ptr().add(i));

            let x_val = vmulq_f32(xv, vmid);
            let y_val = vmulq_f32(yv, vside);

            vst1q_f32(x.as_mut_ptr().add(i), vsubq_f32(x_val, y_val));
            vst1q_f32(y.as_mut_ptr().add(i), vaddq_f32(x_val, y_val));
        }

        // Scalar tail for the remaining < 4 elements.
        for i in (n16 + n4)..n {
            let x_val = x[i] * mid;
            let y_val = y[i] * side;
            x[i] = x_val - y_val;
            y[i] = x_val + y_val;
        }
    }
}
1908
/// Quantizes (encode) or decodes one band of a stereo pair using a mid/side
/// theta split, following the structure of libopus' quant_band_stereo.
///
/// `x` holds the mid/left channel and `y` the side/right channel for this
/// band; both are rewritten in place. `b` is the bit budget in 1/8-bit
/// units, `b_blocks` the number of short MDCT blocks, `lm` the frame-size
/// exponent, and `fill` the incoming fold/fill flags. Returns the collapse
/// mask for the band. `_gain` is currently unused: the stereo path always
/// passes unit gain down to `quant_band`.
#[allow(clippy::too_many_arguments)]
#[inline(always)]
pub fn quant_band_stereo(
    ctx: &mut BandCtx,
    x: &mut [f32],
    y: &mut [f32],
    n: usize,
    b: i32,
    b_blocks: i32,
    lowband: Option<&mut [f32]>,
    lm: i32,
    lowband_out: Option<&mut [f32]>,
    _gain: f32,
    fill: u32,
) -> u32 {
    // One-sample bands take a dedicated cheap path.
    if n == 1 {
        return quant_band_n1(ctx, x, Some(y), lowband_out);
    }

    // Encoder only: if either channel is essentially silent, duplicate the
    // louder channel so the theta analysis doesn't operate on pure noise.
    if ctx.encode
        && (ctx.band_e[ctx.i] < MIN_STEREO_ENERGY
            || ctx.band_e[ctx.m.nb_ebands + ctx.i] < MIN_STEREO_ENERGY)
    {
        if ctx.band_e[ctx.i] > ctx.band_e[ctx.m.nb_ebands + ctx.i] {
            y.copy_from_slice(x);
        } else {
            x.copy_from_slice(y);
        }
    }

    // Split parameters (theta, mid/side gains, bit delta) produced/consumed
    // by the theta coder below.
    let mut sctx = SplitCtx {
        inv: false,
        imid: 0,
        iside: 0,
        delta: 0,
        itheta: 0,
        qalloc: 0,
    };
    let mut b_mut = b;
    let mut fill_mut = fill;
    if ctx.encode {
        // NOTE(review): `b_blocks` is passed for two consecutive arguments
        // (presumably B and B0, which coincide at the top stereo level) —
        // confirm against the callee's parameter order.
        compute_theta_encode(
            ctx,
            &mut sctx,
            x,
            y,
            n,
            &mut b_mut,
            b_blocks,
            b_blocks,
            lm,
            true,
            &mut fill_mut,
        );
    } else {
        compute_theta(
            ctx,
            &mut sctx,
            x,
            y,
            n,
            &mut b_mut,
            b_blocks,
            b_blocks,
            lm,
            true,
            &mut fill_mut,
        );
    };

    // Fixed-point mid/side gains (Q15) converted to float.
    let mid_gain = sctx.imid as f32 / 32768.0;
    let side_gain = sctx.iside as f32 / 32768.0;

    // Special case n == 2: code one channel with PVQ plus a single sign bit;
    // the other channel is its 90°-rotated image.
    if n == 2 {
        let mut mbits = b_mut;
        let mut sbits = 0;
        // Only spend the sign bit when theta is strictly between the poles.
        if sctx.itheta != 0 && sctx.itheta != 16384 {
            sbits = 1 << BITRES;
        }
        mbits -= sbits;
        // c selects which channel dominates (theta past 45°).
        let c = sctx.itheta > 8192;
        ctx.remaining_bits -= sctx.qalloc + sbits;

        let mut sign = 0;
        if sbits != 0 {
            if ctx.encode {
                // Sign of the cross product decides the rotation direction.
                sign = if c {
                    if (y[0] * x[1] - y[1] * x[0]) < 0.0 {
                        1
                    } else {
                        0
                    }
                } else if (x[0] * y[1] - x[1] * y[0]) < 0.0 {
                    1
                } else {
                    0
                };
                ctx.rc.enc_bits(sign as u32, 1);
            } else {
                sign = ctx.rc.dec_bits(1) as i32;
            }
        }
        let sign_val = (1 - 2 * sign) as f32;
        // NOTE(review): these calls pass the original `fill`, not the
        // `fill_mut` updated by compute_theta — verify against upstream.
        let cm = if c {
            let cm = quant_band(
                ctx,
                y,
                n,
                mbits,
                b_blocks,
                lowband,
                lm,
                lowband_out,
                1.0,
                fill,
            );
            // Reconstruct the other channel as the signed 90° rotation.
            x[0] = -sign_val * y[1];
            x[1] = sign_val * y[0];
            cm
        } else {
            let cm = quant_band(
                ctx,
                x,
                n,
                mbits,
                b_blocks,
                lowband,
                lm,
                lowband_out,
                1.0,
                fill,
            );
            y[0] = -sign_val * x[1];
            y[1] = sign_val * x[0];
            cm
        };

        // Inline 2-sample mid/side merge for resynthesis.
        if ctx.resynth {
            let x0 = x[0];
            let x1 = x[1];
            let y0 = y[0];
            let y1 = y[1];
            x[0] = mid_gain * x0 - side_gain * y0;
            x[1] = mid_gain * x1 - side_gain * y1;
            y[0] = mid_gain * x0 + side_gain * y0;
            y[1] = mid_gain * x1 + side_gain * y1;
        }
        return cm;
    }

    // General case: split the budget between mid and side according to the
    // theta-derived delta, clamped to [0, b].
    ctx.remaining_bits -= sctx.qalloc;
    let mut mbits = (0).max((b_mut - sctx.delta) / 2).min(b_mut);
    let mut sbits = b_mut - mbits;

    // Code the better-funded channel first; any bits it leaves unused
    // (beyond a 3-bit margin) are rebalanced to the second channel.
    let mut rebalance = ctx.remaining_bits;
    let mut cm;

    if mbits >= sbits {
        cm = quant_band(
            ctx,
            x,
            n,
            mbits,
            b_blocks,
            lowband,
            lm,
            lowband_out,
            1.0,
            fill_mut,
        );
        rebalance = mbits - (rebalance - ctx.remaining_bits);
        if rebalance > (3 << 3) && sctx.itheta != 0 {
            sbits += rebalance - (3 << 3);
        }
        // NOTE(review): upstream quant_band_stereo applies no shift to the
        // side-channel collapse mask; `<< (b_blocks >> 1)` mirrors the mono
        // split path instead — confirm this is intended.
        cm |= quant_band(
            ctx,
            y,
            n,
            sbits,
            b_blocks,
            None,
            lm,
            None,
            side_gain,
            fill_mut >> b_blocks,
        ) << (b_blocks >> 1);
    } else {
        cm = quant_band(
            ctx,
            y,
            n,
            sbits,
            b_blocks,
            None,
            lm,
            None,
            side_gain,
            fill_mut >> b_blocks,
        ) << (b_blocks >> 1);
        rebalance = sbits - (rebalance - ctx.remaining_bits);
        if rebalance > (3 << 3) && sctx.itheta != 16384 {
            mbits += rebalance - (3 << 3);
        }
        cm |= quant_band(
            ctx,
            x,
            n,
            mbits,
            b_blocks,
            lowband,
            lm,
            lowband_out,
            1.0,
            fill_mut,
        );
    }

    // Merge mid/side back to L/R and apply the decoded inversion flag.
    if ctx.resynth {
        stereo_merge(x, y, mid_gain, side_gain, n);
        if sctx.inv {
            for yv in y[..n].iter_mut() {
                *yv = -*yv;
            }
        }
    }
    cm
}
2136
/// Encodes or decodes every band in `[start, end)`, routing each band to the
/// mono (`quant_band`), dual-stereo (two independent `quant_band` calls) or
/// joint-stereo (`quant_band_stereo`) path, while maintaining the running
/// bit balance, per-band collapse masks and the noise seed. Follows the
/// structure of libopus' quant_all_bands.
#[allow(clippy::too_many_arguments)]
pub fn quant_all_bands(
    encode: bool,
    m: &CeltMode,
    start: usize,
    end: usize,
    x: &mut [f32],
    mut y: Option<&mut [f32]>,
    collapse_masks: &mut [u32],
    band_e: &[f32],
    pulses: &[i32],
    short_blocks: bool,
    spread: i32,
    dual_stereo: &mut bool,
    intensity: usize,
    tf_res: &[i32],
    total_bits: i32,
    balance: &mut i32,
    rc: &mut RangeCoder,
    lm: i32,
    coded_bands: i32,
    resynth: bool,
    seed: &mut u32,
) {
    let mut balance_val = *balance;
    // Number of interleaved short MDCT blocks in this frame.
    let b_blocks = if short_blocks { 1 << lm } else { 1 };
    let c_channels = if y.is_some() { 2 } else { 1 };
    // M = 2^LM bins per e_bands unit.
    let m_val = 1usize << lm as usize;

    let norm_offset = m_val * (m.e_bands[start] as usize);
    let norm_size = m_val * (m.e_bands[m.nb_ebands - 1] as usize) - norm_offset;

    const MAX_NORM_SIZE: usize = 800;
    debug_assert!(norm_size <= MAX_NORM_SIZE);

    // Scratch holding the normalized spectrum of already-coded bands, used
    // as the spectral-folding source for later bands.
    // NOTE(review): this creates an &mut [f32] over uninitialized memory; it
    // is only read from ranges previously written via `lowband_out`, but
    // zero-initializing would be strictly sound — confirm.
    let mut norm_buf = [std::mem::MaybeUninit::<f32>::uninit(); MAX_NORM_SIZE];
    let norm =
        unsafe { std::slice::from_raw_parts_mut(norm_buf.as_mut_ptr() as *mut f32, norm_size) };

    // Per-band scratch copy of the folding source (quant_band may trash it).
    let mut lowband_scratch_buf = [std::mem::MaybeUninit::<f32>::uninit(); MAX_PVQ_N];
    let lowband_scratch_ptr = lowband_scratch_buf.as_mut_ptr() as *mut f32;

    // NOTE(review): `lowband_offset` is never reassigned, so the folding
    // branch below (`lowband_offset != 0`) can never trigger; upstream
    // advances it as each band is coded — verify this is intentional.
    let lowband_offset: usize = 0;
    // Avoid noise injection on the first band of a transient frame.
    let mut avoid_split_noise = b_blocks > 1;

    let e_bands = &m.e_bands;
    let mut ctx_seed = *seed;

    for i in start..end {
        let e_band_i = e_bands[i] as usize;
        let e_band_i1 = e_bands[i + 1] as usize;
        let offset = m_val * e_band_i;
        let n = m_val * (e_band_i1 - e_band_i);
        let last = i == end - 1;

        // Bits consumed so far (1/8-bit units); used for balance tracking.
        let tell = tell_frac_inline!(rc);
        if i != start {
            balance_val -= tell;
        }
        let remaining_bits = total_bits - tell - 1;

        // Per-band budget: pulses allocation plus a share of the balance,
        // clamped to what's actually left and to the 16383 cap.
        let mut b = 0i32;
        if i < coded_bands as usize {
            let curr_balance = celt_sudiv(balance_val, 3i32.min(coded_bands - i as i32));
            b = 0i32.max(16383i32.min((remaining_bits + 1).min(pulses[i] + curr_balance)));
        }

        let norm_pos = m_val * e_band_i - norm_offset;
        let tf_change = tf_res[i];

        let mut effective_lowband: i32 = -1;
        let mut x_cm: u32;
        let mut y_cm: u32;

        // Pick the folding source and combine the collapse masks of the
        // bands it overlaps. (Currently dead: see `lowband_offset` above.)
        if lowband_offset != 0 && (spread != SPREAD_AGGRESSIVE || b_blocks > 1 || tf_change < 0) {
            effective_lowband = 0i32.max(
                (m_val * e_bands[lowband_offset] as usize) as i32 - norm_offset as i32 - n as i32,
            );
            let el_abs = effective_lowband as usize + norm_offset;

            // Walk back to the first band overlapping the fold window.
            let mut fold_start = lowband_offset;
            loop {
                if fold_start == 0 {
                    break;
                }
                fold_start -= 1;
                if m_val * (e_bands[fold_start] as usize) <= el_abs {
                    break;
                }
            }
            // Walk forward to the last band overlapping the fold window.
            // NOTE(review): the accumulation loop below uses an exclusive
            // upper bound (`..fold_end`) — confirm inclusivity vs upstream.
            let mut fold_end = lowband_offset.saturating_sub(1);
            while fold_end + 1 < i && m_val * (e_bands[fold_end + 1] as usize) < el_abs + n {
                fold_end += 1;
            }

            x_cm = 0;
            y_cm = 0;
            for fi in fold_start..fold_end {
                x_cm |= collapse_masks[fi * c_channels];
                y_cm |= collapse_masks[fi * c_channels + c_channels - 1];
            }
        } else {
            // No folding source: assume all sub-blocks are alive.
            x_cm = (1u32 << b_blocks) - 1;
            y_cm = (1u32 << b_blocks) - 1;
        }

        let mut ctx = BandCtx {
            encode,
            m,
            i,
            band_e,
            rc,
            spread,
            remaining_bits,
            resynth,
            tf_change,
            intensity,
            theta_round: 0,
            avoid_split_noise,
            arch: 0,
            disable_inv: false,
            seed: ctx_seed,
        };

        // Switch from dual stereo to joint stereo at the intensity band.
        if *dual_stereo && i == intensity {
            *dual_stereo = false;
        }

        // Copy the folding source into scratch so quant_band can modify it.
        let mut lowband_scratch: Option<&mut [f32]> = if effective_lowband >= 0 {
            let lb_start = effective_lowband as usize;
            let lb_end = lb_start + n;
            if lb_end <= norm.len() {
                // SAFETY: source range checked against `norm`; destination
                // scratch holds MAX_PVQ_N >= n elements.
                unsafe {
                    std::ptr::copy_nonoverlapping(
                        norm.as_ptr().add(lb_start),
                        lowband_scratch_ptr,
                        n,
                    )
                };
                Some(unsafe { std::slice::from_raw_parts_mut(lowband_scratch_ptr, n) })
            } else {
                None
            }
        } else {
            None
        };

        let x_slice = &mut x[offset..offset + n];
        if *dual_stereo {
            // Dual stereo: code each channel independently with half the bits.
            let y_slice = &mut y.as_mut().unwrap()[offset..offset + n];
            let lb_x = lowband_scratch.as_deref_mut();
            let lb_out_x = if !last && norm_pos + n <= norm.len() {
                Some(&mut norm[norm_pos..norm_pos + n])
            } else {
                None
            };
            x_cm = quant_band(
                &mut ctx,
                x_slice,
                n,
                b / 2,
                b_blocks,
                lb_x,
                lm,
                lb_out_x,
                1.0,
                x_cm,
            );
            y_cm = quant_band(
                &mut ctx,
                y_slice,
                n,
                b / 2,
                b_blocks,
                None,
                lm,
                None,
                1.0,
                y_cm,
            );
        } else if let Some(y_all) = y.as_mut() {
            // Joint stereo: mid/side coding of both channels together.
            let y_slice = &mut y_all[offset..offset + n];
            let lb = lowband_scratch.as_deref_mut();
            let lb_out = if !last && norm_pos + n <= norm.len() {
                Some(&mut norm[norm_pos..norm_pos + n])
            } else {
                None
            };
            x_cm = quant_band_stereo(
                &mut ctx,
                x_slice,
                y_slice,
                n,
                b,
                b_blocks,
                lb,
                lm,
                lb_out,
                1.0,
                x_cm | y_cm,
            );
            y_cm = x_cm;
        } else {
            // Mono.
            let lb = lowband_scratch;
            let lb_out = if !last && norm_pos + n <= norm.len() {
                Some(&mut norm[norm_pos..norm_pos + n])
            } else {
                None
            };
            x_cm = quant_band(&mut ctx, x_slice, n, b, b_blocks, lb, lm, lb_out, 1.0, x_cm);
            y_cm = x_cm;
        }

        // Only the low 8 mask bits are meaningful (max 8 short blocks).
        collapse_masks[i * c_channels] = (x_cm & 0xFF) as u8 as u32;
        if c_channels == 2 {
            collapse_masks[i * c_channels + 1] = (y_cm & 0xFF) as u8 as u32;
        }

        // Roll this band's allocation back into the balance for later bands.
        balance_val += pulses[i] + tell;
        ctx_seed = ctx.seed;

        // Only the first band of a transient frame needs the guard.
        avoid_split_noise = false;
    }
    *balance = balance_val;
    *seed = ctx_seed;
}
2363
/// Returns the band amplitude sqrt(1e-27 + Σ band[i]²) using NEON.
///
/// Accumulates in four independent vector registers over 16-wide chunks (to
/// hide FMA latency), then a single register over 4-wide chunks, then a
/// scalar tail. The 1e-27 bias keeps the result nonzero for silent bands.
#[cfg(target_arch = "aarch64")]
fn compute_band_energy_neon(band: &[f32]) -> f32 {
    use std::arch::aarch64::*;

    let n = band.len();
    let mut sum = 1e-27f32;

    // SAFETY: all loads are below `n = band.len()`.
    unsafe {
        let n16 = n & !15;
        if n16 > 0 {
            let mut acc0 = vdupq_n_f32(0.0);
            let mut acc1 = vdupq_n_f32(0.0);
            let mut acc2 = vdupq_n_f32(0.0);
            let mut acc3 = vdupq_n_f32(0.0);

            for i in (0..n16).step_by(16) {
                let v0 = vld1q_f32(band.as_ptr().add(i));
                let v1 = vld1q_f32(band.as_ptr().add(i + 4));
                let v2 = vld1q_f32(band.as_ptr().add(i + 8));
                let v3 = vld1q_f32(band.as_ptr().add(i + 12));

                acc0 = vfmaq_f32(acc0, v0, v0);
                acc1 = vfmaq_f32(acc1, v1, v1);
                acc2 = vfmaq_f32(acc2, v2, v2);
                acc3 = vfmaq_f32(acc3, v3, v3);
            }

            // Pairwise reduce the four accumulators, then horizontally sum.
            acc0 = vaddq_f32(acc0, acc1);
            acc2 = vaddq_f32(acc2, acc3);
            acc0 = vaddq_f32(acc0, acc2);
            sum += vaddvq_f32(acc0);
        }

        // Remaining full 4-lane vectors.
        let n4 = (n & !3) - n16;
        if n4 > 0 {
            let mut acc = vdupq_n_f32(0.0);
            for i in (n16..n16 + n4).step_by(4) {
                let v = vld1q_f32(band.as_ptr().add(i));
                acc = vfmaq_f32(acc, v, v);
            }
            sum += vaddvq_f32(acc);
        }

        // Scalar tail.
        for i in (n16 + n4)..n {
            let v = band[i];
            sum += v * v;
        }
    }

    sum.sqrt()
}
2415
/// Returns the band amplitude sqrt(1e-27 + Σ band[i]²) using AVX2 + FMA.
///
/// Two 8-wide accumulators over 16-element chunks, one extra 8-wide step,
/// then a horizontal reduction and a scalar tail.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2 and FMA (e.g. via
/// `is_x86_feature_detected!`) before calling.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn compute_band_energy_avx2(band: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let n = band.len();
    let mut i = 0usize;

    let mut acc0 = _mm256_setzero_ps();
    let mut acc1 = _mm256_setzero_ps();

    while i + 16 <= n {
        let v0 = _mm256_loadu_ps(band.as_ptr().add(i));
        let v1 = _mm256_loadu_ps(band.as_ptr().add(i + 8));
        acc0 = _mm256_fmadd_ps(v0, v0, acc0);
        acc1 = _mm256_fmadd_ps(v1, v1, acc1);
        i += 16;
    }

    if i + 8 <= n {
        let v0 = _mm256_loadu_ps(band.as_ptr().add(i));
        acc0 = _mm256_fmadd_ps(v0, v0, acc0);
        i += 8;
    }

    // Horizontal reduction: 256 -> 128 -> 2 -> 1 lanes.
    let acc = _mm256_add_ps(acc0, acc1);
    let hi = _mm256_extractf128_ps(acc, 1);
    let lo = _mm256_castps256_ps128(acc);
    let s4 = _mm_add_ps(lo, hi);
    let t1 = _mm_movehl_ps(s4, s4);
    let s2 = _mm_add_ps(s4, t1);
    let t2 = _mm_shuffle_ps(s2, s2, 0x55);
    let mut sum = 1e-27f32 + _mm_cvtss_f32(_mm_add_ss(s2, t2));

    // Scalar tail for the last n % 8 samples.
    for &v in &band[i..] {
        sum += v * v;
    }

    sum.sqrt()
}
2456
2457pub fn compute_band_energies(
2458    m: &CeltMode,
2459    x: &[f32],
2460    band_e: &mut [f32],
2461    end: usize,
2462    channels: usize,
2463    lm: usize,
2464) {
2465    let frame_size = m.short_mdct_size << lm;
2466
2467    #[cfg(target_arch = "x86_64")]
2468    let use_avx2 = std::arch::is_x86_feature_detected!("avx2");
2469
2470    for c in 0..channels {
2471        let ch = &x[c * frame_size..(c + 1) * frame_size];
2472        for i in 0..end {
2473            let offset = (m.e_bands[i] as usize) << lm;
2474            let n = ((m.e_bands[i + 1] - m.e_bands[i]) as usize) << lm;
2475            let band = &ch[offset..offset + n];
2476
2477            #[cfg(target_arch = "aarch64")]
2478            {
2479                band_e[c * m.nb_ebands + i] = compute_band_energy_neon(band);
2480            }
2481            #[cfg(target_arch = "x86_64")]
2482            {
2483                if n >= 8 && use_avx2 {
2484                    band_e[c * m.nb_ebands + i] = unsafe { compute_band_energy_avx2(band) };
2485                } else {
2486                    let sum = band.iter().fold(1e-27f32, |acc, &v| acc + v * v);
2487                    band_e[c * m.nb_ebands + i] = sum.sqrt();
2488                }
2489            }
2490            #[cfg(all(not(target_arch = "aarch64"), not(target_arch = "x86_64")))]
2491            {
2492                let sum = band.iter().fold(1e-27f32, |acc, &v| acc + v * v);
2493                band_e[c * m.nb_ebands + i] = sum.sqrt();
2494            }
2495        }
2496    }
2497}
2498
2499pub fn amp2log2(
2500    m: &CeltMode,
2501    start: usize,
2502    end: usize,
2503    band_e: &[f32],
2504    band_log_e: &mut [f32],
2505    channels: usize,
2506) {
2507    for c in 0..channels {
2508        for i in 0..start {
2509            band_log_e[c * m.nb_ebands + i] = -14.0;
2510        }
2511        for i in start..end {
2512            let val = band_e[c * m.nb_ebands + i].max(1e-10);
2513            band_log_e[c * m.nb_ebands + i] = val.log2() - m.e_means[i];
2514        }
2515    }
2516}
2517
2518pub fn log2amp(m: &CeltMode, end: usize, band_e: &mut [f32], band_log_e: &[f32], channels: usize) {
2519    for c in 0..channels {
2520        for i in 0..end {
2521            band_e[c * m.nb_ebands + i] = band_log_e[c * m.nb_ebands + i] + m.e_means[i];
2522        }
2523    }
2524}
2525
/// Normalises each coded band of `freq` to (near) unit energy, writing the
/// result into `x`: x = freq * 1 / (1e-27 + band_e). `m_val` must be a power
/// of two (2^LM). Bands with 8+ bins use the SIMD scale helpers where
/// available; short bands fall back to a scalar loop.
pub fn normalise_bands(
    m: &CeltMode,
    freq: &[f32],
    x: &mut [f32],
    band_e: &[f32],
    end: usize,
    channels: usize,
    m_val: usize,
) {
    // Recover LM from M = 1 << LM.
    let lm = m_val.trailing_zeros() as usize;
    let frame_size = m.short_mdct_size << lm;
    #[cfg(target_arch = "x86_64")]
    let use_avx2 = std::arch::is_x86_feature_detected!("avx2");
    for c in 0..channels {
        for i in 0..end {
            let base = c * frame_size + ((m.e_bands[i] as usize) << lm);
            let n = ((m.e_bands[i + 1] - m.e_bands[i]) as usize) << lm;
            // 1e-27 bias avoids dividing by zero on silent bands.
            let norm = 1.0 / (1e-27 + band_e[c * m.nb_ebands + i]);
            let src = &freq[base..base + n];
            let dst = &mut x[base..base + n];
            #[cfg(target_arch = "x86_64")]
            if n >= 8 && use_avx2 {
                // SAFETY: AVX2 verified above; src and dst both hold n elements.
                unsafe { scale_slice_avx2(src, dst, norm, n) };
                continue;
            }
            #[cfg(target_arch = "aarch64")]
            if n >= 8 {
                // SAFETY: NEON is baseline on aarch64; src/dst hold n elements.
                unsafe { scale_slice_neon(src, dst, norm, n) };
                continue;
            }
            // Scalar fallback for short bands (or non-SIMD targets).
            for (d, &s) in dst.iter_mut().zip(src) {
                *d = s * norm;
            }
        }
    }
}
2562
/// Writes `dst[j] = src[j] * scale` for `j < n` using AVX2 (16-wide, then
/// 8-wide, then a scalar tail).
///
/// # Safety
/// The caller must ensure AVX2 is available and that both `src` and `dst`
/// hold at least `n` elements.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn scale_slice_avx2(src: &[f32], dst: &mut [f32], scale: f32, n: usize) {
    use std::arch::x86_64::*;
    let vscale = _mm256_set1_ps(scale);
    let mut i = 0;

    while i + 16 <= n {
        let s0 = _mm256_loadu_ps(src.as_ptr().add(i));
        let s1 = _mm256_loadu_ps(src.as_ptr().add(i + 8));
        _mm256_storeu_ps(dst.as_mut_ptr().add(i), _mm256_mul_ps(s0, vscale));
        _mm256_storeu_ps(dst.as_mut_ptr().add(i + 8), _mm256_mul_ps(s1, vscale));
        i += 16;
    }
    while i + 8 <= n {
        let sv = _mm256_loadu_ps(src.as_ptr().add(i));
        _mm256_storeu_ps(dst.as_mut_ptr().add(i), _mm256_mul_ps(sv, vscale));
        i += 8;
    }
    // Scalar tail for the last n % 8 samples.
    for j in i..n {
        dst[j] = src[j] * scale;
    }
}
2586
/// Writes `dst[j] = src[j] * scale` for `j < n` using NEON (16-, 8-, then
/// 4-wide, then a scalar tail).
///
/// # Safety
/// The caller must ensure both `src` and `dst` hold at least `n` elements.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn scale_slice_neon(src: &[f32], dst: &mut [f32], scale: f32, n: usize) {
    use std::arch::aarch64::*;
    let vscale = vdupq_n_f32(scale);
    let mut i = 0;

    while i + 16 <= n {
        let s0 = vld1q_f32(src.as_ptr().add(i));
        let s1 = vld1q_f32(src.as_ptr().add(i + 4));
        let s2 = vld1q_f32(src.as_ptr().add(i + 8));
        let s3 = vld1q_f32(src.as_ptr().add(i + 12));
        vst1q_f32(dst.as_mut_ptr().add(i), vmulq_f32(s0, vscale));
        vst1q_f32(dst.as_mut_ptr().add(i + 4), vmulq_f32(s1, vscale));
        vst1q_f32(dst.as_mut_ptr().add(i + 8), vmulq_f32(s2, vscale));
        vst1q_f32(dst.as_mut_ptr().add(i + 12), vmulq_f32(s3, vscale));
        i += 16;
    }
    while i + 8 <= n {
        let s0 = vld1q_f32(src.as_ptr().add(i));
        let s1 = vld1q_f32(src.as_ptr().add(i + 4));
        vst1q_f32(dst.as_mut_ptr().add(i), vmulq_f32(s0, vscale));
        vst1q_f32(dst.as_mut_ptr().add(i + 4), vmulq_f32(s1, vscale));
        i += 8;
    }
    while i + 4 <= n {
        let s0 = vld1q_f32(src.as_ptr().add(i));
        vst1q_f32(dst.as_mut_ptr().add(i), vmulq_f32(s0, vscale));
        i += 4;
    }
    // Scalar tail for the last n % 4 samples.
    for j in i..n {
        dst[j] = src[j] * scale;
    }
}
2622
/// Expands unit-norm band shapes back to full-scale MDCT coefficients:
/// freq = x * 2^min(band_e, 32). Despite its name, `band_e` here holds
/// log2-domain energies (as left by `log2amp`) — the exponential is applied
/// in this function. `m_val` must be a power of two (2^LM).
#[allow(clippy::too_many_arguments)]
pub fn denormalise_bands(
    m: &CeltMode,
    x: &[f32],
    freq: &mut [f32],
    band_e: &[f32],
    start: usize,
    end: usize,
    channels: usize,
    m_val: usize,
) {
    // Recover LM from M = 1 << LM.
    let lm = m_val.trailing_zeros() as usize;
    let frame_size = m.short_mdct_size << lm;
    #[cfg(target_arch = "x86_64")]
    let use_avx2 = std::arch::is_x86_feature_detected!("avx2");

    for c in 0..channels {
        for i in start..end {
            let base = c * frame_size + ((m.e_bands[i] as usize) << lm);
            let n = ((m.e_bands[i + 1] - m.e_bands[i]) as usize) << lm;
            let band_log = band_e[c * m.nb_ebands + i];

            // Match C: celt_exp2_db(MIN32(32.f, lg)) — cap gain to prevent overflow
            let g = (2.0f32).powf(band_log.min(32.0));
            let src = &x[base..base + n];
            let dst = &mut freq[base..base + n];
            #[cfg(target_arch = "x86_64")]
            if n >= 8 && use_avx2 {
                // SAFETY: AVX2 verified above; src/dst both hold n elements.
                unsafe { scale_slice_avx2(src, dst, g, n) };
                continue;
            }
            #[cfg(target_arch = "aarch64")]
            if n >= 8 {
                // SAFETY: NEON is baseline on aarch64; src/dst hold n elements.
                unsafe { scale_slice_neon(src, dst, g, n) };
                continue;
            }
            // Scalar fallback for short bands (or non-SIMD targets).
            for (d, &s) in dst.iter_mut().zip(src) {
                *d = s * g;
            }
        }
    }
}
2665
/// Advances the 32-bit linear congruential PRNG used for noise injection
/// (Numerical Recipes constants), with explicitly wrapping arithmetic.
pub fn celt_lcg_rand(seed: u32) -> u32 {
    const MULTIPLIER: u32 = 1664525;
    const INCREMENT: u32 = 1013904223;
    seed.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT)
}
2669
/// Rescales `x[..n]` so its L2 norm becomes `gain`, using NEON.
///
/// Pass 1 accumulates the energy (with a 1e-15 bias so an all-zero vector
/// doesn't divide by zero); pass 2 multiplies by `gain / sqrt(energy)`.
///
/// # Safety
/// The caller must ensure `x` holds at least `n` elements.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn renormalise_vector_neon(x: &mut [f32], n: usize, gain: f32) {
    use std::arch::aarch64::*;

    // Pass 1: sum of squares, 16/8/4-wide then scalar tail.
    let mut sum_vec = vdupq_n_f32(0.0);
    let mut i = 0;

    while i + 16 <= n {
        let x0 = vld1q_f32(x.as_ptr().add(i));
        let x1 = vld1q_f32(x.as_ptr().add(i + 4));
        let x2 = vld1q_f32(x.as_ptr().add(i + 8));
        let x3 = vld1q_f32(x.as_ptr().add(i + 12));
        sum_vec = vfmaq_f32(sum_vec, x0, x0);
        sum_vec = vfmaq_f32(sum_vec, x1, x1);
        sum_vec = vfmaq_f32(sum_vec, x2, x2);
        sum_vec = vfmaq_f32(sum_vec, x3, x3);
        i += 16;
    }

    while i + 8 <= n {
        let x0 = vld1q_f32(x.as_ptr().add(i));
        let x1 = vld1q_f32(x.as_ptr().add(i + 4));
        sum_vec = vfmaq_f32(sum_vec, x0, x0);
        sum_vec = vfmaq_f32(sum_vec, x1, x1);
        i += 8;
    }

    while i + 4 <= n {
        let x0 = vld1q_f32(x.as_ptr().add(i));
        sum_vec = vfmaq_f32(sum_vec, x0, x0);
        i += 4;
    }

    let mut e = 1e-15f32 + vaddvq_f32(sum_vec);

    // Scalar tail of the energy accumulation.
    for j in i..n {
        e += x[j] * x[j];
    }

    let norm = gain / e.sqrt();
    let vnorm = vdupq_n_f32(norm);

    // Pass 2: scale everything by the computed norm.
    i = 0;
    while i + 16 <= n {
        let x0 = vld1q_f32(x.as_ptr().add(i));
        let x1 = vld1q_f32(x.as_ptr().add(i + 4));
        let x2 = vld1q_f32(x.as_ptr().add(i + 8));
        let x3 = vld1q_f32(x.as_ptr().add(i + 12));
        vst1q_f32(x.as_mut_ptr().add(i), vmulq_f32(x0, vnorm));
        vst1q_f32(x.as_mut_ptr().add(i + 4), vmulq_f32(x1, vnorm));
        vst1q_f32(x.as_mut_ptr().add(i + 8), vmulq_f32(x2, vnorm));
        vst1q_f32(x.as_mut_ptr().add(i + 12), vmulq_f32(x3, vnorm));
        i += 16;
    }

    while i + 8 <= n {
        let x0 = vld1q_f32(x.as_ptr().add(i));
        let x1 = vld1q_f32(x.as_ptr().add(i + 4));
        vst1q_f32(x.as_mut_ptr().add(i), vmulq_f32(x0, vnorm));
        vst1q_f32(x.as_mut_ptr().add(i + 4), vmulq_f32(x1, vnorm));
        i += 8;
    }

    while i + 4 <= n {
        let x0 = vld1q_f32(x.as_ptr().add(i));
        vst1q_f32(x.as_mut_ptr().add(i), vmulq_f32(x0, vnorm));
        i += 4;
    }

    // Scalar tail of the scaling pass.
    for j in i..n {
        x[j] *= norm;
    }
}
2745
/// Rescales `x[..n]` so its L2 norm becomes `gain`, using AVX2 + FMA.
///
/// Pass 1 accumulates the energy (1e-15 bias avoids division by zero on an
/// all-zero vector); pass 2 multiplies by `gain / sqrt(energy)`.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2 and FMA and that `x` holds
/// at least `n` elements.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn renormalise_vector_avx2(x: &mut [f32], n: usize, gain: f32) {
    use std::arch::x86_64::*;

    let mut i = 0usize;

    // Pass 1: sum of squares in two independent accumulators.
    let mut acc0 = _mm256_setzero_ps();
    let mut acc1 = _mm256_setzero_ps();

    while i + 16 <= n {
        let v0 = _mm256_loadu_ps(x.as_ptr().add(i));
        let v1 = _mm256_loadu_ps(x.as_ptr().add(i + 8));
        acc0 = _mm256_fmadd_ps(v0, v0, acc0);
        acc1 = _mm256_fmadd_ps(v1, v1, acc1);
        i += 16;
    }

    if i + 8 <= n {
        let v0 = _mm256_loadu_ps(x.as_ptr().add(i));
        acc0 = _mm256_fmadd_ps(v0, v0, acc0);
        i += 8;
    }

    // Horizontal reduction: 256 -> 128 -> 2 -> 1 lanes.
    let acc = _mm256_add_ps(acc0, acc1);
    let hi = _mm256_extractf128_ps(acc, 1);
    let lo = _mm256_castps256_ps128(acc);
    let s4 = _mm_add_ps(lo, hi);
    let t1 = _mm_movehl_ps(s4, s4);
    let s2 = _mm_add_ps(s4, t1);
    let t2 = _mm_shuffle_ps(s2, s2, 0x55);
    let mut e = 1e-15f32 + _mm_cvtss_f32(_mm_add_ss(s2, t2));

    // Scalar tail of the energy accumulation.
    for &v in &x[i..n] {
        e += v * v;
    }

    let norm = gain / e.sqrt();
    let vnorm = _mm256_set1_ps(norm);

    // Pass 2: scale everything by the computed norm.
    i = 0;
    while i + 16 <= n {
        let v0 = _mm256_loadu_ps(x.as_ptr().add(i));
        let v1 = _mm256_loadu_ps(x.as_ptr().add(i + 8));
        _mm256_storeu_ps(x.as_mut_ptr().add(i), _mm256_mul_ps(v0, vnorm));
        _mm256_storeu_ps(x.as_mut_ptr().add(i + 8), _mm256_mul_ps(v1, vnorm));
        i += 16;
    }
    while i + 8 <= n {
        let v = _mm256_loadu_ps(x.as_ptr().add(i));
        _mm256_storeu_ps(x.as_mut_ptr().add(i), _mm256_mul_ps(v, vnorm));
        i += 8;
    }
    // Scalar tail of the scaling pass.
    for v in &mut x[i..n] {
        *v *= norm;
    }
}
2803
2804pub fn renormalise_vector(x: &mut [f32], n: usize, gain: f32) {
2805    #[cfg(target_arch = "aarch64")]
2806    unsafe {
2807        renormalise_vector_neon(x, n, gain);
2808    }
2809    #[cfg(target_arch = "x86_64")]
2810    unsafe {
2811        if n >= 16 && std::arch::is_x86_feature_detected!("avx2") {
2812            renormalise_vector_avx2(x, n, gain);
2813            return;
2814        }
2815    }
2816    #[cfg(all(not(target_arch = "aarch64"), not(target_arch = "x86_64")))]
2817    {
2818        let mut e = 1e-15f32;
2819        for &xv in x[..n].iter() {
2820            e += xv * xv;
2821        }
2822        let norm = gain / e.sqrt();
2823        for xv in x[..n].iter_mut() {
2824            *xv *= norm;
2825        }
2826    }
2827    #[cfg(target_arch = "x86_64")]
2828    {
2829        let mut e = 1e-15f32;
2830        for &xv in x[..n].iter() {
2831            e += xv * xv;
2832        }
2833        let norm = gain / e.sqrt();
2834        for xv in x[..n].iter_mut() {
2835            *xv *= norm;
2836        }
2837    }
2838}
2839
/// Injects pseudo-random ±r noise into MDCT sub-blocks whose collapse mask
/// bit is zero (the band fully collapsed there), then renormalises affected
/// bands, so transient frames don't leave audible spectral holes. Returns
/// the updated PRNG seed. Mirrors libopus' anti_collapse.
#[allow(clippy::too_many_arguments)]
pub fn anti_collapse(
    m: &CeltMode,
    x_buf: &mut [f32],
    collapse_masks: &[u32],
    lm: i32,
    channels: usize,
    size: usize,
    start: usize,
    end: usize,
    log_e: &[f32],
    prev1_log_e: &[f32],
    prev2_log_e: &[f32],
    pulses: &[i32],
    mut seed: u32,
) -> u32 {
    for i in start..end {
        let n0 = (m.e_bands[i + 1] - m.e_bands[i]) as usize;
        // Approximate coded depth (bits per sample) for this band.
        let depth = if n0 > 0 {
            ((1 + pulses[i]) / n0 as i32) >> lm
        } else {
            0
        };

        // Noise ceiling: 0.5 * 2^(-depth/8).
        let thresh = 0.5 * (-(0.125 * depth as f32)).exp2();
        // 1/sqrt(N): keeps the injected noise at per-bin unit-norm scale.
        let sqrt_1 = 1.0 / ((n0 << lm) as f32).sqrt();

        for c in 0..channels {
            let p1 = prev1_log_e[c * m.nb_ebands + i];
            let p2 = prev2_log_e[c * m.nb_ebands + i];

            // Mono decode from stereo history: take the max with the second
            // channel's history when that history is present.
            // NOTE(review): presence is inferred from the slice length —
            // confirm this matches the caller's buffer layout.
            let (p1_adj, p2_adj) = if channels == 1 && prev1_log_e.len() >= 2 * m.nb_ebands {
                (
                    p1.max(prev1_log_e[m.nb_ebands + i]),
                    p2.max(prev2_log_e[m.nb_ebands + i]),
                )
            } else {
                (p1, p2)
            };

            // Energy rise relative to the quieter of the two previous frames.
            let e_diff = log_e[c * m.nb_ebands + i] - p1_adj.min(p2_adj);
            let e_diff = e_diff.max(0.0);

            // Noise amplitude: 2 * 2^(-e_diff), sqrt(2) boost for LM==3,
            // capped at thresh, scaled to per-bin amplitude.
            let mut r = 2.0 * (-e_diff).exp2();
            if lm == 3 {
                r *= std::f32::consts::SQRT_2;
            }
            r = r.min(thresh);
            r *= sqrt_1;

            let x_offset = c * size + ((m.e_bands[i] as usize) << lm);
            let mut renormalize = false;
            // Fill every collapsed sub-block k with ±r noise (sign from the
            // LCG), interleaved at stride 2^LM.
            for k in 0..(1 << lm) {
                if (collapse_masks[i * channels + c] & (1 << k)) == 0 {
                    for j in 0..n0 {
                        seed = celt_lcg_rand(seed);
                        x_buf[x_offset + (j << lm) + k] = if (seed & 0x8000) != 0 { r } else { -r };
                    }
                    renormalize = true;
                }
            }
            // Restore unit norm after the injection disturbed it.
            if renormalize {
                renormalise_vector(&mut x_buf[x_offset..x_offset + (n0 << lm)], n0 << lm, 1.0);
            }
        }
    }
    seed
}