1use crate::modes::CeltMode;
2use crate::pvq::*;
3use crate::range_coder::RangeCoder;
4use crate::rate::{BITRES, bits2pulses, get_pulses, pulses2bits};
5use crate::tell_frac_inline;
6
/// Small positive energy floor (`1e-10`).
///
/// NOTE(review): not referenced in the visible part of this file; presumably
/// used by stereo code further down to avoid working with (near-)zero
/// energies — confirm against its users.
const MIN_STEREO_ENERGY: f32 = 1e-10;
8
/// Shared state threaded through the band quantization routines.
///
/// One value is passed by `&mut` down the recursive partition/split calls so
/// that the range coder position and the bit budget stay in sync.
pub struct BandCtx<'a> {
    /// `true` selects the encoder paths (`rc.encode*`, `alg_quant`), `false`
    /// the decoder paths (`rc.decode*`, `alg_unquant`).
    pub encode: bool,
    /// Mode description; this file reads `log_n`, `cache`, `nb_ebands` from it.
    pub m: &'a CeltMode,
    /// Index of the band currently being processed (indexes `m.log_n` and the
    /// pulse cache).
    pub i: usize,
    /// Passed through to `intensity_stereo`; presumably per-band energies —
    /// TODO confirm layout against that function.
    pub band_e: &'a [f32],
    /// Range coder used for all symbols written/read here.
    pub rc: &'a mut RangeCoder,
    /// Spreading mode forwarded to `alg_quant` / `alg_unquant`
    /// (presumably one of the `SPREAD_*` constants).
    pub spread: i32,
    /// Running bit budget; decremented by the cost of theta (`qalloc`) and of
    /// every pulse vector that gets coded.
    pub remaining_bits: i32,
    /// When set, the quantizers also reconstruct the signal into `x`.
    pub resynth: bool,
    /// Not used in the visible part of this file.
    pub tf_change: i32,
    /// First band index for which stereo theta is no longer coded
    /// (`compute_theta` forces `qn = 1` when `i >= intensity`).
    pub intensity: usize,
    /// Stereo theta rounding: `0` rounds to nearest, negative rounds down,
    /// positive rounds up (see `compute_theta`).
    pub theta_round: i32,
    /// For mono splits, lets `compute_theta` snap theta to an extreme when the
    /// resulting bit imbalance would exceed the available bits.
    pub avoid_split_noise: bool,
    /// Not used in the visible part of this file.
    pub arch: i32,
    /// When set, the stereo inversion flag is always treated as `false`.
    pub disable_inv: bool,
    /// State of the LCG (`celt_lcg_rand`) used for noise filling.
    pub seed: u32,
}
26
/// Bit-exact fixed-point cosine approximation.
///
/// `x` is an angle where `16384` corresponds to a quarter turn; the result is
/// a Q15 value (about `32767` near `x = 0`, about `23171` at `x = 8192`).
/// All arithmetic is integer so encoder and decoder agree exactly.
#[inline]
fn bitexact_cos(x: i16) -> i16 {
    /// Q15 product of two 16-bit values, rounded to nearest.
    #[inline(always)]
    fn q15_mul(lhs: i16, rhs: i16) -> i16 {
        let product = i32::from(lhs) * i32::from(rhs);
        ((product + 16384) >> 15) as i16
    }

    // Squared angle, rescaled back into 16 bits.
    let angle = i32::from(x);
    let sq = ((angle * angle + 4096) >> 13) as i16;

    // Polynomial in `sq`, evaluated from the innermost term outwards.
    let inner = 8277 + q15_mul(-626, sq);
    let middle = -7651 + q15_mul(sq, inner);
    let correction = q15_mul(sq, middle);

    let cosine = (32767 - i32::from(sq) + i32::from(correction)) as i16;
    1 + cosine
}
41
/// Bit-exact approximation of `log2(isin / icos)` scaled by `2048` (Q11).
///
/// Negative inputs are treated as zero. Each input is normalised to 15
/// significant bits; the integer part of the logarithm comes from the
/// difference in bit lengths and the fractional part from a quadratic
/// polynomial in the normalised values.
#[inline]
pub fn bitexact_log2tan(isin: i32, icos: i32) -> i32 {
    /// Number of significant bits in `v` (0 for `v == 0`).
    fn bit_length(v: u32) -> i32 {
        (u32::BITS - v.leading_zeros()) as i32
    }

    /// Q15 product with rounding.
    fn q15_mul(a: i32, b: i32) -> i32 {
        (a * b + 16384) >> 15
    }

    /// Returns `(bit_length, value shifted up to 15 bits)`.
    fn normalise(v: i32) -> (i32, i32) {
        let v = v.max(0);
        let bits = bit_length(v as u32);
        let scaled = if bits > 0 { v << (15 - bits).max(0) } else { 0 };
        (bits, scaled)
    }

    /// Fractional-log polynomial applied to a normalised value.
    fn frac_log(v: i32) -> i32 {
        q15_mul(v, q15_mul(v, -2597) + 7932)
    }

    let (lc, cos_norm) = normalise(icos);
    let (ls, sin_norm) = normalise(isin);

    (ls - lc) * (1 << 11) + frac_log(sin_norm) - frac_log(cos_norm)
}
67
/// Signed integer division `n / d`, truncating toward zero.
///
/// Kept as a named helper so call sites read like the fixed-point reference
/// code. Panics on `d == 0`, exactly like the `/` operator.
#[inline(always)]
fn celt_sudiv(n: i32, d: i32) -> i32 {
    std::ops::Div::div(n, d)
}
72
/// Integer square root: returns `floor(sqrt(val))`.
///
/// Uses the classic bit-by-bit method: starting from the highest possible
/// result bit, each bit is tentatively set and kept only if the square of the
/// candidate still fits into the remaining value.
///
/// `val == 0` is handled explicitly and yields `0`; without the guard the
/// initial shift amount would be `-1`, which overflows the shift in
/// overflow-checked builds.
#[inline]
fn isqrt32(mut val: u32) -> u32 {
    if val == 0 {
        return 0;
    }

    let mut root = 0u32;
    // Index of the highest bit the result can have.
    let mut shift = ((31 - val.leading_zeros()) >> 1) as i32;
    let mut bit = 1u32 << shift;

    while shift >= 0 {
        // (root + bit)^2 - root^2 == (2 * root + bit) * bit
        let trial = u64::from((root << 1) + bit) << shift;
        if trial <= u64::from(val) {
            root += bit;
            val -= trial as u32;
        }
        bit >>= 1;
        shift -= 1;
    }
    root
}
89
// Spreading decisions returned by `spreading_decision`, ordered from no
// spreading to the most aggressive setting.

/// No spreading.
pub const SPREAD_NONE: i32 = 0;
/// Light spreading.
pub const SPREAD_LIGHT: i32 = 1;
/// Normal spreading (also the fallback when no band contributed a weight).
pub const SPREAD_NORMAL: i32 = 2;
/// Aggressive spreading.
pub const SPREAD_AGGRESSIVE: i32 = 3;
94
95#[allow(clippy::too_many_arguments)]
96pub fn spreading_decision(
97 m: &CeltMode,
98 x_buf: &[f32],
99 average: &mut i32,
100 last_decision: i32,
101 hf_average: &mut i32,
102 tapset_decision: &mut i32,
103 update_hf: bool,
104 end: usize,
105 channels: usize,
106 m_val: usize,
107 spread_weight: &[i32],
108) -> i32 {
109 let mut sum = 0;
110 let mut nb_bands = 0;
111 let n0 = m_val * m.short_mdct_size;
112 let mut hf_sum = 0;
113
114 if m_val * (m.e_bands[end] as usize - m.e_bands[end - 1] as usize) <= 8 {
115 return SPREAD_NONE;
116 }
117
118 for c in 0..channels {
119 for (i, &sw) in spread_weight[..end].iter().enumerate() {
120 let n = m_val * (m.e_bands[i + 1] as usize - m.e_bands[i] as usize);
121 if n <= 8 {
122 continue;
123 }
124
125 let mut tcount = [0; 3];
126 let offset = m_val * m.e_bands[i] as usize + c * n0;
127 let x = &x_buf[offset..offset + n];
128
129 for xv in x.iter().copied() {
130 let x2n = xv * xv * (n as f32);
131 if x2n < 0.25 {
132 tcount[0] += 1;
133 }
134 if x2n < 0.0625 {
135 tcount[1] += 1;
136 }
137 if x2n < 0.015625 {
138 tcount[2] += 1;
139 }
140 }
141
142 if i > m.nb_ebands - 4 {
143 hf_sum += 32 * (tcount[1] + tcount[0]) / (n as i32);
144 }
145
146 let tmp = (if 2 * tcount[2] >= (n as i32) { 1 } else { 0 })
147 + (if 2 * tcount[1] >= (n as i32) { 1 } else { 0 })
148 + (if 2 * tcount[0] >= (n as i32) { 1 } else { 0 });
149 sum += tmp * sw;
150 nb_bands += sw;
151 }
152 }
153
154 if update_hf {
155 if hf_sum > 0 {
156 hf_sum /= (channels as i32) * (4 - m.nb_ebands as i32 + end as i32);
157 }
158 *hf_average = (*hf_average + hf_sum) >> 1;
159 hf_sum = *hf_average;
160
161 if *tapset_decision == 2 {
162 hf_sum += 4;
163 } else if *tapset_decision == 0 {
164 hf_sum -= 4;
165 }
166
167 if hf_sum > 22 {
168 *tapset_decision = 2;
169 } else if hf_sum > 18 {
170 *tapset_decision = 1;
171 } else {
172 *tapset_decision = 0;
173 }
174 }
175
176 if nb_bands == 0 {
177 return SPREAD_NORMAL;
178 }
179
180 let mut sum_scaled = (sum << 8) / nb_bands;
181 sum_scaled = (sum_scaled + *average) >> 1;
182 *average = sum_scaled;
183
184 let sum_final = (3 * sum_scaled + (((3 - last_decision) << 7) + 64) + 2) >> 2;
185
186 if sum_final < 80 {
187 SPREAD_AGGRESSIVE
188 } else if sum_final < 256 {
189 SPREAD_NORMAL
190 } else if sum_final < 384 {
191 SPREAD_LIGHT
192 } else {
193 SPREAD_NONE
194 }
195}
196
197pub fn haar1(x: &mut [f32], n0: usize, stride: usize) {
198 #[cfg(target_arch = "aarch64")]
199 {
200 if stride == 1 && n0 >= 64 {
201 haar1_neon(x, n0);
202 } else {
203 haar1_scalar(x, n0, stride);
204 }
205 return;
206 }
207 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
208 unsafe {
209 if stride == 1 && n0 >= 16 && is_x86_feature_detected!("avx") {
210 haar1_avx(x, n0);
211 return;
212 }
213 }
214 #[cfg(not(target_arch = "aarch64"))]
215 haar1_scalar(x, n0, stride);
216}
217
/// AVX implementation of the contiguous (`stride == 1`) Haar step.
///
/// Processes `n0 / 2` adjacent pairs `(a, b)` of `x` in place, replacing them
/// with `(s*a + s*b, s*a - s*b)` where `s = 1/sqrt(2)`. Each sample is scaled
/// first and then added/subtracted, which is the same operation order as the
/// scalar implementation, so both produce bit-identical output.
///
/// Each 256-bit vector holds four complete pairs, so the butterfly is done
/// entirely in-lane: a permute swaps the two members of every pair, and a
/// blend picks the sum for even positions and the difference for odd ones.
///
/// # Safety
///
/// The CPU must support AVX.
///
/// # Panics
///
/// Panics if `x` holds fewer than `2 * (n0 / 2)` samples.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx")]
unsafe fn haar1_avx(x: &mut [f32], n0: usize) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    const SCALE: f32 = std::f32::consts::FRAC_1_SQRT_2;
    // Swap neighbours inside every pair: [1, 0, 3, 2] per 128-bit lane.
    const SWAP_PAIRS: i32 = 0b10_11_00_01;
    // Take the second operand at odd positions.
    const ODD_LANES: i32 = 0b1010_1010;

    let n = n0 >> 1;
    // One bounds check up front; everything below stays inside this slice.
    let x = &mut x[..2 * n];

    let mut blocks = x.chunks_exact_mut(8);
    for block in &mut blocks {
        // SAFETY: `block` is exactly 8 contiguous `f32`s, matching the
        // unaligned 256-bit load and store; AVX availability is the caller's
        // contract for this function.
        unsafe {
            let vscale = _mm256_set1_ps(SCALE);
            let scaled = _mm256_mul_ps(_mm256_loadu_ps(block.as_ptr()), vscale);
            let swapped = _mm256_permute_ps::<SWAP_PAIRS>(scaled);
            // Even positions: s*a + s*b. Odd positions: s*a - s*b.
            let sum = _mm256_add_ps(scaled, swapped);
            let diff = _mm256_sub_ps(swapped, scaled);
            _mm256_storeu_ps(block.as_mut_ptr(), _mm256_blend_ps::<ODD_LANES>(sum, diff));
        }
    }

    // Scalar tail for the pairs that do not fill a whole vector.
    for pair in blocks.into_remainder().chunks_exact_mut(2) {
        let a = SCALE * pair[0];
        let b = SCALE * pair[1];
        pair[0] = a + b;
        pair[1] = a - b;
    }
}
261
/// Portable Haar step for `stride` interleaved sub-sequences.
///
/// Sub-sequence `lane` consists of the samples at `lane`, `lane + stride`,
/// `lane + 2 * stride`, ... Within each sub-sequence, `n0 / 2` consecutive
/// pairs `(a, b)` are replaced by `(s*a + s*b, s*a - s*b)` with
/// `s = 1/sqrt(2)`.
///
/// # Panics
///
/// Panics if an accessed index lies outside `x`.
#[cfg_attr(target_arch = "aarch64", allow(dead_code))]
#[inline]
fn haar1_scalar(x: &mut [f32], n0: usize, stride: usize) {
    const SCALE: f32 = std::f32::consts::FRAC_1_SQRT_2;
    let pairs = n0 >> 1;

    for lane in 0..stride {
        for pair in 0..pairs {
            let first = stride * 2 * pair + lane;
            let second = stride * (2 * pair + 1) + lane;

            let a = SCALE * x[first];
            let b = SCALE * x[second];

            x[first] = a + b;
            x[second] = a - b;
        }
    }
}
278
/// NEON implementation of the contiguous (`stride == 1`) Haar step.
///
/// Processes `n0 / 2` adjacent pairs `(a, b)` of `x` in place, replacing them
/// with `(s*a + s*b, s*a - s*b)` where `s = 1/sqrt(2)`. Four pairs are handled
/// per iteration using de-interleaving loads (`vld2q`) and re-interleaving
/// stores (`vst2q`); leftover pairs are handled with scalar code.
///
/// The slice is re-sliced to exactly the processed range up front, so a
/// too-short `x` results in a panic rather than out-of-bounds SIMD accesses.
///
/// # Panics
///
/// Panics if `x` holds fewer than `2 * (n0 / 2)` samples.
#[cfg(target_arch = "aarch64")]
fn haar1_neon(x: &mut [f32], n0: usize) {
    use std::arch::aarch64::*;

    const SCALE: f32 = std::f32::consts::FRAC_1_SQRT_2;

    let n = n0 >> 1;
    // One bounds check up front; everything below stays inside this slice.
    let x = &mut x[..2 * n];

    let mut blocks = x.chunks_exact_mut(8);
    for block in &mut blocks {
        // SAFETY: `block` is exactly 8 contiguous `f32`s, which is what one
        // `vld2q_f32` reads and one `vst2q_f32` writes.
        unsafe {
            let vscale = vdupq_n_f32(SCALE);
            let pairs = vld2q_f32(block.as_ptr());
            let even = vmulq_f32(pairs.0, vscale);
            let odd = vmulq_f32(pairs.1, vscale);
            let out = float32x4x2_t(vaddq_f32(even, odd), vsubq_f32(even, odd));
            vst2q_f32(block.as_mut_ptr(), out);
        }
    }

    // Scalar tail for the pairs that do not fill a whole vector.
    for pair in blocks.into_remainder().chunks_exact_mut(2) {
        let a = SCALE * pair[0];
        let b = SCALE * pair[1];
        pair[0] = a + b;
        pair[1] = a - b;
    }
}
315
316#[inline(always)]
317pub fn compute_qn(n: usize, b: i32, offset: i32, pulse_cap: i32, stereo: bool) -> i32 {
318 static EXP2_TABLE8: [i16; 8] = [16384, 17866, 19483, 21247, 23170, 25267, 27554, 30048];
319 let mut n2 = (2 * n as i32) - 1;
320 if stereo && n == 2 {
321 n2 -= 1;
322 }
323 let mut qb = celt_sudiv(b + n2 * offset, n2);
324 qb = qb.min(b - pulse_cap - (4 << BITRES));
325 qb = qb.min(8 << BITRES);
326 if qb < (1i32 << BITRES >> 1) {
327 1
328 } else {
329 let val = EXP2_TABLE8[(qb & 0x7) as usize] as i32;
330 let shift = 14 - (qb >> BITRES);
331 let raw = if (0..32).contains(&shift) {
332 val >> shift
333 } else {
334 0
335 };
336 let qn = (raw + 1) >> 1 << 1;
337 qn.min(256)
338 }
339}
340
/// NEON kernel computing the quantizable split angle between `x` and `y`.
///
/// Accumulates two energies over the first `n` samples — of `x + y` and
/// `x - y` when `stereo` is set, of `x` and `y` otherwise — and returns
/// `atan2(sqrt(eside), sqrt(emid))` normalised so that a quarter turn maps to
/// `16384`, rounded to the nearest integer.
///
/// The vector loops are unrolled by 16, 8 and 4 samples; all of them feed the
/// same two accumulators in sample order, and a scalar loop handles the tail.
///
/// # Safety
///
/// `x` and `y` must each contain at least `n` samples: the vector loads go
/// through raw pointers and are not bounds-checked.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn stereo_itheta_neon(x: &[f32], y: &[f32], stereo: bool, n: usize) -> i32 {
    use std::arch::aarch64::*;

    // Start slightly above zero so the square roots below never see 0.
    let mut emid = 1e-15f32;
    let mut eside = 1e-15f32;

    if stereo {
        // Mid/side energies: sum of (x + y)^2 and (x - y)^2.
        let mut sum_mid = vdupq_n_f32(0.0);
        let mut sum_side = vdupq_n_f32(0.0);
        let mut i = 0;

        // 16 samples per iteration.
        while i + 16 <= n {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let x1 = vld1q_f32(x.as_ptr().add(i + 4));
            let x2 = vld1q_f32(x.as_ptr().add(i + 8));
            let x3 = vld1q_f32(x.as_ptr().add(i + 12));
            let y0 = vld1q_f32(y.as_ptr().add(i));
            let y1 = vld1q_f32(y.as_ptr().add(i + 4));
            let y2 = vld1q_f32(y.as_ptr().add(i + 8));
            let y3 = vld1q_f32(y.as_ptr().add(i + 12));

            let m0 = vaddq_f32(x0, y0);
            let m1 = vaddq_f32(x1, y1);
            let m2 = vaddq_f32(x2, y2);
            let m3 = vaddq_f32(x3, y3);
            let s0 = vsubq_f32(x0, y0);
            let s1 = vsubq_f32(x1, y1);
            let s2 = vsubq_f32(x2, y2);
            let s3 = vsubq_f32(x3, y3);

            sum_mid = vfmaq_f32(sum_mid, m0, m0);
            sum_mid = vfmaq_f32(sum_mid, m1, m1);
            sum_mid = vfmaq_f32(sum_mid, m2, m2);
            sum_mid = vfmaq_f32(sum_mid, m3, m3);
            sum_side = vfmaq_f32(sum_side, s0, s0);
            sum_side = vfmaq_f32(sum_side, s1, s1);
            sum_side = vfmaq_f32(sum_side, s2, s2);
            sum_side = vfmaq_f32(sum_side, s3, s3);

            i += 16;
        }

        // 8 samples per iteration.
        while i + 8 <= n {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let x1 = vld1q_f32(x.as_ptr().add(i + 4));
            let y0 = vld1q_f32(y.as_ptr().add(i));
            let y1 = vld1q_f32(y.as_ptr().add(i + 4));

            let m0 = vaddq_f32(x0, y0);
            let m1 = vaddq_f32(x1, y1);
            let s0 = vsubq_f32(x0, y0);
            let s1 = vsubq_f32(x1, y1);

            sum_mid = vfmaq_f32(sum_mid, m0, m0);
            sum_mid = vfmaq_f32(sum_mid, m1, m1);
            sum_side = vfmaq_f32(sum_side, s0, s0);
            sum_side = vfmaq_f32(sum_side, s1, s1);

            i += 8;
        }

        // 4 samples per iteration.
        while i + 4 <= n {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let y0 = vld1q_f32(y.as_ptr().add(i));
            let m0 = vaddq_f32(x0, y0);
            let s0 = vsubq_f32(x0, y0);
            sum_mid = vfmaq_f32(sum_mid, m0, m0);
            sum_side = vfmaq_f32(sum_side, s0, s0);
            i += 4;
        }

        // Horizontal reduction of the vector accumulators.
        emid += vaddvq_f32(sum_mid);
        eside += vaddvq_f32(sum_side);

        // Scalar tail (fewer than 4 samples left).
        for j in i..n {
            let m = x[j] + y[j];
            let s = x[j] - y[j];
            emid += m * m;
            eside += s * s;
        }
    } else {
        // Plain energies: sum of x^2 and y^2.
        let mut sum_mid = vdupq_n_f32(0.0);
        let mut sum_side = vdupq_n_f32(0.0);
        let mut i = 0;

        // 16 samples per iteration.
        while i + 16 <= n {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let x1 = vld1q_f32(x.as_ptr().add(i + 4));
            let x2 = vld1q_f32(x.as_ptr().add(i + 8));
            let x3 = vld1q_f32(x.as_ptr().add(i + 12));
            let y0 = vld1q_f32(y.as_ptr().add(i));
            let y1 = vld1q_f32(y.as_ptr().add(i + 4));
            let y2 = vld1q_f32(y.as_ptr().add(i + 8));
            let y3 = vld1q_f32(y.as_ptr().add(i + 12));

            sum_mid = vfmaq_f32(sum_mid, x0, x0);
            sum_mid = vfmaq_f32(sum_mid, x1, x1);
            sum_mid = vfmaq_f32(sum_mid, x2, x2);
            sum_mid = vfmaq_f32(sum_mid, x3, x3);
            sum_side = vfmaq_f32(sum_side, y0, y0);
            sum_side = vfmaq_f32(sum_side, y1, y1);
            sum_side = vfmaq_f32(sum_side, y2, y2);
            sum_side = vfmaq_f32(sum_side, y3, y3);

            i += 16;
        }

        // 8 samples per iteration.
        while i + 8 <= n {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let x1 = vld1q_f32(x.as_ptr().add(i + 4));
            let y0 = vld1q_f32(y.as_ptr().add(i));
            let y1 = vld1q_f32(y.as_ptr().add(i + 4));

            sum_mid = vfmaq_f32(sum_mid, x0, x0);
            sum_mid = vfmaq_f32(sum_mid, x1, x1);
            sum_side = vfmaq_f32(sum_side, y0, y0);
            sum_side = vfmaq_f32(sum_side, y1, y1);

            i += 8;
        }

        // 4 samples per iteration.
        while i + 4 <= n {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let y0 = vld1q_f32(y.as_ptr().add(i));
            sum_mid = vfmaq_f32(sum_mid, x0, x0);
            sum_side = vfmaq_f32(sum_side, y0, y0);
            i += 4;
        }

        // Horizontal reduction of the vector accumulators.
        emid += vaddvq_f32(sum_mid);
        eside += vaddvq_f32(sum_side);

        // Scalar tail (fewer than 4 samples left).
        for j in i..n {
            emid += x[j] * x[j];
            eside += y[j] * y[j];
        }
    }

    let mid = emid.sqrt();
    let side = eside.sqrt();
    // Normalised angle in [0, 1], scaled to [0, 16384] and rounded.
    let theta_norm = celt_atan2p_norm(side, mid);
    (0.5 + 16384.0 * theta_norm) as i32
}
487
/// Computes the split angle between `x` and `y` over their first `n` samples
/// (aarch64 build, backed by the NEON kernel).
///
/// With `stereo` set the angle is taken between the mid (`x + y`) and side
/// (`x - y`) signals, otherwise between `x` and `y` directly. The result is
/// scaled so that a quarter turn corresponds to `16384`.
///
/// # Panics
///
/// Panics if `n` exceeds the length of `x` or `y`. The NEON kernel reads
/// through raw pointers, so the length is validated here; this matches the
/// non-aarch64 implementation, which panics through slice indexing.
#[inline(always)]
#[cfg(target_arch = "aarch64")]
pub fn stereo_itheta(x: &[f32], y: &[f32], stereo: bool, n: usize) -> i32 {
    assert!(
        n <= x.len() && n <= y.len(),
        "stereo_itheta: n exceeds input length"
    );
    // SAFETY: both slices were just checked to contain at least `n` samples,
    // which is the only requirement of the kernel.
    unsafe { stereo_itheta_neon(x, y, stereo, n) }
}
493
494#[inline(always)]
495#[cfg(not(target_arch = "aarch64"))]
496pub fn stereo_itheta(x: &[f32], y: &[f32], stereo: bool, n: usize) -> i32 {
497 #[cfg(target_arch = "aarch64")]
498 unsafe {
499 return stereo_itheta_neon(x, y, stereo, n);
500 }
501 #[cfg(not(target_arch = "aarch64"))]
502 {
503 let mut emid = 1e-15f32;
504 let mut eside = 1e-15f32;
505 if stereo {
506 for i in 0..n {
507 let m = x[i] + y[i];
508 let s = x[i] - y[i];
509 emid += m * m;
510 eside += s * s;
511 }
512 } else {
513 for i in 0..n {
514 emid += x[i] * x[i];
515 eside += y[i] * y[i];
516 }
517 }
518 let mid = emid.sqrt();
519 let side = eside.sqrt();
520 let theta_norm = celt_atan2p_norm(side, mid);
521 (0.5 + 16384.0 * theta_norm) as i32
522 }
523}
524
/// Polynomial approximation of `atan2(y, x) * 2 / pi`.
///
/// The caller is expected to pass non-negative `x` and `y` (TODO confirm — the
/// range reduction below only distinguishes `y < x` from `y >= x`). Returns
/// `0.0` when both inputs are essentially zero.
#[inline(always)]
fn celt_atan2p_norm(y: f32, x: f32) -> f32 {
    /// `atan(t) * 2 / pi` via an odd polynomial in `t`.
    #[inline(always)]
    fn atan_norm(t: f32) -> f32 {
        // Coefficient of the highest-order term (t^15 / t).
        const LEADING: f32 = -4.355_406e-3;
        // Remaining coefficients, from t^13 down to t^3.
        const LOWER: [f32; 6] = [
            2.304_014e-2,
            -5.777_359e-2,
            9.794_234e-2,
            -1.397_658_3e-1,
            1.996_270_4e-1,
            -3.333_166e-1,
        ];

        let t2 = t * t;
        // Horner evaluation in t^2, innermost term first.
        let mut acc = LEADING;
        for &coeff in LOWER.iter() {
            acc = coeff + t2 * acc;
        }
        let series = 1.0 + t2 * acc;

        std::f32::consts::FRAC_2_PI * t * series
    }

    let energy = x * x + y * y;
    if energy < 1e-18 {
        return 0.0;
    }

    // Keep the polynomial argument in [0, 1] by using
    // atan(y/x) = pi/2 - atan(x/y) when y >= x.
    match y < x {
        true => atan_norm(y / x),
        false => 1.0 - atan_norm(x / y),
    }
}
553
/// Results of `compute_theta` for one split.
pub struct SplitCtx {
    /// Inversion flag for the side channel; only set on the stereo path where
    /// theta itself is not coded (`qn == 1`).
    pub inv: bool,
    /// Gain of the first ("mid") half: `bitexact_cos(itheta)`, or `32767` / `0`
    /// at the extremes. Callers scale it by `1 / 32768`.
    pub imid: i32,
    /// Gain of the second ("side") half: `bitexact_cos(16384 - itheta)`, or
    /// `0` / `32767` at the extremes.
    pub iside: i32,
    /// Bit-allocation bias between the two halves, derived from
    /// `bitexact_log2tan(iside, imid)`; `-16384` / `16384` at the extremes.
    pub delta: i32,
    /// Quantized angle, `0..=16384` (16384 = quarter turn).
    pub itheta: i32,
    /// Bits consumed by coding theta, as the difference of two
    /// `tell_frac_inline!` readings of the range coder.
    pub qalloc: i32,
}
562
/// Quantizes and codes (or decodes) the angle `theta` that describes how a
/// vector's energy is divided between its two halves `x` and `y`.
///
/// * `ctx` – shared band state; the range coder in it is advanced.
/// * `sctx` – receives the results (`itheta`, `imid`, `iside`, `delta`,
///   `qalloc`, and `inv` on the stereo no-theta path).
/// * `x`, `y` – the two halves (mono split) or the two channels (stereo).
/// * `n` – number of samples in each of `x` and `y`.
/// * `b` – available bits; reduced by the bits spent on theta.
/// * `b_blocks` – number of blocks after the split, used for the `fill` masks.
/// * `b0` – number of blocks before the split; selects the theta PDF.
/// * `lm` – size shift used for `pulse_cap`.
/// * `stereo` – `true` for a stereo split, `false` for a mono split.
/// * `fill` – collapse mask; bits belonging to a half that receives no
///   energy are cleared.
///
/// NOTE(review): on the stereo encode paths this function writes to `x` and
/// `y` through pointers obtained from the shared slices (`as_ptr() as *mut`).
/// That relies on the callers actually owning the data mutably and is not
/// something `&[f32]` permits; changing the parameters to `&mut [f32]` would
/// be the proper fix but alters the signature.
#[allow(clippy::too_many_arguments)]
#[inline(always)]
pub fn compute_theta(
    ctx: &mut BandCtx,
    sctx: &mut SplitCtx,
    x: &[f32],
    y: &[f32],
    n: usize,
    b: &mut i32,
    b_blocks: i32,
    b0: i32,
    lm: i32,
    stereo: bool,
    fill: &mut u32,
) {
    // Decide how many quantization steps theta gets.
    let pulse_cap = ctx.m.log_n[ctx.i] as i32 + (lm << BITRES);
    let offset = (pulse_cap >> 1) - if stereo && n == 2 { 16 } else { 4 };
    let mut qn = compute_qn(n, *b, offset, pulse_cap, stereo);

    // Stereo bands at or above the intensity threshold never code theta.
    if stereo && ctx.i >= ctx.intensity {
        qn = 1;
    }

    // Only the encoder can measure the angle; the decoder reads it below.
    let mut itheta = 0;
    if ctx.encode {
        itheta = stereo_itheta(x, y, stereo, n);
    }

    // Coder position before theta, used to compute `qalloc` at the end.
    let tell_start = tell_frac_inline!(ctx.rc);

    if qn != 1 {
        if ctx.encode {
            if !stereo || ctx.theta_round == 0 {
                // Round to the nearest of the qn + 1 levels.
                itheta = (itheta * qn + 8192) >> 14;
                if !stereo && ctx.avoid_split_noise && itheta > 0 && itheta < qn {
                    // If the bit imbalance implied by this theta exceeds the
                    // available bits, snap theta to the matching extreme so
                    // one half is not left with noise only.
                    let unquantized = (itheta * 16384) / qn;
                    let imid = bitexact_cos(unquantized as i16) as i32;
                    let iside = bitexact_cos((16384 - unquantized) as i16) as i32;
                    let delta =
                        (((n as i32 - 1) << 7) * bitexact_log2tan(iside, imid) + 16384) >> 15;
                    if delta > *b {
                        itheta = qn;
                    } else if delta < -*b {
                        itheta = 0;
                    }
                }
            } else {
                // Directed rounding: the bias pushes theta away from 0 and
                // 16384 unless it is very close to them; `theta_round` then
                // picks the level below or above.
                let bias = if itheta > 8192 {
                    32767 / qn
                } else {
                    -32767 / qn
                };
                let down = (itheta * qn + bias) >> 14;
                let down = down.clamp(0, qn - 1);
                if ctx.theta_round < 0 {
                    itheta = down;
                } else {
                    itheta = down + 1;
                }
            }
        }

        // Entropy-code theta with one of three PDFs.
        if stereo && n > 2 {
            // Step PDF: levels up to x0 have weight p0, the rest weight 1.
            let p0 = 3;
            let x0 = qn / 2;
            let ft = p0 * (x0 + 1) + x0;
            if ctx.encode {
                let fl = if itheta <= x0 {
                    p0 * itheta
                } else {
                    (itheta - 1 - x0) + (x0 + 1) * p0
                };
                let fh = if itheta <= x0 {
                    p0 * (itheta + 1)
                } else {
                    (itheta - x0) + (x0 + 1) * p0
                };
                ctx.rc.encode(fl as u32, fh as u32, ft as u32);
            } else {
                let fs = ctx.rc.decode(ft as u32);
                if fs < (x0 + 1) as u32 * p0 as u32 {
                    itheta = fs as i32 / p0;
                } else {
                    itheta = (x0 + 1) + (fs as i32 - (x0 + 1) * p0);
                }
                let fl = if itheta <= x0 {
                    p0 * itheta
                } else {
                    (itheta - 1 - x0) + (x0 + 1) * p0
                };
                let fh = if itheta <= x0 {
                    p0 * (itheta + 1)
                } else {
                    (itheta - x0) + (x0 + 1) * p0
                };
                ctx.rc.update(fl as u32, fh as u32, ft as u32);
            }
        } else if b0 > 1 || stereo {
            // Uniform PDF over the qn + 1 levels.
            if ctx.encode {
                ctx.rc.enc_uint(itheta as u32, (qn + 1) as u32);
            } else {
                itheta = ctx.rc.dec_uint((qn + 1) as u32) as i32;
            }
        } else {
            // Triangular PDF peaking at qn / 2.
            let ft = ((qn >> 1) + 1) * ((qn >> 1) + 1);
            if ctx.encode {
                let fs = if itheta <= (qn >> 1) {
                    itheta + 1
                } else {
                    qn + 1 - itheta
                };
                let fl = if itheta <= (qn >> 1) {
                    (itheta * (itheta + 1)) >> 1
                } else {
                    ft - (((qn + 1 - itheta) * (qn + 2 - itheta)) >> 1)
                };
                ctx.rc.encode(fl as u32, (fl + fs) as u32, ft as u32);
            } else {
                let fm = ctx.rc.decode(ft as u32) as i32;
                if fm < (((qn >> 1) * ((qn >> 1) + 1)) >> 1) {
                    // Rising half of the triangle: invert the triangular sum.
                    itheta = (isqrt32((8 * fm + 1) as u32) as i32 - 1) >> 1;
                    let fl = (itheta * (itheta + 1)) >> 1;
                    let fs = itheta + 1;
                    ctx.rc.update(fl as u32, (fl + fs) as u32, ft as u32);
                } else {
                    // Falling half, mirrored from the top.
                    itheta = (2 * (qn + 1) - isqrt32((8 * (ft - fm - 1) + 1) as u32) as i32) >> 1;
                    let fs = qn + 1 - itheta;
                    let fl = ft - (((qn + 1 - itheta) * (qn + 2 - itheta)) >> 1);
                    ctx.rc.update(fl as u32, (fl + fs) as u32, ft as u32);
                }
            }
        }
        // Map the level index back to the 0..=16384 angle range.
        itheta = (itheta as u32 * 16384 / qn as u32) as i32;
        if ctx.encode && stereo {
            // NOTE(review): mutable views created from shared slices; see the
            // function-level note.
            let (bx, by) = (x.as_ptr() as *mut f32, y.as_ptr() as *mut f32);
            let (sx, sy) = unsafe {
                (
                    std::slice::from_raw_parts_mut(bx, n),
                    std::slice::from_raw_parts_mut(by, n),
                )
            };
            if itheta == 0 {
                intensity_stereo(ctx.m, sx, sy, ctx.band_e, ctx.i, n);
            } else {
                stereo_split(sx, sy, n);
            }
        }
    } else if stereo {
        // Theta is not coded for this stereo band; only an optional
        // inversion flag is transmitted.
        if ctx.encode {
            let inv = itheta > 8192 && !ctx.disable_inv;
            // NOTE(review): mutable views created from shared slices; see the
            // function-level note.
            let (bx, by) = (x.as_ptr() as *mut f32, y.as_ptr() as *mut f32);
            let (sx, sy) = unsafe {
                (
                    std::slice::from_raw_parts_mut(bx, n),
                    std::slice::from_raw_parts_mut(by, n),
                )
            };
            if inv {
                for yv in sy.iter_mut() {
                    *yv = -*yv;
                }
            }
            intensity_stereo(ctx.m, sx, sy, ctx.band_e, ctx.i, n);
            // The flag is only coded when more than two bits are available
            // both locally and in the overall budget.
            if *b > (2 << BITRES) && ctx.remaining_bits > (2 << BITRES) {
                ctx.rc.encode_bit_logp(inv, 2);
            }
            itheta = 0;
            sctx.inv = inv;
        } else {
            if *b > (2 << BITRES) && ctx.remaining_bits > (2 << BITRES) {
                sctx.inv = ctx.rc.decode_bit_logp(2);
            } else {
                sctx.inv = false;
            }
            // The flag is still read (to keep the coder in sync) but ignored.
            if ctx.disable_inv {
                sctx.inv = false;
            }
            itheta = 0;
        }
    }

    sctx.itheta = itheta;

    // Charge the bits spent on theta against the local budget.
    sctx.qalloc = tell_frac_inline!(ctx.rc) - tell_start;
    *b -= sctx.qalloc;

    if itheta == 0 {
        // All energy in the first half: only its collapse-mask bits remain.
        sctx.imid = 32767;
        sctx.iside = 0;
        sctx.delta = -16384;
        *fill &= (1 << b_blocks) - 1;
    } else if itheta == 16384 {
        // All energy in the second half: only its collapse-mask bits remain.
        sctx.imid = 0;
        sctx.iside = 32767;
        sctx.delta = 16384;
        *fill &= ((1 << b_blocks) - 1) << b_blocks;
    } else {
        let imid = bitexact_cos(itheta as i16);
        sctx.imid = imid as i32;
        let iside = bitexact_cos((16384 - itheta) as i16);
        sctx.iside = iside as i32;
        // Bit-allocation bias proportional to log2(iside / imid).
        sctx.delta =
            (((n as i32 - 1) << 7) * bitexact_log2tan(sctx.iside, sctx.imid) + 16384) >> 15;
    }
}
768
769#[inline(always)]
770fn quant_partition_n2_encode(
771 ctx: &mut BandCtx,
772 x: &mut [f32],
773 b: i32,
774 b_blocks: i32,
775 lowband: Option<&mut [f32]>,
776 lm: i32,
777 gain: f32,
778 fill: u32,
779) -> u32 {
780 let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
781 let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
782 ctx.remaining_bits -= curr_bits;
783
784 while ctx.remaining_bits < 0 && q > 0 {
785 ctx.remaining_bits += curr_bits;
786 q -= 1;
787 curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
788 ctx.remaining_bits -= curr_bits;
789 }
790
791 if q != 0 {
792 let k = get_pulses(q);
793 alg_quant(x, 2, k, ctx.spread, b_blocks as usize, ctx.rc, gain, false)
794 } else {
795 let has_lowband = lowband.is_some();
796 if has_lowband {
797 fill
798 } else {
799 (1u32 << b_blocks) - 1
800 }
801 }
802}
803
804#[inline(always)]
805fn quant_partition_n4_encode(
806 ctx: &mut BandCtx,
807 x: &mut [f32],
808 b: i32,
809 b_blocks: i32,
810 lowband: Option<&mut [f32]>,
811 lm: i32,
812 gain: f32,
813 fill: u32,
814) -> u32 {
815 let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
816 let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
817 ctx.remaining_bits -= curr_bits;
818
819 while ctx.remaining_bits < 0 && q > 0 {
820 ctx.remaining_bits += curr_bits;
821 q -= 1;
822 curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
823 ctx.remaining_bits -= curr_bits;
824 }
825
826 if q != 0 {
827 let k = get_pulses(q);
828 alg_quant(x, 4, k, ctx.spread, b_blocks as usize, ctx.rc, gain, false)
829 } else {
830 let has_lowband = lowband.is_some();
831 if has_lowband {
832 fill
833 } else {
834 (1u32 << b_blocks) - 1
835 }
836 }
837}
838
839#[inline(always)]
840fn quant_partition_n8_encode(
841 ctx: &mut BandCtx,
842 x: &mut [f32],
843 b: i32,
844 b_blocks: i32,
845 lowband: Option<&mut [f32]>,
846 lm: i32,
847 gain: f32,
848 fill: u32,
849) -> u32 {
850 let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
851 let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
852 ctx.remaining_bits -= curr_bits;
853
854 while ctx.remaining_bits < 0 && q > 0 {
855 ctx.remaining_bits += curr_bits;
856 q -= 1;
857 curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
858 ctx.remaining_bits -= curr_bits;
859 }
860
861 if q != 0 {
862 let k = get_pulses(q);
863 alg_quant(x, 8, k, ctx.spread, b_blocks as usize, ctx.rc, gain, false)
864 } else {
865 let has_lowband = lowband.is_some();
866 if has_lowband {
867 fill
868 } else {
869 (1u32 << b_blocks) - 1
870 }
871 }
872}
873
874#[inline(always)]
875#[allow(clippy::too_many_arguments)]
876fn quant_partition_direct_encode(
877 ctx: &mut BandCtx,
878 x: &mut [f32],
879 n: usize,
880 b: i32,
881 b_blocks: i32,
882 lowband: Option<&mut [f32]>,
883 lm: i32,
884 gain: f32,
885 fill: u32,
886) -> u32 {
887 let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
888 let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
889 ctx.remaining_bits -= curr_bits;
890
891 while ctx.remaining_bits < 0 && q > 0 {
892 ctx.remaining_bits += curr_bits;
893 q -= 1;
894 curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
895 ctx.remaining_bits -= curr_bits;
896 }
897
898 if q != 0 {
899 let k = get_pulses(q);
900 alg_quant(x, n, k, ctx.spread, b_blocks as usize, ctx.rc, gain, false)
901 } else {
902 let has_lowband = lowband.is_some();
903 if has_lowband {
904 fill
905 } else {
906 (1u32 << b_blocks) - 1
907 }
908 }
909}
910
911#[inline(always)]
912#[allow(clippy::too_many_arguments)]
913fn quant_partition_encode(
914 ctx: &mut BandCtx,
915 x: &mut [f32],
916 n: usize,
917 b: i32,
918 b_blocks: i32,
919 lowband: Option<&mut [f32]>,
920 lm: i32,
921 gain: f32,
922 fill: u32,
923) -> u32 {
924 if n == 2 {
926 return quant_partition_n2_encode(ctx, x, b, b_blocks, lowband, lm, gain, fill);
927 }
928
929 let should_split = if lm >= 0 && n > 2 {
931 let cache_idx = (lm + 1) as usize * ctx.m.nb_ebands + ctx.i;
932 let cache_base = unsafe { *ctx.m.cache.index.get_unchecked(cache_idx) };
933 if cache_base >= 0 {
934 let cache_base = cache_base as usize;
935 let cache_ptr = ctx.m.cache.bits.as_ptr().wrapping_add(cache_base);
936 let max_q = unsafe { *cache_ptr } as usize;
937 b > (unsafe { *cache_ptr.add(max_q) } as i32) + 12
938 } else {
939 false
940 }
941 } else {
942 false
943 };
944
945 if should_split {
946 let mut sctx = SplitCtx {
947 inv: false,
948 imid: 0,
949 iside: 0,
950 delta: 0,
951 itheta: 0,
952 qalloc: 0,
953 };
954 let mut b_mut = b;
955 let mut fill_mut = fill;
956 let mid = n / 2;
957 let lm = lm - 1;
958 let b0 = b_blocks;
959 if b_blocks == 1 {
960 fill_mut = (fill_mut & 1) | (fill_mut << 1);
961 }
962 let b_blocks = (b_blocks + 1) >> 1;
963 let (x_mid, x_side) = x.split_at_mut(mid);
964
965 compute_theta(
966 ctx,
967 &mut sctx,
968 x_mid,
969 x_side,
970 mid,
971 &mut b_mut,
972 b_blocks,
973 b0,
974 lm,
975 false,
976 &mut fill_mut,
977 );
978
979 ctx.remaining_bits -= sctx.qalloc;
980 let mut delta = sctx.delta;
981 if b0 > 1 && (sctx.itheta & 0x3fff) != 0 {
983 if sctx.itheta > 8192 {
984 delta -= delta >> (4 - lm);
985 } else {
986 delta = 0.min(delta + ((mid as i32) << BITRES >> (5 - lm)));
987 }
988 }
989 let mbits = (0).max((b_mut - delta) / 2).min(b_mut);
990 let mut sbits = b_mut - mbits;
991 let mut mbits = mbits;
992
993 let mut rebalance = ctx.remaining_bits;
994 let mut cm;
995 let mid_gain = gain * (sctx.imid as f32 / 32768.0);
996 let side_gain = gain * (sctx.iside as f32 / 32768.0);
997
998 if mbits >= sbits {
999 if let Some(lb) = lowband {
1000 let (lb_mid, lb_side) = lb.split_at_mut(mid);
1001 cm = quant_partition_encode(
1002 ctx,
1003 x_mid,
1004 mid,
1005 mbits,
1006 b_blocks,
1007 Some(lb_mid),
1008 lm,
1009 mid_gain,
1010 fill_mut,
1011 );
1012 rebalance = mbits - (rebalance - ctx.remaining_bits);
1013 if rebalance > (3 << 3) && sctx.itheta != 0 {
1014 sbits += rebalance - (3 << 3);
1015 }
1016 cm |= quant_partition_encode(
1017 ctx,
1018 x_side,
1019 mid,
1020 sbits,
1021 b_blocks,
1022 Some(lb_side),
1023 lm,
1024 side_gain,
1025 fill_mut >> b_blocks,
1026 ) << (b0 >> 1);
1027 } else {
1028 cm = quant_partition_encode(
1029 ctx, x_mid, mid, mbits, b_blocks, None, lm, mid_gain, fill_mut,
1030 );
1031 rebalance = mbits - (rebalance - ctx.remaining_bits);
1032 if rebalance > (3 << 3) && sctx.itheta != 0 {
1033 sbits += rebalance - (3 << 3);
1034 }
1035 cm |= quant_partition_encode(
1036 ctx,
1037 x_side,
1038 mid,
1039 sbits,
1040 b_blocks,
1041 None,
1042 lm,
1043 side_gain,
1044 fill_mut >> b_blocks,
1045 ) << (b0 >> 1);
1046 }
1047 } else if let Some(lb) = lowband {
1048 let (lb_mid, lb_side) = lb.split_at_mut(mid);
1049 cm = quant_partition_encode(
1050 ctx,
1051 x_side,
1052 mid,
1053 sbits,
1054 b_blocks,
1055 Some(lb_side),
1056 lm,
1057 side_gain,
1058 fill_mut >> b_blocks,
1059 ) << (b0 >> 1);
1060 rebalance = sbits - (rebalance - ctx.remaining_bits);
1061 if rebalance > (3 << 3) && sctx.itheta != 16384 {
1062 mbits += rebalance - (3 << 3);
1063 }
1064 cm |= quant_partition_encode(
1065 ctx,
1066 x_mid,
1067 mid,
1068 mbits,
1069 b_blocks,
1070 Some(lb_mid),
1071 lm,
1072 mid_gain,
1073 fill_mut,
1074 );
1075 } else {
1076 cm = quant_partition_encode(
1077 ctx,
1078 x_side,
1079 mid,
1080 sbits,
1081 b_blocks,
1082 None,
1083 lm,
1084 side_gain,
1085 fill_mut >> b_blocks,
1086 ) << (b0 >> 1);
1087 rebalance = sbits - (rebalance - ctx.remaining_bits);
1088 if rebalance > (3 << 3) && sctx.itheta != 16384 {
1089 mbits += rebalance - (3 << 3);
1090 }
1091 cm |= quant_partition_encode(
1092 ctx, x_mid, mid, mbits, b_blocks, None, lm, mid_gain, fill_mut,
1093 );
1094 }
1095 cm
1096 } else {
1097 if n == 4 {
1099 return quant_partition_n4_encode(ctx, x, b, b_blocks, lowband, lm, gain, fill);
1100 }
1101 if n == 8 {
1102 return quant_partition_n8_encode(ctx, x, b, b_blocks, lowband, lm, gain, fill);
1103 }
1104 if n == 16 {
1105 return quant_partition_direct_encode(ctx, x, n, b, b_blocks, lowband, lm, gain, fill);
1106 }
1107 let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
1108 let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1109 ctx.remaining_bits -= curr_bits;
1110
1111 while ctx.remaining_bits < 0 && q > 0 {
1112 ctx.remaining_bits += curr_bits;
1113 q -= 1;
1114 curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1115 ctx.remaining_bits -= curr_bits;
1116 }
1117
1118 if q != 0 {
1119 let k = get_pulses(q);
1120 alg_quant(x, n, k, ctx.spread, b_blocks as usize, ctx.rc, gain, false)
1121 } else if lowband.is_some() {
1122 fill
1123 } else {
1124 (1 << b_blocks) - 1
1125 }
1126 }
1127}
1128
1129#[inline(always)]
1130#[allow(clippy::too_many_arguments)]
1131pub fn quant_partition(
1132 ctx: &mut BandCtx,
1133 x: &mut [f32],
1134 n: usize,
1135 b: i32,
1136 b_blocks: i32,
1137 lowband: Option<&mut [f32]>,
1138 lm: i32,
1139 gain: f32,
1140 fill: u32,
1141) -> u32 {
1142 let should_split = if lm >= 0 && n > 2 {
1145 let cache_idx = (lm + 1) as usize * ctx.m.nb_ebands + ctx.i;
1146 let cache_base = unsafe { *ctx.m.cache.index.get_unchecked(cache_idx) };
1147 if cache_base >= 0 {
1148 let cache_base = cache_base as usize;
1149 let cache_ptr = ctx.m.cache.bits.as_ptr().wrapping_add(cache_base);
1150 let max_q = unsafe { *cache_ptr } as usize;
1151 b > (unsafe { *cache_ptr.add(max_q) } as i32) + 12
1152 } else {
1153 false
1154 }
1155 } else {
1156 false
1157 };
1158 if should_split {
1159 let mut sctx = SplitCtx {
1160 inv: false,
1161 imid: 0,
1162 iside: 0,
1163 delta: 0,
1164 itheta: 0,
1165 qalloc: 0,
1166 };
1167 let mut b_mut = b;
1168 let mut fill_mut = fill;
1169 let mid = n / 2;
1170 let lm = lm - 1;
1171 let b0 = b_blocks; if b_blocks == 1 {
1173 fill_mut = (fill_mut & 1) | (fill_mut << 1);
1174 }
1175 let b_blocks = (b_blocks + 1) >> 1;
1176 let (x_mid, x_side) = x.split_at_mut(mid);
1177
1178 compute_theta(
1179 ctx,
1180 &mut sctx,
1181 x_mid,
1182 x_side,
1183 mid,
1184 &mut b_mut,
1185 b_blocks,
1186 b0,
1187 lm,
1188 false,
1189 &mut fill_mut,
1190 );
1191
1192 ctx.remaining_bits -= sctx.qalloc;
1193 let mut delta = sctx.delta;
1194 if b0 > 1 && (sctx.itheta & 0x3fff) != 0 {
1197 if sctx.itheta > 8192 {
1198 delta -= delta >> (4 - lm);
1199 } else {
1200 delta = 0.min(delta + ((mid as i32) << BITRES >> (5 - lm)));
1201 }
1202 }
1203 let mbits = (0).max((b_mut - delta) / 2).min(b_mut);
1204 let mut sbits = b_mut - mbits;
1205 let mut mbits = mbits;
1206
1207 let mut rebalance = ctx.remaining_bits;
1208 let mut cm;
1209
1210 if mbits >= sbits {
1211 if let Some(lb) = lowband {
1212 let (lb_mid, lb_side) = lb.split_at_mut(mid);
1213 cm = quant_partition(
1214 ctx,
1215 x_mid,
1216 mid,
1217 mbits,
1218 b_blocks,
1219 Some(lb_mid),
1220 lm,
1221 gain * (sctx.imid as f32 / 32768.0),
1222 fill_mut,
1223 );
1224 rebalance = mbits - (rebalance - ctx.remaining_bits);
1225 if rebalance > (3 << 3) && sctx.itheta != 0 {
1226 sbits += rebalance - (3 << 3);
1227 }
1228 cm |= quant_partition(
1229 ctx,
1230 x_side,
1231 mid,
1232 sbits,
1233 b_blocks,
1234 Some(lb_side),
1235 lm,
1236 gain * (sctx.iside as f32 / 32768.0),
1237 fill_mut >> b_blocks,
1238 ) << (b0 >> 1);
1239 } else {
1240 cm = quant_partition(
1241 ctx,
1242 x_mid,
1243 mid,
1244 mbits,
1245 b_blocks,
1246 None,
1247 lm,
1248 gain * (sctx.imid as f32 / 32768.0),
1249 fill_mut,
1250 );
1251 rebalance = mbits - (rebalance - ctx.remaining_bits);
1252 if rebalance > (3 << 3) && sctx.itheta != 0 {
1253 sbits += rebalance - (3 << 3);
1254 }
1255 cm |= quant_partition(
1256 ctx,
1257 x_side,
1258 mid,
1259 sbits,
1260 b_blocks,
1261 None,
1262 lm,
1263 gain * (sctx.iside as f32 / 32768.0),
1264 fill_mut >> b_blocks,
1265 ) << (b0 >> 1);
1266 }
1267 } else if let Some(lb) = lowband {
1268 let (lb_mid, lb_side) = lb.split_at_mut(mid);
1269 cm = quant_partition(
1270 ctx,
1271 x_side,
1272 mid,
1273 sbits,
1274 b_blocks,
1275 Some(lb_side),
1276 lm,
1277 gain * (sctx.iside as f32 / 32768.0),
1278 fill_mut >> b_blocks,
1279 ) << (b0 >> 1);
1280 rebalance = sbits - (rebalance - ctx.remaining_bits);
1281 if rebalance > (3 << 3) && sctx.itheta != 16384 {
1282 mbits += rebalance - (3 << 3);
1283 }
1284 cm |= quant_partition(
1285 ctx,
1286 x_mid,
1287 mid,
1288 mbits,
1289 b_blocks,
1290 Some(lb_mid),
1291 lm,
1292 gain * (sctx.imid as f32 / 32768.0),
1293 fill_mut,
1294 );
1295 } else {
1296 cm = quant_partition(
1297 ctx,
1298 x_side,
1299 mid,
1300 sbits,
1301 b_blocks,
1302 None,
1303 lm,
1304 gain * (sctx.iside as f32 / 32768.0),
1305 fill_mut >> b_blocks,
1306 ) << (b0 >> 1);
1307 rebalance = sbits - (rebalance - ctx.remaining_bits);
1308 if rebalance > (3 << 3) && sctx.itheta != 16384 {
1309 mbits += rebalance - (3 << 3);
1310 }
1311 cm |= quant_partition(
1312 ctx,
1313 x_mid,
1314 mid,
1315 mbits,
1316 b_blocks,
1317 None,
1318 lm,
1319 gain * (sctx.imid as f32 / 32768.0),
1320 fill_mut,
1321 );
1322 }
1323 cm
1324 } else {
1325 let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
1326 let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1327 ctx.remaining_bits -= curr_bits;
1328
1329 while ctx.remaining_bits < 0 && q > 0 {
1330 ctx.remaining_bits += curr_bits;
1331 q -= 1;
1332 curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
1333 ctx.remaining_bits -= curr_bits;
1334 }
1335
1336 if q != 0 {
1337 let k = get_pulses(q);
1338 if ctx.encode {
1339 alg_quant(
1340 x,
1341 n,
1342 k,
1343 ctx.spread,
1344 b_blocks as usize,
1345 ctx.rc,
1346 gain,
1347 ctx.resynth,
1348 )
1349 } else {
1350 alg_unquant(x, n, k, ctx.spread, b_blocks as usize, ctx.rc, gain)
1351 }
1352 } else {
1353 let mut cm = 0u32;
1354 if ctx.resynth {
1355 let cm_mask = (1u32 << b_blocks) - 1;
1356 let fill_masked = fill & cm_mask;
1357 if fill_masked == 0 {
1358 x[..n].fill(0.0);
1359 } else if let Some(lb) = lowband {
1360 #[cfg(target_arch = "aarch64")]
1361 unsafe {
1362 use std::arch::aarch64::*;
1363 let n8 = n & !7;
1364 let mut i = 0;
1365 while i < n8 {
1366 let mut vals = [0.0f32; 8];
1367 for j in 0..8 {
1368 ctx.seed = celt_lcg_rand(ctx.seed);
1369 vals[j] = if ctx.seed & 0x8000 != 0 {
1370 1.0 / 256.0
1371 } else {
1372 -1.0 / 256.0
1373 };
1374 }
1375 let vnoise = vld1q_f32(vals.as_ptr());
1376 let vnoise1 = vld1q_f32(vals.as_ptr().add(4));
1377 let vlb = vld1q_f32(lb.as_ptr().add(i));
1378 let vlb1 = vld1q_f32(lb.as_ptr().add(i + 4));
1379 let vres = vaddq_f32(vlb, vnoise);
1380 let vres1 = vaddq_f32(vlb1, vnoise1);
1381 vst1q_f32(x.as_mut_ptr().add(i), vres);
1382 vst1q_f32(x.as_mut_ptr().add(i + 4), vres1);
1383 i += 8;
1384 }
1385 for j in i..n {
1386 ctx.seed = celt_lcg_rand(ctx.seed);
1387 x[j] = lb[j]
1388 + if ctx.seed & 0x8000 != 0 {
1389 1.0 / 256.0
1390 } else {
1391 -1.0 / 256.0
1392 };
1393 }
1394 }
1395 #[cfg(not(target_arch = "aarch64"))]
1396 {
1397 for j in 0..n {
1398 ctx.seed = celt_lcg_rand(ctx.seed);
1399 x[j] = lb[j]
1400 + if ctx.seed & 0x8000 != 0 {
1401 1.0 / 256.0
1402 } else {
1403 -1.0 / 256.0
1404 };
1405 }
1406 }
1407 renormalise_vector(x, n, gain);
1408 cm = fill_masked;
1409 } else {
1410 for xv in x[..n].iter_mut() {
1411 ctx.seed = celt_lcg_rand(ctx.seed);
1412 *xv = ((ctx.seed as i32 >> 20) as f32) / 16384.0;
1413 }
1414 renormalise_vector(x, n, gain);
1415 cm = cm_mask;
1416 }
1417 }
1418 cm
1419 }
1420 }
1421}
1422
/// Regroups `x` from interleaved order (`x[j * stride + i]`) into `stride`
/// contiguous runs of `n0` samples each (`x[i * n0 + j]`), in place.
///
/// Despite the name, the body is plain scalar code; it is only compiled for
/// aarch64 targets.
///
/// # Safety
///
/// `n0 * stride` must not exceed `MAX_PVQ_N`: the scratch slice is created
/// over a stack array of that capacity without a bounds check.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn deinterleave_hadamard_neon(x: &mut [f32], n0: usize, stride: usize) {
    let total = n0 * stride;
    let mut scratch = [std::mem::MaybeUninit::<f32>::uninit(); MAX_PVQ_N];
    let planar = std::slice::from_raw_parts_mut(scratch.as_mut_ptr() as *mut f32, total);

    // `chunks_exact_mut` rejects a chunk size of zero; with `n0 == 0` there is
    // nothing to move anyway.
    if n0 > 0 {
        for (block, run) in planar.chunks_exact_mut(n0).enumerate() {
            for (j, slot) in run.iter_mut().enumerate() {
                *slot = x[j * stride + block];
            }
        }
    }

    x[..total].copy_from_slice(planar);
}
1440
1441pub fn deinterleave_hadamard(x: &mut [f32], n0: usize, stride: usize, hadamard: bool) {
1442 let n = n0 * stride;
1443
1444 let mut tmp_buf = [std::mem::MaybeUninit::<f32>::uninit(); MAX_PVQ_N];
1445
1446 let tmp = unsafe { std::slice::from_raw_parts_mut(tmp_buf.as_mut_ptr() as *mut f32, n) };
1447 if hadamard {
1448 let offset = match stride {
1449 2 => 0,
1450 4 => 2,
1451 8 => 6,
1452 16 => 14,
1453 _ => 0,
1454 };
1455 let ordery = &ORDERY_TABLE[offset..offset + stride];
1456 for i in 0..stride {
1457 for j in 0..n0 {
1458 tmp[ordery[i] as usize * n0 + j] = x[j * stride + i];
1459 }
1460 }
1461 } else {
1462 #[cfg(target_arch = "aarch64")]
1463 unsafe {
1464 if n0 >= 4 {
1465 deinterleave_hadamard_neon(x, n0, stride);
1466 return;
1467 }
1468 }
1469 for i in 0..stride {
1470 for j in 0..n0 {
1471 tmp[i * n0 + j] = x[j * stride + i];
1472 }
1473 }
1474 }
1475 x[..n].copy_from_slice(tmp);
1476}
1477
/// Inverse of the plain (non-Hadamard) deinterleave: takes `stride`
/// contiguous runs of `n0` samples (`x[i * n0 + j]`) and interleaves them in
/// place so that sample `j` of run `i` lands at `x[j * stride + i]`.
///
/// Despite the name, the body is plain scalar code; it is only compiled for
/// aarch64 targets.
///
/// # Safety
///
/// `n0 * stride` must not exceed `MAX_PVQ_N`: the scratch slice is created
/// over a stack array of that capacity without a bounds check.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn interleave_hadamard_neon(x: &mut [f32], n0: usize, stride: usize) {
    let total = n0 * stride;
    let mut scratch = [std::mem::MaybeUninit::<f32>::uninit(); MAX_PVQ_N];
    let woven = std::slice::from_raw_parts_mut(scratch.as_mut_ptr() as *mut f32, total);

    // Walk the output one interleaved frame at a time; each frame gathers one
    // sample from every run. `chunks_exact_mut` rejects a chunk size of zero.
    if stride > 0 {
        for (j, frame) in woven.chunks_exact_mut(stride).enumerate() {
            for (run, slot) in frame.iter_mut().enumerate() {
                *slot = x[run * n0 + j];
            }
        }
    }

    x[..total].copy_from_slice(woven);
}
1495
1496pub fn interleave_hadamard(x: &mut [f32], n0: usize, stride: usize, hadamard: bool) {
1497 let n = n0 * stride;
1498 let mut tmp_buf = [std::mem::MaybeUninit::<f32>::uninit(); MAX_PVQ_N];
1499 let tmp = unsafe { std::slice::from_raw_parts_mut(tmp_buf.as_mut_ptr() as *mut f32, n) };
1500 if hadamard {
1501 let offset = match stride {
1502 2 => 0,
1503 4 => 2,
1504 8 => 6,
1505 16 => 14,
1506 _ => 0,
1507 };
1508 let ordery = &ORDERY_TABLE[offset..offset + stride];
1509 for i in 0..stride {
1510 for j in 0..n0 {
1511 tmp[j * stride + i] = x[ordery[i] as usize * n0 + j];
1512 }
1513 }
1514 } else {
1515 #[cfg(target_arch = "aarch64")]
1516 unsafe {
1517 if n0 >= 4 {
1518 interleave_hadamard_neon(x, n0, stride);
1519 return;
1520 }
1521 }
1522 for i in 0..stride {
1523 for j in 0..n0 {
1524 tmp[j * stride + i] = x[i * n0 + j];
1525 }
1526 }
1527 }
1528 x[..n].copy_from_slice(tmp);
1529}
1530
/// Run-index permutations used by the Hadamard (de)interleavers.
///
/// The rows are stored back to back; the row for stride 2, 4, 8 and 16 starts
/// at offset 0, 2, 6 and 14 respectively.
// NOTE(review): the stride-16 row should be compared against the reference
// codec's `ordery_table`; the values are reproduced here unchanged.
const ORDERY_TABLE: [i32; 30] = [
    // stride 2
    1, 0, //
    // stride 4
    3, 0, 2, 1, //
    // stride 8
    7, 0, 4, 3, 6, 1, 5, 2, //
    // stride 16
    15, 0, 8, 7, 12, 3, 11, 4, 14, 1, 9, 6, 13, 2, 10, 5,
];
1534
1535fn quant_band_n1(
1536 ctx: &mut BandCtx,
1537 x: &mut [f32],
1538 y: Option<&mut [f32]>,
1539 lowband_out: Option<&mut [f32]>,
1540) -> u32 {
1541 let mut sign = 0;
1542 if ctx.remaining_bits >= 1 << BITRES {
1543 if ctx.encode {
1544 sign = if x[0] < 0.0 { 1 } else { 0 };
1545 ctx.rc.enc_bits(sign as u32, 1);
1546 } else {
1547 sign = ctx.rc.dec_bits(1) as i32;
1548 }
1549 ctx.remaining_bits -= 1 << BITRES;
1550 }
1551 if ctx.resynth {
1552 x[0] = if sign != 0 { -1.0 } else { 1.0 };
1553 }
1554 if let Some(y_val) = y {
1555 let mut y_sign = 0;
1556 if ctx.remaining_bits >= 1 << BITRES {
1557 if ctx.encode {
1558 y_sign = if y_val[0] < 0.0 { 1 } else { 0 };
1559 ctx.rc.enc_bits(y_sign as u32, 1);
1560 } else {
1561 y_sign = ctx.rc.dec_bits(1) as i32;
1562 }
1563 ctx.remaining_bits -= 1 << BITRES;
1564 }
1565 if ctx.resynth {
1566 y_val[0] = if y_sign != 0 { -1.0 } else { 1.0 };
1567 }
1568 }
1569 if let Some(l_out) = lowband_out {
1570 l_out[0] = x[0] / 16.0;
1571 }
1572 1
1573}
1574
1575#[allow(clippy::too_many_arguments)]
1576#[inline(always)]
1577pub fn quant_band(
1578 ctx: &mut BandCtx,
1579 x: &mut [f32],
1580 n: usize,
1581 b: i32,
1582 b_blocks: i32,
1583 lowband: Option<&mut [f32]>,
1584 lm: i32,
1585 lowband_out: Option<&mut [f32]>,
1586 gain: f32,
1587 fill: u32,
1588) -> u32 {
1589 let n0 = n;
1590 let b0 = b_blocks;
1591 let long_blocks = b0 == 1;
1592
1593 if n == 1 {
1594 return quant_band_n1(ctx, x, None, lowband_out);
1595 }
1596
1597 let mut b_blocks = b_blocks;
1598 let mut n_b = n / b_blocks as usize;
1599 let mut time_divide = 0;
1600 let mut recombine = 0;
1601 let mut tf_change_local = ctx.tf_change;
1602 let mut fill = fill;
1603
1604 if tf_change_local > 0 {
1605 recombine = tf_change_local;
1606 }
1607
1608 let mut lowband_buf = lowband;
1609
1610 static BIT_INTERLEAVE_TABLE: [u8; 16] = [0, 1, 1, 1, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3];
1611
1612 for k in 0..recombine {
1613 if ctx.encode {
1614 haar1(x, n >> k, 1 << k);
1615 }
1616 if let Some(ref mut lb) = lowband_buf {
1617 haar1(lb, n >> k, 1 << k);
1618 }
1619 fill = (BIT_INTERLEAVE_TABLE[(fill & 0xF) as usize] as u32)
1620 | ((BIT_INTERLEAVE_TABLE[(fill >> 4) as usize] as u32) << 2);
1621 }
1622 b_blocks >>= recombine;
1623 n_b <<= recombine;
1624
1625 while n_b & 1 == 0 && tf_change_local < 0 {
1626 if ctx.encode {
1627 haar1(x, n_b, b_blocks as usize);
1628 }
1629 if let Some(ref mut lb) = lowband_buf {
1630 haar1(lb, n_b, b_blocks as usize);
1631 }
1632 fill |= fill << b_blocks;
1633 b_blocks <<= 1;
1634 n_b >>= 1;
1635 time_divide += 1;
1636 tf_change_local += 1;
1637 }
1638
1639 let b0_after = b_blocks;
1640 let n_b0 = n_b;
1641
1642 if b_blocks > 1 {
1643 if ctx.encode {
1644 deinterleave_hadamard(
1645 x,
1646 n_b >> recombine as usize,
1647 (b_blocks << recombine) as usize,
1648 long_blocks,
1649 );
1650 }
1651 if let Some(ref mut lb) = lowband_buf {
1652 deinterleave_hadamard(
1653 lb,
1654 n_b >> recombine as usize,
1655 (b_blocks << recombine) as usize,
1656 long_blocks,
1657 );
1658 }
1659 }
1660
1661 let cm = if ctx.encode {
1662 quant_partition_encode(ctx, x, n, b, b_blocks, lowband_buf, lm, gain, fill)
1663 } else {
1664 quant_partition(ctx, x, n, b, b_blocks, lowband_buf, lm, gain, fill)
1665 };
1666
1667 if ctx.resynth {
1668 let mut cm = cm;
1669
1670 if b_blocks > 1 {
1671 interleave_hadamard(
1672 x,
1673 n_b >> recombine as usize,
1674 (b0_after << recombine) as usize,
1675 long_blocks,
1676 );
1677 }
1678
1679 let mut n_b_undo = n_b0;
1680 let mut b_undo = b0_after;
1681 for _ in 0..time_divide {
1682 b_undo >>= 1;
1683 n_b_undo <<= 1;
1684 cm |= cm >> b_undo;
1685 haar1(x, n_b_undo, b_undo as usize);
1686 }
1687
1688 static BIT_DEINTERLEAVE_TABLE: [u8; 16] = [
1689 0x00, 0x03, 0x0C, 0x0F, 0x30, 0x33, 0x3C, 0x3F, 0xC0, 0xC3, 0xCC, 0xCF, 0xF0, 0xF3,
1690 0xFC, 0xFF,
1691 ];
1692 for k in 0..recombine {
1693 cm = BIT_DEINTERLEAVE_TABLE[cm as usize & 0xF] as u32;
1694 haar1(x, n0 >> k, 1 << k);
1695 }
1696 let mut b_final = b_undo;
1697 b_final <<= recombine;
1698
1699 if let Some(lb_out) = lowband_out {
1700 let scale = (n0 as f32).sqrt();
1701 for j in 0..n0 {
1702 lb_out[j] = scale * x[j];
1703 }
1704 }
1705 cm &= (1u32 << b_final) - 1;
1706 return cm;
1707 }
1708
1709 cm
1710}
1711
/// Converts a decoded mid/side pair back to normalised left/right, in place.
///
/// On entry `x` holds the mid vector and `y` the side vector (the side is
/// used as is; the mid is scaled by `mid`). On return `x` holds
/// `(mid*x - y)` and `y` holds `(mid*x + y)`, each divided by its own norm.
/// If either squared norm is below `6e-4`, `y` is simply overwritten with a
/// copy of `x` and `x` is left untouched.
///
/// `_side` is accepted for signature compatibility and is not used.
/// Only the first `n` samples of each slice are touched.
pub fn stereo_merge(x: &mut [f32], y: &mut [f32], mid: f32, _side: f32, n: usize) {
    // One pass for both <x, y> and <y, y>.
    let (cross, side_energy) = x[..n]
        .iter()
        .zip(&y[..n])
        .fold((0.0f32, 0.0f32), |(xy, yy), (&xv, &yv)| {
            (xy + yv * xv, yy + yv * yv)
        });

    // Account for the mid scaling that has not been applied to `x` yet.
    let cross = cross * mid;
    let energy_left = mid * mid + side_energy - 2.0 * cross;
    let energy_right = mid * mid + side_energy + 2.0 * cross;

    if energy_right < 6e-4f32 || energy_left < 6e-4f32 {
        y[..n].copy_from_slice(&x[..n]);
        return;
    }

    let gain_left = 1.0 / energy_left.sqrt();
    let gain_right = 1.0 / energy_right.sqrt();

    for (xv, yv) in x[..n].iter_mut().zip(y[..n].iter_mut()) {
        let m = mid * *xv;
        let s = *yv;
        *xv = gain_left * (m - s);
        *yv = gain_right * (m + s);
    }
}
1739
/// In-place sum/difference rotation of a channel pair over the first `n`
/// samples: `x` becomes `(x + y) / sqrt(2)` and `y` becomes
/// `(y - x) / sqrt(2)`.
#[inline(always)]
fn stereo_split(x: &mut [f32], y: &mut [f32], n: usize) {
    const INV_SQRT2: f32 = std::f32::consts::FRAC_1_SQRT_2;
    for (xv, yv) in x[..n].iter_mut().zip(y[..n].iter_mut()) {
        // Scale first, then combine, so the rounding matches sample for sample.
        let a = INV_SQRT2 * *xv;
        let c = INV_SQRT2 * *yv;
        *xv = a + c;
        *yv = c - a;
    }
}
1750
1751#[inline(always)]
1752fn intensity_stereo(
1753 m: &CeltMode,
1754 x: &mut [f32],
1755 y: &mut [f32],
1756 band_e: &[f32],
1757 band: usize,
1758 n: usize,
1759) {
1760 let left = band_e[band].max(MIN_STEREO_ENERGY);
1761 let right = band_e[m.nb_ebands + band].max(MIN_STEREO_ENERGY);
1762 let norm = (left * left + right * right).sqrt().max(MIN_STEREO_ENERGY);
1763 let a1 = left / norm;
1764 let a2 = right / norm;
1765 for i in 0..n {
1766 x[i] = a1 * x[i] + a2 * y[i];
1767 }
1768}
1769
1770#[inline(always)]
1771fn special_hybrid_folding(m: &CeltMode, norm: &mut [f32], start: usize, m_val: usize) {
1772 if start + 2 >= m.e_bands.len() {
1773 return;
1774 }
1775 let n1 = m_val * (m.e_bands[start + 1] - m.e_bands[start]) as usize;
1776 let n2 = m_val * (m.e_bands[start + 2] - m.e_bands[start + 1]) as usize;
1777 if n2 <= n1 {
1778 return;
1779 }
1780 let len = n2 - n1;
1781 let src_start = 2 * n1 - n2;
1782 if src_start + len <= norm.len() && n1 + len <= norm.len() {
1783 norm.copy_within(src_start..src_start + len, n1);
1784 }
1785}
1786
/// Carves the folding source (`lowband`) and the folding output
/// (`lowband_out`) for one band out of the shared `norm` buffer.
///
/// Returns `(lowband, lowband_out)`:
///
/// * `lowband_out` is `norm[norm_pos..norm_pos + n]` when `want_out` is set
///   and that range fits inside `norm`.
/// * `lowband` is `None` when `effective_lowband` is negative or the range
///   `effective_lowband..effective_lowband + n` does not fit inside `norm`.
/// * With `allow_lowband_scratch`, the source range is copied to
///   `lowband_scratch_ptr` and the copy is returned, so `norm` is not
///   modified through `lowband`.
/// * Otherwise `lowband` aliases `norm` directly. If it overlaps the output
///   range, the output is dropped (`None`) and only the source is returned.
///
/// `lowband_scratch_ptr` must be valid for writing `n` floats and must not
/// point into `norm` whenever `allow_lowband_scratch` is set; the function
/// itself cannot check this.
fn prepare_lowband_views(
    norm: &mut [f32],
    lowband_scratch_ptr: *mut f32,
    allow_lowband_scratch: bool,
    effective_lowband: i32,
    norm_pos: usize,
    n: usize,
    want_out: bool,
) -> (Option<&mut [f32]>, Option<&mut [f32]>) {
    /// Borrows the output window, if one was requested and fits.
    fn out_view(norm: &mut [f32], start: Option<usize>, n: usize) -> Option<&mut [f32]> {
        match start {
            Some(s) => Some(&mut norm[s..s + n]),
            None => None,
        }
    }

    let total = norm.len();
    let out_start = if want_out && norm_pos + n <= total {
        Some(norm_pos)
    } else {
        None
    };

    // No folding source at all.
    if effective_lowband < 0 {
        return (None, out_view(norm, out_start, n));
    }
    let src_start = effective_lowband as usize;
    let src_end = src_start + n;
    if src_end > total {
        return (None, out_view(norm, out_start, n));
    }

    if allow_lowband_scratch {
        // SAFETY: `src_start + n <= norm.len()` was checked above, so the read
        // side is in bounds; the caller guarantees the scratch pointer is
        // valid for `n` floats and disjoint from `norm`.
        let copy = unsafe {
            std::ptr::copy_nonoverlapping(norm.as_ptr().add(src_start), lowband_scratch_ptr, n);
            std::slice::from_raw_parts_mut(lowband_scratch_ptr, n)
        };
        return (Some(copy), out_view(norm, out_start, n));
    }

    match out_start {
        // Source lies entirely before the output window.
        Some(dst) if src_end <= dst => {
            let (head, tail) = norm.split_at_mut(dst);
            (Some(&mut head[src_start..src_end]), Some(&mut tail[..n]))
        }
        // Output window lies entirely before the source.
        Some(dst) if dst + n <= src_start => {
            let (head, tail) = norm.split_at_mut(src_start);
            (Some(&mut tail[..n]), Some(&mut head[dst..dst + n]))
        }
        // Overlapping windows, or no output requested: source only.
        _ => (Some(&mut norm[src_start..src_end]), None),
    }
}
1844
/// AVX2 butterfly over the first `n` samples: `x[i]` becomes
/// `x[i]*mid - y[i]*side` and `y[i]` becomes `x[i]*mid + y[i]*side`
/// (both computed from the original values).
///
/// # Safety
///
/// The CPU must support AVX2, and both `x` and `y` must hold at least `n`
/// samples: the vector part reads and writes through raw pointers without
/// bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[allow(dead_code)]
unsafe fn stereo_merge_avx2(x: &mut [f32], y: &mut [f32], mid: f32, side: f32, n: usize) {
    use std::arch::x86_64::*;

    const LANES: usize = 8;

    let mid_v = _mm256_set1_ps(mid);
    let side_v = _mm256_set1_ps(side);
    let mut done = 0usize;

    // Two vectors per iteration while at least 16 samples remain.
    while n - done >= 2 * LANES {
        for half in 0..2 {
            let at = done + half * LANES;
            let scaled_x = _mm256_mul_ps(_mm256_loadu_ps(x.as_ptr().add(at)), mid_v);
            let scaled_y = _mm256_mul_ps(_mm256_loadu_ps(y.as_ptr().add(at)), side_v);
            _mm256_storeu_ps(x.as_mut_ptr().add(at), _mm256_sub_ps(scaled_x, scaled_y));
            _mm256_storeu_ps(y.as_mut_ptr().add(at), _mm256_add_ps(scaled_x, scaled_y));
        }
        done += 2 * LANES;
    }

    // One vector at a time for what is left.
    while n - done >= LANES {
        let scaled_x = _mm256_mul_ps(_mm256_loadu_ps(x.as_ptr().add(done)), mid_v);
        let scaled_y = _mm256_mul_ps(_mm256_loadu_ps(y.as_ptr().add(done)), side_v);
        _mm256_storeu_ps(x.as_mut_ptr().add(done), _mm256_sub_ps(scaled_x, scaled_y));
        _mm256_storeu_ps(y.as_mut_ptr().add(done), _mm256_add_ps(scaled_x, scaled_y));
        done += LANES;
    }

    // Scalar tail (bounds-checked).
    for k in done..n {
        let scaled_x = x[k] * mid;
        let scaled_y = y[k] * side;
        x[k] = scaled_x - scaled_y;
        y[k] = scaled_x + scaled_y;
    }
}
1903
/// Scalar butterfly over the first `n` samples: `x[i]` becomes
/// `x[i]*mid - y[i]*side` and `y[i]` becomes `x[i]*mid + y[i]*side`
/// (both computed from the original values).
#[allow(dead_code)]
#[inline]
fn stereo_merge_scalar(x: &mut [f32], y: &mut [f32], mid: f32, side: f32, n: usize) {
    for (xv, yv) in x[..n].iter_mut().zip(y[..n].iter_mut()) {
        let scaled_x = *xv * mid;
        let scaled_y = *yv * side;
        *xv = scaled_x - scaled_y;
        *yv = scaled_x + scaled_y;
    }
}
1914
/// NEON butterfly over the first `n` samples: `x[i]` becomes
/// `x[i]*mid - y[i]*side` and `y[i]` becomes `x[i]*mid + y[i]*side`
/// (both computed from the original values).
///
/// # Panics
///
/// Panics if `x` or `y` holds fewer than `n` samples.
#[cfg(target_arch = "aarch64")]
#[allow(dead_code)]
fn stereo_merge_neon(x: &mut [f32], y: &mut [f32], mid: f32, side: f32, n: usize) {
    use std::arch::aarch64::*;

    // This is a safe function, so the raw-pointer accesses below must be
    // provably in bounds: reslicing panics up front when `n` is too large
    // instead of reading and writing past the end of the buffers.
    let x = &mut x[..n];
    let y = &mut y[..n];

    let n16 = n & !15;
    let n4 = n & !3;

    // SAFETY: both slices are exactly `n` long; the 16-wide loop touches
    // offsets below `n16 <= n` and the 4-wide loop offsets below `n4 <= n`.
    unsafe {
        let vmid = vdupq_n_f32(mid);
        let vside = vdupq_n_f32(side);

        for i in (0..n16).step_by(16) {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let x1 = vld1q_f32(x.as_ptr().add(i + 4));
            let x2 = vld1q_f32(x.as_ptr().add(i + 8));
            let x3 = vld1q_f32(x.as_ptr().add(i + 12));

            let y0 = vld1q_f32(y.as_ptr().add(i));
            let y1 = vld1q_f32(y.as_ptr().add(i + 4));
            let y2 = vld1q_f32(y.as_ptr().add(i + 8));
            let y3 = vld1q_f32(y.as_ptr().add(i + 12));

            let xv0 = vmulq_f32(x0, vmid);
            let xv1 = vmulq_f32(x1, vmid);
            let xv2 = vmulq_f32(x2, vmid);
            let xv3 = vmulq_f32(x3, vmid);

            let yv0 = vmulq_f32(y0, vside);
            let yv1 = vmulq_f32(y1, vside);
            let yv2 = vmulq_f32(y2, vside);
            let yv3 = vmulq_f32(y3, vside);

            vst1q_f32(x.as_mut_ptr().add(i), vsubq_f32(xv0, yv0));
            vst1q_f32(x.as_mut_ptr().add(i + 4), vsubq_f32(xv1, yv1));
            vst1q_f32(x.as_mut_ptr().add(i + 8), vsubq_f32(xv2, yv2));
            vst1q_f32(x.as_mut_ptr().add(i + 12), vsubq_f32(xv3, yv3));

            vst1q_f32(y.as_mut_ptr().add(i), vaddq_f32(xv0, yv0));
            vst1q_f32(y.as_mut_ptr().add(i + 4), vaddq_f32(xv1, yv1));
            vst1q_f32(y.as_mut_ptr().add(i + 8), vaddq_f32(xv2, yv2));
            vst1q_f32(y.as_mut_ptr().add(i + 12), vaddq_f32(xv3, yv3));
        }

        for i in (n16..n4).step_by(4) {
            let xv = vld1q_f32(x.as_ptr().add(i));
            let yv = vld1q_f32(y.as_ptr().add(i));

            let x_val = vmulq_f32(xv, vmid);
            let y_val = vmulq_f32(yv, vside);

            vst1q_f32(x.as_mut_ptr().add(i), vsubq_f32(x_val, y_val));
            vst1q_f32(y.as_mut_ptr().add(i), vaddq_f32(x_val, y_val));
        }
    }

    // Scalar tail for the last `n % 4` samples.
    for i in n4..n {
        let x_val = x[i] * mid;
        let y_val = y[i] * side;
        x[i] = x_val - y_val;
        y[i] = x_val + y_val;
    }
}
1977
1978#[allow(clippy::too_many_arguments)]
1979#[inline(always)]
1980pub fn quant_band_stereo(
1981 ctx: &mut BandCtx,
1982 x: &mut [f32],
1983 y: &mut [f32],
1984 n: usize,
1985 b: i32,
1986 b_blocks: i32,
1987 lowband: Option<&mut [f32]>,
1988 lm: i32,
1989 lowband_out: Option<&mut [f32]>,
1990 _gain: f32,
1991 fill: u32,
1992) -> u32 {
1993 if n == 1 {
1994 return quant_band_n1(ctx, x, Some(y), lowband_out);
1995 }
1996
1997 if ctx.encode
1998 && (ctx.band_e[ctx.i] < MIN_STEREO_ENERGY
1999 || ctx.band_e[ctx.m.nb_ebands + ctx.i] < MIN_STEREO_ENERGY)
2000 {
2001 if ctx.band_e[ctx.i] > ctx.band_e[ctx.m.nb_ebands + ctx.i] {
2002 y.copy_from_slice(x);
2003 } else {
2004 x.copy_from_slice(y);
2005 }
2006 }
2007
2008 let mut sctx = SplitCtx {
2009 inv: false,
2010 imid: 0,
2011 iside: 0,
2012 delta: 0,
2013 itheta: 0,
2014 qalloc: 0,
2015 };
2016 let mut b_mut = b;
2017 let mut fill_mut = fill;
2018 compute_theta(
2019 ctx,
2020 &mut sctx,
2021 x,
2022 y,
2023 n,
2024 &mut b_mut,
2025 b_blocks,
2026 b_blocks,
2027 lm,
2028 true,
2029 &mut fill_mut,
2030 );
2031
2032 let mid_gain = sctx.imid as f32 / 32768.0;
2033 let side_gain = sctx.iside as f32 / 32768.0;
2034
2035 if n == 2 {
2036 let orig_fill = fill;
2037 let mut mbits = b_mut;
2038 let mut sbits = 0;
2039 if sctx.itheta != 0 && sctx.itheta != 16384 {
2040 sbits = 1 << BITRES;
2041 }
2042 mbits -= sbits;
2043 let c = sctx.itheta > 8192;
2044 ctx.remaining_bits -= sctx.qalloc + sbits;
2045
2046 let mut sign = 0;
2047 if sbits != 0 {
2048 if ctx.encode {
2049 sign = if c {
2050 if (y[0] * x[1] - y[1] * x[0]) < 0.0 {
2051 1
2052 } else {
2053 0
2054 }
2055 } else if (x[0] * y[1] - x[1] * y[0]) < 0.0 {
2056 1
2057 } else {
2058 0
2059 };
2060 ctx.rc.enc_bits(sign as u32, 1);
2061 } else {
2062 sign = ctx.rc.dec_bits(1) as i32;
2063 }
2064 }
2065 let sign_val = (1 - 2 * sign) as f32;
2066 let cm = if c {
2067 let cm = quant_band(
2068 ctx,
2069 y,
2070 n,
2071 mbits,
2072 b_blocks,
2073 lowband,
2074 lm,
2075 lowband_out,
2076 1.0,
2077 orig_fill,
2078 );
2079 x[0] = -sign_val * y[1];
2080 x[1] = sign_val * y[0];
2081 cm
2082 } else {
2083 let cm = quant_band(
2084 ctx,
2085 x,
2086 n,
2087 mbits,
2088 b_blocks,
2089 lowband,
2090 lm,
2091 lowband_out,
2092 1.0,
2093 orig_fill,
2094 );
2095 y[0] = -sign_val * x[1];
2096 y[1] = sign_val * x[0];
2097 cm
2098 };
2099
2100 if ctx.resynth {
2101 let x0 = x[0];
2102 let x1 = x[1];
2103 let y0 = y[0];
2104 let y1 = y[1];
2105 let mx0 = mid_gain * x0;
2106 let mx1 = mid_gain * x1;
2107 let sy0 = side_gain * y0;
2108 let sy1 = side_gain * y1;
2109 x[0] = mx0 - sy0;
2110 x[1] = mx1 - sy1;
2111 y[0] = mx0 + sy0;
2112 y[1] = mx1 + sy1;
2113 }
2114 return cm;
2115 }
2116
2117 ctx.remaining_bits -= sctx.qalloc;
2118 let mut mbits = (0).max((b_mut - sctx.delta) / 2).min(b_mut);
2119 let mut sbits = b_mut - mbits;
2120
2121 let mut rebalance = ctx.remaining_bits;
2122 let mut cm;
2123
2124 if mbits >= sbits {
2125 cm = quant_band(
2126 ctx,
2127 x,
2128 n,
2129 mbits,
2130 b_blocks,
2131 lowband,
2132 lm,
2133 lowband_out,
2134 1.0,
2135 fill_mut,
2136 );
2137 rebalance = mbits - (rebalance - ctx.remaining_bits);
2138 if rebalance > (3 << 3) && sctx.itheta != 0 {
2139 sbits += rebalance - (3 << 3);
2140 }
2141 cm |= quant_band(
2142 ctx,
2143 y,
2144 n,
2145 sbits,
2146 b_blocks,
2147 None,
2148 lm,
2149 None,
2150 side_gain,
2151 fill_mut >> b_blocks,
2152 );
2153 } else {
2154 cm = quant_band(
2155 ctx,
2156 y,
2157 n,
2158 sbits,
2159 b_blocks,
2160 None,
2161 lm,
2162 None,
2163 side_gain,
2164 fill_mut >> b_blocks,
2165 );
2166 rebalance = sbits - (rebalance - ctx.remaining_bits);
2167 if rebalance > (3 << 3) && sctx.itheta != 16384 {
2168 mbits += rebalance - (3 << 3);
2169 }
2170 cm |= quant_band(
2171 ctx,
2172 x,
2173 n,
2174 mbits,
2175 b_blocks,
2176 lowband,
2177 lm,
2178 lowband_out,
2179 1.0,
2180 fill_mut,
2181 );
2182 }
2183
2184 if ctx.resynth {
2185 stereo_merge(x, y, mid_gain, side_gain, n);
2186 if sctx.inv {
2187 for yv in y[..n].iter_mut() {
2188 *yv = -*yv;
2189 }
2190 }
2191 }
2192 cm
2193}
2194
2195#[allow(clippy::too_many_arguments)]
2196pub fn quant_all_bands(
2197 encode: bool,
2198 m: &CeltMode,
2199 start: usize,
2200 end: usize,
2201 x: &mut [f32],
2202 mut y: Option<&mut [f32]>,
2203 collapse_masks: &mut [u32],
2204 band_e: &[f32],
2205 pulses: &[i32],
2206 short_blocks: bool,
2207 spread: i32,
2208 dual_stereo: &mut bool,
2209 intensity: usize,
2210 tf_res: &[i32],
2211 total_bits: i32,
2212 balance: &mut i32,
2213 rc: &mut RangeCoder,
2214 lm: i32,
2215 coded_bands: i32,
2216 resynth: bool,
2217 disable_inv: bool,
2218 seed: &mut u32,
2219) {
2220 let mut balance_val = *balance;
2221 let b_blocks = if short_blocks { 1 << lm } else { 1 };
2222 let c_channels = if y.is_some() { 2 } else { 1 };
2223 let m_val = 1usize << lm as usize;
2224
2225 let norm_offset = m_val * (m.e_bands[start] as usize);
2226 let norm_size = m_val * (m.e_bands[m.nb_ebands - 1] as usize) - norm_offset;
2227
2228 const MAX_NORM_SIZE: usize = 800;
2229 debug_assert!(norm_size <= MAX_NORM_SIZE);
2230
2231 let mut norm_buf = [std::mem::MaybeUninit::<f32>::uninit(); MAX_NORM_SIZE];
2232 let norm =
2233 unsafe { std::slice::from_raw_parts_mut(norm_buf.as_mut_ptr() as *mut f32, norm_size) };
2234 let mut norm2_buf = [std::mem::MaybeUninit::<f32>::uninit(); MAX_NORM_SIZE];
2235 let norm2 =
2236 unsafe { std::slice::from_raw_parts_mut(norm2_buf.as_mut_ptr() as *mut f32, norm_size) };
2237
2238 let mut lowband_scratch_buf = [std::mem::MaybeUninit::<f32>::uninit(); MAX_PVQ_N];
2239 let lowband_scratch_ptr = lowband_scratch_buf.as_mut_ptr() as *mut f32;
2240
2241 let mut lowband_offset: usize = 0;
2242 let mut update_lowband = true;
2243 let mut avoid_split_noise = b_blocks > 1;
2244
2245 let e_bands = &m.e_bands;
2246 let mut ctx_seed = *seed;
2247
2248 for i in start..end {
2249 let e_band_i = e_bands[i] as usize;
2250 let e_band_i1 = e_bands[i + 1] as usize;
2251 let offset = m_val * e_band_i;
2252 let n = m_val * (e_band_i1 - e_band_i);
2253 let last = i == end - 1;
2254
2255 let tell = tell_frac_inline!(rc);
2256 if i != start {
2257 balance_val -= tell;
2258 }
2259 let remaining_bits = total_bits - tell - 1;
2260
2261 let mut b = 0i32;
2262 if i < coded_bands as usize {
2263 let curr_balance = celt_sudiv(balance_val, 3i32.min(coded_bands - i as i32));
2264 b = 0i32.max(16383i32.min((remaining_bits + 1).min(pulses[i] + curr_balance)));
2265 }
2266
2267 let norm_pos = m_val * e_band_i - norm_offset;
2268 let tf_change = tf_res[i];
2269
2270 let mut effective_lowband: i32 = -1;
2271 let mut x_cm: u32;
2272 let mut y_cm: u32;
2273
2274 let band_start_abs = m_val * e_band_i;
2275 let start_abs = m_val * (e_bands[start] as usize);
2276 if resynth
2277 && ((band_start_abs as isize - n as isize >= start_abs as isize) || i == start + 1)
2278 && (update_lowband || lowband_offset == 0)
2279 {
2280 lowband_offset = i;
2281 }
2282
2283 if resynth && i == start + 1 {
2284 special_hybrid_folding(m, norm, start, m_val);
2285 if *dual_stereo {
2286 special_hybrid_folding(m, norm2, start, m_val);
2287 }
2288 }
2289
2290 if lowband_offset != 0 && (spread != SPREAD_AGGRESSIVE || b_blocks > 1 || tf_change < 0) {
2291 effective_lowband = 0i32.max(
2292 (m_val * e_bands[lowband_offset] as usize) as i32 - norm_offset as i32 - n as i32,
2293 );
2294 let el_abs = effective_lowband as usize + norm_offset;
2295
2296 let mut fold_start = lowband_offset;
2297 while fold_start > 0 {
2298 fold_start -= 1;
2299 if m_val * (e_bands[fold_start] as usize) <= el_abs {
2300 break;
2301 }
2302 }
2303
2304 let mut fold_end = lowband_offset.saturating_sub(1);
2305 loop {
2306 fold_end += 1;
2307 if fold_end >= i || m_val * (e_bands[fold_end] as usize) >= el_abs + n {
2308 break;
2309 }
2310 }
2311
2312 x_cm = 0;
2313 y_cm = 0;
2314 let mut fi = fold_start;
2315 loop {
2316 x_cm |= collapse_masks[fi * c_channels];
2317 y_cm |= collapse_masks[fi * c_channels + c_channels - 1];
2318 fi += 1;
2319 if fi >= fold_end {
2320 break;
2321 }
2322 }
2323 } else {
2324 x_cm = (1u32 << b_blocks) - 1;
2325 y_cm = (1u32 << b_blocks) - 1;
2326 }
2327
2328 let mut ctx = BandCtx {
2329 encode,
2330 m,
2331 i,
2332 band_e,
2333 rc,
2334 spread,
2335 remaining_bits,
2336 resynth,
2337 tf_change,
2338 intensity,
2339 theta_round: 0,
2340 avoid_split_noise,
2341 arch: 0,
2342 disable_inv,
2343 seed: ctx_seed,
2344 };
2345
2346 let x_slice = &mut x[offset..offset + n];
2347 let band_uses_direct_norm = i >= m.eff_ebands;
2348 let allow_lowband_scratch = !(band_uses_direct_norm || (last && !encode));
2349 if *dual_stereo && i == intensity {
2350 *dual_stereo = false;
2351 if resynth {
2352 for j in 0..norm_pos {
2353 norm[j] = 0.5 * (norm[j] + norm2[j]);
2354 }
2355 }
2356 }
2357
2358 if *dual_stereo {
2359 let y_slice = &mut y.as_mut().unwrap()[offset..offset + n];
2360
2361 let (lb_x, lb_out_x) = prepare_lowband_views(
2362 norm,
2363 lowband_scratch_ptr,
2364 allow_lowband_scratch,
2365 effective_lowband,
2366 norm_pos,
2367 n,
2368 !last,
2369 );
2370 x_cm = quant_band(
2371 &mut ctx,
2372 x_slice,
2373 n,
2374 b / 2,
2375 b_blocks,
2376 lb_x,
2377 lm,
2378 lb_out_x,
2379 1.0,
2380 x_cm,
2381 );
2382
2383 let (lb_y, lb_out_y) = prepare_lowband_views(
2384 norm2,
2385 lowband_scratch_ptr,
2386 allow_lowband_scratch,
2387 effective_lowband,
2388 norm_pos,
2389 n,
2390 !last,
2391 );
2392 y_cm = quant_band(
2393 &mut ctx,
2394 y_slice,
2395 n,
2396 b / 2,
2397 b_blocks,
2398 lb_y,
2399 lm,
2400 lb_out_y,
2401 1.0,
2402 y_cm,
2403 );
2404 } else if let Some(y_all) = y.as_mut() {
2405 let y_slice = &mut y_all[offset..offset + n];
2406 let (lb, lb_out) = prepare_lowband_views(
2407 norm,
2408 lowband_scratch_ptr,
2409 allow_lowband_scratch,
2410 effective_lowband,
2411 norm_pos,
2412 n,
2413 !last,
2414 );
2415 x_cm = quant_band_stereo(
2416 &mut ctx,
2417 x_slice,
2418 y_slice,
2419 n,
2420 b,
2421 b_blocks,
2422 lb,
2423 lm,
2424 lb_out,
2425 1.0,
2426 x_cm | y_cm,
2427 );
2428 y_cm = x_cm;
2429 } else {
2430 let (lb, lb_out) = prepare_lowband_views(
2431 norm,
2432 lowband_scratch_ptr,
2433 allow_lowband_scratch,
2434 effective_lowband,
2435 norm_pos,
2436 n,
2437 !last,
2438 );
2439 x_cm = quant_band(&mut ctx, x_slice, n, b, b_blocks, lb, lm, lb_out, 1.0, x_cm);
2440 y_cm = x_cm;
2441 }
2442
2443 collapse_masks[i * c_channels] = (x_cm & 0xFF) as u8 as u32;
2444 if c_channels == 2 {
2445 collapse_masks[i * c_channels + 1] = (y_cm & 0xFF) as u8 as u32;
2446 }
2447
2448 balance_val += pulses[i] + tell;
2449 ctx_seed = ctx.seed;
2450 update_lowband = b > ((n as i32) << BITRES);
2451
2452 avoid_split_noise = false;
2453 }
2454 *balance = balance_val;
2455 *seed = ctx_seed;
2456}
2457
/// Returns `sqrt(1e-27 + sum(v * v))` over `band`, using NEON fused
/// multiply-adds.
///
/// The first `len & !15` samples are accumulated in four independent vector
/// accumulators, the following whole groups of four in a single accumulator,
/// and the remaining samples one by one.
#[cfg(target_arch = "aarch64")]
fn compute_band_energy_neon(band: &[f32]) -> f32 {
    use std::arch::aarch64::*;

    let len = band.len();
    let wide_end = len & !15;
    let quad_end = len & !3;
    let base = band.as_ptr();
    let mut energy = 1e-27f32;

    if wide_end > 0 {
        // SAFETY: each load reads four floats starting at an offset of at
        // most `wide_end - 4`, which is within `band`.
        let partial = unsafe {
            let mut lanes = [vdupq_n_f32(0.0); 4];
            let mut at = 0usize;
            while at < wide_end {
                for (k, lane) in lanes.iter_mut().enumerate() {
                    let v = vld1q_f32(base.add(at + 4 * k));
                    *lane = vfmaq_f32(*lane, v, v);
                }
                at += 16;
            }
            // Pairwise reduction: (0 + 1) + (2 + 3), then across lanes.
            let low = vaddq_f32(lanes[0], lanes[1]);
            let high = vaddq_f32(lanes[2], lanes[3]);
            vaddvq_f32(vaddq_f32(low, high))
        };
        energy += partial;
    }

    if quad_end > wide_end {
        // SAFETY: each load reads four floats starting at an offset of at
        // most `quad_end - 4`, which is within `band`.
        let partial = unsafe {
            let mut lane = vdupq_n_f32(0.0);
            let mut at = wide_end;
            while at < quad_end {
                let v = vld1q_f32(base.add(at));
                lane = vfmaq_f32(lane, v, v);
                at += 4;
            }
            vaddvq_f32(lane)
        };
        energy += partial;
    }

    for &v in &band[quad_end..] {
        energy += v * v;
    }

    energy.sqrt()
}
2509
/// Returns `sqrt(1e-27 + sum(v * v))` over `band`, using AVX2 loads and
/// fused multiply-adds.
///
/// Samples are accumulated sixteen at a time in two vector accumulators,
/// then at most one more group of eight, and the remainder one by one.
///
/// # Safety
///
/// The CPU must support both AVX2 and FMA.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn compute_band_energy_avx2(band: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let len = band.len();
    let base = band.as_ptr();
    let mut done = 0usize;

    let mut even = _mm256_setzero_ps();
    let mut odd = _mm256_setzero_ps();

    while len - done >= 16 {
        let a = _mm256_loadu_ps(base.add(done));
        let c = _mm256_loadu_ps(base.add(done + 8));
        even = _mm256_fmadd_ps(a, a, even);
        odd = _mm256_fmadd_ps(c, c, odd);
        done += 16;
    }

    if len - done >= 8 {
        let a = _mm256_loadu_ps(base.add(done));
        even = _mm256_fmadd_ps(a, a, even);
        done += 8;
    }

    // Horizontal reduction: 8 lanes -> 4 -> 2 -> 1.
    let both = _mm256_add_ps(even, odd);
    let upper = _mm256_extractf128_ps(both, 1);
    let lower = _mm256_castps256_ps128(both);
    let four = _mm_add_ps(lower, upper);
    let two = _mm_add_ps(four, _mm_movehl_ps(four, four));
    let one = _mm_add_ss(two, _mm_shuffle_ps(two, two, 0x55));

    let mut energy = 1e-27f32 + _mm_cvtss_f32(one);
    for &v in &band[done..] {
        energy += v * v;
    }

    energy.sqrt()
}
2550
2551pub fn compute_band_energies(
2552 m: &CeltMode,
2553 x: &[f32],
2554 band_e: &mut [f32],
2555 end: usize,
2556 channels: usize,
2557 lm: usize,
2558) {
2559 let frame_size = m.short_mdct_size << lm;
2560
2561 #[cfg(target_arch = "x86_64")]
2562 let use_avx2 = std::arch::is_x86_feature_detected!("avx2");
2563
2564 for c in 0..channels {
2565 let ch = &x[c * frame_size..(c + 1) * frame_size];
2566 for i in 0..end {
2567 let offset = (m.e_bands[i] as usize) << lm;
2568 let n = ((m.e_bands[i + 1] - m.e_bands[i]) as usize) << lm;
2569 let band = &ch[offset..offset + n];
2570
2571 #[cfg(target_arch = "aarch64")]
2572 {
2573 band_e[c * m.nb_ebands + i] = compute_band_energy_neon(band);
2574 }
2575 #[cfg(target_arch = "x86_64")]
2576 {
2577 if n >= 8 && use_avx2 {
2578 band_e[c * m.nb_ebands + i] = unsafe { compute_band_energy_avx2(band) };
2579 } else {
2580 let sum = band.iter().fold(1e-27f32, |acc, &v| acc + v * v);
2581 band_e[c * m.nb_ebands + i] = sum.sqrt();
2582 }
2583 }
2584 #[cfg(all(not(target_arch = "aarch64"), not(target_arch = "x86_64")))]
2585 {
2586 let sum = band.iter().fold(1e-27f32, |acc, &v| acc + v * v);
2587 band_e[c * m.nb_ebands + i] = sum.sqrt();
2588 }
2589 }
2590 }
2591}
2592
2593pub fn amp2log2(
2594 m: &CeltMode,
2595 start: usize,
2596 end: usize,
2597 band_e: &[f32],
2598 band_log_e: &mut [f32],
2599 channels: usize,
2600) {
2601 for c in 0..channels {
2602 for i in 0..start {
2603 band_log_e[c * m.nb_ebands + i] = -14.0;
2604 }
2605 for i in start..end {
2606 let val = band_e[c * m.nb_ebands + i].max(1e-10);
2607 band_log_e[c * m.nb_ebands + i] = val.log2() - m.e_means[i];
2608 }
2609 }
2610}
2611
2612pub fn log2amp(m: &CeltMode, end: usize, band_e: &mut [f32], band_log_e: &[f32], channels: usize) {
2613 for c in 0..channels {
2614 for i in 0..end {
2615 band_e[c * m.nb_ebands + i] = band_log_e[c * m.nb_ebands + i] + m.e_means[i];
2616 }
2617 }
2618}
2619
2620pub fn normalise_bands(
2621 m: &CeltMode,
2622 freq: &[f32],
2623 x: &mut [f32],
2624 band_e: &[f32],
2625 end: usize,
2626 channels: usize,
2627 m_val: usize,
2628) {
2629 let lm = m_val.trailing_zeros() as usize;
2630 let frame_size = m.short_mdct_size << lm;
2631 #[cfg(target_arch = "x86_64")]
2632 let use_avx2 = std::arch::is_x86_feature_detected!("avx2");
2633 for c in 0..channels {
2634 for i in 0..end {
2635 let base = c * frame_size + ((m.e_bands[i] as usize) << lm);
2636 let n = ((m.e_bands[i + 1] - m.e_bands[i]) as usize) << lm;
2637 let norm = 1.0 / (1e-27 + band_e[c * m.nb_ebands + i]);
2638 let src = &freq[base..base + n];
2639 let dst = &mut x[base..base + n];
2640 #[cfg(target_arch = "x86_64")]
2641 if n >= 8 && use_avx2 {
2642 unsafe { scale_slice_avx2(src, dst, norm, n) };
2643 continue;
2644 }
2645 #[cfg(target_arch = "aarch64")]
2646 if n >= 8 {
2647 unsafe { scale_slice_neon(src, dst, norm, n) };
2648 continue;
2649 }
2650 for (d, &s) in dst.iter_mut().zip(src) {
2651 *d = s * norm;
2652 }
2653 }
2654 }
2655}
2656
/// Writes `src[k] * scale` into `dst[k]` for every `k` in `0..n` using 256-bit
/// loads and stores.
///
/// # Safety
///
/// * The executing CPU must support AVX2.
/// * `src` and `dst` must each contain at least `n` elements: the vector loops
///   go through raw pointers and perform no bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn scale_slice_avx2(src: &[f32], dst: &mut [f32], scale: f32, n: usize) {
    use std::arch::x86_64::*;

    let factor = _mm256_set1_ps(scale);
    let input = src.as_ptr();
    let output = dst.as_mut_ptr();
    let mut done = 0usize;

    // Two vectors (16 floats) per iteration.
    while done + 16 <= n {
        let lo = _mm256_loadu_ps(input.add(done));
        let hi = _mm256_loadu_ps(input.add(done + 8));
        _mm256_storeu_ps(output.add(done), _mm256_mul_ps(lo, factor));
        _mm256_storeu_ps(output.add(done + 8), _mm256_mul_ps(hi, factor));
        done += 16;
    }

    // One vector (8 floats) per iteration.
    while done + 8 <= n {
        let lane = _mm256_loadu_ps(input.add(done));
        _mm256_storeu_ps(output.add(done), _mm256_mul_ps(lane, factor));
        done += 8;
    }

    // Scalar remainder.
    for (out, &val) in dst[done..n].iter_mut().zip(&src[done..n]) {
        *out = val * scale;
    }
}
2680
/// Writes `src[k] * scale` into `dst[k]` for every `k` in `0..n` using 128-bit
/// NEON loads and stores.
///
/// # Safety
///
/// `src` and `dst` must each contain at least `n` elements: the vector loops
/// go through raw pointers and perform no bounds checks.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn scale_slice_neon(src: &[f32], dst: &mut [f32], scale: f32, n: usize) {
    use std::arch::aarch64::*;

    let factor = vdupq_n_f32(scale);
    let mut done = 0usize;

    // Four vectors (16 floats) per iteration.
    while done + 16 <= n {
        let a = vld1q_f32(src.as_ptr().add(done));
        let b = vld1q_f32(src.as_ptr().add(done + 4));
        let c = vld1q_f32(src.as_ptr().add(done + 8));
        let d = vld1q_f32(src.as_ptr().add(done + 12));
        vst1q_f32(dst.as_mut_ptr().add(done), vmulq_f32(a, factor));
        vst1q_f32(dst.as_mut_ptr().add(done + 4), vmulq_f32(b, factor));
        vst1q_f32(dst.as_mut_ptr().add(done + 8), vmulq_f32(c, factor));
        vst1q_f32(dst.as_mut_ptr().add(done + 12), vmulq_f32(d, factor));
        done += 16;
    }

    // Two vectors (8 floats) per iteration.
    while done + 8 <= n {
        let a = vld1q_f32(src.as_ptr().add(done));
        let b = vld1q_f32(src.as_ptr().add(done + 4));
        vst1q_f32(dst.as_mut_ptr().add(done), vmulq_f32(a, factor));
        vst1q_f32(dst.as_mut_ptr().add(done + 4), vmulq_f32(b, factor));
        done += 8;
    }

    // One vector (4 floats) per iteration.
    while done + 4 <= n {
        let a = vld1q_f32(src.as_ptr().add(done));
        vst1q_f32(dst.as_mut_ptr().add(done), vmulq_f32(a, factor));
        done += 4;
    }

    // Scalar remainder.
    for (out, &val) in dst[done..n].iter_mut().zip(&src[done..n]) {
        *out = val * scale;
    }
}
2716
2717#[allow(clippy::too_many_arguments)]
2718pub fn denormalise_bands(
2719 m: &CeltMode,
2720 x: &[f32],
2721 freq: &mut [f32],
2722 band_e: &[f32],
2723 start: usize,
2724 end: usize,
2725 channels: usize,
2726 m_val: usize,
2727) {
2728 let lm = m_val.trailing_zeros() as usize;
2729 let frame_size = m.short_mdct_size << lm;
2730 #[cfg(target_arch = "x86_64")]
2731 let use_avx2 = std::arch::is_x86_feature_detected!("avx2");
2732
2733 for c in 0..channels {
2734 for i in start..end {
2735 let base = c * frame_size + ((m.e_bands[i] as usize) << lm);
2736 let n = ((m.e_bands[i + 1] - m.e_bands[i]) as usize) << lm;
2737 let band_log = band_e[c * m.nb_ebands + i];
2738
2739 let g = (2.0f32).powf(band_log.min(32.0));
2741 let src = &x[base..base + n];
2742 let dst = &mut freq[base..base + n];
2743 #[cfg(target_arch = "x86_64")]
2744 if n >= 8 && use_avx2 {
2745 unsafe { scale_slice_avx2(src, dst, g, n) };
2746 continue;
2747 }
2748 #[cfg(target_arch = "aarch64")]
2749 if n >= 8 {
2750 unsafe { scale_slice_neon(src, dst, g, n) };
2751 continue;
2752 }
2753 for (d, &s) in dst.iter_mut().zip(src) {
2754 *d = s * g;
2755 }
2756 }
2757 }
2758}
2759
/// Advances a 32-bit linear congruential generator by one step.
///
/// Returns `seed * 1664525 + 1013904223`, with both operations wrapping
/// modulo 2^32.
pub fn celt_lcg_rand(seed: u32) -> u32 {
    const MULTIPLIER: u32 = 1_664_525;
    const INCREMENT: u32 = 1_013_904_223;
    MULTIPLIER.wrapping_mul(seed).wrapping_add(INCREMENT)
}
2763
/// Rescales the first `n` values of `x` in place so that their L2 norm becomes
/// `gain` (up to the `1e-15` bias added to the energy), using NEON.
///
/// Pass one accumulates the sum of squares in a single 4-lane accumulator and
/// finishes the leftover elements in scalar code; pass two multiplies every
/// element by `gain / sqrt(energy)`.
///
/// # Safety
///
/// `x` must contain at least `n` elements: the vector loops go through raw
/// pointers and perform no bounds checks.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn renormalise_vector_neon(x: &mut [f32], n: usize, gain: f32) {
    use std::arch::aarch64::*;

    // ---- Pass 1: energy -------------------------------------------------
    let mut squares = vdupq_n_f32(0.0);
    let mut pos = 0usize;

    while pos + 16 <= n {
        let a = vld1q_f32(x.as_ptr().add(pos));
        let b = vld1q_f32(x.as_ptr().add(pos + 4));
        let c = vld1q_f32(x.as_ptr().add(pos + 8));
        let d = vld1q_f32(x.as_ptr().add(pos + 12));
        squares = vfmaq_f32(squares, a, a);
        squares = vfmaq_f32(squares, b, b);
        squares = vfmaq_f32(squares, c, c);
        squares = vfmaq_f32(squares, d, d);
        pos += 16;
    }

    while pos + 8 <= n {
        let a = vld1q_f32(x.as_ptr().add(pos));
        let b = vld1q_f32(x.as_ptr().add(pos + 4));
        squares = vfmaq_f32(squares, a, a);
        squares = vfmaq_f32(squares, b, b);
        pos += 8;
    }

    while pos + 4 <= n {
        let a = vld1q_f32(x.as_ptr().add(pos));
        squares = vfmaq_f32(squares, a, a);
        pos += 4;
    }

    // The small bias keeps the division below finite for an all-zero input.
    let mut energy = 1e-15f32 + vaddvq_f32(squares);
    for &v in &x[pos..n] {
        energy += v * v;
    }

    // ---- Pass 2: scale --------------------------------------------------
    let norm = gain / energy.sqrt();
    let factor = vdupq_n_f32(norm);

    pos = 0;
    while pos + 16 <= n {
        let a = vld1q_f32(x.as_ptr().add(pos));
        let b = vld1q_f32(x.as_ptr().add(pos + 4));
        let c = vld1q_f32(x.as_ptr().add(pos + 8));
        let d = vld1q_f32(x.as_ptr().add(pos + 12));
        vst1q_f32(x.as_mut_ptr().add(pos), vmulq_f32(a, factor));
        vst1q_f32(x.as_mut_ptr().add(pos + 4), vmulq_f32(b, factor));
        vst1q_f32(x.as_mut_ptr().add(pos + 8), vmulq_f32(c, factor));
        vst1q_f32(x.as_mut_ptr().add(pos + 12), vmulq_f32(d, factor));
        pos += 16;
    }

    while pos + 8 <= n {
        let a = vld1q_f32(x.as_ptr().add(pos));
        let b = vld1q_f32(x.as_ptr().add(pos + 4));
        vst1q_f32(x.as_mut_ptr().add(pos), vmulq_f32(a, factor));
        vst1q_f32(x.as_mut_ptr().add(pos + 4), vmulq_f32(b, factor));
        pos += 8;
    }

    while pos + 4 <= n {
        let a = vld1q_f32(x.as_ptr().add(pos));
        vst1q_f32(x.as_mut_ptr().add(pos), vmulq_f32(a, factor));
        pos += 4;
    }

    for v in &mut x[pos..n] {
        *v *= norm;
    }
}
2839
/// AVX2/FMA implementation of [`renormalise_vector`].
///
/// Pass one accumulates the sum of squares of `x[..n]` in two 8-lane
/// accumulators, reduces them horizontally and finishes the leftover elements
/// in scalar code; pass two multiplies every element by
/// `gain / sqrt(1e-15 + energy)`.
///
/// # Safety
///
/// * The executing CPU must support both AVX2 and FMA.
/// * `x` must contain at least `n` elements: the vector loops go through raw
///   pointers and perform no bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn renormalise_vector_avx2(x: &mut [f32], n: usize, gain: f32) {
    use std::arch::x86_64::*;

    debug_assert!(n <= x.len());

    let mut i = 0usize;

    let mut acc0 = _mm256_setzero_ps();
    let mut acc1 = _mm256_setzero_ps();

    while i + 16 <= n {
        let v0 = _mm256_loadu_ps(x.as_ptr().add(i));
        let v1 = _mm256_loadu_ps(x.as_ptr().add(i + 8));
        acc0 = _mm256_fmadd_ps(v0, v0, acc0);
        acc1 = _mm256_fmadd_ps(v1, v1, acc1);
        i += 16;
    }

    // After the 16-wide loop at most one 8-wide step can still fit.
    if i + 8 <= n {
        let v0 = _mm256_loadu_ps(x.as_ptr().add(i));
        acc0 = _mm256_fmadd_ps(v0, v0, acc0);
        i += 8;
    }

    // Horizontal reduction: 8 lanes -> 4 -> 2 -> 1.
    let acc = _mm256_add_ps(acc0, acc1);
    let hi = _mm256_extractf128_ps(acc, 1);
    let lo = _mm256_castps256_ps128(acc);
    let s4 = _mm_add_ps(lo, hi);
    let t1 = _mm_movehl_ps(s4, s4);
    let s2 = _mm_add_ps(s4, t1);
    let t2 = _mm_shuffle_ps(s2, s2, 0x55);
    // The small bias keeps the division below finite for an all-zero input.
    let mut e = 1e-15f32 + _mm_cvtss_f32(_mm_add_ss(s2, t2));

    for &v in &x[i..n] {
        e += v * v;
    }

    let norm = gain / e.sqrt();
    let vnorm = _mm256_set1_ps(norm);

    i = 0;
    while i + 16 <= n {
        let v0 = _mm256_loadu_ps(x.as_ptr().add(i));
        let v1 = _mm256_loadu_ps(x.as_ptr().add(i + 8));
        _mm256_storeu_ps(x.as_mut_ptr().add(i), _mm256_mul_ps(v0, vnorm));
        _mm256_storeu_ps(x.as_mut_ptr().add(i + 8), _mm256_mul_ps(v1, vnorm));
        i += 16;
    }
    while i + 8 <= n {
        let v = _mm256_loadu_ps(x.as_ptr().add(i));
        _mm256_storeu_ps(x.as_mut_ptr().add(i), _mm256_mul_ps(v, vnorm));
        i += 8;
    }
    for v in &mut x[i..n] {
        *v *= norm;
    }
}

/// Rescales the first `n` values of `x` in place so that their L2 norm becomes
/// `gain` (up to the `1e-15` bias added to the energy).
///
/// Every element of `x[..n]` is multiplied by
/// `gain / sqrt(1e-15 + sum(x[k] * x[k]))`. Elements at index `n` and beyond
/// are left untouched. A SIMD helper is used where the target supports one;
/// otherwise a scalar loop produces the result.
///
/// # Panics
///
/// Panics if `n > x.len()`.
pub fn renormalise_vector(x: &mut [f32], n: usize, gain: f32) {
    // A single checked slice up front: an oversized `n` panics on every
    // target before any SIMD helper touches memory through raw pointers.
    let x = &mut x[..n];

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: `x` holds exactly `n` elements (sliced above) and the helper
        // carries no additional target-feature requirement.
        unsafe { renormalise_vector_neon(x, n, gain) };
    }

    #[cfg(not(target_arch = "aarch64"))]
    {
        #[cfg(target_arch = "x86_64")]
        {
            // The helper is compiled with AVX2 *and* FMA enabled, so both
            // features must be detected before it may be called.
            if n >= 16
                && std::arch::is_x86_feature_detected!("avx2")
                && std::arch::is_x86_feature_detected!("fma")
            {
                // SAFETY: both CPU features were verified at runtime and `x`
                // holds exactly `n` elements.
                unsafe { renormalise_vector_avx2(x, n, gain) };
                return;
            }
        }

        // Scalar fallback: short vectors, or no SIMD support on this CPU.
        let energy = x.iter().fold(1e-15f32, |acc, &v| acc + v * v);
        let norm = gain / energy.sqrt();
        for v in x.iter_mut() {
            *v *= norm;
        }
    }
}
2933
2934#[allow(clippy::too_many_arguments)]
2935pub fn anti_collapse(
2936 m: &CeltMode,
2937 x_buf: &mut [f32],
2938 collapse_masks: &[u32],
2939 lm: i32,
2940 channels: usize,
2941 size: usize,
2942 start: usize,
2943 end: usize,
2944 log_e: &[f32],
2945 prev1_log_e: &[f32],
2946 prev2_log_e: &[f32],
2947 pulses: &[i32],
2948 mut seed: u32,
2949) -> u32 {
2950 for i in start..end {
2951 let n0 = (m.e_bands[i + 1] - m.e_bands[i]) as usize;
2952 let depth = if n0 > 0 {
2953 ((1 + pulses[i]) / n0 as i32) >> lm
2954 } else {
2955 0
2956 };
2957
2958 let thresh = 0.5 * (-(0.125 * depth as f32)).exp2();
2959 let sqrt_1 = 1.0 / ((n0 << lm) as f32).sqrt();
2960
2961 for c in 0..channels {
2962 let p1 = prev1_log_e[c * m.nb_ebands + i];
2963 let p2 = prev2_log_e[c * m.nb_ebands + i];
2964
2965 let (p1_adj, p2_adj) = if channels == 1 && prev1_log_e.len() >= 2 * m.nb_ebands {
2966 (
2967 p1.max(prev1_log_e[m.nb_ebands + i]),
2968 p2.max(prev2_log_e[m.nb_ebands + i]),
2969 )
2970 } else {
2971 (p1, p2)
2972 };
2973
2974 let e_diff = log_e[c * m.nb_ebands + i] - p1_adj.min(p2_adj);
2975 let e_diff = e_diff.max(0.0);
2976
2977 let mut r = 2.0 * (-e_diff).exp2();
2978 if lm == 3 {
2979 r *= std::f32::consts::SQRT_2;
2980 }
2981 r = r.min(thresh);
2982 r *= sqrt_1;
2983
2984 let x_offset = c * size + ((m.e_bands[i] as usize) << lm);
2985 let mut renormalize = false;
2986 for k in 0..(1 << lm) {
2987 if (collapse_masks[i * channels + c] & (1 << k)) == 0 {
2988 for j in 0..n0 {
2989 seed = celt_lcg_rand(seed);
2990 x_buf[x_offset + (j << lm) + k] = if (seed & 0x8000) != 0 { r } else { -r };
2991 }
2992 renormalize = true;
2993 }
2994 }
2995 if renormalize {
2996 renormalise_vector(&mut x_buf[x_offset..x_offset + (n0 << lm)], n0 << lm, 1.0);
2997 }
2998 }
2999 }
3000 seed
3001}
3002
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the fixed-point cosine and log2-tangent helpers to known outputs.
    #[test]
    fn test_bitexact_primitives_reference_values() {
        // (input, expected)
        let cos_cases: [(i16, i16); 3] = [(64, 32767), (8192, 23171), (16320, 200)];
        for &(angle, expected) in cos_cases.iter() {
            assert_eq!(bitexact_cos(angle), expected);
        }

        // (isin, icos, expected)
        let log2tan_cases: [(i32, i32, i32); 4] = [
            (32767, 200, 15059),
            (30274, 12540, 2611),
            (23171, 23171, 0),
            (200, 32767, -15059),
        ];
        for &(isin, icos, expected) in log2tan_cases.iter() {
            assert_eq!(bitexact_log2tan(isin, icos), expected);
        }
    }
}