1use crate::modes::CeltMode;
2use crate::pvq::*;
3use crate::range_coder::RangeCoder;
4use crate::rate::{BITRES, bits2pulses, get_pulses, pulses2bits};
5use crate::tell_frac_inline;
6
/// Small positive energy floor (`1e-10`).
///
/// NOTE(review): not referenced in this part of the file — confirm where (and
/// whether) it is used before relying on its meaning.
const MIN_STEREO_ENERGY: f32 = 1.0e-10;
8
9pub struct BandCtx<'a> {
10 pub encode: bool,
11 pub m: &'a CeltMode,
12 pub i: usize,
13 pub band_e: &'a [f32],
14 pub rc: &'a mut RangeCoder,
15 pub spread: i32,
16 pub remaining_bits: i32,
17 pub resynth: bool,
18 pub tf_change: i32,
19 pub intensity: usize,
20 pub theta_round: i32,
21 pub avoid_split_noise: bool,
22 pub arch: i32,
23 pub disable_inv: bool,
24 pub seed: u32,
25}
26
/// Bit-exact fixed-point cosine used for split angles.
///
/// `x` is an angle where `16384` is a quarter turn; the result is Q15
/// (`23171` at `x = 8192`). Only defined for `64 <= x < 16384`: smaller
/// inputs make the final `1 + ...` overflow `i16`, which callers avoid by
/// never passing the exact endpoints.
#[inline]
fn bitexact_cos(x: i16) -> i16 {
    /// Rounded Q15 product of two Q15 values.
    #[inline(always)]
    fn frac_mul16(a: i16, b: i16) -> i16 {
        ((i32::from(a) * i32::from(b) + 16384) >> 15) as i16
    }

    // x^2 in Q13 -> Q15-ish range.
    let x_sq = ((i32::from(x) * i32::from(x) + 4096) >> 13) as i16;
    // Polynomial correction term, evaluated in Horner form.
    let correction = frac_mul16(x_sq, -7651 + frac_mul16(x_sq, 8277 + frac_mul16(-626, x_sq)));
    let biased = (32767 - i32::from(x_sq) + i32::from(correction)) as i16;
    1 + biased
}
46
/// Bit-exact approximation of `log2(isin / icos)` in Q11.
///
/// Each argument is normalised to 15 significant bits. Its integer log2
/// contributes `2048` per bit of difference, and a quadratic polynomial in the
/// normalised mantissa supplies the fractional part. Non-positive inputs are
/// treated as `0`. The result is antisymmetric: `f(a, b) == -f(b, a)`.
#[inline]
pub fn bitexact_log2tan(isin: i32, icos: i32) -> i32 {
    /// Returns `(ilog(v), v normalised to 15 bits)` for `v` clamped to `>= 0`.
    fn normalise(v: i32) -> (i32, i32) {
        let clamped = v.max(0);
        let bits = (32 - (clamped as u32).leading_zeros()) as i32;
        let mantissa = if bits > 0 {
            clamped << (15 - bits).max(0)
        } else {
            0
        };
        (bits, mantissa)
    }

    let q15_mul = |a: i32, b: i32| -> i32 { (a * b + 16384) >> 15 };
    let frac_part = |m: i32| q15_mul(m, q15_mul(m, -2597) + 7932);

    let (lc, cos_mant) = normalise(icos);
    let (ls, sin_mant) = normalise(isin);
    (ls - lc) * (1 << 11) + frac_part(sin_mant) - frac_part(cos_mant)
}
72
/// Signed integer division `n / d`, truncating toward zero.
///
/// This must match the libopus `celt_sudiv` exactly. That function is plain C
/// division (or an equivalent table-driven unsigned division of `|n|` with the
/// sign restored). Its result feeds `compute_qn`, whose output determines how
/// the split angle is entropy-coded, so any rounding here would desynchronise
/// the bitstream from the reference codec. The previous implementation
/// rounded to nearest (`(n + d/2) / d`), which produced a different `qb` for
/// many budgets.
///
/// `d` is expected to be positive.
#[inline(always)]
fn celt_sudiv(n: i32, d: i32) -> i32 {
    // Rust `/` on signed integers truncates toward zero, like C99.
    n / d
}
81
/// Integer square root: returns `floor(sqrt(val))`.
///
/// Uses the bit-by-bit restoring method, producing one result bit per
/// iteration starting from the most significant. `val == 0` returns `0`.
/// Previously it computed a shift of `-1` and panicked in debug builds.
#[inline]
fn isqrt32(mut val: u32) -> u32 {
    if val == 0 {
        return 0;
    }
    let mut g = 0u32;
    // Highest candidate bit position of the root: (ilog2(val)) / 2.
    let mut bshift = ((32 - val.leading_zeros()) as i32 - 1) >> 1;
    let mut b = 1u32 << bshift;
    while bshift >= 0 {
        // (2g + b) * 2^bshift is the increase in g^2 if bit `b` is set;
        // computed in u64 so it cannot overflow before the comparison.
        let t = (((g << 1) + b) as u64) << bshift;
        if t <= val as u64 {
            g += b;
            val -= t as u32;
        }
        b >>= 1;
        bshift -= 1;
    }
    g
}
98
/// Spreading mode 0: returned by `spreading_decision` when the last band is
/// too narrow (`<= 8` bins) or the hysteresis score is `>= 384`.
pub const SPREAD_NONE: i32 = 0;
/// Spreading mode 1: hysteresis score in `256..384`.
pub const SPREAD_LIGHT: i32 = SPREAD_NONE + 1;
/// Spreading mode 2: score in `80..256`, also the fallback when no band
/// contributed any weight.
pub const SPREAD_NORMAL: i32 = SPREAD_LIGHT + 1;
/// Spreading mode 3: hysteresis score below `80`.
pub const SPREAD_AGGRESSIVE: i32 = SPREAD_NORMAL + 1;
103
104#[allow(clippy::too_many_arguments)]
105pub fn spreading_decision(
106 m: &CeltMode,
107 x_buf: &[f32],
108 average: &mut i32,
109 last_decision: i32,
110 hf_average: &mut i32,
111 tapset_decision: &mut i32,
112 update_hf: bool,
113 end: usize,
114 channels: usize,
115 m_val: usize,
116 spread_weight: &[i32],
117) -> i32 {
118 let mut sum = 0;
119 let mut nb_bands = 0;
120 let n0 = m_val * m.short_mdct_size;
121 let mut hf_sum = 0;
122
123 if m_val * (m.e_bands[end] as usize - m.e_bands[end - 1] as usize) <= 8 {
124 return SPREAD_NONE;
125 }
126
127 for c in 0..channels {
128 for (i, &sw) in spread_weight[..end].iter().enumerate() {
129 let n = m_val * (m.e_bands[i + 1] as usize - m.e_bands[i] as usize);
130 if n <= 8 {
131 continue;
132 }
133
134 let mut tcount = [0; 3];
135 let offset = m_val * m.e_bands[i] as usize + c * n0;
136 let x = &x_buf[offset..offset + n];
137
138 for xv in x.iter().copied() {
139 let x2n = xv * xv * (n as f32);
140 if x2n < 0.25 {
141 tcount[0] += 1;
142 }
143 if x2n < 0.0625 {
144 tcount[1] += 1;
145 }
146 if x2n < 0.015625 {
147 tcount[2] += 1;
148 }
149 }
150
151 if i > m.nb_ebands - 4 {
152 hf_sum += 32 * (tcount[1] + tcount[0]) / (n as i32);
153 }
154
155 let tmp = (if 2 * tcount[2] >= (n as i32) { 1 } else { 0 })
156 + (if 2 * tcount[1] >= (n as i32) { 1 } else { 0 })
157 + (if 2 * tcount[0] >= (n as i32) { 1 } else { 0 });
158 sum += tmp * sw;
159 nb_bands += sw;
160 }
161 }
162
163 if update_hf {
164 if hf_sum > 0 {
165 hf_sum /= (channels as i32) * (4 - m.nb_ebands as i32 + end as i32);
166 }
167 *hf_average = (*hf_average + hf_sum) >> 1;
168 hf_sum = *hf_average;
169
170 if *tapset_decision == 2 {
171 hf_sum += 4;
172 } else if *tapset_decision == 0 {
173 hf_sum -= 4;
174 }
175
176 if hf_sum > 22 {
177 *tapset_decision = 2;
178 } else if hf_sum > 18 {
179 *tapset_decision = 1;
180 } else {
181 *tapset_decision = 0;
182 }
183 }
184
185 if nb_bands == 0 {
186 return SPREAD_NORMAL;
187 }
188
189 let mut sum_scaled = (sum << 8) / nb_bands;
190 sum_scaled = (sum_scaled + *average) >> 1;
191 *average = sum_scaled;
192
193 let sum_final = (3 * sum_scaled + (((3 - last_decision) << 7) + 64) + 2) >> 2;
194
195 if sum_final < 80 {
196 SPREAD_AGGRESSIVE
197 } else if sum_final < 256 {
198 SPREAD_NORMAL
199 } else if sum_final < 384 {
200 SPREAD_LIGHT
201 } else {
202 SPREAD_NONE
203 }
204}
205
206pub fn haar1(x: &mut [f32], n0: usize, stride: usize) {
207 #[cfg(target_arch = "aarch64")]
208 {
209 haar1_neon(x, n0, stride);
210 return;
211 }
212 #[cfg(not(target_arch = "aarch64"))]
213 {
214 haar1_scalar(x, n0, stride);
215 }
216}
217
/// Portable Haar butterfly: see `haar1`.
///
/// Element `(row, col)` lives at `row * stride + col`. Row pairs
/// `(2j, 2j + 1)` for `j < n0 / 2` are combined. Panics if
/// `x.len() < 2 * (n0 / 2) * stride` (when there is any work to do).
#[cfg(not(target_arch = "aarch64"))]
#[inline]
fn haar1_scalar(x: &mut [f32], n0: usize, stride: usize) {
    let pairs = n0 / 2;
    let s = std::f32::consts::FRAC_1_SQRT_2;
    for col in 0..stride {
        for row in 0..pairs {
            let lo = 2 * row * stride + col;
            let hi = lo + stride;
            let a = s * x[lo];
            let b = s * x[hi];
            x[lo] = a + b;
            x[hi] = a - b;
        }
    }
}
234
/// NEON Haar butterfly for aarch64: same result as `haar1_scalar`.
///
/// Element `(row, col)` lives at `row * stride + col`. For each row pair
/// `(2j, 2j + 1)`, four adjacent columns are processed per NEON operation
/// (they are contiguous in memory), and any leftover columns are handled with
/// scalar code. Every lane performs the same IEEE multiply/add/sub as the
/// scalar loop, so the output is bit-identical.
///
/// The previous version vectorised across 4 columns while still stepping the
/// column index by 1 (and, for `stride == 1`, mixed even and odd rows). That
/// produced wrong butterflies and let the raw-pointer loads and stores run
/// past the end of `x`.
///
/// Panics if `x.len() < 2 * (n0 / 2) * stride`, the same inputs on which the
/// scalar version panics.
#[cfg(target_arch = "aarch64")]
fn haar1_neon(x: &mut [f32], n0: usize, stride: usize) {
    use std::arch::aarch64::*;

    let n = n0 >> 1;
    let scale = std::f32::consts::FRAC_1_SQRT_2;
    // Every index touched below is < 2 * n * stride; checking once makes the
    // unchecked vector accesses provably in bounds.
    assert!(2 * n * stride <= x.len());

    // SAFETY: `vdupq_n_f32` only broadcasts a scalar; NEON is a baseline
    // feature of aarch64 targets.
    let vscale = unsafe { vdupq_n_f32(scale) };

    for j in 0..n {
        let even = 2 * j * stride;
        let odd = even + stride;
        let mut i = 0;
        while i + 4 <= stride {
            // SAFETY: `i + 4 <= stride`, so lanes `even + i .. even + i + 4`
            // stay inside the even row and `odd + i .. odd + i + 4` inside the
            // odd row. Both are `<= 2 * n * stride <= x.len()` by the assert
            // above, and the two ranges are disjoint. Both pointers derive
            // from a single `as_mut_ptr` call.
            unsafe {
                let p = x.as_mut_ptr();
                let pe = p.add(even + i);
                let po = p.add(odd + i);
                let te = vmulq_f32(vld1q_f32(pe), vscale);
                let to = vmulq_f32(vld1q_f32(po), vscale);
                vst1q_f32(pe, vaddq_f32(te, to));
                vst1q_f32(po, vsubq_f32(te, to));
            }
            i += 4;
        }
        // Remaining (< 4) columns of this row pair.
        for col in i..stride {
            let tmp1 = scale * x[even + col];
            let tmp2 = scale * x[odd + col];
            x[even + col] = tmp1 + tmp2;
            x[odd + col] = tmp1 - tmp2;
        }
    }
}
319
320#[inline(always)]
321pub fn compute_qn(n: usize, b: i32, offset: i32, pulse_cap: i32, stereo: bool) -> i32 {
322 static EXP2_TABLE8: [i16; 8] = [16384, 17866, 19483, 21247, 23170, 25267, 27554, 30048];
323 let mut n2 = (2 * n as i32) - 1;
324 if stereo && n == 2 {
325 n2 -= 1;
326 }
327 let mut qb = celt_sudiv(b + n2 * offset, n2);
328 qb = qb.min(b - pulse_cap - (4 << BITRES));
329 qb = qb.min(8 << BITRES);
330 if qb < (1i32 << BITRES >> 1) {
331 1
332 } else {
333 let val = EXP2_TABLE8[(qb & 0x7) as usize] as i32;
334 let shift = 14 - (qb >> BITRES);
335 let raw = if (0..32).contains(&shift) {
336 val >> shift
337 } else {
338 0
339 };
340 let qn = (raw + 1) >> 1 << 1;
341 qn.min(256)
342 }
343}
344
/// NEON implementation of the split-angle estimate used by `stereo_itheta`
/// on aarch64.
///
/// With `stereo` set, it accumulates the energies of `x + y` (mid) and `x - y`
/// (side); otherwise it accumulates the energies of `x` and `y` directly. Both
/// energies start at `1e-15`. Accumulation uses four-lane fused multiply-adds
/// unrolled 16/8/4 samples at a time. The lanes are then reduced, and the last
/// `n % 4` samples are added with scalar code. The result is
/// `celt_atan2p_norm(sqrt(side), sqrt(mid))` scaled by 16384, with 0.5 added
/// before truncation.
///
/// The FMA, lane-wise summation order differs from the portable loop, so the
/// energies (and occasionally the rounded angle) may differ slightly from
/// non-aarch64 builds.
///
/// # Safety
///
/// The caller must guarantee `x.len() >= n` and `y.len() >= n`. The vector
/// loops read `x[i..i + 4]` and `y[i..i + 4]` through raw pointers for every
/// `i + 4 <= n` without bounds checks. NEON must be available.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn stereo_itheta_neon(x: &[f32], y: &[f32], stereo: bool, n: usize) -> i32 {
    use std::arch::aarch64::*;

    // Small floor so silent input still yields finite square roots.
    let mut emid = 1e-15f32;
    let mut eside = 1e-15f32;

    if stereo {
        // Mid/side energies: |x + y|^2 and |x - y|^2.
        let mut sum_mid = vdupq_n_f32(0.0);
        let mut sum_side = vdupq_n_f32(0.0);
        let mut i = 0;

        // 16 samples per iteration.
        while i + 16 <= n {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let x1 = vld1q_f32(x.as_ptr().add(i + 4));
            let x2 = vld1q_f32(x.as_ptr().add(i + 8));
            let x3 = vld1q_f32(x.as_ptr().add(i + 12));
            let y0 = vld1q_f32(y.as_ptr().add(i));
            let y1 = vld1q_f32(y.as_ptr().add(i + 4));
            let y2 = vld1q_f32(y.as_ptr().add(i + 8));
            let y3 = vld1q_f32(y.as_ptr().add(i + 12));

            let m0 = vaddq_f32(x0, y0);
            let m1 = vaddq_f32(x1, y1);
            let m2 = vaddq_f32(x2, y2);
            let m3 = vaddq_f32(x3, y3);
            let s0 = vsubq_f32(x0, y0);
            let s1 = vsubq_f32(x1, y1);
            let s2 = vsubq_f32(x2, y2);
            let s3 = vsubq_f32(x3, y3);

            sum_mid = vfmaq_f32(sum_mid, m0, m0);
            sum_mid = vfmaq_f32(sum_mid, m1, m1);
            sum_mid = vfmaq_f32(sum_mid, m2, m2);
            sum_mid = vfmaq_f32(sum_mid, m3, m3);
            sum_side = vfmaq_f32(sum_side, s0, s0);
            sum_side = vfmaq_f32(sum_side, s1, s1);
            sum_side = vfmaq_f32(sum_side, s2, s2);
            sum_side = vfmaq_f32(sum_side, s3, s3);

            i += 16;
        }

        // 8 samples per iteration.
        while i + 8 <= n {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let x1 = vld1q_f32(x.as_ptr().add(i + 4));
            let y0 = vld1q_f32(y.as_ptr().add(i));
            let y1 = vld1q_f32(y.as_ptr().add(i + 4));

            let m0 = vaddq_f32(x0, y0);
            let m1 = vaddq_f32(x1, y1);
            let s0 = vsubq_f32(x0, y0);
            let s1 = vsubq_f32(x1, y1);

            sum_mid = vfmaq_f32(sum_mid, m0, m0);
            sum_mid = vfmaq_f32(sum_mid, m1, m1);
            sum_side = vfmaq_f32(sum_side, s0, s0);
            sum_side = vfmaq_f32(sum_side, s1, s1);

            i += 8;
        }

        // 4 samples per iteration.
        while i + 4 <= n {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let y0 = vld1q_f32(y.as_ptr().add(i));
            let m0 = vaddq_f32(x0, y0);
            let s0 = vsubq_f32(x0, y0);
            sum_mid = vfmaq_f32(sum_mid, m0, m0);
            sum_side = vfmaq_f32(sum_side, s0, s0);
            i += 4;
        }

        // Horizontal reduction of the four lanes.
        emid += vaddvq_f32(sum_mid);
        eside += vaddvq_f32(sum_side);

        // Scalar tail (bounds-checked indexing).
        for j in i..n {
            let m = x[j] + y[j];
            let s = x[j] - y[j];
            emid += m * m;
            eside += s * s;
        }
    } else {
        // Plain energies: |x|^2 (stored in emid) and |y|^2 (stored in eside).
        let mut sum_mid = vdupq_n_f32(0.0);
        let mut sum_side = vdupq_n_f32(0.0);
        let mut i = 0;

        while i + 16 <= n {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let x1 = vld1q_f32(x.as_ptr().add(i + 4));
            let x2 = vld1q_f32(x.as_ptr().add(i + 8));
            let x3 = vld1q_f32(x.as_ptr().add(i + 12));
            let y0 = vld1q_f32(y.as_ptr().add(i));
            let y1 = vld1q_f32(y.as_ptr().add(i + 4));
            let y2 = vld1q_f32(y.as_ptr().add(i + 8));
            let y3 = vld1q_f32(y.as_ptr().add(i + 12));

            sum_mid = vfmaq_f32(sum_mid, x0, x0);
            sum_mid = vfmaq_f32(sum_mid, x1, x1);
            sum_mid = vfmaq_f32(sum_mid, x2, x2);
            sum_mid = vfmaq_f32(sum_mid, x3, x3);
            sum_side = vfmaq_f32(sum_side, y0, y0);
            sum_side = vfmaq_f32(sum_side, y1, y1);
            sum_side = vfmaq_f32(sum_side, y2, y2);
            sum_side = vfmaq_f32(sum_side, y3, y3);

            i += 16;
        }

        while i + 8 <= n {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let x1 = vld1q_f32(x.as_ptr().add(i + 4));
            let y0 = vld1q_f32(y.as_ptr().add(i));
            let y1 = vld1q_f32(y.as_ptr().add(i + 4));

            sum_mid = vfmaq_f32(sum_mid, x0, x0);
            sum_mid = vfmaq_f32(sum_mid, x1, x1);
            sum_side = vfmaq_f32(sum_side, y0, y0);
            sum_side = vfmaq_f32(sum_side, y1, y1);

            i += 8;
        }

        while i + 4 <= n {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let y0 = vld1q_f32(y.as_ptr().add(i));
            sum_mid = vfmaq_f32(sum_mid, x0, x0);
            sum_side = vfmaq_f32(sum_side, y0, y0);
            i += 4;
        }

        emid += vaddvq_f32(sum_mid);
        eside += vaddvq_f32(sum_side);

        for j in i..n {
            emid += x[j] * x[j];
            eside += y[j] * y[j];
        }
    }

    let mid = emid.sqrt();
    let side = eside.sqrt();
    // Normalised angle in [0, 1] scaled to the 0..=16384 itheta range.
    let theta_norm = celt_atan2p_norm(side, mid);
    (0.5 + 16384.0 * theta_norm) as i32
}
500
/// Estimates the split angle between `x` and `y` (aarch64 / NEON build).
///
/// With `stereo` set, compares the energies of `x + y` (mid) and `x - y`
/// (side); otherwise compares the energies of `x` and `y`. Returns the angle
/// on the `0..=16384` itheta scale (0 = all mid, 16384 = all side).
///
/// The inputs are first narrowed to `n` elements. The NEON kernel reads
/// through unchecked raw pointers, so without this a safe call with `n` larger
/// than either slice was an out-of-bounds read. It now panics instead,
/// matching the indexing panic of the portable implementation.
#[inline(always)]
#[cfg(target_arch = "aarch64")]
pub fn stereo_itheta(x: &[f32], y: &[f32], stereo: bool, n: usize) -> i32 {
    let (x, y) = (&x[..n], &y[..n]);
    // SAFETY: both slices now have length exactly `n`, which is the
    // precondition of `stereo_itheta_neon`; NEON is available on the
    // aarch64 targets this item is compiled for.
    unsafe { stereo_itheta_neon(x, y, stereo, n) }
}
508
509#[inline(always)]
510#[cfg(not(target_arch = "aarch64"))]
511pub fn stereo_itheta(x: &[f32], y: &[f32], stereo: bool, n: usize) -> i32 {
512 #[cfg(target_arch = "aarch64")]
514 unsafe {
515 return stereo_itheta_neon(x, y, stereo, n);
516 }
517 #[cfg(not(target_arch = "aarch64"))]
518 {
519 let mut emid = 1e-15f32;
520 let mut eside = 1e-15f32;
521 if stereo {
522 for i in 0..n {
523 let m = x[i] + y[i];
524 let s = x[i] - y[i];
525 emid += m * m;
526 eside += s * s;
527 }
528 } else {
529 for i in 0..n {
530 emid += x[i] * x[i];
531 eside += y[i] * y[i];
532 }
533 }
534 let mid = emid.sqrt();
535 let side = eside.sqrt();
536 let theta_norm = celt_atan2p_norm(side, mid);
537 (0.5 + 16384.0 * theta_norm) as i32
538 }
539}
540
/// Polynomial approximation of `atan2(y, x) / (pi / 2)` for non-negative
/// inputs, giving a value in `[0, 1]`.
///
/// Returns `0.0` when `x^2 + y^2 < 1e-18`. Otherwise it evaluates the odd
/// polynomial on the smaller-over-larger ratio and reflects it through
/// `1 - f(x / y)` when `y >= x`.
#[inline(always)]
fn celt_atan2p_norm(y: f32, x: f32) -> f32 {
    /// `atan(t) * 2 / pi` for `t` in `[0, 1]`, via Horner's rule on `t^2`.
    #[inline(always)]
    fn atan_norm(t: f32) -> f32 {
        const ATAN2_2_OVER_PI: f32 = 0.636_619_77_f32;
        // Coefficients of t^3, t^5, ..., t^15 (the t^1 coefficient is 1).
        const COEFFS: [f32; 7] = [
            -3.333_165_9e-1_f32,
            1.996_270_4e-1_f32,
            -1.397_658_3e-1_f32,
            9.794_234_e-2_f32,
            -5.777_359_e-2_f32,
            2.304_014e-2_f32,
            -4.355_406e-3_f32,
        ];
        let t2 = t * t;
        let mut acc = COEFFS[6];
        for &c in COEFFS[..6].iter().rev() {
            acc = c + t2 * acc;
        }
        ATAN2_2_OVER_PI * t * (1.0 + t2 * acc)
    }

    if x * x + y * y < 1e-18 {
        return 0.0;
    }
    if x > y {
        atan_norm(y / x)
    } else {
        1.0 - atan_norm(x / y)
    }
}
572
/// Outcome of coding one split angle; filled by `compute_theta` and
/// `compute_theta_encode`.
pub struct SplitCtx {
    /// Inversion flag; only written on the intensity-stereo path.
    pub inv: bool,
    /// Q15 weight of the first half: `32767` when `itheta == 0`, `0` when
    /// `itheta == 16384`, otherwise `bitexact_cos(itheta)`.
    pub imid: i32,
    /// Q15 weight of the second half: `0` when `itheta == 0`, `32767` when
    /// `itheta == 16384`, otherwise `bitexact_cos(16384 - itheta)`.
    pub iside: i32,
    /// Bit-allocation offset between the halves, derived from
    /// `bitexact_log2tan(iside, imid)` (`-16384` / `16384` at the extremes).
    pub delta: i32,
    /// Quantised angle on the `0..=16384` scale.
    pub itheta: i32,
    /// Fractional bits spent coding the angle (difference of two
    /// `tell_frac_inline!` readings).
    pub qalloc: i32,
}
581
/// Encoder-only variant of `compute_theta`.
///
/// Chooses and codes the split angle between the halves `x` and `y` (each of
/// length `n`), then stores in `sctx` the quantised angle, the Q15 half
/// weights, the allocation offset `delta`, and the bits spent (`qalloc`).
///
/// * `b` — bit budget for the split, in `1 << BITRES` fractional bits (as
///   `compute_qn` treats it). Only read here.
///   NOTE(review): the libopus reference subtracts `qalloc` from `*b` at the
///   end of compute_theta. This implementation leaves `*b` unchanged, so
///   confirm the caller and decoder account for it consistently.
/// * `b_blocks` — number of short blocks; selects the uniform angle pdf when `> 1`.
/// * `_b0` — unused. NOTE(review): libopus selects the uniform pdf from the
///   *unsplit* block count (B0). Any change must be mirrored in `compute_theta`.
/// * `fill` — collapse/fill mask. Narrowed to the low `b_blocks` bits when
///   `itheta == 0`, and has those bits cleared when `itheta == 16384`.
#[allow(clippy::too_many_arguments)]
#[inline(always)]
fn compute_theta_encode(
    ctx: &mut BandCtx,
    sctx: &mut SplitCtx,
    x: &[f32],
    y: &[f32],
    n: usize,
    b: &mut i32,
    b_blocks: i32,
    _b0: i32,
    lm: i32,
    stereo: bool,
    fill: &mut u32,
) {
    // Resolution of the angle: derived from the band's pulse cap and budget.
    let pulse_cap = ctx.m.log_n[ctx.i] as i32 + (lm << BITRES);
    let offset = (pulse_cap >> 1) - if stereo && n == 2 { 16 } else { 4 };
    let mut qn = compute_qn(n, *b, offset, pulse_cap, stereo);

    // Intensity-stereo bands do not code an angle at all.
    if stereo && ctx.i >= ctx.intensity {
        qn = 1;
    }

    // No angle resolution and not intensity stereo: use an even split
    // (itheta = 8192) without touching the range coder, `fill` or `sctx.inv`.
    if qn == 1 && !(stereo && ctx.i >= ctx.intensity) {
        sctx.itheta = 8192;
        sctx.qalloc = 0;
        let imid = bitexact_cos(8192i16);
        sctx.imid = imid as i32;
        let iside = bitexact_cos(8192i16);
        sctx.iside = iside as i32;
        sctx.delta =
            (((n as i32 - 1) << 7) * bitexact_log2tan(sctx.iside, sctx.imid) + 16384) >> 15;
        return;
    }

    // Unquantised angle on the 0..=16384 scale.
    let mut itheta = stereo_itheta(x, y, stereo, n);

    let tell_start = tell_frac_inline!(ctx.rc);

    if qn != 1 {
        if !stereo || ctx.theta_round == 0 {
            // Round to the nearest of qn + 1 steps.
            itheta = (itheta * qn + 8192) >> 14;
            if !stereo && ctx.avoid_split_noise && itheta > 0 && itheta < qn {
                // If the quantised split would move the mid/side allocation
                // by more than the budget, code "no split" (0 or qn) instead.
                let unquantized = (itheta * 16384) / qn;
                let imid = bitexact_cos(unquantized as i16) as i32;
                let iside = bitexact_cos((16384 - unquantized) as i16) as i32;
                let delta = (((n as i32 - 1) << 7) * bitexact_log2tan(iside, imid) + 16384) >> 15;
                if delta > *b {
                    itheta = qn;
                } else if delta < -*b {
                    itheta = 0;
                }
            }
        } else {
            // Biased rounding towards the extremes; theta_round picks
            // the lower (< 0) or upper (> 0) neighbour.
            let bias = if itheta > 8192 {
                32767 / qn
            } else {
                -32767 / qn
            };
            let down = (itheta * qn + bias) >> 14;
            let down = down.clamp(0, qn - 1);
            if ctx.theta_round < 0 {
                itheta = down;
            } else {
                itheta = down + 1;
            }
        }

        if stereo && n > 2 {
            // Step pdf: weight p0 = 3 for itheta <= qn/2, weight 1 above.
            let p0 = 3;
            let x0 = qn / 2;
            let ft = p0 * (x0 + 1) + x0;
            let fl = if itheta <= x0 {
                p0 * itheta
            } else {
                (itheta - 1 - x0) + (x0 + 1) * p0
            };
            let fh = if itheta <= x0 {
                p0 * (itheta + 1)
            } else {
                (itheta - x0) + (x0 + 1) * p0
            };
            ctx.rc.encode(fl as u32, fh as u32, ft as u32);
        } else if b_blocks > 1 || stereo {
            // Uniform pdf over qn + 1 values.
            ctx.rc.enc_uint(itheta as u32, (qn + 1) as u32);
        } else {
            // Triangular pdf peaking at qn / 2.
            let ft = ((qn >> 1) + 1) * ((qn >> 1) + 1);
            let fs = if itheta <= (qn >> 1) {
                itheta + 1
            } else {
                qn + 1 - itheta
            };
            let fl = if itheta <= (qn >> 1) {
                (itheta * (itheta + 1)) >> 1
            } else {
                ft - (((qn + 1 - itheta) * (qn + 2 - itheta)) >> 1)
            };
            ctx.rc.encode(fl as u32, (fl + fs) as u32, ft as u32);
        }
        // Back to the 0..=16384 scale (rounded).
        // NOTE(review): libopus uses truncating division here; this must stay
        // identical to compute_theta's decoder path.
        itheta = (itheta * 16384 + (qn >> 1)) / qn;
    } else {
        // qn == 1 here implies the intensity case (the other case returned
        // early above), so the `else` arm below is unreachable.
        if stereo && ctx.i >= ctx.intensity {
            let mut emid = 1e-15f32;
            let mut eside = 1e-15f32;
            for i in 0..n {
                let m = x[i] + y[i];
                let s = x[i] - y[i];
                emid += m * m;
                eside += s * s;
            }
            // Inversion flag: side louder than mid, coded with logp = 1.
            let inv = eside > emid;
            ctx.rc.encode_bit_logp(inv, 1);
            itheta = 0;
            sctx.inv = inv;
        } else {
            itheta = 8192;
        }
    }

    sctx.itheta = itheta;

    // The first condition is always false here because of the early return.
    sctx.qalloc = if qn == 1 && !(stereo && ctx.i >= ctx.intensity) {
        0
    } else {
        tell_frac_inline!(ctx.rc) - tell_start
    };

    if itheta == 0 {
        // Everything in the first half.
        sctx.imid = 32767;
        sctx.iside = 0;
        sctx.delta = -16384;
        *fill &= (1 << b_blocks) - 1;
    } else if itheta == 16384 {
        // Everything in the second half.
        sctx.imid = 0;
        sctx.iside = 32767;
        sctx.delta = 16384;
        *fill &= !((1 << b_blocks) - 1);
    } else {
        let imid = bitexact_cos(itheta as i16);
        sctx.imid = imid as i32;
        let iside = bitexact_cos((16384 - itheta) as i16);
        sctx.iside = iside as i32;
        // Q15 product of ((n - 1) << 7) and log2(iside / imid).
        sctx.delta =
            (((n as i32 - 1) << 7) * bitexact_log2tan(sctx.iside, sctx.imid) + 16384) >> 15;
    }
}
732
/// Codes (encoder) or decodes (decoder) the split angle between the halves
/// `x` and `y` of length `n`, and fills `sctx` with the quantised angle, the
/// Q15 half weights, the allocation offset `delta`, and the bits spent.
///
/// The encoder and decoder paths must make identical range-coder calls with
/// identical `fl`/`fh`/`ft`, so any change to one branch has to be mirrored in
/// the other (and in `compute_theta_encode`).
///
/// * `b` — bit budget for the split, in `1 << BITRES` fractional bits. Only read here.
///   NOTE(review): libopus subtracts `qalloc` from `*b` at the end of
///   compute_theta. Confirm callers account for this.
/// * `b_blocks` — number of short blocks; selects the uniform pdf when `> 1`.
/// * `_b0` — unused. NOTE(review): libopus uses the unsplit block count B0 for
///   that pdf choice.
/// * `fill` — collapse/fill mask, narrowed when `itheta` is 0 or 16384.
#[allow(clippy::too_many_arguments)]
#[inline(always)]
pub fn compute_theta(
    ctx: &mut BandCtx,
    sctx: &mut SplitCtx,
    x: &[f32],
    y: &[f32],
    n: usize,
    b: &mut i32,
    b_blocks: i32,
    _b0: i32,
    lm: i32,
    stereo: bool,
    fill: &mut u32,
) {
    // Resolution of the angle (must be computed identically on both sides).
    let pulse_cap = ctx.m.log_n[ctx.i] as i32 + (lm << BITRES);
    let offset = (pulse_cap >> 1) - if stereo && n == 2 { 16 } else { 4 };
    let mut qn = compute_qn(n, *b, offset, pulse_cap, stereo);

    // Intensity-stereo bands code only an inversion bit.
    if stereo && ctx.i >= ctx.intensity {
        qn = 1;
    }

    // Only the encoder can measure the angle from the signal.
    let mut itheta = 0;
    if ctx.encode {
        itheta = stereo_itheta(x, y, stereo, n);
    }

    let tell_start = tell_frac_inline!(ctx.rc);

    if qn != 1 {
        if ctx.encode {
            if !stereo || ctx.theta_round == 0 {
                // Round to the nearest of qn + 1 steps.
                itheta = (itheta * qn + 8192) >> 14;
                if !stereo && ctx.avoid_split_noise && itheta > 0 && itheta < qn {
                    // Prefer "no split" when the quantised angle would move
                    // the mid/side allocation by more than the budget.
                    let unquantized = (itheta * 16384) / qn;
                    let imid = bitexact_cos(unquantized as i16) as i32;
                    let iside = bitexact_cos((16384 - unquantized) as i16) as i32;
                    let delta =
                        (((n as i32 - 1) << 7) * bitexact_log2tan(iside, imid) + 16384) >> 15;
                    if delta > *b {
                        itheta = qn;
                    } else if delta < -*b {
                        itheta = 0;
                    }
                }
            } else {
                // Biased rounding; theta_round chooses the lower (< 0) or
                // upper (> 0) neighbour.
                let bias = if itheta > 8192 {
                    32767 / qn
                } else {
                    -32767 / qn
                };
                let down = (itheta * qn + bias) >> 14;
                let down = down.clamp(0, qn - 1);
                if ctx.theta_round < 0 {
                    itheta = down;
                } else {
                    itheta = down + 1;
                }
            }
        }

        if stereo && n > 2 {
            // Step pdf: weight p0 = 3 for itheta <= qn/2, weight 1 above.
            let p0 = 3;
            let x0 = qn / 2;
            let ft = p0 * (x0 + 1) + x0;
            if ctx.encode {
                let fl = if itheta <= x0 {
                    p0 * itheta
                } else {
                    (itheta - 1 - x0) + (x0 + 1) * p0
                };
                let fh = if itheta <= x0 {
                    p0 * (itheta + 1)
                } else {
                    (itheta - x0) + (x0 + 1) * p0
                };
                ctx.rc.encode(fl as u32, fh as u32, ft as u32);
            } else {
                // Invert the step cdf, then consume the same interval.
                let fs = ctx.rc.decode(ft as u32);
                if fs < (x0 + 1) as u32 * p0 as u32 {
                    itheta = fs as i32 / p0;
                } else {
                    itheta = (x0 + 1) + (fs as i32 - (x0 + 1) * p0);
                }
                let fl = if itheta <= x0 {
                    p0 * itheta
                } else {
                    (itheta - 1 - x0) + (x0 + 1) * p0
                };
                let fh = if itheta <= x0 {
                    p0 * (itheta + 1)
                } else {
                    (itheta - x0) + (x0 + 1) * p0
                };
                ctx.rc.update(fl as u32, fh as u32, ft as u32);
            }
        } else if b_blocks > 1 || stereo {
            // Uniform pdf over qn + 1 values.
            if ctx.encode {
                ctx.rc.enc_uint(itheta as u32, (qn + 1) as u32);
            } else {
                itheta = ctx.rc.dec_uint((qn + 1) as u32) as i32;
            }
        } else {
            // Triangular pdf peaking at qn / 2.
            let ft = ((qn >> 1) + 1) * ((qn >> 1) + 1);
            if ctx.encode {
                let fs = if itheta <= (qn >> 1) {
                    itheta + 1
                } else {
                    qn + 1 - itheta
                };
                let fl = if itheta <= (qn >> 1) {
                    (itheta * (itheta + 1)) >> 1
                } else {
                    ft - (((qn + 1 - itheta) * (qn + 2 - itheta)) >> 1)
                };
                ctx.rc.encode(fl as u32, (fl + fs) as u32, ft as u32);
            } else {
                // Invert the triangular cdf with an integer square root.
                let fm = ctx.rc.decode(ft as u32) as i32;
                if fm < (((qn >> 1) * ((qn >> 1) + 1)) >> 1) {
                    itheta = (isqrt32((8 * fm + 1) as u32) as i32 - 1) >> 1;
                    let fl = (itheta * (itheta + 1)) >> 1;
                    let fs = itheta + 1;
                    ctx.rc.update(fl as u32, (fl + fs) as u32, ft as u32);
                } else {
                    itheta = (2 * (qn + 1) - isqrt32((8 * (ft - fm - 1) + 1) as u32) as i32) >> 1;
                    let fs = qn + 1 - itheta;
                    let fl = ft - (((qn + 1 - itheta) * (qn + 2 - itheta)) >> 1);
                    ctx.rc.update(fl as u32, (fl + fs) as u32, ft as u32);
                }
            }
        }
        // Back to the 0..=16384 scale (rounded).
        // NOTE(review): libopus truncates here (no + qn/2).
        itheta = (itheta * 16384 + (qn >> 1)) / qn;
    } else {
        if stereo && ctx.i >= ctx.intensity {
            // Intensity stereo: only an inversion bit (logp = 1) is coded.
            if ctx.encode {
                let mut emid = 1e-15f32;
                let mut eside = 1e-15f32;
                for i in 0..n {
                    let m = x[i] + y[i];
                    let s = x[i] - y[i];
                    emid += m * m;
                    eside += s * s;
                }
                let inv = eside > emid;
                ctx.rc.encode_bit_logp(inv, 1);
                itheta = 0;
                sctx.inv = inv;
            } else {
                sctx.inv = ctx.rc.decode_bit_logp(1);
                itheta = 0;
            }
        } else {
            // No resolution and no intensity: even split, nothing coded.
            itheta = 8192;
        }
    }

    sctx.itheta = itheta;

    sctx.qalloc = if qn == 1 && !(stereo && ctx.i >= ctx.intensity) {
        0
    } else {
        tell_frac_inline!(ctx.rc) - tell_start
    };

    if itheta == 0 {
        // Everything in the first half.
        sctx.imid = 32767;
        sctx.iside = 0;
        sctx.delta = -16384;
        *fill &= (1 << b_blocks) - 1;
    } else if itheta == 16384 {
        // Everything in the second half.
        sctx.imid = 0;
        sctx.iside = 32767;
        sctx.delta = 16384;
        *fill &= !((1 << b_blocks) - 1);
    } else {
        let imid = bitexact_cos(itheta as i16);
        sctx.imid = imid as i32;
        let iside = bitexact_cos((16384 - itheta) as i16);
        sctx.iside = iside as i32;
        // Q15 product of ((n - 1) << 7) and log2(iside / imid).
        sctx.delta =
            (((n as i32 - 1) << 7) * bitexact_log2tan(sctx.iside, sctx.imid) + 16384) >> 15;
    }
}
919
920#[inline(always)]
922fn quant_partition_n2_encode(
923 ctx: &mut BandCtx,
924 x: &mut [f32],
925 b: i32,
926 b_blocks: i32,
927 lowband: Option<&mut [f32]>,
928 lm: i32,
929 gain: f32,
930 fill: u32,
931) -> u32 {
932 let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
933 let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
934 ctx.remaining_bits -= curr_bits;
935
936 while ctx.remaining_bits < 0 && q > 0 {
937 ctx.remaining_bits += curr_bits;
938 q -= 1;
939 curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
940 ctx.remaining_bits -= curr_bits;
941 }
942
943 if q != 0 {
944 let k = get_pulses(q);
945 alg_quant(x, 2, k, ctx.spread, b_blocks as usize, ctx.rc, gain, false)
946 } else {
947 let has_lowband = lowband.is_some();
948 if has_lowband {
949 fill
950 } else {
951 (1u32 << b_blocks) - 1
952 }
953 }
954}
955
956#[inline(always)]
958fn quant_partition_n2(
959 ctx: &mut BandCtx,
960 x: &mut [f32],
961 b: i32,
962 b_blocks: i32,
963 lowband: Option<&mut [f32]>,
964 lm: i32,
965 gain: f32,
966 fill: u32,
967) -> u32 {
968 let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
970 let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
971 ctx.remaining_bits -= curr_bits;
972
973 while ctx.remaining_bits < 0 && q > 0 {
974 ctx.remaining_bits += curr_bits;
975 q -= 1;
976 curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
977 ctx.remaining_bits -= curr_bits;
978 }
979
980 if q != 0 {
981 let k = get_pulses(q);
982 if ctx.encode {
983 alg_quant(
984 x,
985 2,
986 k,
987 ctx.spread,
988 b_blocks as usize,
989 ctx.rc,
990 gain,
991 ctx.resynth,
992 )
993 } else {
994 alg_unquant(x, 2, k, ctx.spread, b_blocks as usize, ctx.rc, gain)
995 }
996 } else {
997 let has_lowband = lowband.is_some();
998 if ctx.resynth {
999 let mut seed = ctx.rc.tell() as u32;
1000 if has_lowband {
1001 let lb = lowband.unwrap();
1002 for i in 0..2 {
1003 seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
1004 x[i] = lb[i]
1005 + if seed & 0x8000 != 0 {
1006 1.0 / 256.0
1007 } else {
1008 -1.0 / 256.0
1009 };
1010 }
1011 } else {
1012 for i in 0..2 {
1013 seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
1014 x[i] = ((seed as i32 >> 20) as f32) / 16384.0;
1015 }
1016 }
1017 renormalise_vector(x, 2, gain);
1018 }
1019 if has_lowband {
1020 fill
1021 } else {
1022 (1u32 << b_blocks) - 1
1023 }
1024 }
1025}
1026
1027#[inline(always)]
1029fn quant_partition_n4_encode(
1030 ctx: &mut BandCtx,
1031 x: &mut [f32],
1032 b: i32,
1033 b_blocks: i32,
1034 lowband: Option<&mut [f32]>,
1035 lm: i32,
1036 gain: f32,
1037 fill: u32,
1038) -> u32 {
1039 let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
1040 let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1041 ctx.remaining_bits -= curr_bits;
1042
1043 while ctx.remaining_bits < 0 && q > 0 {
1044 ctx.remaining_bits += curr_bits;
1045 q -= 1;
1046 curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1047 ctx.remaining_bits -= curr_bits;
1048 }
1049
1050 if q != 0 {
1051 let k = get_pulses(q);
1052 alg_quant(x, 4, k, ctx.spread, b_blocks as usize, ctx.rc, gain, false)
1053 } else {
1054 let has_lowband = lowband.is_some();
1055 if has_lowband {
1056 fill
1057 } else {
1058 (1u32 << b_blocks) - 1
1059 }
1060 }
1061}
1062
1063#[inline(always)]
1065fn quant_partition_n8_encode(
1066 ctx: &mut BandCtx,
1067 x: &mut [f32],
1068 b: i32,
1069 b_blocks: i32,
1070 lowband: Option<&mut [f32]>,
1071 lm: i32,
1072 gain: f32,
1073 fill: u32,
1074) -> u32 {
1075 let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
1076 let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1077 ctx.remaining_bits -= curr_bits;
1078
1079 while ctx.remaining_bits < 0 && q > 0 {
1080 ctx.remaining_bits += curr_bits;
1081 q -= 1;
1082 curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1083 ctx.remaining_bits -= curr_bits;
1084 }
1085
1086 if q != 0 {
1087 let k = get_pulses(q);
1088 alg_quant(x, 8, k, ctx.spread, b_blocks as usize, ctx.rc, gain, false)
1089 } else {
1090 let has_lowband = lowband.is_some();
1091 if has_lowband {
1092 fill
1093 } else {
1094 (1u32 << b_blocks) - 1
1095 }
1096 }
1097}
1098
1099#[inline(always)]
1101#[allow(clippy::too_many_arguments)]
1102fn quant_partition_direct_encode(
1103 ctx: &mut BandCtx,
1104 x: &mut [f32],
1105 n: usize,
1106 b: i32,
1107 b_blocks: i32,
1108 lowband: Option<&mut [f32]>,
1109 lm: i32,
1110 gain: f32,
1111 fill: u32,
1112) -> u32 {
1113 let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
1114 let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1115 ctx.remaining_bits -= curr_bits;
1116
1117 while ctx.remaining_bits < 0 && q > 0 {
1118 ctx.remaining_bits += curr_bits;
1119 q -= 1;
1120 curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1121 ctx.remaining_bits -= curr_bits;
1122 }
1123
1124 if q != 0 {
1125 let k = get_pulses(q);
1126 alg_quant(x, n, k, ctx.spread, b_blocks as usize, ctx.rc, gain, false)
1127 } else {
1128 let has_lowband = lowband.is_some();
1129 if has_lowband {
1130 fill
1131 } else {
1132 (1u32 << b_blocks) - 1
1133 }
1134 }
1135}
1136
1137#[inline(always)]
1140#[allow(clippy::too_many_arguments)]
1141fn quant_partition_encode(
1142 ctx: &mut BandCtx,
1143 x: &mut [f32],
1144 n: usize,
1145 b: i32,
1146 b_blocks: i32,
1147 lowband: Option<&mut [f32]>,
1148 lm: i32,
1149 gain: f32,
1150 fill: u32,
1151) -> u32 {
1152 if n == 2 {
1153 return quant_partition_n2_encode(ctx, x, b, b_blocks, lowband, lm, gain, fill);
1154 }
1155 if n == 4 {
1156 return quant_partition_n4_encode(ctx, x, b, b_blocks, lowband, lm, gain, fill);
1157 }
1158 if n == 8 {
1159 return quant_partition_n8_encode(ctx, x, b, b_blocks, lowband, lm, gain, fill);
1160 }
1161 if n == 16 {
1162 return quant_partition_direct_encode(ctx, x, n, b, b_blocks, lowband, lm, gain, fill);
1163 }
1164
1165 let should_split = if lm >= 0 && n > 2 {
1167 let cache_idx = (lm + 1) as usize * ctx.m.nb_ebands + ctx.i;
1168 let cache_base = unsafe { *ctx.m.cache.index.get_unchecked(cache_idx) } as usize;
1169 if cache_base > 0 {
1170 let cache_ptr = ctx.m.cache.bits.as_ptr().wrapping_add(cache_base);
1171 let max_q = unsafe { *cache_ptr } as usize;
1172 b > (unsafe { *cache_ptr.add(max_q) } as i32) + 12
1173 } else {
1174 false
1175 }
1176 } else {
1177 false
1178 };
1179
1180 if should_split {
1181 let mut sctx = SplitCtx {
1182 inv: false,
1183 imid: 0,
1184 iside: 0,
1185 delta: 0,
1186 itheta: 0,
1187 qalloc: 0,
1188 };
1189 let mut b_mut = b;
1190 let mut fill_mut = fill;
1191 let mid = n / 2;
1192 let (x_mid, x_side) = x.split_at_mut(mid);
1193
1194 compute_theta_encode(
1195 ctx,
1196 &mut sctx,
1197 x_mid,
1198 x_side,
1199 mid,
1200 &mut b_mut,
1201 (b_blocks + 1) >> 1,
1202 b_blocks,
1203 lm,
1204 false,
1205 &mut fill_mut,
1206 );
1207
1208 ctx.remaining_bits -= sctx.qalloc;
1209 let mbits = (0).max((b_mut - sctx.delta) / 2).min(b_mut);
1210 let mut sbits = b_mut - mbits;
1211 let mut mbits = mbits;
1212
1213 let mut rebalance = ctx.remaining_bits;
1214 let mut cm;
1215
1216 if mbits >= sbits {
1219 cm = quant_partition_encode(
1220 ctx,
1221 x_mid,
1222 mid,
1223 mbits,
1224 (b_blocks + 1) >> 1,
1225 lowband,
1226 lm,
1227 gain,
1228 fill_mut,
1229 );
1230 rebalance = mbits - (rebalance - ctx.remaining_bits);
1231 if rebalance > (3 << 3) && sctx.itheta != 0 {
1232 sbits += rebalance - (3 << 3);
1233 }
1234 cm |= quant_partition_encode(
1235 ctx,
1236 x_side,
1237 mid,
1238 sbits,
1239 (b_blocks + 1) >> 1,
1240 None,
1241 lm,
1242 gain,
1243 fill_mut >> b_blocks,
1244 ) << (b_blocks >> 1);
1245 } else {
1246 cm = quant_partition_encode(
1247 ctx,
1248 x_side,
1249 mid,
1250 sbits,
1251 (b_blocks + 1) >> 1,
1252 None,
1253 lm,
1254 gain,
1255 fill_mut >> b_blocks,
1256 ) << (b_blocks >> 1);
1257 rebalance = sbits - (rebalance - ctx.remaining_bits);
1258 if rebalance > (3 << 3) && sctx.itheta != 16384 {
1259 mbits += rebalance - (3 << 3);
1260 }
1261 cm |= quant_partition_encode(
1262 ctx,
1263 x_mid,
1264 mid,
1265 mbits,
1266 (b_blocks + 1) >> 1,
1267 lowband,
1268 lm,
1269 gain,
1270 fill_mut,
1271 );
1272 }
1273 cm
1274 } else {
1275 let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
1276 let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1277 ctx.remaining_bits -= curr_bits;
1278
1279 while ctx.remaining_bits < 0 && q > 0 {
1280 ctx.remaining_bits += curr_bits;
1281 q -= 1;
1282 curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1283 ctx.remaining_bits -= curr_bits;
1284 }
1285
1286 if q != 0 {
1287 let k = get_pulses(q);
1288 alg_quant(x, n, k, ctx.spread, b_blocks as usize, ctx.rc, gain, false)
1289 } else {
1290 if lowband.is_some() {
1291 fill
1292 } else {
1293 (1 << b_blocks) - 1
1294 }
1295 }
1296 }
1297}
1298
1299#[inline(always)]
1301fn quant_partition_n4(
1302 ctx: &mut BandCtx,
1303 x: &mut [f32],
1304 b: i32,
1305 b_blocks: i32,
1306 lowband: Option<&mut [f32]>,
1307 lm: i32,
1308 gain: f32,
1309 fill: u32,
1310) -> u32 {
1311 let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
1313 let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1314 ctx.remaining_bits -= curr_bits;
1315
1316 while ctx.remaining_bits < 0 && q > 0 {
1317 ctx.remaining_bits += curr_bits;
1318 q -= 1;
1319 curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1320 ctx.remaining_bits -= curr_bits;
1321 }
1322
1323 if q != 0 {
1324 let k = get_pulses(q);
1325 if ctx.encode {
1326 alg_quant(
1327 x,
1328 4,
1329 k,
1330 ctx.spread,
1331 b_blocks as usize,
1332 ctx.rc,
1333 gain,
1334 ctx.resynth,
1335 )
1336 } else {
1337 alg_unquant(x, 4, k, ctx.spread, b_blocks as usize, ctx.rc, gain)
1338 }
1339 } else {
1340 let has_lowband = lowband.is_some();
1341 if ctx.resynth {
1342 let mut seed = ctx.rc.tell() as u32;
1343 if has_lowband {
1344 let lb = lowband.unwrap();
1345 for i in 0..4 {
1346 seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
1347 x[i] = lb[i]
1348 + if seed & 0x8000 != 0 {
1349 1.0 / 256.0
1350 } else {
1351 -1.0 / 256.0
1352 };
1353 }
1354 } else {
1355 for i in 0..4 {
1356 seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
1357 x[i] = ((seed as i32 >> 20) as f32) / 16384.0;
1358 }
1359 }
1360 renormalise_vector(x, 4, gain);
1361 }
1362 if has_lowband {
1363 fill
1364 } else {
1365 (1u32 << b_blocks) - 1
1366 }
1367 }
1368}
1369
1370#[inline(always)]
1372fn quant_partition_n8(
1373 ctx: &mut BandCtx,
1374 x: &mut [f32],
1375 b: i32,
1376 b_blocks: i32,
1377 lowband: Option<&mut [f32]>,
1378 lm: i32,
1379 gain: f32,
1380 fill: u32,
1381) -> u32 {
1382 let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
1383 let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1384 ctx.remaining_bits -= curr_bits;
1385
1386 while ctx.remaining_bits < 0 && q > 0 {
1387 ctx.remaining_bits += curr_bits;
1388 q -= 1;
1389 curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1390 ctx.remaining_bits -= curr_bits;
1391 }
1392
1393 if q != 0 {
1394 let k = get_pulses(q);
1395 if ctx.encode {
1396 alg_quant(
1397 x,
1398 8,
1399 k,
1400 ctx.spread,
1401 b_blocks as usize,
1402 ctx.rc,
1403 gain,
1404 ctx.resynth,
1405 )
1406 } else {
1407 alg_unquant(x, 8, k, ctx.spread, b_blocks as usize, ctx.rc, gain)
1408 }
1409 } else {
1410 let has_lowband = lowband.is_some();
1411 if ctx.resynth {
1412 let mut seed = ctx.rc.tell() as u32;
1413 if has_lowband {
1414 let lb = lowband.unwrap();
1415 for i in 0..8 {
1416 seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
1417 x[i] = lb[i]
1418 + if seed & 0x8000 != 0 {
1419 1.0 / 256.0
1420 } else {
1421 -1.0 / 256.0
1422 };
1423 }
1424 } else {
1425 for i in 0..8 {
1426 seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
1427 x[i] = ((seed as i32 >> 20) as f32) / 16384.0;
1428 }
1429 }
1430 renormalise_vector(x, 8, gain);
1431 }
1432 if has_lowband {
1433 fill
1434 } else {
1435 (1u32 << b_blocks) - 1
1436 }
1437 }
1438}
1439
1440#[inline(always)]
1442#[allow(clippy::too_many_arguments)]
1443fn quant_partition_direct(
1444 ctx: &mut BandCtx,
1445 x: &mut [f32],
1446 n: usize,
1447 b: i32,
1448 b_blocks: i32,
1449 lowband: Option<&mut [f32]>,
1450 lm: i32,
1451 gain: f32,
1452 fill: u32,
1453) -> u32 {
1454 let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
1455 let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1456 ctx.remaining_bits -= curr_bits;
1457
1458 while ctx.remaining_bits < 0 && q > 0 {
1459 ctx.remaining_bits += curr_bits;
1460 q -= 1;
1461 curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1462 ctx.remaining_bits -= curr_bits;
1463 }
1464
1465 if q != 0 {
1466 let k = get_pulses(q);
1467 if ctx.encode {
1468 alg_quant(
1469 x,
1470 n,
1471 k,
1472 ctx.spread,
1473 b_blocks as usize,
1474 ctx.rc,
1475 gain,
1476 ctx.resynth,
1477 )
1478 } else {
1479 alg_unquant(x, n, k, ctx.spread, b_blocks as usize, ctx.rc, gain)
1480 }
1481 } else {
1482 let has_lowband = lowband.is_some();
1483 if ctx.resynth {
1484 let mut seed = ctx.rc.tell() as u32;
1485 if let Some(lb) = lowband {
1486 for i in 0..n {
1487 seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
1488 x[i] = lb[i]
1489 + if seed & 0x8000 != 0 {
1490 1.0 / 256.0
1491 } else {
1492 -1.0 / 256.0
1493 };
1494 }
1495 } else {
1496 for i in 0..n {
1497 seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
1498 x[i] = ((seed as i32 >> 20) as f32) / 16384.0;
1499 }
1500 }
1501 renormalise_vector(x, n, gain);
1502 }
1503 if has_lowband {
1504 fill
1505 } else {
1506 (1u32 << b_blocks) - 1
1507 }
1508 }
1509}
1510
1511#[inline(always)]
1512#[allow(clippy::too_many_arguments)]
1513pub fn quant_partition(
1514 ctx: &mut BandCtx,
1515 x: &mut [f32],
1516 n: usize,
1517 b: i32,
1518 b_blocks: i32,
1519 lowband: Option<&mut [f32]>,
1520 lm: i32,
1521 gain: f32,
1522 fill: u32,
1523) -> u32 {
1524 if n == 2 {
1527 return quant_partition_n2(ctx, x, b, b_blocks, lowband, lm, gain, fill);
1528 }
1529
1530 if n == 4 {
1531 return quant_partition_n4(ctx, x, b, b_blocks, lowband, lm, gain, fill);
1532 }
1533
1534 if n == 8 {
1535 return quant_partition_n8(ctx, x, b, b_blocks, lowband, lm, gain, fill);
1536 }
1537
1538 if n == 16 {
1540 return quant_partition_direct(ctx, x, n, b, b_blocks, lowband, lm, gain, fill);
1541 }
1542
1543 let should_split = if lm >= 0 && n > 2 {
1547 let cache_idx = (lm + 1) as usize * ctx.m.nb_ebands + ctx.i;
1548 let cache_base = unsafe { *ctx.m.cache.index.get_unchecked(cache_idx) } as usize;
1549 if cache_base > 0 {
1550 let cache_ptr = ctx.m.cache.bits.as_ptr().wrapping_add(cache_base);
1551 let max_q = unsafe { *cache_ptr } as usize;
1552 b > (unsafe { *cache_ptr.add(max_q) } as i32) + 12
1553 } else {
1554 false
1555 }
1556 } else {
1557 false
1558 };
1559 if should_split {
1560 let mut sctx = SplitCtx {
1561 inv: false,
1562 imid: 0,
1563 iside: 0,
1564 delta: 0,
1565 itheta: 0,
1566 qalloc: 0,
1567 };
1568 let mut b_mut = b;
1569 let mut fill_mut = fill;
1570 let mid = n / 2;
1571 let (x_mid, x_side) = x.split_at_mut(mid);
1572
1573 compute_theta(
1574 ctx,
1575 &mut sctx,
1576 x_mid,
1577 x_side,
1578 mid,
1579 &mut b_mut,
1580 (b_blocks + 1) >> 1,
1581 b_blocks,
1582 lm,
1583 false,
1584 &mut fill_mut,
1585 );
1586
1587 ctx.remaining_bits -= sctx.qalloc;
1588 let mbits = (0).max((b_mut - sctx.delta) / 2).min(b_mut);
1589 let mut sbits = b_mut - mbits;
1590 let mut mbits = mbits;
1591
1592 let mut rebalance = ctx.remaining_bits;
1593 let mut cm;
1594
1595 if mbits >= sbits {
1596 cm = quant_partition(
1597 ctx,
1598 x_mid,
1599 mid,
1600 mbits,
1601 (b_blocks + 1) >> 1,
1602 lowband,
1603 lm,
1604 gain * (sctx.imid as f32 / 32768.0),
1605 fill_mut,
1606 );
1607 rebalance = mbits - (rebalance - ctx.remaining_bits);
1608 if rebalance > (3 << 3) && sctx.itheta != 0 {
1609 sbits += rebalance - (3 << 3);
1610 }
1611 cm |= quant_partition(
1612 ctx,
1613 x_side,
1614 mid,
1615 sbits,
1616 (b_blocks + 1) >> 1,
1617 None,
1618 lm,
1619 gain * (sctx.iside as f32 / 32768.0),
1620 fill_mut >> b_blocks,
1621 ) << (b_blocks >> 1);
1622 } else {
1623 cm = quant_partition(
1624 ctx,
1625 x_side,
1626 mid,
1627 sbits,
1628 (b_blocks + 1) >> 1,
1629 None,
1630 lm,
1631 gain * (sctx.iside as f32 / 32768.0),
1632 fill_mut >> b_blocks,
1633 ) << (b_blocks >> 1);
1634 rebalance = sbits - (rebalance - ctx.remaining_bits);
1635 if rebalance > (3 << 3) && sctx.itheta != 16384 {
1636 mbits += rebalance - (3 << 3);
1637 }
1638 cm |= quant_partition(
1639 ctx,
1640 x_mid,
1641 mid,
1642 mbits,
1643 (b_blocks + 1) >> 1,
1644 lowband,
1645 lm,
1646 gain * (sctx.imid as f32 / 32768.0),
1647 fill_mut,
1648 );
1649 }
1650 cm
1651 } else {
1652 let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
1653 let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1654 ctx.remaining_bits -= curr_bits;
1655
1656 while ctx.remaining_bits < 0 && q > 0 {
1659 ctx.remaining_bits += curr_bits;
1660 q -= 1;
1661 curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1662 ctx.remaining_bits -= curr_bits;
1663 }
1664
1665 if q != 0 {
1666 let k = get_pulses(q);
1667 if ctx.encode {
1668 alg_quant(
1669 x,
1670 n,
1671 k,
1672 ctx.spread,
1673 b_blocks as usize,
1674 ctx.rc,
1675 gain,
1676 ctx.resynth,
1677 )
1678 } else {
1679 alg_unquant(x, n, k, ctx.spread, b_blocks as usize, ctx.rc, gain)
1680 }
1681 } else {
1682 let has_lowband = lowband.is_some();
1684 if ctx.resynth {
1685 let cm_mask = (1u32 << b_blocks) - 1;
1686 let fill_masked = fill & cm_mask;
1687 if fill_masked == 0 {
1688 x[..n].fill(0.0);
1690 } else if has_lowband {
1691 let lb = lowband.unwrap();
1692 #[cfg(target_arch = "aarch64")]
1693 unsafe {
1694 use std::arch::aarch64::*;
1695 let n8 = n & !7;
1696 let mut i = 0;
1697 while i < n8 {
1698 let mut vals = [0.0f32; 8];
1700 for j in 0..8 {
1701 ctx.seed = celt_lcg_rand(ctx.seed);
1702 vals[j] = if ctx.seed & 0x8000 != 0 {
1703 1.0 / 256.0
1704 } else {
1705 -1.0 / 256.0
1706 };
1707 }
1708 let vnoise = vld1q_f32(vals.as_ptr());
1709 let vnoise1 = vld1q_f32(vals.as_ptr().add(4));
1710 let vlb = vld1q_f32(lb.as_ptr().add(i));
1711 let vlb1 = vld1q_f32(lb.as_ptr().add(i + 4));
1712 let vres = vaddq_f32(vlb, vnoise);
1713 let vres1 = vaddq_f32(vlb1, vnoise1);
1714 vst1q_f32(x.as_mut_ptr().add(i), vres);
1715 vst1q_f32(x.as_mut_ptr().add(i + 4), vres1);
1716 i += 8;
1717 }
1718 for j in i..n {
1719 ctx.seed = celt_lcg_rand(ctx.seed);
1720 x[j] = lb[j]
1721 + if ctx.seed & 0x8000 != 0 {
1722 1.0 / 256.0
1723 } else {
1724 -1.0 / 256.0
1725 };
1726 }
1727 }
1728 #[cfg(not(target_arch = "aarch64"))]
1729 {
1730 for j in 0..n {
1731 ctx.seed = celt_lcg_rand(ctx.seed);
1732 x[j] = lb[j]
1733 + if ctx.seed & 0x8000 != 0 {
1734 1.0 / 256.0
1735 } else {
1736 -1.0 / 256.0
1737 };
1738 }
1739 }
1740 renormalise_vector(x, n, gain);
1741 } else {
1742 for xv in x[..n].iter_mut() {
1743 ctx.seed = celt_lcg_rand(ctx.seed);
1744 *xv = ((ctx.seed as i32 >> 20) as f32) / 16384.0;
1745 }
1746 renormalise_vector(x, n, gain);
1747 }
1748 }
1749 if has_lowband {
1750 fill
1751 } else {
1752 (1 << b_blocks) - 1
1753 }
1754 }
1755 }
1756}
1757
/// Non-Hadamard deinterleave used on aarch64 for `n0 >= 4`.
///
/// Gathers sample `j` of block `i` from `x[j * stride + i]` into `x[i * n0 + j]` for the
/// first `n0 * stride` values. Despite the name, the body is a plain scalar gather.
///
/// # Safety
/// No caller obligations beyond passing valid slices. The scratch buffer is zero-initialised
/// and bounds-checked, so an oversized `n0 * stride` (more than `MAX_PVQ_N` or `x.len()`)
/// panics instead of causing undefined behaviour.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn deinterleave_hadamard_neon(x: &mut [f32], n0: usize, stride: usize) {
    let n = n0 * stride;
    let mut tmp_buf = [0.0f32; MAX_PVQ_N];
    let tmp = &mut tmp_buf[..n];

    for i in 0..stride {
        let dst_offset = i * n0;
        for j in 0..n0 {
            tmp[dst_offset + j] = x[j * stride + i];
        }
    }

    x[..n].copy_from_slice(tmp);
}
1776
1777pub fn deinterleave_hadamard(x: &mut [f32], n0: usize, stride: usize, hadamard: bool) {
1778 let n = n0 * stride;
1779 let mut tmp_buf = [std::mem::MaybeUninit::<f32>::uninit(); MAX_PVQ_N];
1781 let tmp = unsafe { std::slice::from_raw_parts_mut(tmp_buf.as_mut_ptr() as *mut f32, n) };
1783 if hadamard {
1784 let offset = match stride {
1785 2 => 0,
1786 4 => 2,
1787 8 => 6,
1788 16 => 14,
1789 _ => 0,
1790 };
1791 let ordery = &ORDERY_TABLE[offset..offset + stride];
1792 for i in 0..stride {
1793 for j in 0..n0 {
1794 tmp[ordery[i] as usize * n0 + j] = x[j * stride + i];
1795 }
1796 }
1797 } else {
1798 #[cfg(target_arch = "aarch64")]
1800 unsafe {
1801 if n0 >= 4 {
1802 deinterleave_hadamard_neon(x, n0, stride);
1803 return;
1804 }
1805 }
1806 for i in 0..stride {
1807 for j in 0..n0 {
1808 tmp[i * n0 + j] = x[j * stride + i];
1809 }
1810 }
1811 }
1812 x[..n].copy_from_slice(tmp);
1813}
1814
/// Non-Hadamard interleave used on aarch64 for `n0 >= 4`; the inverse of the deinterleave.
///
/// Scatters `x[i * n0 + j]` (block `i`, sample `j`) to `x[j * stride + i]` for the first
/// `n0 * stride` values. Despite the name, the body is a plain scalar scatter.
///
/// # Safety
/// No caller obligations beyond passing valid slices. The scratch buffer is zero-initialised
/// and bounds-checked, so an oversized `n0 * stride` (more than `MAX_PVQ_N` or `x.len()`)
/// panics instead of causing undefined behaviour.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn interleave_hadamard_neon(x: &mut [f32], n0: usize, stride: usize) {
    let n = n0 * stride;
    let mut tmp_buf = [0.0f32; MAX_PVQ_N];
    let tmp = &mut tmp_buf[..n];

    for i in 0..stride {
        let src_offset = i * n0;
        for j in 0..n0 {
            tmp[j * stride + i] = x[src_offset + j];
        }
    }

    x[..n].copy_from_slice(tmp);
}
1833
1834pub fn interleave_hadamard(x: &mut [f32], n0: usize, stride: usize, hadamard: bool) {
1835 let n = n0 * stride;
1836 let mut tmp_buf = [std::mem::MaybeUninit::<f32>::uninit(); MAX_PVQ_N];
1837 let tmp = unsafe { std::slice::from_raw_parts_mut(tmp_buf.as_mut_ptr() as *mut f32, n) };
1838 if hadamard {
1839 let offset = match stride {
1840 2 => 0,
1841 4 => 2,
1842 8 => 6,
1843 16 => 14,
1844 _ => 0,
1845 };
1846 let ordery = &ORDERY_TABLE[offset..offset + stride];
1847 for i in 0..stride {
1848 for j in 0..n0 {
1849 tmp[j * stride + i] = x[ordery[i] as usize * n0 + j];
1850 }
1851 }
1852 } else {
1853 #[cfg(target_arch = "aarch64")]
1855 unsafe {
1856 if n0 >= 4 {
1857 interleave_hadamard_neon(x, n0, stride);
1858 return;
1859 }
1860 }
1861 for i in 0..stride {
1862 for j in 0..n0 {
1863 tmp[j * stride + i] = x[i * n0 + j];
1864 }
1865 }
1866 }
1867 x[..n].copy_from_slice(tmp);
1868}
1869
/// Block placement tables for the Hadamard (de)interleave, concatenated by stride.
///
/// Segments start at offset 0 (stride 2), 2 (stride 4), 6 (stride 8) and 14 (stride 16).
/// Each segment is a permutation of `0..stride`.
#[rustfmt::skip]
const ORDERY_TABLE: [i32; 30] = [
    // stride 2
    1, 0,
    // stride 4
    3, 0, 2, 1,
    // stride 8
    7, 0, 4, 3, 6, 1, 5, 2,
    // stride 16
    15, 0, 8, 7, 12, 3, 11, 4, 14, 1, 9, 6, 13, 2, 10, 5,
];
1873
1874fn quant_band_n1(
1875 ctx: &mut BandCtx,
1876 x: &mut [f32],
1877 y: Option<&mut [f32]>,
1878 lowband_out: Option<&mut [f32]>,
1879) -> u32 {
1880 let mut sign = 0;
1881 if ctx.remaining_bits >= 1 << BITRES {
1882 if ctx.encode {
1883 sign = if x[0] < 0.0 { 1 } else { 0 };
1884 ctx.rc.enc_bits(sign as u32, 1);
1885 } else {
1886 sign = ctx.rc.dec_bits(1) as i32;
1887 }
1888 ctx.remaining_bits -= 1 << BITRES;
1889 }
1890 if ctx.resynth {
1891 x[0] = if sign != 0 { -1.0 } else { 1.0 };
1892 }
1893 if let Some(y_val) = y {
1894 let mut y_sign = 0;
1895 if ctx.remaining_bits >= 1 << BITRES {
1896 if ctx.encode {
1897 y_sign = if y_val[0] < 0.0 { 1 } else { 0 };
1898 ctx.rc.enc_bits(y_sign as u32, 1);
1899 } else {
1900 y_sign = ctx.rc.dec_bits(1) as i32;
1901 }
1902 ctx.remaining_bits -= 1 << BITRES;
1903 }
1904 if ctx.resynth {
1905 y_val[0] = if y_sign != 0 { -1.0 } else { 1.0 };
1906 }
1907 }
1908 if let Some(l_out) = lowband_out {
1909 l_out[0] = x[0] / 16.0;
1910 }
1911 1
1912}
1913
1914#[allow(clippy::too_many_arguments)]
1915#[inline(always)]
1916pub fn quant_band(
1917 ctx: &mut BandCtx,
1918 x: &mut [f32],
1919 n: usize,
1920 b: i32,
1921 b_blocks: i32,
1922 lowband: Option<&mut [f32]>,
1923 lm: i32,
1924 lowband_out: Option<&mut [f32]>,
1925 gain: f32,
1926 fill: u32,
1927) -> u32 {
1928 let n0 = n;
1929 let b0 = b_blocks;
1930 let long_blocks = b0 == 1;
1931
1932 if n == 1 {
1933 return quant_band_n1(ctx, x, None, lowband_out);
1934 }
1935
1936 let mut b_blocks = b_blocks;
1937 let mut n_b = n / b_blocks as usize;
1938 let mut time_divide = 0;
1939 let mut recombine = 0;
1940 let mut tf_change_local = ctx.tf_change;
1941 let mut fill = fill;
1942
1943 if tf_change_local > 0 {
1944 recombine = tf_change_local;
1945 }
1946
1947 let mut lowband_buf = lowband;
1952
1953 static BIT_INTERLEAVE_TABLE: [u8; 16] = [0, 1, 1, 1, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3];
1954
1955 for k in 0..recombine {
1956 if ctx.encode {
1957 haar1(x, n >> k, 1 << k);
1958 }
1959 if let Some(ref mut lb) = lowband_buf {
1960 haar1(lb, n >> k, 1 << k);
1961 }
1962 fill = (BIT_INTERLEAVE_TABLE[(fill & 0xF) as usize] as u32)
1963 | ((BIT_INTERLEAVE_TABLE[(fill >> 4) as usize] as u32) << 2);
1964 }
1965 b_blocks >>= recombine;
1966 n_b <<= recombine;
1967
1968 while n_b & 1 == 0 && tf_change_local < 0 {
1969 if ctx.encode {
1970 haar1(x, n_b, b_blocks as usize);
1971 }
1972 if let Some(ref mut lb) = lowband_buf {
1973 haar1(lb, n_b, b_blocks as usize);
1974 }
1975 fill |= fill << b_blocks;
1976 b_blocks <<= 1;
1977 n_b >>= 1;
1978 time_divide += 1;
1979 tf_change_local += 1;
1980 }
1981
1982 let b0_after = b_blocks;
1983 let n_b0 = n_b;
1984
1985 if b_blocks > 1 {
1986 if ctx.encode {
1987 deinterleave_hadamard(
1988 x,
1989 n_b >> recombine as usize,
1990 (b_blocks << recombine) as usize,
1991 long_blocks,
1992 );
1993 }
1994 if let Some(ref mut lb) = lowband_buf {
1995 deinterleave_hadamard(
1996 lb,
1997 n_b >> recombine as usize,
1998 (b_blocks << recombine) as usize,
1999 long_blocks,
2000 );
2001 }
2002 }
2003
2004 let cm = if ctx.encode {
2005 quant_partition_encode(
2006 ctx,
2007 x,
2008 n,
2009 b,
2010 b_blocks,
2011 lowband_buf.as_deref_mut(),
2012 lm,
2013 gain,
2014 fill,
2015 )
2016 } else {
2017 quant_partition(
2018 ctx,
2019 x,
2020 n,
2021 b,
2022 b_blocks,
2023 lowband_buf.as_deref_mut(),
2024 lm,
2025 gain,
2026 fill,
2027 )
2028 };
2029
2030 if ctx.resynth {
2031 let mut cm = cm;
2032
2033 if b_blocks > 1 {
2034 interleave_hadamard(
2035 x,
2036 n_b >> recombine as usize,
2037 (b0_after << recombine) as usize,
2038 long_blocks,
2039 );
2040 }
2041
2042 let mut n_b_undo = n_b0;
2043 let mut b_undo = b0_after;
2044 for _ in 0..time_divide {
2045 b_undo >>= 1;
2046 n_b_undo <<= 1;
2047 cm |= cm >> b_undo;
2048 haar1(x, n_b_undo, b_undo as usize);
2049 }
2050
2051 static BIT_DEINTERLEAVE_TABLE: [u8; 16] = [
2052 0x00, 0x03, 0x0C, 0x0F, 0x30, 0x33, 0x3C, 0x3F, 0xC0, 0xC3, 0xCC, 0xCF, 0xF0, 0xF3,
2053 0xFC, 0xFF,
2054 ];
2055 for k in 0..recombine {
2056 cm = BIT_DEINTERLEAVE_TABLE[cm as usize & 0xF] as u32;
2057 haar1(x, n0 >> k, 1 << k);
2058 }
2059 let mut b_final = b0_after;
2060 b_final <<= recombine;
2061
2062 if let Some(lb_out) = lowband_out {
2063 let scale = (n0 as f32).sqrt();
2064 for j in 0..n0 {
2065 lb_out[j] = scale * x[j];
2066 }
2067 }
2068 cm &= (1u32 << b_final) - 1;
2069 return cm;
2070 }
2071
2072 cm
2073}
2074
2075pub fn stereo_merge(x: &mut [f32], y: &mut [f32], mid: f32, side: f32, n: usize) {
2076 #[cfg(target_arch = "aarch64")]
2077 {
2078 stereo_merge_neon(x, y, mid, side, n);
2079 return;
2080 }
2081 #[cfg(not(target_arch = "aarch64"))]
2082 {
2083 stereo_merge_scalar(x, y, mid, side, n);
2084 }
2085}
2086
/// Portable mid/side mix over the first `n` samples of `x` and `y`.
///
/// Computes `x' = mid * x - side * y` and `y' = mid * x + side * y`. Panics if `n` exceeds
/// either slice length.
#[cfg(not(target_arch = "aarch64"))]
#[inline]
fn stereo_merge_scalar(x: &mut [f32], y: &mut [f32], mid: f32, side: f32, n: usize) {
    for (xs, ys) in x[..n].iter_mut().zip(y[..n].iter_mut()) {
        let mid_part = *xs * mid;
        let side_part = *ys * side;
        *xs = mid_part - side_part;
        *ys = mid_part + side_part;
    }
}
2097
/// NEON mid/side mix over the first `n` samples.
///
/// Computes `x' = mid * x - side * y` and `y' = mid * x + side * y`. It processes 16 lanes
/// per iteration, then 4, then a scalar tail.
///
/// # Panics
/// Panics if `n` exceeds `x.len()` or `y.len()`. The check is made up front because the
/// vector loads and stores use raw pointer offsets.
#[cfg(target_arch = "aarch64")]
fn stereo_merge_neon(x: &mut [f32], y: &mut [f32], mid: f32, side: f32, n: usize) {
    use std::arch::aarch64::*;

    assert!(
        n <= x.len() && n <= y.len(),
        "stereo_merge_neon: n exceeds slice length"
    );

    // SAFETY: every pointer offset below is < n, which the assert bounds by both slice
    // lengths; NEON is baseline on aarch64.
    unsafe {
        let vmid = vdupq_n_f32(mid);
        let vside = vdupq_n_f32(side);

        let n16 = n & !15;
        for i in (0..n16).step_by(16) {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let x1 = vld1q_f32(x.as_ptr().add(i + 4));
            let x2 = vld1q_f32(x.as_ptr().add(i + 8));
            let x3 = vld1q_f32(x.as_ptr().add(i + 12));

            let y0 = vld1q_f32(y.as_ptr().add(i));
            let y1 = vld1q_f32(y.as_ptr().add(i + 4));
            let y2 = vld1q_f32(y.as_ptr().add(i + 8));
            let y3 = vld1q_f32(y.as_ptr().add(i + 12));

            let xv0 = vmulq_f32(x0, vmid);
            let xv1 = vmulq_f32(x1, vmid);
            let xv2 = vmulq_f32(x2, vmid);
            let xv3 = vmulq_f32(x3, vmid);

            let yv0 = vmulq_f32(y0, vside);
            let yv1 = vmulq_f32(y1, vside);
            let yv2 = vmulq_f32(y2, vside);
            let yv3 = vmulq_f32(y3, vside);

            vst1q_f32(x.as_mut_ptr().add(i), vsubq_f32(xv0, yv0));
            vst1q_f32(x.as_mut_ptr().add(i + 4), vsubq_f32(xv1, yv1));
            vst1q_f32(x.as_mut_ptr().add(i + 8), vsubq_f32(xv2, yv2));
            vst1q_f32(x.as_mut_ptr().add(i + 12), vsubq_f32(xv3, yv3));

            vst1q_f32(y.as_mut_ptr().add(i), vaddq_f32(xv0, yv0));
            vst1q_f32(y.as_mut_ptr().add(i + 4), vaddq_f32(xv1, yv1));
            vst1q_f32(y.as_mut_ptr().add(i + 8), vaddq_f32(xv2, yv2));
            vst1q_f32(y.as_mut_ptr().add(i + 12), vaddq_f32(xv3, yv3));
        }

        // Remaining whole groups of 4.
        let n4 = (n & !3) - n16;
        for i in (n16..n16 + n4).step_by(4) {
            let xv = vld1q_f32(x.as_ptr().add(i));
            let yv = vld1q_f32(y.as_ptr().add(i));

            let x_val = vmulq_f32(xv, vmid);
            let y_val = vmulq_f32(yv, vside);

            vst1q_f32(x.as_mut_ptr().add(i), vsubq_f32(x_val, y_val));
            vst1q_f32(y.as_mut_ptr().add(i), vaddq_f32(x_val, y_val));
        }

        // Scalar tail (fewer than 4 samples).
        for i in (n16 + n4)..n {
            let x_val = x[i] * mid;
            let y_val = y[i] * side;
            x[i] = x_val - y_val;
            y[i] = x_val + y_val;
        }
    }
}
2164
2165#[allow(clippy::too_many_arguments)]
2166#[inline(always)]
2167pub fn quant_band_stereo(
2168 ctx: &mut BandCtx,
2169 x: &mut [f32],
2170 y: &mut [f32],
2171 n: usize,
2172 b: i32,
2173 b_blocks: i32,
2174 lowband: Option<&mut [f32]>,
2175 lm: i32,
2176 lowband_out: Option<&mut [f32]>,
2177 _gain: f32,
2178 fill: u32,
2179) -> u32 {
2180 if n == 1 {
2181 return quant_band_n1(ctx, x, Some(y), lowband_out);
2182 }
2183
2184 if ctx.encode
2185 && (ctx.band_e[ctx.i] < MIN_STEREO_ENERGY
2186 || ctx.band_e[ctx.m.nb_ebands + ctx.i] < MIN_STEREO_ENERGY)
2187 {
2188 if ctx.band_e[ctx.i] > ctx.band_e[ctx.m.nb_ebands + ctx.i] {
2189 y.copy_from_slice(x);
2190 } else {
2191 x.copy_from_slice(y);
2192 }
2193 }
2194
2195 let mut sctx = SplitCtx {
2196 inv: false,
2197 imid: 0,
2198 iside: 0,
2199 delta: 0,
2200 itheta: 0,
2201 qalloc: 0,
2202 };
2203 let mut b_mut = b;
2204 let mut fill_mut = fill;
2205 if ctx.encode {
2206 compute_theta_encode(
2207 ctx,
2208 &mut sctx,
2209 x,
2210 y,
2211 n,
2212 &mut b_mut,
2213 b_blocks,
2214 b_blocks,
2215 lm,
2216 true,
2217 &mut fill_mut,
2218 );
2219 } else {
2220 compute_theta(
2221 ctx,
2222 &mut sctx,
2223 x,
2224 y,
2225 n,
2226 &mut b_mut,
2227 b_blocks,
2228 b_blocks,
2229 lm,
2230 true,
2231 &mut fill_mut,
2232 );
2233 };
2234
2235 let mid_gain = sctx.imid as f32 / 32768.0;
2236 let side_gain = sctx.iside as f32 / 32768.0;
2237
2238 if n == 2 {
2239 let mut mbits = b_mut;
2240 let mut sbits = 0;
2241 if sctx.itheta != 0 && sctx.itheta != 16384 {
2242 sbits = 1 << BITRES;
2243 }
2244 mbits -= sbits;
2245 let c = sctx.itheta > 8192;
2246 ctx.remaining_bits -= sctx.qalloc + sbits;
2247
2248 let mut sign = 0;
2249 if sbits != 0 {
2250 if ctx.encode {
2251 sign = if c {
2252 if (y[0] * x[1] - y[1] * x[0]) < 0.0 {
2253 1
2254 } else {
2255 0
2256 }
2257 } else {
2258 if (x[0] * y[1] - x[1] * y[0]) < 0.0 {
2259 1
2260 } else {
2261 0
2262 }
2263 };
2264 ctx.rc.enc_bits(sign as u32, 1);
2265 } else {
2266 sign = ctx.rc.dec_bits(1) as i32;
2267 }
2268 }
2269 let sign_val = (1 - 2 * sign) as f32;
2270 let cm = if c {
2271 let cm = quant_band(
2272 ctx,
2273 y,
2274 n,
2275 mbits,
2276 b_blocks,
2277 lowband,
2278 lm,
2279 lowband_out,
2280 1.0,
2281 fill,
2282 );
2283 x[0] = -sign_val * y[1];
2284 x[1] = sign_val * y[0];
2285 cm
2286 } else {
2287 let cm = quant_band(
2288 ctx,
2289 x,
2290 n,
2291 mbits,
2292 b_blocks,
2293 lowband,
2294 lm,
2295 lowband_out,
2296 1.0,
2297 fill,
2298 );
2299 y[0] = -sign_val * x[1];
2300 y[1] = sign_val * x[0];
2301 cm
2302 };
2303
2304 if ctx.resynth {
2305 let x0 = x[0];
2306 let x1 = x[1];
2307 let y0 = y[0];
2308 let y1 = y[1];
2309 x[0] = mid_gain * x0 - side_gain * y0;
2310 x[1] = mid_gain * x1 - side_gain * y1;
2311 y[0] = mid_gain * x0 + side_gain * y0;
2312 y[1] = mid_gain * x1 + side_gain * y1;
2313 }
2314 return cm;
2315 }
2316
2317 ctx.remaining_bits -= sctx.qalloc;
2318 let mut mbits = (0).max((b_mut - sctx.delta) / 2).min(b_mut);
2319 let mut sbits = b_mut - mbits;
2320
2321 let mut rebalance = ctx.remaining_bits;
2322 let mut cm;
2323
2324 if mbits >= sbits {
2325 cm = quant_band(
2326 ctx,
2327 x,
2328 n,
2329 mbits,
2330 b_blocks,
2331 lowband,
2332 lm,
2333 lowband_out,
2334 1.0,
2335 fill_mut,
2336 );
2337 rebalance = mbits - (rebalance - ctx.remaining_bits);
2338 if rebalance > (3 << 3) && sctx.itheta != 0 {
2339 sbits += rebalance - (3 << 3);
2340 }
2341 cm |= quant_band(
2342 ctx,
2343 y,
2344 n,
2345 sbits,
2346 b_blocks,
2347 None,
2348 lm,
2349 None,
2350 side_gain,
2351 fill_mut >> b_blocks,
2352 ) << (b_blocks >> 1);
2353 } else {
2354 cm = quant_band(
2355 ctx,
2356 y,
2357 n,
2358 sbits,
2359 b_blocks,
2360 None,
2361 lm,
2362 None,
2363 side_gain,
2364 fill_mut >> b_blocks,
2365 ) << (b_blocks >> 1);
2366 rebalance = sbits - (rebalance - ctx.remaining_bits);
2367 if rebalance > (3 << 3) && sctx.itheta != 16384 {
2368 mbits += rebalance - (3 << 3);
2369 }
2370 cm |= quant_band(
2371 ctx,
2372 x,
2373 n,
2374 mbits,
2375 b_blocks,
2376 lowband,
2377 lm,
2378 lowband_out,
2379 1.0,
2380 fill_mut,
2381 );
2382 }
2383
2384 if ctx.resynth {
2385 stereo_merge(x, y, mid_gain, side_gain, n);
2386 if sctx.inv {
2387 for yv in y[..n].iter_mut() {
2388 *yv = -*yv;
2389 }
2390 }
2391 }
2392 cm
2393}
2394
2395#[allow(clippy::too_many_arguments)]
2396pub fn quant_all_bands(
2397 encode: bool,
2398 m: &CeltMode,
2399 start: usize,
2400 end: usize,
2401 x: &mut [f32],
2402 mut y: Option<&mut [f32]>,
2403 collapse_masks: &mut [u32],
2404 band_e: &[f32],
2405 pulses: &[i32],
2406 short_blocks: bool,
2407 spread: i32,
2408 dual_stereo: &mut bool,
2409 intensity: usize,
2410 tf_res: &[i32],
2411 total_bits: i32,
2412 balance: &mut i32,
2413 rc: &mut RangeCoder,
2414 lm: i32,
2415 coded_bands: i32,
2416 resynth: bool,
2417 seed: &mut u32,
2418) {
2419 let mut balance_val = *balance;
2420 let b_blocks = if short_blocks { 1 << lm } else { 1 };
2421 let c_channels = if y.is_some() { 2 } else { 1 };
2422 let m_val = 1usize << lm as usize;
2423
2424 let norm_offset = m_val * (m.e_bands[start] as usize);
2425 let norm_size = m_val * (m.e_bands[m.nb_ebands - 1] as usize) - norm_offset;
2426 const MAX_NORM_SIZE: usize = 800;
2428 debug_assert!(norm_size <= MAX_NORM_SIZE);
2429 let mut norm_buf = [std::mem::MaybeUninit::<f32>::uninit(); MAX_NORM_SIZE];
2435 let norm =
2436 unsafe { std::slice::from_raw_parts_mut(norm_buf.as_mut_ptr() as *mut f32, norm_size) };
2437
2438 let mut lowband_scratch_buf = [std::mem::MaybeUninit::<f32>::uninit(); MAX_PVQ_N];
2442 let lowband_scratch_ptr = lowband_scratch_buf.as_mut_ptr() as *mut f32;
2443
2444 let mut lowband_offset: usize = 0;
2445 let mut update_lowband = true;
2446 let mut avoid_split_noise = b_blocks > 1;
2447
2448 let e_bands = &m.e_bands;
2450
2451 let mut ctx_seed = *seed;
2453
2454 for i in start..end {
2455 let e_band_i = e_bands[i] as usize;
2457 let e_band_i1 = e_bands[i + 1] as usize;
2458 let offset = m_val * e_band_i;
2459 let n = m_val * (e_band_i1 - e_band_i);
2460 let last = i == end - 1;
2461
2462 let tell = tell_frac_inline!(rc);
2464 if i != start {
2465 balance_val -= tell;
2466 }
2467 let remaining_bits = total_bits - tell - 1;
2468
2469 let mut b = 0i32;
2470 if i < coded_bands as usize {
2471 let curr_balance = celt_sudiv(balance_val, 3i32.min(coded_bands - i as i32));
2472 b = 0i32.max(16383i32.min((remaining_bits + 1).min(pulses[i] + curr_balance)));
2473 }
2474
2475 let norm_pos = m_val * e_band_i - norm_offset;
2476
2477 let band_start = m_val * e_band_i;
2478 let bands_start = m_val * (e_bands[start] as usize);
2479 if resynth
2480 && (band_start as i32 - n as i32 >= bands_start as i32 || i == start + 1)
2481 && (update_lowband || lowband_offset == 0)
2482 {
2483 lowband_offset = i;
2484 }
2485
2486 let tf_change = tf_res[i];
2487
2488 let mut effective_lowband: i32 = -1;
2489 let mut x_cm: u32;
2490 let mut y_cm: u32;
2491
2492 if lowband_offset != 0 && (spread != SPREAD_AGGRESSIVE || b_blocks > 1 || tf_change < 0) {
2493 effective_lowband = 0i32.max(
2494 (m_val * e_bands[lowband_offset] as usize) as i32 - norm_offset as i32 - n as i32,
2495 );
2496 let el_abs = effective_lowband as usize + norm_offset;
2497
2498 let mut fold_start = lowband_offset;
2499 loop {
2500 if fold_start == 0 {
2501 break;
2502 }
2503 fold_start -= 1;
2504 if m_val * (e_bands[fold_start] as usize) <= el_abs {
2505 break;
2506 }
2507 }
2508 let mut fold_end = lowband_offset.saturating_sub(1);
2509 while fold_end + 1 < i && m_val * (e_bands[fold_end + 1] as usize) < el_abs + n {
2510 fold_end += 1;
2511 }
2512
2513 x_cm = 0;
2514 y_cm = 0;
2515 for fi in fold_start..fold_end {
2516 x_cm |= collapse_masks[fi * c_channels];
2517 y_cm |= collapse_masks[fi * c_channels + c_channels - 1];
2518 }
2519 } else {
2520 x_cm = (1u32 << b_blocks) - 1;
2521 y_cm = (1u32 << b_blocks) - 1;
2522 }
2523
2524 let mut ctx = BandCtx {
2525 encode,
2526 m,
2527 i,
2528 band_e,
2529 rc,
2530 spread,
2531 remaining_bits,
2532 resynth,
2533 tf_change,
2534 intensity,
2535 theta_round: 0,
2536 avoid_split_noise,
2537 arch: 0,
2538 disable_inv: false,
2539 seed: ctx_seed,
2540 };
2541
2542 if *dual_stereo && i == intensity {
2543 *dual_stereo = false;
2544 }
2545
2546 let mut lowband_scratch: Option<&mut [f32]> = if effective_lowband >= 0 {
2547 let lb_start = effective_lowband as usize;
2548 let lb_end = lb_start + n;
2549 if lb_end <= norm.len() {
2550 unsafe {
2551 std::ptr::copy_nonoverlapping(
2552 norm.as_ptr().add(lb_start),
2553 lowband_scratch_ptr,
2554 n,
2555 )
2556 };
2557 Some(unsafe { std::slice::from_raw_parts_mut(lowband_scratch_ptr, n) })
2558 } else {
2559 None
2560 }
2561 } else {
2562 None
2563 };
2564
2565 let x_slice = &mut x[offset..];
2571 if *dual_stereo {
2572 let y_slice = &mut y.as_mut().unwrap()[offset..];
2573 let lb_x = lowband_scratch.as_deref_mut();
2574 let lb_out_x = if !last && norm_pos + n <= norm.len() {
2575 Some(&mut norm[norm_pos..norm_pos + n])
2576 } else {
2577 None
2578 };
2579 x_cm = quant_band(
2580 &mut ctx,
2581 x_slice,
2582 n,
2583 b / 2,
2584 b_blocks,
2585 lb_x,
2586 lm,
2587 lb_out_x,
2588 1.0,
2589 x_cm,
2590 );
2591 y_cm = quant_band(
2592 &mut ctx,
2593 y_slice,
2594 n,
2595 b / 2,
2596 b_blocks,
2597 None,
2598 lm,
2599 None,
2600 1.0,
2601 y_cm,
2602 );
2603 } else {
2604 if let Some(y_all) = y.as_mut() {
2605 let y_slice = &mut y_all[offset..offset + n];
2606 let lb = lowband_scratch.as_deref_mut();
2607 let lb_out = if !last && norm_pos + n <= norm.len() {
2608 Some(&mut norm[norm_pos..norm_pos + n])
2609 } else {
2610 None
2611 };
2612 x_cm = quant_band_stereo(
2613 &mut ctx,
2614 x_slice,
2615 y_slice,
2616 n,
2617 b,
2618 b_blocks,
2619 lb,
2620 lm,
2621 lb_out,
2622 1.0,
2623 x_cm | y_cm,
2624 );
2625 y_cm = x_cm;
2626 } else {
2627 let lb = lowband_scratch.as_deref_mut();
2628 let lb_out = if !last && norm_pos + n <= norm.len() {
2629 Some(&mut norm[norm_pos..norm_pos + n])
2630 } else {
2631 None
2632 };
2633 x_cm = quant_band(&mut ctx, x_slice, n, b, b_blocks, lb, lm, lb_out, 1.0, x_cm);
2634 y_cm = x_cm;
2635 }
2636 }
2637
2638 collapse_masks[i * c_channels] = (x_cm & 0xFF) as u8 as u32; if c_channels == 2 {
2640 collapse_masks[i * c_channels + 1] = (y_cm & 0xFF) as u8 as u32;
2641 }
2642
2643 balance_val += pulses[i] + tell;
2644
2645 update_lowband = b > ((n as i32) << BITRES);
2646
2647 ctx_seed = ctx.seed;
2649
2650 avoid_split_noise = false;
2651 }
2652 *balance = balance_val;
2653 *seed = ctx_seed;
2654}
2655
/// NEON implementation of a band's L2 amplitude: `sqrt(1e-15 + sum(v * v))` over `band`.
///
/// The sum is accumulated in three stages:
/// - 16-wide with four independent FMA accumulators;
/// - a 4-wide pass over the remaining whole groups of four;
/// - a scalar tail.
///
/// The summation order differs from the scalar fold in `compute_band_energies`, so results
/// may differ in the last bits.
#[cfg(target_arch = "aarch64")]
fn compute_band_energy_neon(band: &[f32]) -> f32 {
    use std::arch::aarch64::*;

    let n = band.len();
    // Small floor so the square root never sees exactly zero.
    let mut sum = 1e-15f32;

    // SAFETY: every load reads 4 floats at an offset + 3 below n = band.len().
    unsafe {
        let n16 = n & !15;
        if n16 > 0 {
            let mut acc0 = vdupq_n_f32(0.0);
            let mut acc1 = vdupq_n_f32(0.0);
            let mut acc2 = vdupq_n_f32(0.0);
            let mut acc3 = vdupq_n_f32(0.0);

            for i in (0..n16).step_by(16) {
                let v0 = vld1q_f32(band.as_ptr().add(i));
                let v1 = vld1q_f32(band.as_ptr().add(i + 4));
                let v2 = vld1q_f32(band.as_ptr().add(i + 8));
                let v3 = vld1q_f32(band.as_ptr().add(i + 12));

                acc0 = vfmaq_f32(acc0, v0, v0);
                acc1 = vfmaq_f32(acc1, v1, v1);
                acc2 = vfmaq_f32(acc2, v2, v2);
                acc3 = vfmaq_f32(acc3, v3, v3);
            }

            // Pairwise combine the accumulators, then reduce across lanes.
            acc0 = vaddq_f32(acc0, acc1);
            acc2 = vaddq_f32(acc2, acc3);
            acc0 = vaddq_f32(acc0, acc2);
            sum += vaddvq_f32(acc0);
        }

        // Remaining whole groups of four after the 16-wide blocks.
        let n4 = (n & !3) - n16;
        if n4 > 0 {
            let mut acc = vdupq_n_f32(0.0);
            for i in (n16..n16 + n4).step_by(4) {
                let v = vld1q_f32(band.as_ptr().add(i));
                acc = vfmaq_f32(acc, v, v);
            }
            sum += vaddvq_f32(acc);
        }

        // Scalar tail (fewer than four samples).
        for i in (n16 + n4)..n {
            let v = band[i];
            sum += v * v;
        }
    }

    sum.sqrt()
}
2712
2713pub fn compute_band_energies(
2714 m: &CeltMode,
2715 x: &[f32],
2716 band_e: &mut [f32],
2717 end: usize,
2718 channels: usize,
2719 lm: usize,
2720) {
2721 let frame_size = m.short_mdct_size << lm;
2722
2723 #[cfg(target_arch = "aarch64")]
2724 let _use_neon = true;
2725 #[cfg(not(target_arch = "aarch64"))]
2726 let _use_neon = false;
2727
2728 for c in 0..channels {
2729 let ch = &x[c * frame_size..(c + 1) * frame_size];
2730 for i in 0..end {
2731 let offset = (m.e_bands[i] as usize) << lm;
2732 let n = ((m.e_bands[i + 1] - m.e_bands[i]) as usize) << lm;
2733 let band = &ch[offset..offset + n];
2734
2735 #[cfg(target_arch = "aarch64")]
2736 {
2737 band_e[c * m.nb_ebands + i] = compute_band_energy_neon(band);
2738 }
2739 #[cfg(not(target_arch = "aarch64"))]
2740 {
2741 let sum = band.iter().fold(1e-15f32, |acc, &v| acc + v * v);
2742 band_e[c * m.nb_ebands + i] = sum.sqrt();
2743 }
2744 }
2745 }
2746}
2747
2748pub fn amp2log2(
2749 m: &CeltMode,
2750 eff_ebands: usize,
2751 end: usize,
2752 band_e: &[f32],
2753 band_log_e: &mut [f32],
2754 channels: usize,
2755) {
2756 for c in 0..channels {
2757 for i in 0..eff_ebands {
2758 let val = band_e[c * m.nb_ebands + i].max(1e-10);
2759 band_log_e[c * m.nb_ebands + i] = val.log2() - m.e_means[i];
2760 }
2761 for i in eff_ebands..end {
2762 band_log_e[c * m.nb_ebands + i] = -14.0;
2763 }
2764 }
2765}
2766
2767pub fn log2amp(m: &CeltMode, end: usize, band_e: &mut [f32], band_log_e: &[f32], channels: usize) {
2770 for c in 0..channels {
2771 for i in 0..end {
2772 band_e[c * m.nb_ebands + i] = band_log_e[c * m.nb_ebands + i] + m.e_means[i];
2774 }
2775 }
2776}
2777
2778pub fn normalise_bands(
2779 m: &CeltMode,
2780 freq: &[f32],
2781 x: &mut [f32],
2782 band_e: &[f32],
2783 end: usize,
2784 channels: usize,
2785 m_val: usize,
2786) {
2787 let lm = m_val.trailing_zeros() as usize;
2788 let frame_size = m.short_mdct_size << lm;
2789 for c in 0..channels {
2790 for i in 0..end {
2791 let base = c * frame_size + ((m.e_bands[i] as usize) << lm);
2792 let n = ((m.e_bands[i + 1] - m.e_bands[i]) as usize) << lm;
2793 let norm = 1.0 / (1e-15 + band_e[c * m.nb_ebands + i]);
2794 let src = &freq[base..base + n];
2795 let dst = &mut x[base..base + n];
2796 for (d, &s) in dst.iter_mut().zip(src) {
2797 *d = s * norm;
2798 }
2799 }
2800 }
2801}
2802
2803#[allow(clippy::too_many_arguments)]
2804pub fn denormalise_bands(
2805 m: &CeltMode,
2806 x: &[f32],
2807 freq: &mut [f32],
2808 band_e: &[f32],
2809 start: usize,
2810 end: usize,
2811 channels: usize,
2812 m_val: usize,
2813) {
2814 let lm = m_val.trailing_zeros() as usize;
2815 let frame_size = m.short_mdct_size << lm;
2816
2817 for c in 0..channels {
2818 for i in start..end {
2819 let base = c * frame_size + ((m.e_bands[i] as usize) << lm);
2820 let n = ((m.e_bands[i + 1] - m.e_bands[i]) as usize) << lm;
2821 let band_log = band_e[c * m.nb_ebands + i];
2822 let g = (2.0f32).powf(band_log);
2824 let src = &x[base..base + n];
2825 let dst = &mut freq[base..base + n];
2826 for (d, &s) in dst.iter_mut().zip(src) {
2827 *d = s * g;
2828 }
2829 }
2830 }
2831}
2832
/// Advances the 32-bit linear congruential generator by one step.
///
/// Computes `1664525 * seed + 1013904223 (mod 2^32)`. These are the classic
/// Numerical Recipes LCG constants. All arithmetic wraps rather than panicking.
pub fn celt_lcg_rand(seed: u32) -> u32 {
    const MULTIPLIER: u32 = 1_664_525;
    const INCREMENT: u32 = 1_013_904_223;

    let scaled = MULTIPLIER.wrapping_mul(seed);
    scaled.wrapping_add(INCREMENT)
}
2836
/// Scales the first `n` elements of `x` in place so their L2 norm becomes `gain`
/// (NEON kernel used by `renormalise_vector` on aarch64).
///
/// The energy is accumulated as `1e-15 + sum(x[j]^2)` in 16-, 8- and 4-wide
/// vector stages followed by a scalar tail. The small floor keeps an all-zero
/// input from dividing by zero. Every element is then multiplied by
/// `gain / sqrt(energy)`. Elements at index `n` and beyond are not touched.
///
/// # Panics
///
/// Panics if `n > x.len()`.
///
/// # Safety
///
/// Must only run on a CPU with NEON, which is always the case on aarch64. Memory
/// safety of the vector loads/stores does not depend on the caller, because the
/// slice is cut to exactly `n` elements before any raw-pointer access.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
// The NEON intrinsics are called directly in this `unsafe fn` body.
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn renormalise_vector_neon(x: &mut [f32], n: usize, gain: f32) {
    use std::arch::aarch64::*;

    // Bound every raw access by the slice itself. Trusting `n` alone would let
    // `n > x.len()` read and write past the end of the buffer (UB). Re-slicing
    // instead panics, exactly like the scalar path's `x[..n]`, and guarantees
    // that each `x.as_ptr().add(i)` below with `i + 4 <= n` stays in bounds.
    let x = &mut x[..n];

    let mut sum_vec = vdupq_n_f32(0.0);
    let mut i = 0;

    while i + 16 <= n {
        let x0 = vld1q_f32(x.as_ptr().add(i));
        let x1 = vld1q_f32(x.as_ptr().add(i + 4));
        let x2 = vld1q_f32(x.as_ptr().add(i + 8));
        let x3 = vld1q_f32(x.as_ptr().add(i + 12));
        sum_vec = vfmaq_f32(sum_vec, x0, x0);
        sum_vec = vfmaq_f32(sum_vec, x1, x1);
        sum_vec = vfmaq_f32(sum_vec, x2, x2);
        sum_vec = vfmaq_f32(sum_vec, x3, x3);
        i += 16;
    }

    while i + 8 <= n {
        let x0 = vld1q_f32(x.as_ptr().add(i));
        let x1 = vld1q_f32(x.as_ptr().add(i + 4));
        sum_vec = vfmaq_f32(sum_vec, x0, x0);
        sum_vec = vfmaq_f32(sum_vec, x1, x1);
        i += 8;
    }

    while i + 4 <= n {
        let x0 = vld1q_f32(x.as_ptr().add(i));
        sum_vec = vfmaq_f32(sum_vec, x0, x0);
        i += 4;
    }

    let mut e = 1e-15f32 + vaddvq_f32(sum_vec);

    // Scalar tail: fewer than four elements remain.
    for j in i..n {
        e += x[j] * x[j];
    }

    let norm = gain / e.sqrt();
    let vnorm = vdupq_n_f32(norm);

    i = 0;
    while i + 16 <= n {
        let x0 = vld1q_f32(x.as_ptr().add(i));
        let x1 = vld1q_f32(x.as_ptr().add(i + 4));
        let x2 = vld1q_f32(x.as_ptr().add(i + 8));
        let x3 = vld1q_f32(x.as_ptr().add(i + 12));
        vst1q_f32(x.as_mut_ptr().add(i), vmulq_f32(x0, vnorm));
        vst1q_f32(x.as_mut_ptr().add(i + 4), vmulq_f32(x1, vnorm));
        vst1q_f32(x.as_mut_ptr().add(i + 8), vmulq_f32(x2, vnorm));
        vst1q_f32(x.as_mut_ptr().add(i + 12), vmulq_f32(x3, vnorm));
        i += 16;
    }

    while i + 8 <= n {
        let x0 = vld1q_f32(x.as_ptr().add(i));
        let x1 = vld1q_f32(x.as_ptr().add(i + 4));
        vst1q_f32(x.as_mut_ptr().add(i), vmulq_f32(x0, vnorm));
        vst1q_f32(x.as_mut_ptr().add(i + 4), vmulq_f32(x1, vnorm));
        i += 8;
    }

    while i + 4 <= n {
        let x0 = vld1q_f32(x.as_ptr().add(i));
        vst1q_f32(x.as_mut_ptr().add(i), vmulq_f32(x0, vnorm));
        i += 4;
    }

    for j in i..n {
        x[j] *= norm;
    }
}
2920
/// Scales the first `n` elements of `x` in place so their L2 norm becomes `gain`.
///
/// The energy is `1e-15 + sum(x[j]^2)`. The floor keeps an all-zero input from
/// dividing by zero, and such an input stays all zero. Elements at index `n` and
/// beyond are left untouched. On aarch64 the work is done by the NEON kernel,
/// whose floating-point summation order differs from the scalar fold used on
/// other targets.
///
/// # Panics
///
/// Panics if `n > x.len()`, on every target.
pub fn renormalise_vector(x: &mut [f32], n: usize, gain: f32) {
    // Enforce `n <= x.len()` here, at the safe API boundary, so an oversized `n`
    // can never reach the NEON kernel's raw-pointer loads and stores.
    let x = &mut x[..n];

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is mandatory on aarch64, and `x` has exactly `n`
        // elements, so every vector access in the kernel stays in bounds.
        unsafe { renormalise_vector_neon(x, n, gain) };
    }

    #[cfg(not(target_arch = "aarch64"))]
    {
        let energy = x.iter().fold(1e-15f32, |acc, &v| acc + v * v);
        let norm = gain / energy.sqrt();
        x.iter_mut().for_each(|v| *v *= norm);
    }
}
2939
2940#[allow(clippy::too_many_arguments)]
2941pub fn anti_collapse(
2942 m: &CeltMode,
2943 x_buf: &mut [f32],
2944 collapse_masks: &[u32],
2945 lm: i32,
2946 channels: usize,
2947 size: usize,
2948 start: usize,
2949 end: usize,
2950 log_e: &[f32],
2951 prev1_log_e: &[f32],
2952 prev2_log_e: &[f32],
2953 pulses: &[i32],
2954 mut seed: u32,
2955) -> u32 {
2956 for i in start..end {
2957 let n0 = (m.e_bands[i + 1] - m.e_bands[i]) as usize;
2958 let depth = if n0 > 0 {
2959 ((1 + pulses[i]) / n0 as i32) >> lm
2960 } else {
2961 0
2962 };
2963
2964 let thresh = 0.5 * (-(0.125 * depth as f32)).exp2();
2965 let sqrt_1 = 1.0 / ((n0 << lm) as f32).sqrt();
2966
2967 for c in 0..channels {
2968 let p1 = prev1_log_e[c * m.nb_ebands + i];
2969 let p2 = prev2_log_e[c * m.nb_ebands + i];
2970
2971 let (p1_adj, p2_adj) = if channels == 1 && prev1_log_e.len() >= 2 * m.nb_ebands {
2972 (
2973 p1.max(prev1_log_e[m.nb_ebands + i]),
2974 p2.max(prev2_log_e[m.nb_ebands + i]),
2975 )
2976 } else {
2977 (p1, p2)
2978 };
2979
2980 let e_diff = log_e[c * m.nb_ebands + i] - p1_adj.min(p2_adj);
2981 let e_diff = e_diff.max(0.0);
2982
2983 let mut r = 2.0 * (-e_diff).exp2();
2984 if lm == 3 {
2985 r *= std::f32::consts::SQRT_2;
2986 }
2987 r = r.min(thresh);
2988 r *= sqrt_1;
2989
2990 let x_offset = c * size + ((m.e_bands[i] as usize) << lm);
2991 let mut renormalize = false;
2992 for k in 0..(1 << lm) {
2993 if (collapse_masks[i * channels + c] & (1 << k)) == 0 {
2994 for j in 0..n0 {
2995 seed = celt_lcg_rand(seed);
2996 x_buf[x_offset + (j << lm) + k] = if (seed & 0x8000) != 0 { r } else { -r };
2997 }
2998 renormalize = true;
2999 }
3000 }
3001 if renormalize {
3002 renormalise_vector(&mut x_buf[x_offset..x_offset + (n0 << lm)], n0 << lm, 1.0);
3003 }
3004 }
3005 }
3006 seed
3007}