use crate::modes::CeltMode;
use crate::pvq::*;
use crate::range_coder::RangeCoder;
use crate::rate::{BITRES, bits2pulses, get_pulses, pulses2bits};
use crate::tell_frac_inline;

const MIN_STEREO_ENERGY: f32 = 1e-10;

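/// Shared per-band state threaded through the band quantisation routines
/// (the counterpart of `band_ctx` in CELT's bands.c).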
pub struct BandCtx<'a> {
    pub encode: bool,
    pub m: &'a CeltMode,
    pub i: usize,
    pub band_e: &'a [f32],
    pub rc: &'a mut RangeCoder,
    pub spread: i32,
    pub remaining_bits: i32,
    pub resynth: bool,
    pub tf_change: i32,
    pub intensity: usize,
    pub theta_round: i32,
    pub avoid_split_noise: bool,
    pub arch: i32,
    pub disable_inv: bool,
    pub seed: u32,
}

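/// Bit-exact polynomial cosine (Q14 angle in, roughly Q15 value out) used to
/// derive the mid/side weights; mirrors `bitexact_cos()` in CELT so encoder
/// and decoder compute identical values.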
#[inline]
fn bitexact_cos(x: i16) -> i16 {
    #[inline(always)]
    fn frac_mul16(a: i16, b: i16) -> i16 {
        ((16384i32 + (a as i32) * (b as i32)) >> 15) as i16
    }

    let tmp = (4096i32 + (x as i32) * (x as i32)) >> 13;
    let x2 = tmp as i16;
    let x2 = (32767 - x2 as i32
        + frac_mul16(x2, -7651 + frac_mul16(x2, 8277 + frac_mul16(-626, x2))) as i32)
        as i16;
    1 + x2
}

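/// Bit-exact base-2 log-tangent in Q11, computed from a Q15 sine/cosine
/// pair; used to bias the bit split between the two halves of a band.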
#[inline]
pub fn bitexact_log2tan(isin: i32, icos: i32) -> i32 {
    let ec_ilog = |x: u32| -> i32 {
        if x == 0 {
            0
        } else {
            32 - x.leading_zeros() as i32
        }
    };
    let lc = ec_ilog(icos.max(0) as u32);
    let ls = ec_ilog(isin.max(0) as u32);
    let icos_shifted = if lc > 0 {
        icos.max(0) << (15 - lc).max(0)
    } else {
        0
    };
    let isin_shifted = if ls > 0 {
        isin.max(0) << (15 - ls).max(0)
    } else {
        0
    };
    let fract_mul = |a: i32, b: i32| -> i32 { (a * b + 16384) >> 15 };
    (ls - lc) * (1 << 11) + fract_mul(isin_shifted, fract_mul(isin_shifted, -2597) + 7932)
        - fract_mul(icos_shifted, fract_mul(icos_shifted, -2597) + 7932)
}

#[inline(always)]
fn celt_sudiv(n: i32, d: i32) -> i32 {
    n / d
}

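/// Integer square root: the largest `g` with `g * g <= val`. Used when
/// decoding the triangular pdf for the split angle.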
#[inline]
fn isqrt32(mut val: u32) -> u32 {
    // Guard the degenerate case: the bshift computation below would
    // otherwise attempt a negative shift for val == 0.
    if val == 0 {
        return 0;
    }
    let mut g = 0u32;
    let mut bshift = ((32 - val.leading_zeros()) as i32 - 1) >> 1;
    let mut b = 1u32 << bshift;
    while bshift >= 0 {
        let t = (((g << 1) + b) as u64) << bshift;
        if t <= val as u64 {
            g += b;
            val -= t as u32;
        }
        b >>= 1;
        bshift -= 1;
    }
    g
}

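// Small sanity checks for the helpers in this module (a sketch; it assumes
// the usual `cargo test` setup for this crate and exercises whichever
// `haar1` kernel the build selects).
#[cfg(test)]
mod band_math_tests {
    use super::*;

    #[test]
    fn haar1_is_an_involution() {
        // The 1/sqrt(2) butterfly is orthonormal and symmetric, so applying
        // it twice must return the input (up to float rounding).
        let orig: Vec<f32> = (0..32).map(|i| (i as f32 * 0.37).sin()).collect();
        let mut x = orig.clone();
        haar1(&mut x, 32, 1);
        haar1(&mut x, 32, 1);
        for (a, b) in x.iter().zip(orig.iter()) {
            assert!((a - b).abs() < 1e-5);
        }
    }

    #[test]
    fn isqrt32_is_floor_sqrt() {
        for v in [0u32, 1, 2, 3, 8, 9, 15, 16, 12345, u32::MAX] {
            let g = isqrt32(v) as u64;
            assert!(g * g <= v as u64);
            assert!((g + 1) * (g + 1) > v as u64);
        }
    }
}
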
pub const SPREAD_NONE: i32 = 0;
pub const SPREAD_LIGHT: i32 = 1;
pub const SPREAD_NORMAL: i32 = 2;
pub const SPREAD_AGGRESSIVE: i32 = 3;

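/// Decides how much rotation-based spreading to apply from the fraction of
/// low-energy coefficients per band, and (when `update_hf` is set) updates
/// the running HF average that drives the tapset decision.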
#[allow(clippy::too_many_arguments)]
pub fn spreading_decision(
    m: &CeltMode,
    x_buf: &[f32],
    average: &mut i32,
    last_decision: i32,
    hf_average: &mut i32,
    tapset_decision: &mut i32,
    update_hf: bool,
    end: usize,
    channels: usize,
    m_val: usize,
    spread_weight: &[i32],
) -> i32 {
    let mut sum = 0;
    let mut nb_bands = 0;
    let n0 = m_val * m.short_mdct_size;
    let mut hf_sum = 0;

    if m_val * (m.e_bands[end] as usize - m.e_bands[end - 1] as usize) <= 8 {
        return SPREAD_NONE;
    }

    for c in 0..channels {
        for (i, &sw) in spread_weight[..end].iter().enumerate() {
            let n = m_val * (m.e_bands[i + 1] as usize - m.e_bands[i] as usize);
            if n <= 8 {
                continue;
            }

            let mut tcount = [0; 3];
            let offset = m_val * m.e_bands[i] as usize + c * n0;
            let x = &x_buf[offset..offset + n];

            for xv in x.iter().copied() {
                let x2n = xv * xv * (n as f32);
                if x2n < 0.25 {
                    tcount[0] += 1;
                }
                if x2n < 0.0625 {
                    tcount[1] += 1;
                }
                if x2n < 0.015625 {
                    tcount[2] += 1;
                }
            }

            if i > m.nb_ebands - 4 {
                hf_sum += 32 * (tcount[1] + tcount[0]) / (n as i32);
            }

            let tmp = (if 2 * tcount[2] >= (n as i32) { 1 } else { 0 })
                + (if 2 * tcount[1] >= (n as i32) { 1 } else { 0 })
                + (if 2 * tcount[0] >= (n as i32) { 1 } else { 0 });
            sum += tmp * sw;
            nb_bands += sw;
        }
    }

    if update_hf {
        if hf_sum > 0 {
            hf_sum /= (channels as i32) * (4 - m.nb_ebands as i32 + end as i32);
        }
        *hf_average = (*hf_average + hf_sum) >> 1;
        hf_sum = *hf_average;

        if *tapset_decision == 2 {
            hf_sum += 4;
        } else if *tapset_decision == 0 {
            hf_sum -= 4;
        }

        if hf_sum > 22 {
            *tapset_decision = 2;
        } else if hf_sum > 18 {
            *tapset_decision = 1;
        } else {
            *tapset_decision = 0;
        }
    }

    if nb_bands == 0 {
        return SPREAD_NORMAL;
    }

    let mut sum_scaled = (sum << 8) / nb_bands;
    sum_scaled = (sum_scaled + *average) >> 1;
    *average = sum_scaled;

    let sum_final = (3 * sum_scaled + (((3 - last_decision) << 7) + 64) + 2) >> 2;

    if sum_final < 80 {
        SPREAD_AGGRESSIVE
    } else if sum_final < 256 {
        SPREAD_NORMAL
    } else if sum_final < 384 {
        SPREAD_LIGHT
    } else {
        SPREAD_NONE
    }
}

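/// Orthonormal Haar butterfly over `n0 / 2` pairs at the given stride: each
/// pair (a, b) becomes ((a + b) / sqrt(2), (a - b) / sqrt(2)). Dispatches to
/// a SIMD kernel where one is available.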
pub fn haar1(x: &mut [f32], n0: usize, stride: usize) {
    #[cfg(target_arch = "aarch64")]
    {
        haar1_neon(x, n0, stride);
    }
    // `haar1_avx` uses x86_64 intrinsics, so the AVX path is gated on
    // x86_64 only.
    #[cfg(target_arch = "x86_64")]
    unsafe {
        if stride == 1 && n0 >= 16 && is_x86_feature_detected!("avx") {
            haar1_avx(x, n0);
            return;
        }
    }
    #[cfg(not(target_arch = "aarch64"))]
    haar1_scalar(x, n0, stride);
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx")]
unsafe fn haar1_avx(x: &mut [f32], n0: usize) {
    use std::arch::x86_64::*;
    let n = n0 >> 1;
    let scale = _mm256_set1_ps(std::f32::consts::FRAC_1_SQRT_2);
    let mut j = 0;
    while j + 4 <= n {
        // Four adjacent (even, odd) pairs per iteration.
        let ptr = x.as_mut_ptr().add(2 * j);
        let v = _mm256_loadu_ps(ptr);

        // Swap the two elements of every pair: [x1, x0, x3, x2, ...].
        let sw = _mm256_permute_ps::<0b1011_0001>(v);

        let sum = _mm256_mul_ps(_mm256_add_ps(v, sw), scale);
        // In the odd lanes, sw - v = x[2j] - x[2j + 1], which is exactly the
        // scalar butterfly's difference term.
        let diff = _mm256_mul_ps(_mm256_sub_ps(sw, v), scale);

        // Even lanes take the scaled sum, odd lanes the scaled difference.
        let out = _mm256_blend_ps::<0b1010_1010>(sum, diff);

        _mm256_storeu_ps(ptr, out);
        j += 4;
    }

    let scale = std::f32::consts::FRAC_1_SQRT_2;
    while j < n {
        let idx1 = 2 * j;
        let idx2 = 2 * j + 1;
        let tmp1 = scale * x[idx1];
        let tmp2 = scale * x[idx2];
        x[idx1] = tmp1 + tmp2;
        x[idx2] = tmp1 - tmp2;
        j += 1;
    }
}

#[cfg_attr(target_arch = "aarch64", allow(dead_code))]
#[inline]
fn haar1_scalar(x: &mut [f32], n0: usize, stride: usize) {
    let n = n0 >> 1;
    let scale = std::f32::consts::FRAC_1_SQRT_2;
    for i in 0..stride {
        for j in 0..n {
            let idx1 = stride * 2 * j + i;
            let idx2 = stride * (2 * j + 1) + i;
            let tmp1 = scale * x[idx1];
            let tmp2 = scale * x[idx2];
            x[idx1] = tmp1 + tmp2;
            x[idx2] = tmp1 - tmp2;
        }
    }
}

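/// NEON Haar butterfly. For `stride == 1` the pairs are adjacent in memory,
/// so de-interleaving `vld2q`/`vst2q` accesses are used; for larger strides
/// the butterfly is vectorised across the columns of each row pair.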
#[cfg(target_arch = "aarch64")]
fn haar1_neon(x: &mut [f32], n0: usize, stride: usize) {
    use std::arch::aarch64::*;

    let n = n0 >> 1;
    let scale = std::f32::consts::FRAC_1_SQRT_2;

    unsafe {
        let vscale = vdupq_n_f32(scale);

        if stride == 1 {
            // Adjacent pairs: de-interleave four pairs at a time, butterfly,
            // and re-interleave on the store.
            let mut j = 0;
            while j + 4 <= n {
                let v = vld2q_f32(x.as_ptr().add(2 * j)); // v.0 = even, v.1 = odd
                let a = vmulq_f32(v.0, vscale);
                let b = vmulq_f32(v.1, vscale);
                vst2q_f32(
                    x.as_mut_ptr().add(2 * j),
                    float32x4x2_t(vaddq_f32(a, b), vsubq_f32(a, b)),
                );
                j += 4;
            }
            while j < n {
                let tmp1 = scale * x[2 * j];
                let tmp2 = scale * x[2 * j + 1];
                x[2 * j] = tmp1 + tmp2;
                x[2 * j + 1] = tmp1 - tmp2;
                j += 1;
            }
            return;
        }

        // Interleaved layout: the pair partners sit one full row apart, so
        // vectorise across the columns of each (even, odd) row pair.
        for j in 0..n {
            let row_e = stride * 2 * j;
            let row_o = stride * (2 * j + 1);
            let mut i = 0;
            while i + 4 <= stride {
                let e = vmulq_f32(vld1q_f32(x.as_ptr().add(row_e + i)), vscale);
                let o = vmulq_f32(vld1q_f32(x.as_ptr().add(row_o + i)), vscale);
                vst1q_f32(x.as_mut_ptr().add(row_e + i), vaddq_f32(e, o));
                vst1q_f32(x.as_mut_ptr().add(row_o + i), vsubq_f32(e, o));
                i += 4;
            }
            while i < stride {
                let tmp1 = scale * x[row_e + i];
                let tmp2 = scale * x[row_o + i];
                x[row_e + i] = tmp1 + tmp2;
                x[row_o + i] = tmp1 - tmp2;
                i += 1;
            }
        }
    }
}

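/// Number of levels used to quantise the split angle `itheta` for an
/// `n`-sample band with `b` eighth-bits available; the result is even and
/// capped at 256 (mirrors `compute_qn()` in bands.c).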
#[inline(always)]
pub fn compute_qn(n: usize, b: i32, offset: i32, pulse_cap: i32, stereo: bool) -> i32 {
    static EXP2_TABLE8: [i16; 8] = [16384, 17866, 19483, 21247, 23170, 25267, 27554, 30048];
    let mut n2 = (2 * n as i32) - 1;
    if stereo && n == 2 {
        n2 -= 1;
    }
    let mut qb = celt_sudiv(b + n2 * offset, n2);
    qb = qb.min(b - pulse_cap - (4 << BITRES));
    qb = qb.min(8 << BITRES);
    if qb < (1i32 << BITRES >> 1) {
        1
    } else {
        let val = EXP2_TABLE8[(qb & 0x7) as usize] as i32;
        let shift = 14 - (qb >> BITRES);
        let raw = if (0..32).contains(&shift) {
            val >> shift
        } else {
            0
        };
        let qn = (raw + 1) >> 1 << 1;
        qn.min(256)
    }
}

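/// NEON energy accumulation behind `stereo_itheta`: sums the mid/side (or
/// per-half) energies and converts the ratio to a Q14 angle.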
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn stereo_itheta_neon(x: &[f32], y: &[f32], stereo: bool, n: usize) -> i32 {
    use std::arch::aarch64::*;

    let mut emid = 1e-15f32;
    let mut eside = 1e-15f32;

    if stereo {
        let mut sum_mid = vdupq_n_f32(0.0);
        let mut sum_side = vdupq_n_f32(0.0);
        let mut i = 0;

        while i + 16 <= n {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let x1 = vld1q_f32(x.as_ptr().add(i + 4));
            let x2 = vld1q_f32(x.as_ptr().add(i + 8));
            let x3 = vld1q_f32(x.as_ptr().add(i + 12));
            let y0 = vld1q_f32(y.as_ptr().add(i));
            let y1 = vld1q_f32(y.as_ptr().add(i + 4));
            let y2 = vld1q_f32(y.as_ptr().add(i + 8));
            let y3 = vld1q_f32(y.as_ptr().add(i + 12));

            let m0 = vaddq_f32(x0, y0);
            let m1 = vaddq_f32(x1, y1);
            let m2 = vaddq_f32(x2, y2);
            let m3 = vaddq_f32(x3, y3);
            let s0 = vsubq_f32(x0, y0);
            let s1 = vsubq_f32(x1, y1);
            let s2 = vsubq_f32(x2, y2);
            let s3 = vsubq_f32(x3, y3);

            sum_mid = vfmaq_f32(sum_mid, m0, m0);
            sum_mid = vfmaq_f32(sum_mid, m1, m1);
            sum_mid = vfmaq_f32(sum_mid, m2, m2);
            sum_mid = vfmaq_f32(sum_mid, m3, m3);
            sum_side = vfmaq_f32(sum_side, s0, s0);
            sum_side = vfmaq_f32(sum_side, s1, s1);
            sum_side = vfmaq_f32(sum_side, s2, s2);
            sum_side = vfmaq_f32(sum_side, s3, s3);

            i += 16;
        }

        while i + 8 <= n {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let x1 = vld1q_f32(x.as_ptr().add(i + 4));
            let y0 = vld1q_f32(y.as_ptr().add(i));
            let y1 = vld1q_f32(y.as_ptr().add(i + 4));

            let m0 = vaddq_f32(x0, y0);
            let m1 = vaddq_f32(x1, y1);
            let s0 = vsubq_f32(x0, y0);
            let s1 = vsubq_f32(x1, y1);

            sum_mid = vfmaq_f32(sum_mid, m0, m0);
            sum_mid = vfmaq_f32(sum_mid, m1, m1);
            sum_side = vfmaq_f32(sum_side, s0, s0);
            sum_side = vfmaq_f32(sum_side, s1, s1);

            i += 8;
        }

        while i + 4 <= n {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let y0 = vld1q_f32(y.as_ptr().add(i));
            let m0 = vaddq_f32(x0, y0);
            let s0 = vsubq_f32(x0, y0);
            sum_mid = vfmaq_f32(sum_mid, m0, m0);
            sum_side = vfmaq_f32(sum_side, s0, s0);
            i += 4;
        }

        emid += vaddvq_f32(sum_mid);
        eside += vaddvq_f32(sum_side);

        for j in i..n {
            let m = x[j] + y[j];
            let s = x[j] - y[j];
            emid += m * m;
            eside += s * s;
        }
    } else {
        let mut sum_mid = vdupq_n_f32(0.0);
        let mut sum_side = vdupq_n_f32(0.0);
        let mut i = 0;

        while i + 16 <= n {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let x1 = vld1q_f32(x.as_ptr().add(i + 4));
            let x2 = vld1q_f32(x.as_ptr().add(i + 8));
            let x3 = vld1q_f32(x.as_ptr().add(i + 12));
            let y0 = vld1q_f32(y.as_ptr().add(i));
            let y1 = vld1q_f32(y.as_ptr().add(i + 4));
            let y2 = vld1q_f32(y.as_ptr().add(i + 8));
            let y3 = vld1q_f32(y.as_ptr().add(i + 12));

            sum_mid = vfmaq_f32(sum_mid, x0, x0);
            sum_mid = vfmaq_f32(sum_mid, x1, x1);
            sum_mid = vfmaq_f32(sum_mid, x2, x2);
            sum_mid = vfmaq_f32(sum_mid, x3, x3);
            sum_side = vfmaq_f32(sum_side, y0, y0);
            sum_side = vfmaq_f32(sum_side, y1, y1);
            sum_side = vfmaq_f32(sum_side, y2, y2);
            sum_side = vfmaq_f32(sum_side, y3, y3);

            i += 16;
        }

        while i + 8 <= n {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let x1 = vld1q_f32(x.as_ptr().add(i + 4));
            let y0 = vld1q_f32(y.as_ptr().add(i));
            let y1 = vld1q_f32(y.as_ptr().add(i + 4));

            sum_mid = vfmaq_f32(sum_mid, x0, x0);
            sum_mid = vfmaq_f32(sum_mid, x1, x1);
            sum_side = vfmaq_f32(sum_side, y0, y0);
            sum_side = vfmaq_f32(sum_side, y1, y1);

            i += 8;
        }

        while i + 4 <= n {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let y0 = vld1q_f32(y.as_ptr().add(i));
            sum_mid = vfmaq_f32(sum_mid, x0, x0);
            sum_side = vfmaq_f32(sum_side, y0, y0);
            i += 4;
        }

        emid += vaddvq_f32(sum_mid);
        eside += vaddvq_f32(sum_side);

        for j in i..n {
            emid += x[j] * x[j];
            eside += y[j] * y[j];
        }
    }

    let mid = emid.sqrt();
    let side = eside.sqrt();
    let theta_norm = celt_atan2p_norm(side, mid);
    (0.5 + 16384.0 * theta_norm) as i32
}

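/// Split angle `itheta` in Q14 (16384 spans the first quadrant), computed
/// from the energies of the mid/side signals (`stereo`) or of the two band
/// halves (mono split).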
#[inline(always)]
#[cfg(target_arch = "aarch64")]
pub fn stereo_itheta(x: &[f32], y: &[f32], stereo: bool, n: usize) -> i32 {
    unsafe { stereo_itheta_neon(x, y, stereo, n) }
}

#[inline(always)]
#[cfg(not(target_arch = "aarch64"))]
pub fn stereo_itheta(x: &[f32], y: &[f32], stereo: bool, n: usize) -> i32 {
    let mut emid = 1e-15f32;
    let mut eside = 1e-15f32;
    if stereo {
        for i in 0..n {
            let m = x[i] + y[i];
            let s = x[i] - y[i];
            emid += m * m;
            eside += s * s;
        }
    } else {
        for i in 0..n {
            emid += x[i] * x[i];
            eside += y[i] * y[i];
        }
    }
    let mid = emid.sqrt();
    let side = eside.sqrt();
    let theta_norm = celt_atan2p_norm(side, mid);
    (0.5 + 16384.0 * theta_norm) as i32
}

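/// atan2(y, x) over the first quadrant, normalised so that 1.0 corresponds
/// to pi/2; evaluated as an odd polynomial in the smaller of the two ratios.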
#[inline(always)]
fn celt_atan2p_norm(y: f32, x: f32) -> f32 {
    #[inline(always)]
    fn atan_norm(x: f32) -> f32 {
        const ATAN2_2_OVER_PI: f32 = std::f32::consts::FRAC_2_PI;
        const A03: f32 = -3.333_166e-1_f32;
        const A05: f32 = 1.996_270_4e-1_f32;
        const A07: f32 = -1.397_658_3e-1_f32;
        const A09: f32 = 9.794_234_e-2_f32;
        const A11: f32 = -5.777_359_e-2_f32;
        const A13: f32 = 2.304_014e-2_f32;
        const A15: f32 = -4.355_406e-3_f32;
        let x2 = x * x;
        ATAN2_2_OVER_PI
            * x
            * (1.0
                + x2 * (A03
                    + x2 * (A05 + x2 * (A07 + x2 * (A09 + x2 * (A11 + x2 * (A13 + x2 * A15)))))))
    }
    if x * x + y * y < 1e-18 {
        return 0.0;
    }
    if y < x {
        atan_norm(y / x)
    } else {
        1.0 - atan_norm(x / y)
    }
}

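/// Output of `compute_theta`: the (de)quantised angle, the Q15 mid/side
/// weights derived from it, the bit-budget delta between the halves, and
/// the number of eighth-bits spent coding the angle.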
pub struct SplitCtx {
    pub inv: bool,
    pub imid: i32,
    pub iside: i32,
    pub delta: i32,
    pub itheta: i32,
    pub qalloc: i32,
}

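/// Encoder-only specialisation of `compute_theta` below: measures, quantises
/// and range-codes the split angle, then derives `imid`/`iside`/`delta`.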
#[allow(clippy::too_many_arguments)]
#[inline(always)]
fn compute_theta_encode(
    ctx: &mut BandCtx,
    sctx: &mut SplitCtx,
    x: &[f32],
    y: &[f32],
    n: usize,
    b: &mut i32,
    b_blocks: i32,
    b0: i32,
    lm: i32,
    stereo: bool,
    fill: &mut u32,
) {
    let pulse_cap = ctx.m.log_n[ctx.i] as i32 + (lm << BITRES);
    let offset = (pulse_cap >> 1) - if stereo && n == 2 { 16 } else { 4 };
    let mut qn = compute_qn(n, *b, offset, pulse_cap, stereo);

    if stereo && ctx.i >= ctx.intensity {
        qn = 1;
    }

    if qn == 1 && !(stereo && ctx.i >= ctx.intensity) {
        sctx.itheta = 8192;
        sctx.qalloc = 0;
        let imid = bitexact_cos(8192i16);
        sctx.imid = imid as i32;
        let iside = bitexact_cos(8192i16);
        sctx.iside = iside as i32;
        sctx.delta =
            (((n as i32 - 1) << 7) * bitexact_log2tan(sctx.iside, sctx.imid) + 16384) >> 15;
        return;
    }

    let mut itheta = stereo_itheta(x, y, stereo, n);

    let tell_start = tell_frac_inline!(ctx.rc);

    if qn != 1 {
        if !stereo || ctx.theta_round == 0 {
            itheta = (itheta * qn + 8192) >> 14;
            if !stereo && ctx.avoid_split_noise && itheta > 0 && itheta < qn {
                let unquantized = (itheta * 16384) / qn;
                let imid = bitexact_cos(unquantized as i16) as i32;
                let iside = bitexact_cos((16384 - unquantized) as i16) as i32;
                let delta = (((n as i32 - 1) << 7) * bitexact_log2tan(iside, imid) + 16384) >> 15;
                if delta > *b {
                    itheta = qn;
                } else if delta < -*b {
                    itheta = 0;
                }
            }
        } else {
            let bias = if itheta > 8192 {
                32767 / qn
            } else {
                -32767 / qn
            };
            let down = (itheta * qn + bias) >> 14;
            let down = down.clamp(0, qn - 1);
            if ctx.theta_round < 0 {
                itheta = down;
            } else {
                itheta = down + 1;
            }
        }

        if stereo && n > 2 {
            let p0 = 3;
            let x0 = qn / 2;
            let ft = p0 * (x0 + 1) + x0;
            let fl = if itheta <= x0 {
                p0 * itheta
            } else {
                (itheta - 1 - x0) + (x0 + 1) * p0
            };
            let fh = if itheta <= x0 {
                p0 * (itheta + 1)
            } else {
                (itheta - x0) + (x0 + 1) * p0
            };
            ctx.rc.encode(fl as u32, fh as u32, ft as u32);
        } else if b0 > 1 || stereo {
            ctx.rc.enc_uint(itheta as u32, (qn + 1) as u32);
        } else {
            let ft = ((qn >> 1) + 1) * ((qn >> 1) + 1);
            let fs = if itheta <= (qn >> 1) {
                itheta + 1
            } else {
                qn + 1 - itheta
            };
            let fl = if itheta <= (qn >> 1) {
                (itheta * (itheta + 1)) >> 1
            } else {
                ft - (((qn + 1 - itheta) * (qn + 2 - itheta)) >> 1)
            };
            ctx.rc.encode(fl as u32, (fl + fs) as u32, ft as u32);
        }
        itheta = (itheta as u32 * 16384 / qn as u32) as i32;
    } else if stereo && ctx.i >= ctx.intensity {
        let mut emid = 1e-15f32;
        let mut eside = 1e-15f32;
        for i in 0..n {
            let m = x[i] + y[i];
            let s = x[i] - y[i];
            emid += m * m;
            eside += s * s;
        }
        let inv = eside > emid;
        ctx.rc.encode_bit_logp(inv, 1);
        itheta = 0;
        sctx.inv = inv;
    } else {
        itheta = 8192;
    }

    sctx.itheta = itheta;

    sctx.qalloc = if qn == 1 && !(stereo && ctx.i >= ctx.intensity) {
        0
    } else {
        tell_frac_inline!(ctx.rc) - tell_start
    };
    *b -= sctx.qalloc;

    if itheta == 0 {
        sctx.imid = 32767;
        sctx.iside = 0;
        sctx.delta = -16384;
        *fill &= (1 << b_blocks) - 1;
    } else if itheta == 16384 {
        sctx.imid = 0;
        sctx.iside = 32767;
        sctx.delta = 16384;
        *fill &= !((1 << b_blocks) - 1);
    } else {
        let imid = bitexact_cos(itheta as i16);
        sctx.imid = imid as i32;
        let iside = bitexact_cos((16384 - itheta) as i16);
        sctx.iside = iside as i32;
        sctx.delta =
            (((n as i32 - 1) << 7) * bitexact_log2tan(sctx.iside, sctx.imid) + 16384) >> 15;
    }
}

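/// Quantises (encode) or decodes the angle that splits the bit budget
/// between the two halves of a band: a step pdf for stereo, a uniform pdf
/// for time splits, and a triangular pdf otherwise; ported from
/// `compute_theta()` in bands.c.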
#[allow(clippy::too_many_arguments)]
#[inline(always)]
pub fn compute_theta(
    ctx: &mut BandCtx,
    sctx: &mut SplitCtx,
    x: &[f32],
    y: &[f32],
    n: usize,
    b: &mut i32,
    b_blocks: i32,
    b0: i32,
    lm: i32,
    stereo: bool,
    fill: &mut u32,
) {
    let pulse_cap = ctx.m.log_n[ctx.i] as i32 + (lm << BITRES);
    let offset = (pulse_cap >> 1) - if stereo && n == 2 { 16 } else { 4 };
    let mut qn = compute_qn(n, *b, offset, pulse_cap, stereo);

    if stereo && ctx.i >= ctx.intensity {
        qn = 1;
    }

    let mut itheta = 0;
    if ctx.encode {
        itheta = stereo_itheta(x, y, stereo, n);
    }

    let tell_start = tell_frac_inline!(ctx.rc);

    if qn != 1 {
        if ctx.encode {
            if !stereo || ctx.theta_round == 0 {
                itheta = (itheta * qn + 8192) >> 14;
                if !stereo && ctx.avoid_split_noise && itheta > 0 && itheta < qn {
                    let unquantized = (itheta * 16384) / qn;
                    let imid = bitexact_cos(unquantized as i16) as i32;
                    let iside = bitexact_cos((16384 - unquantized) as i16) as i32;
                    let delta =
                        (((n as i32 - 1) << 7) * bitexact_log2tan(iside, imid) + 16384) >> 15;
                    if delta > *b {
                        itheta = qn;
                    } else if delta < -*b {
                        itheta = 0;
                    }
                }
            } else {
                let bias = if itheta > 8192 {
                    32767 / qn
                } else {
                    -32767 / qn
                };
                let down = (itheta * qn + bias) >> 14;
                let down = down.clamp(0, qn - 1);
                if ctx.theta_round < 0 {
                    itheta = down;
                } else {
                    itheta = down + 1;
                }
            }
        }

        if stereo && n > 2 {
            let p0 = 3;
            let x0 = qn / 2;
            let ft = p0 * (x0 + 1) + x0;
            if ctx.encode {
                let fl = if itheta <= x0 {
                    p0 * itheta
                } else {
                    (itheta - 1 - x0) + (x0 + 1) * p0
                };
                let fh = if itheta <= x0 {
                    p0 * (itheta + 1)
                } else {
                    (itheta - x0) + (x0 + 1) * p0
                };
                ctx.rc.encode(fl as u32, fh as u32, ft as u32);
            } else {
                let fs = ctx.rc.decode(ft as u32);
                if fs < (x0 + 1) as u32 * p0 as u32 {
                    itheta = fs as i32 / p0;
                } else {
                    itheta = (x0 + 1) + (fs as i32 - (x0 + 1) * p0);
                }
                let fl = if itheta <= x0 {
                    p0 * itheta
                } else {
                    (itheta - 1 - x0) + (x0 + 1) * p0
                };
                let fh = if itheta <= x0 {
                    p0 * (itheta + 1)
                } else {
                    (itheta - x0) + (x0 + 1) * p0
                };
                ctx.rc.update(fl as u32, fh as u32, ft as u32);
            }
        } else if b0 > 1 || stereo {
            if ctx.encode {
                ctx.rc.enc_uint(itheta as u32, (qn + 1) as u32);
            } else {
                itheta = ctx.rc.dec_uint((qn + 1) as u32) as i32;
            }
        } else {
            let ft = ((qn >> 1) + 1) * ((qn >> 1) + 1);
            if ctx.encode {
                let fs = if itheta <= (qn >> 1) {
                    itheta + 1
                } else {
                    qn + 1 - itheta
                };
                let fl = if itheta <= (qn >> 1) {
                    (itheta * (itheta + 1)) >> 1
                } else {
                    ft - (((qn + 1 - itheta) * (qn + 2 - itheta)) >> 1)
                };
                ctx.rc.encode(fl as u32, (fl + fs) as u32, ft as u32);
            } else {
                let fm = ctx.rc.decode(ft as u32) as i32;
                if fm < (((qn >> 1) * ((qn >> 1) + 1)) >> 1) {
                    itheta = (isqrt32((8 * fm + 1) as u32) as i32 - 1) >> 1;
                    let fl = (itheta * (itheta + 1)) >> 1;
                    let fs = itheta + 1;
                    ctx.rc.update(fl as u32, (fl + fs) as u32, ft as u32);
                } else {
                    itheta = (2 * (qn + 1) - isqrt32((8 * (ft - fm - 1) + 1) as u32) as i32) >> 1;
                    let fs = qn + 1 - itheta;
                    let fl = ft - (((qn + 1 - itheta) * (qn + 2 - itheta)) >> 1);
                    ctx.rc.update(fl as u32, (fl + fs) as u32, ft as u32);
                }
            }
        }
        itheta = (itheta as u32 * 16384 / qn as u32) as i32;
    } else if stereo && ctx.i >= ctx.intensity {
        if ctx.encode {
            let mut emid = 1e-15f32;
            let mut eside = 1e-15f32;
            for i in 0..n {
                let m = x[i] + y[i];
                let s = x[i] - y[i];
                emid += m * m;
                eside += s * s;
            }
            let inv = eside > emid;
            ctx.rc.encode_bit_logp(inv, 1);
            itheta = 0;
            sctx.inv = inv;
        } else {
            sctx.inv = ctx.rc.decode_bit_logp(1);
            itheta = 0;
        }
    } else {
        itheta = 8192;
    }

    sctx.itheta = itheta;

    sctx.qalloc = if qn == 1 && !(stereo && ctx.i >= ctx.intensity) {
        0
    } else {
        tell_frac_inline!(ctx.rc) - tell_start
    };
    *b -= sctx.qalloc;

    if itheta == 0 {
        sctx.imid = 32767;
        sctx.iside = 0;
        sctx.delta = -16384;
        *fill &= (1 << b_blocks) - 1;
    } else if itheta == 16384 {
        sctx.imid = 0;
        sctx.iside = 32767;
        sctx.delta = 16384;
        *fill &= !((1 << b_blocks) - 1);
    } else {
        let imid = bitexact_cos(itheta as i16);
        sctx.imid = imid as i32;
        let iside = bitexact_cos((16384 - itheta) as i16);
        sctx.iside = iside as i32;
        sctx.delta =
            (((n as i32 - 1) << 7) * bitexact_log2tan(sctx.iside, sctx.imid) + 16384) >> 15;
    }
}

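/// Leaf PVQ quantiser specialised for `n == 2`; the `n == 4`, `n == 8` and
/// `n == 16` variants below are identical up to the hardcoded size, which
/// keeps the hot paths monomorphised.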
#[inline(always)]
fn quant_partition_n2_encode(
    ctx: &mut BandCtx,
    x: &mut [f32],
    b: i32,
    b_blocks: i32,
    lowband: Option<&mut [f32]>,
    lm: i32,
    gain: f32,
    fill: u32,
) -> u32 {
    let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
    let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
    ctx.remaining_bits -= curr_bits;

    while ctx.remaining_bits < 0 && q > 0 {
        ctx.remaining_bits += curr_bits;
        q -= 1;
        curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
        ctx.remaining_bits -= curr_bits;
    }

    if q != 0 {
        let k = get_pulses(q);
        alg_quant(x, 2, k, ctx.spread, b_blocks as usize, ctx.rc, gain, false)
    } else if lowband.is_some() {
        fill
    } else {
        (1u32 << b_blocks) - 1
    }
}

#[inline(always)]
fn quant_partition_n4_encode(
    ctx: &mut BandCtx,
    x: &mut [f32],
    b: i32,
    b_blocks: i32,
    lowband: Option<&mut [f32]>,
    lm: i32,
    gain: f32,
    fill: u32,
) -> u32 {
    let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
    let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
    ctx.remaining_bits -= curr_bits;

    while ctx.remaining_bits < 0 && q > 0 {
        ctx.remaining_bits += curr_bits;
        q -= 1;
        curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
        ctx.remaining_bits -= curr_bits;
    }

    if q != 0 {
        let k = get_pulses(q);
        alg_quant(x, 4, k, ctx.spread, b_blocks as usize, ctx.rc, gain, false)
    } else if lowband.is_some() {
        fill
    } else {
        (1u32 << b_blocks) - 1
    }
}

#[inline(always)]
fn quant_partition_n8_encode(
    ctx: &mut BandCtx,
    x: &mut [f32],
    b: i32,
    b_blocks: i32,
    lowband: Option<&mut [f32]>,
    lm: i32,
    gain: f32,
    fill: u32,
) -> u32 {
    let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
    let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
    ctx.remaining_bits -= curr_bits;

    while ctx.remaining_bits < 0 && q > 0 {
        ctx.remaining_bits += curr_bits;
        q -= 1;
        curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
        ctx.remaining_bits -= curr_bits;
    }

    if q != 0 {
        let k = get_pulses(q);
        alg_quant(x, 8, k, ctx.spread, b_blocks as usize, ctx.rc, gain, false)
    } else if lowband.is_some() {
        fill
    } else {
        (1u32 << b_blocks) - 1
    }
}

#[inline(always)]
#[allow(clippy::too_many_arguments)]
fn quant_partition_direct_encode(
    ctx: &mut BandCtx,
    x: &mut [f32],
    n: usize,
    b: i32,
    b_blocks: i32,
    lowband: Option<&mut [f32]>,
    lm: i32,
    gain: f32,
    fill: u32,
) -> u32 {
    let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
    let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
    ctx.remaining_bits -= curr_bits;

    while ctx.remaining_bits < 0 && q > 0 {
        ctx.remaining_bits += curr_bits;
        q -= 1;
        curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
        ctx.remaining_bits -= curr_bits;
    }

    if q != 0 {
        let k = get_pulses(q);
        alg_quant(x, n, k, ctx.spread, b_blocks as usize, ctx.rc, gain, false)
    } else if lowband.is_some() {
        fill
    } else {
        (1u32 << b_blocks) - 1
    }
}

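/// Encoder-only recursive band partition: while enough bits remain the band
/// is halved and a split angle is coded at each level; leaf partitions go to
/// the PVQ quantiser. This path never resynthesises, so PVQ gains stay
/// unscaled.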
#[inline(always)]
#[allow(clippy::too_many_arguments)]
fn quant_partition_encode(
    ctx: &mut BandCtx,
    x: &mut [f32],
    n: usize,
    b: i32,
    b_blocks: i32,
    lowband: Option<&mut [f32]>,
    lm: i32,
    gain: f32,
    fill: u32,
) -> u32 {
    if n == 2 {
        return quant_partition_n2_encode(ctx, x, b, b_blocks, lowband, lm, gain, fill);
    }

    let should_split = if lm >= 0 && n > 2 {
        let cache_idx = (lm + 1) as usize * ctx.m.nb_ebands + ctx.i;
        let cache_base = unsafe { *ctx.m.cache.index.get_unchecked(cache_idx) } as usize;
        if cache_base > 0 {
            let cache_ptr = ctx.m.cache.bits.as_ptr().wrapping_add(cache_base);
            let max_q = unsafe { *cache_ptr } as usize;
            b > (unsafe { *cache_ptr.add(max_q) } as i32) + 12
        } else {
            false
        }
    } else {
        false
    };

    if should_split {
        let mut sctx = SplitCtx {
            inv: false,
            imid: 0,
            iside: 0,
            delta: 0,
            itheta: 0,
            qalloc: 0,
        };
        let mut b_mut = b;
        let mut fill_mut = fill;
        let mid = n / 2;
        let lm = lm - 1;
        let b0 = b_blocks;
        if b_blocks == 1 {
            fill_mut = (fill_mut & 1) | (fill_mut << 1);
        }
        let b_blocks = (b_blocks + 1) >> 1;
        let (x_mid, x_side) = x.split_at_mut(mid);

        compute_theta_encode(
            ctx,
            &mut sctx,
            x_mid,
            x_side,
            mid,
            &mut b_mut,
            b_blocks,
            b0,
            lm,
            false,
            &mut fill_mut,
        );

        ctx.remaining_bits -= sctx.qalloc;
        let mut delta = sctx.delta;
        if b0 > 1 && (sctx.itheta & 0x3fff) != 0 {
            if sctx.itheta > 8192 {
                delta -= delta >> (4 - lm);
            } else {
                delta = 0i32.min(delta + ((mid as i32) << BITRES >> (5 - lm)));
            }
        }
        let mut mbits = 0i32.max((b_mut - delta) / 2).min(b_mut);
        let mut sbits = b_mut - mbits;

        let mut rebalance = ctx.remaining_bits;
        let mut cm;

        // This encode-only path never resynthesises (the leaves call
        // alg_quant with resynth = false), so the mid/side gains do not need
        // to be folded into `gain` here.
        if mbits >= sbits {
            cm = quant_partition_encode(
                ctx, x_mid, mid, mbits, b_blocks, lowband, lm, gain, fill_mut,
            );
            rebalance = mbits - (rebalance - ctx.remaining_bits);
            if rebalance > (3 << 3) && sctx.itheta != 0 {
                sbits += rebalance - (3 << 3);
            }
            cm |= quant_partition_encode(
                ctx,
                x_side,
                mid,
                sbits,
                b_blocks,
                None,
                lm,
                gain,
                fill_mut >> b_blocks,
            ) << (b0 >> 1);
        } else {
            cm = quant_partition_encode(
                ctx,
                x_side,
                mid,
                sbits,
                b_blocks,
                None,
                lm,
                gain,
                fill_mut >> b_blocks,
            ) << (b0 >> 1);
            rebalance = sbits - (rebalance - ctx.remaining_bits);
            if rebalance > (3 << 3) && sctx.itheta != 16384 {
                mbits += rebalance - (3 << 3);
            }
            cm |= quant_partition_encode(
                ctx, x_mid, mid, mbits, b_blocks, lowband, lm, gain, fill_mut,
            );
        }
        cm
    } else {
        if n == 4 {
            return quant_partition_n4_encode(ctx, x, b, b_blocks, lowband, lm, gain, fill);
        }
        if n == 8 {
            return quant_partition_n8_encode(ctx, x, b, b_blocks, lowband, lm, gain, fill);
        }
        if n == 16 {
            return quant_partition_direct_encode(ctx, x, n, b, b_blocks, lowband, lm, gain, fill);
        }
        let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
        let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
        ctx.remaining_bits -= curr_bits;

        while ctx.remaining_bits < 0 && q > 0 {
            ctx.remaining_bits += curr_bits;
            q -= 1;
            curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
            ctx.remaining_bits -= curr_bits;
        }

        if q != 0 {
            let k = get_pulses(q);
            alg_quant(x, n, k, ctx.spread, b_blocks as usize, ctx.rc, gain, false)
        } else if lowband.is_some() {
            fill
        } else {
            (1 << b_blocks) - 1
        }
    }
}

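/// Generic (encode or decode) recursive band partition. On the decode side
/// the split angles are read back and the leaves are de-quantised; when a
/// leaf gets no pulses the band is filled with noise or folded from
/// `lowband` and renormalised.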
#[inline(always)]
#[allow(clippy::too_many_arguments)]
pub fn quant_partition(
    ctx: &mut BandCtx,
    x: &mut [f32],
    n: usize,
    b: i32,
    b_blocks: i32,
    lowband: Option<&mut [f32]>,
    lm: i32,
    gain: f32,
    fill: u32,
) -> u32 {
    let should_split = if lm >= 0 && n > 2 {
        let cache_idx = (lm + 1) as usize * ctx.m.nb_ebands + ctx.i;
        let cache_base = unsafe { *ctx.m.cache.index.get_unchecked(cache_idx) } as usize;
        if cache_base > 0 {
            let cache_ptr = ctx.m.cache.bits.as_ptr().wrapping_add(cache_base);
            let max_q = unsafe { *cache_ptr } as usize;
            b > (unsafe { *cache_ptr.add(max_q) } as i32) + 12
        } else {
            false
        }
    } else {
        false
    };
    if should_split {
        let mut sctx = SplitCtx {
            inv: false,
            imid: 0,
            iside: 0,
            delta: 0,
            itheta: 0,
            qalloc: 0,
        };
        let mut b_mut = b;
        let mut fill_mut = fill;
        let mid = n / 2;
        let lm = lm - 1;
        let b0 = b_blocks;
        if b_blocks == 1 {
            fill_mut = (fill_mut & 1) | (fill_mut << 1);
        }
        let b_blocks = (b_blocks + 1) >> 1;
        let (x_mid, x_side) = x.split_at_mut(mid);

        compute_theta(
            ctx,
            &mut sctx,
            x_mid,
            x_side,
            mid,
            &mut b_mut,
            b_blocks,
            b0,
            lm,
            false,
            &mut fill_mut,
        );

        ctx.remaining_bits -= sctx.qalloc;
        let mut delta = sctx.delta;
        if b0 > 1 && (sctx.itheta & 0x3fff) != 0 {
            if sctx.itheta > 8192 {
                delta -= delta >> (4 - lm);
            } else {
                delta = 0i32.min(delta + ((mid as i32) << BITRES >> (5 - lm)));
            }
        }
        let mut mbits = 0i32.max((b_mut - delta) / 2).min(b_mut);
        let mut sbits = b_mut - mbits;

        let mut rebalance = ctx.remaining_bits;
        let mut cm;

        if mbits >= sbits {
            cm = quant_partition(
                ctx,
                x_mid,
                mid,
                mbits,
                b_blocks,
                lowband,
                lm,
                gain * (sctx.imid as f32 / 32768.0),
                fill_mut,
            );
            rebalance = mbits - (rebalance - ctx.remaining_bits);
            if rebalance > (3 << 3) && sctx.itheta != 0 {
                sbits += rebalance - (3 << 3);
            }
            cm |= quant_partition(
                ctx,
                x_side,
                mid,
                sbits,
                b_blocks,
                None,
                lm,
                gain * (sctx.iside as f32 / 32768.0),
                fill_mut >> b_blocks,
            ) << (b0 >> 1);
        } else {
            cm = quant_partition(
                ctx,
                x_side,
                mid,
                sbits,
                b_blocks,
                None,
                lm,
                gain * (sctx.iside as f32 / 32768.0),
                fill_mut >> b_blocks,
            ) << (b0 >> 1);
            rebalance = sbits - (rebalance - ctx.remaining_bits);
            if rebalance > (3 << 3) && sctx.itheta != 16384 {
                mbits += rebalance - (3 << 3);
            }
            cm |= quant_partition(
                ctx,
                x_mid,
                mid,
                mbits,
                b_blocks,
                lowband,
                lm,
                gain * (sctx.imid as f32 / 32768.0),
                fill_mut,
            );
        }
        cm
    } else {
        let mut q = bits2pulses(ctx.m, ctx.i, lm, b);
        let mut curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
        ctx.remaining_bits -= curr_bits;

        while ctx.remaining_bits < 0 && q > 0 {
            ctx.remaining_bits += curr_bits;
            q -= 1;
            curr_bits = pulses2bits(ctx.m, ctx.i, lm, q);
            ctx.remaining_bits -= curr_bits;
        }

        if q != 0 {
            let k = get_pulses(q);
            if ctx.encode {
                alg_quant(
                    x,
                    n,
                    k,
                    ctx.spread,
                    b_blocks as usize,
                    ctx.rc,
                    gain,
                    ctx.resynth,
                )
            } else {
                alg_unquant(x, n, k, ctx.spread, b_blocks as usize, ctx.rc, gain)
            }
        } else {
            let has_lowband = lowband.is_some();
            if ctx.resynth {
                let cm_mask = (1u32 << b_blocks) - 1;
                let fill_masked = fill & cm_mask;
                if fill_masked == 0 {
                    x[..n].fill(0.0);
                } else if has_lowband {
                    let lb = lowband.unwrap();
                    #[cfg(target_arch = "aarch64")]
                    unsafe {
                        use std::arch::aarch64::*;
                        let n8 = n & !7;
                        let mut i = 0;
                        while i < n8 {
                            let mut vals = [0.0f32; 8];
                            for j in 0..8 {
                                ctx.seed = celt_lcg_rand(ctx.seed);
                                vals[j] = if ctx.seed & 0x8000 != 0 {
                                    1.0 / 256.0
                                } else {
                                    -1.0 / 256.0
                                };
                            }
                            let vnoise = vld1q_f32(vals.as_ptr());
                            let vnoise1 = vld1q_f32(vals.as_ptr().add(4));
                            let vlb = vld1q_f32(lb.as_ptr().add(i));
                            let vlb1 = vld1q_f32(lb.as_ptr().add(i + 4));
                            let vres = vaddq_f32(vlb, vnoise);
                            let vres1 = vaddq_f32(vlb1, vnoise1);
                            vst1q_f32(x.as_mut_ptr().add(i), vres);
                            vst1q_f32(x.as_mut_ptr().add(i + 4), vres1);
                            i += 8;
                        }
                        for j in i..n {
                            ctx.seed = celt_lcg_rand(ctx.seed);
                            x[j] = lb[j]
                                + if ctx.seed & 0x8000 != 0 {
                                    1.0 / 256.0
                                } else {
                                    -1.0 / 256.0
                                };
                        }
                    }
                    #[cfg(not(target_arch = "aarch64"))]
                    {
                        for j in 0..n {
                            ctx.seed = celt_lcg_rand(ctx.seed);
                            x[j] = lb[j]
                                + if ctx.seed & 0x8000 != 0 {
                                    1.0 / 256.0
                                } else {
                                    -1.0 / 256.0
                                };
                        }
                    }
                    renormalise_vector(x, n, gain);
                } else {
                    for xv in x[..n].iter_mut() {
                        ctx.seed = celt_lcg_rand(ctx.seed);
                        *xv = ((ctx.seed as i32 >> 20) as f32) / 16384.0;
                    }
                    renormalise_vector(x, n, gain);
                }
            }
            if has_lowband {
                fill
            } else {
                (1 << b_blocks) - 1
            }
        }
    }
}

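/// Contiguous de-interleave used on aarch64 when no Hadamard reordering is
/// needed. Despite the `_neon` suffix this is a scalar copy through a stack
/// buffer; the suffix only marks it as the aarch64 fast path.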
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn deinterleave_hadamard_neon(x: &mut [f32], n0: usize, stride: usize) {
    let n = n0 * stride;
    let mut tmp_buf = [std::mem::MaybeUninit::<f32>::uninit(); MAX_PVQ_N];
    let tmp = std::slice::from_raw_parts_mut(tmp_buf.as_mut_ptr() as *mut f32, n);

    for i in 0..stride {
        let src_offset = i;
        let dst_offset = i * n0;
        for j in 0..n0 {
            tmp[dst_offset + j] = x[j * stride + src_offset];
        }
    }

    x[..n].copy_from_slice(tmp);
}

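/// Reorders samples from frequency order into `stride` blocks in time order;
/// with `hadamard` set the rows are additionally permuted via `ORDERY_TABLE`
/// so the Hadamard butterflies stay in natural order.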
pub fn deinterleave_hadamard(x: &mut [f32], n0: usize, stride: usize, hadamard: bool) {
    let n = n0 * stride;

    let mut tmp_buf = [std::mem::MaybeUninit::<f32>::uninit(); MAX_PVQ_N];

    let tmp = unsafe { std::slice::from_raw_parts_mut(tmp_buf.as_mut_ptr() as *mut f32, n) };
    if hadamard {
        let offset = match stride {
            2 => 0,
            4 => 2,
            8 => 6,
            16 => 14,
            _ => 0,
        };
        let ordery = &ORDERY_TABLE[offset..offset + stride];
        for i in 0..stride {
            for j in 0..n0 {
                tmp[ordery[i] as usize * n0 + j] = x[j * stride + i];
            }
        }
    } else {
        #[cfg(target_arch = "aarch64")]
        unsafe {
            if n0 >= 4 {
                deinterleave_hadamard_neon(x, n0, stride);
                return;
            }
        }
        for i in 0..stride {
            for j in 0..n0 {
                tmp[i * n0 + j] = x[j * stride + i];
            }
        }
    }
    x[..n].copy_from_slice(tmp);
}

#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn interleave_hadamard_neon(x: &mut [f32], n0: usize, stride: usize) {
    let n = n0 * stride;
    let mut tmp_buf = [std::mem::MaybeUninit::<f32>::uninit(); MAX_PVQ_N];
    let tmp = std::slice::from_raw_parts_mut(tmp_buf.as_mut_ptr() as *mut f32, n);

    for i in 0..stride {
        let src_offset = i * n0;
        let dst_offset = i;
        for j in 0..n0 {
            tmp[j * stride + dst_offset] = x[src_offset + j];
        }
    }

    x[..n].copy_from_slice(tmp);
}

pub fn interleave_hadamard(x: &mut [f32], n0: usize, stride: usize, hadamard: bool) {
    let n = n0 * stride;
    let mut tmp_buf = [std::mem::MaybeUninit::<f32>::uninit(); MAX_PVQ_N];
    let tmp = unsafe { std::slice::from_raw_parts_mut(tmp_buf.as_mut_ptr() as *mut f32, n) };
    if hadamard {
        let offset = match stride {
            2 => 0,
            4 => 2,
            8 => 6,
            16 => 14,
            _ => 0,
        };
        let ordery = &ORDERY_TABLE[offset..offset + stride];
        for i in 0..stride {
            for j in 0..n0 {
                tmp[j * stride + i] = x[ordery[i] as usize * n0 + j];
            }
        }
    } else {
        #[cfg(target_arch = "aarch64")]
        unsafe {
            if n0 >= 4 {
                interleave_hadamard_neon(x, n0, stride);
                return;
            }
        }
        for i in 0..stride {
            for j in 0..n0 {
                tmp[j * stride + i] = x[i * n0 + j];
            }
        }
    }
    x[..n].copy_from_slice(tmp);
}

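/// Row permutations for the Hadamard variants, one group per stride
/// (2, 4, 8 and 16) concatenated; same contents as `ordery_table` in
/// bands.c.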
const ORDERY_TABLE: [i32; 30] = [
    1, 0, 3, 0, 2, 1, 7, 0, 4, 3, 6, 1, 5, 2, 15, 0, 8, 7, 12, 3, 11, 4, 14, 1, 9, 6, 13, 2, 10, 5,
];

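/// Special case for a one-sample band: at most one sign bit per channel can
/// be coded.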
fn quant_band_n1(
    ctx: &mut BandCtx,
    x: &mut [f32],
    y: Option<&mut [f32]>,
    lowband_out: Option<&mut [f32]>,
) -> u32 {
    let mut sign = 0;
    if ctx.remaining_bits >= 1 << BITRES {
        if ctx.encode {
            sign = if x[0] < 0.0 { 1 } else { 0 };
            ctx.rc.enc_bits(sign as u32, 1);
        } else {
            sign = ctx.rc.dec_bits(1) as i32;
        }
        ctx.remaining_bits -= 1 << BITRES;
    }
    if ctx.resynth {
        x[0] = if sign != 0 { -1.0 } else { 1.0 };
    }
    if let Some(y_val) = y {
        let mut y_sign = 0;
        if ctx.remaining_bits >= 1 << BITRES {
            if ctx.encode {
                y_sign = if y_val[0] < 0.0 { 1 } else { 0 };
                ctx.rc.enc_bits(y_sign as u32, 1);
            } else {
                y_sign = ctx.rc.dec_bits(1) as i32;
            }
            ctx.remaining_bits -= 1 << BITRES;
        }
        if ctx.resynth {
            y_val[0] = if y_sign != 0 { -1.0 } else { 1.0 };
        }
    }
    if let Some(l_out) = lowband_out {
        // The folding state uses the same sqrt(n0) scaling as quant_band,
        // which is 1 for a one-sample band (the fixed-point reference shifts
        // Q14 down to Q10 here, which has no float counterpart).
        l_out[0] = x[0];
    }
    1
}

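/// Quantises a single band of one channel: applies the TF-resolution
/// transforms (recombination and Haar butterflies), reorders into time
/// order, runs the recursive partition, then undoes the reorganisation and
/// writes the scaled band into `lowband_out` when resynthesising.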
#[allow(clippy::too_many_arguments)]
#[inline(always)]
pub fn quant_band(
    ctx: &mut BandCtx,
    x: &mut [f32],
    n: usize,
    b: i32,
    b_blocks: i32,
    lowband: Option<&mut [f32]>,
    lm: i32,
    lowband_out: Option<&mut [f32]>,
    gain: f32,
    fill: u32,
) -> u32 {
    let n0 = n;
    let b0 = b_blocks;
    let long_blocks = b0 == 1;

    if n == 1 {
        return quant_band_n1(ctx, x, None, lowband_out);
    }

    let mut b_blocks = b_blocks;
    let mut n_b = n / b_blocks as usize;
    let mut time_divide = 0;
    let mut recombine = 0;
    let mut tf_change_local = ctx.tf_change;
    let mut fill = fill;

    if tf_change_local > 0 {
        recombine = tf_change_local;
    }

    let mut lowband_buf = lowband;

    static BIT_INTERLEAVE_TABLE: [u8; 16] = [0, 1, 1, 1, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3];

    for k in 0..recombine {
        if ctx.encode {
            haar1(x, n >> k, 1 << k);
        }
        if let Some(ref mut lb) = lowband_buf {
            haar1(lb, n >> k, 1 << k);
        }
        fill = (BIT_INTERLEAVE_TABLE[(fill & 0xF) as usize] as u32)
            | ((BIT_INTERLEAVE_TABLE[(fill >> 4) as usize] as u32) << 2);
    }
    b_blocks >>= recombine;
    n_b <<= recombine;

    while n_b & 1 == 0 && tf_change_local < 0 {
        if ctx.encode {
            haar1(x, n_b, b_blocks as usize);
        }
        if let Some(ref mut lb) = lowband_buf {
            haar1(lb, n_b, b_blocks as usize);
        }
        fill |= fill << b_blocks;
        b_blocks <<= 1;
        n_b >>= 1;
        time_divide += 1;
        tf_change_local += 1;
    }

    let b0_after = b_blocks;
    let n_b0 = n_b;

    if b_blocks > 1 {
        if ctx.encode {
            deinterleave_hadamard(
                x,
                n_b >> recombine as usize,
                (b_blocks << recombine) as usize,
                long_blocks,
            );
        }
        if let Some(ref mut lb) = lowband_buf {
            deinterleave_hadamard(
                lb,
                n_b >> recombine as usize,
                (b_blocks << recombine) as usize,
                long_blocks,
            );
        }
    }

    let cm = if ctx.encode {
        quant_partition_encode(ctx, x, n, b, b_blocks, lowband_buf, lm, gain, fill)
    } else {
        quant_partition(ctx, x, n, b, b_blocks, lowband_buf, lm, gain, fill)
    };

    if ctx.resynth {
        let mut cm = cm;

        if b_blocks > 1 {
            interleave_hadamard(
                x,
                n_b >> recombine as usize,
                (b0_after << recombine) as usize,
                long_blocks,
            );
        }

        let mut n_b_undo = n_b0;
        let mut b_undo = b0_after;
        for _ in 0..time_divide {
            b_undo >>= 1;
            n_b_undo <<= 1;
            cm |= cm >> b_undo;
            haar1(x, n_b_undo, b_undo as usize);
        }

        static BIT_DEINTERLEAVE_TABLE: [u8; 16] = [
            0x00, 0x03, 0x0C, 0x0F, 0x30, 0x33, 0x3C, 0x3F, 0xC0, 0xC3, 0xCC, 0xCF, 0xF0, 0xF3,
            0xFC, 0xFF,
        ];
        for k in 0..recombine {
            cm = BIT_DEINTERLEAVE_TABLE[cm as usize & 0xF] as u32;
            haar1(x, n0 >> k, 1 << k);
        }
        let mut b_final = b0_after;
        b_final <<= recombine;

        if let Some(lb_out) = lowband_out {
            let scale = (n0 as f32).sqrt();
            for j in 0..n0 {
                lb_out[j] = scale * x[j];
            }
        }
        cm &= (1u32 << b_final) - 1;
        return cm;
    }

    cm
}

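/// In-place mid/side to left/right conversion of two unit-norm vectors:
/// x <- mid * x - side * y and y <- mid * x + side * y.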
pub fn stereo_merge(x: &mut [f32], y: &mut [f32], mid: f32, side: f32, n: usize) {
    #[cfg(target_arch = "aarch64")]
    {
        stereo_merge_neon(x, y, mid, side, n);
    }
    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        unsafe { stereo_merge_avx2(x, y, mid, side, n) };
        return;
    }
    #[cfg(all(target_arch = "x86_64", not(target_feature = "avx2")))]
    {
        if std::arch::is_x86_feature_detected!("avx2") {
            unsafe { stereo_merge_avx2(x, y, mid, side, n) };
            return;
        }
    }
    #[cfg(not(target_arch = "aarch64"))]
    stereo_merge_scalar(x, y, mid, side, n);
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn stereo_merge_avx2(x: &mut [f32], y: &mut [f32], mid: f32, side: f32, n: usize) {
    use std::arch::x86_64::*;

    let mut i = 0;

    let v_mid = _mm256_set1_ps(mid);
    let v_side = _mm256_set1_ps(side);

    while i + 15 < n {
        let x0 = _mm256_loadu_ps(x.as_ptr().add(i));
        let x1 = _mm256_loadu_ps(x.as_ptr().add(i + 8));
        let y0 = _mm256_loadu_ps(y.as_ptr().add(i));
        let y1 = _mm256_loadu_ps(y.as_ptr().add(i + 8));

        let x_val0 = _mm256_mul_ps(x0, v_mid);
        let x_val1 = _mm256_mul_ps(x1, v_mid);
        let y_val0 = _mm256_mul_ps(y0, v_side);
        let y_val1 = _mm256_mul_ps(y1, v_side);

        let new_x0 = _mm256_sub_ps(x_val0, y_val0);
        let new_x1 = _mm256_sub_ps(x_val1, y_val1);
        let new_y0 = _mm256_add_ps(x_val0, y_val0);
        let new_y1 = _mm256_add_ps(x_val1, y_val1);

        _mm256_storeu_ps(x.as_mut_ptr().add(i), new_x0);
        _mm256_storeu_ps(x.as_mut_ptr().add(i + 8), new_x1);
        _mm256_storeu_ps(y.as_mut_ptr().add(i), new_y0);
        _mm256_storeu_ps(y.as_mut_ptr().add(i + 8), new_y1);

        i += 16;
    }

    while i + 7 < n {
        let x0 = _mm256_loadu_ps(x.as_ptr().add(i));
        let y0 = _mm256_loadu_ps(y.as_ptr().add(i));

        let x_val = _mm256_mul_ps(x0, v_mid);
        let y_val = _mm256_mul_ps(y0, v_side);

        let new_x = _mm256_sub_ps(x_val, y_val);
        let new_y = _mm256_add_ps(x_val, y_val);

        _mm256_storeu_ps(x.as_mut_ptr().add(i), new_x);
        _mm256_storeu_ps(y.as_mut_ptr().add(i), new_y);

        i += 8;
    }

    for j in i..n {
        let x_val = x[j] * mid;
        let y_val = y[j] * side;
        x[j] = x_val - y_val;
        y[j] = x_val + y_val;
    }
}

#[cfg_attr(target_arch = "aarch64", allow(dead_code))]
#[inline]
fn stereo_merge_scalar(x: &mut [f32], y: &mut [f32], mid: f32, side: f32, n: usize) {
    for i in 0..n {
        let x_val = x[i] * mid;
        let y_val = y[i] * side;
        x[i] = x_val - y_val;
        y[i] = x_val + y_val;
    }
}

#[cfg(target_arch = "aarch64")]
fn stereo_merge_neon(x: &mut [f32], y: &mut [f32], mid: f32, side: f32, n: usize) {
    use std::arch::aarch64::*;

    unsafe {
        let vmid = vdupq_n_f32(mid);
        let vside = vdupq_n_f32(side);

        let n16 = n & !15;
        for i in (0..n16).step_by(16) {
            let x0 = vld1q_f32(x.as_ptr().add(i));
            let x1 = vld1q_f32(x.as_ptr().add(i + 4));
            let x2 = vld1q_f32(x.as_ptr().add(i + 8));
            let x3 = vld1q_f32(x.as_ptr().add(i + 12));

            let y0 = vld1q_f32(y.as_ptr().add(i));
            let y1 = vld1q_f32(y.as_ptr().add(i + 4));
            let y2 = vld1q_f32(y.as_ptr().add(i + 8));
            let y3 = vld1q_f32(y.as_ptr().add(i + 12));

            let xv0 = vmulq_f32(x0, vmid);
            let xv1 = vmulq_f32(x1, vmid);
            let xv2 = vmulq_f32(x2, vmid);
            let xv3 = vmulq_f32(x3, vmid);

            let yv0 = vmulq_f32(y0, vside);
            let yv1 = vmulq_f32(y1, vside);
            let yv2 = vmulq_f32(y2, vside);
            let yv3 = vmulq_f32(y3, vside);

            vst1q_f32(x.as_mut_ptr().add(i), vsubq_f32(xv0, yv0));
            vst1q_f32(x.as_mut_ptr().add(i + 4), vsubq_f32(xv1, yv1));
            vst1q_f32(x.as_mut_ptr().add(i + 8), vsubq_f32(xv2, yv2));
            vst1q_f32(x.as_mut_ptr().add(i + 12), vsubq_f32(xv3, yv3));

            vst1q_f32(y.as_mut_ptr().add(i), vaddq_f32(xv0, yv0));
            vst1q_f32(y.as_mut_ptr().add(i + 4), vaddq_f32(xv1, yv1));
            vst1q_f32(y.as_mut_ptr().add(i + 8), vaddq_f32(xv2, yv2));
            vst1q_f32(y.as_mut_ptr().add(i + 12), vaddq_f32(xv3, yv3));
        }

        let n4 = (n & !3) - n16;
        for i in (n16..n16 + n4).step_by(4) {
            let xv = vld1q_f32(x.as_ptr().add(i));
            let yv = vld1q_f32(y.as_ptr().add(i));

            let x_val = vmulq_f32(xv, vmid);
            let y_val = vmulq_f32(yv, vside);

            vst1q_f32(x.as_mut_ptr().add(i), vsubq_f32(x_val, y_val));
            vst1q_f32(y.as_mut_ptr().add(i), vaddq_f32(x_val, y_val));
        }

        for i in (n16 + n4)..n {
            let x_val = x[i] * mid;
            let y_val = y[i] * side;
            x[i] = x_val - y_val;
            y[i] = x_val + y_val;
        }
    }
}

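/// Stereo counterpart of `quant_band`: codes the mid/side angle, splits the
/// bit budget between the channels (with a one-bit special case for n == 2),
/// and merges back to left/right when resynthesising.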
#[allow(clippy::too_many_arguments)]
#[inline(always)]
pub fn quant_band_stereo(
    ctx: &mut BandCtx,
    x: &mut [f32],
    y: &mut [f32],
    n: usize,
    b: i32,
    b_blocks: i32,
    lowband: Option<&mut [f32]>,
    lm: i32,
    lowband_out: Option<&mut [f32]>,
    _gain: f32,
    fill: u32,
) -> u32 {
    if n == 1 {
        return quant_band_n1(ctx, x, Some(y), lowband_out);
    }

    if ctx.encode
        && (ctx.band_e[ctx.i] < MIN_STEREO_ENERGY
            || ctx.band_e[ctx.m.nb_ebands + ctx.i] < MIN_STEREO_ENERGY)
    {
        if ctx.band_e[ctx.i] > ctx.band_e[ctx.m.nb_ebands + ctx.i] {
            y.copy_from_slice(x);
        } else {
            x.copy_from_slice(y);
        }
    }

    let mut sctx = SplitCtx {
        inv: false,
        imid: 0,
        iside: 0,
        delta: 0,
        itheta: 0,
        qalloc: 0,
    };
    let mut b_mut = b;
    let mut fill_mut = fill;
    if ctx.encode {
        compute_theta_encode(
            ctx,
            &mut sctx,
            x,
            y,
            n,
            &mut b_mut,
            b_blocks,
            b_blocks,
            lm,
            true,
            &mut fill_mut,
        );
    } else {
        compute_theta(
            ctx,
            &mut sctx,
            x,
            y,
            n,
            &mut b_mut,
            b_blocks,
            b_blocks,
            lm,
            true,
            &mut fill_mut,
        );
    }

    let mid_gain = sctx.imid as f32 / 32768.0;
    let side_gain = sctx.iside as f32 / 32768.0;

    if n == 2 {
        let mut mbits = b_mut;
        let mut sbits = 0;
        // Only one bit is needed for the side: its sign.
        if sctx.itheta != 0 && sctx.itheta != 16384 {
            sbits = 1 << BITRES;
        }
        mbits -= sbits;
        let c = sctx.itheta > 8192;
        ctx.remaining_bits -= sctx.qalloc + sbits;

        let mut sign = 0;
        if sbits != 0 {
            if ctx.encode {
                sign = if c {
                    if (y[0] * x[1] - y[1] * x[0]) < 0.0 {
                        1
                    } else {
                        0
                    }
                } else if (x[0] * y[1] - x[1] * y[0]) < 0.0 {
                    1
                } else {
                    0
                };
                ctx.rc.enc_bits(sign as u32, 1);
            } else {
                sign = ctx.rc.dec_bits(1) as i32;
            }
        }
        let sign_val = (1 - 2 * sign) as f32;
        let cm = if c {
            let cm = quant_band(
                ctx,
                y,
                n,
                mbits,
                b_blocks,
                lowband,
                lm,
                lowband_out,
                1.0,
                fill,
            );
            x[0] = -sign_val * y[1];
            x[1] = sign_val * y[0];
            cm
        } else {
            let cm = quant_band(
                ctx,
                x,
                n,
                mbits,
                b_blocks,
                lowband,
                lm,
                lowband_out,
                1.0,
                fill,
            );
            y[0] = -sign_val * x[1];
            y[1] = sign_val * x[0];
            cm
        };

        if ctx.resynth {
            let x0 = x[0];
            let x1 = x[1];
            let y0 = y[0];
            let y1 = y[1];
            x[0] = mid_gain * x0 - side_gain * y0;
            x[1] = mid_gain * x1 - side_gain * y1;
            y[0] = mid_gain * x0 + side_gain * y0;
            y[1] = mid_gain * x1 + side_gain * y1;
        }
        return cm;
    }

    ctx.remaining_bits -= sctx.qalloc;
    let mut mbits = 0i32.max((b_mut - sctx.delta) / 2).min(b_mut);
    let mut sbits = b_mut - mbits;

    let mut rebalance = ctx.remaining_bits;
    let mut cm;

    // Both channels occupy the same set of blocks, so the side-channel
    // collapse mask is OR-ed in without the block shift that
    // quant_partition applies to time splits.
    if mbits >= sbits {
        cm = quant_band(
            ctx,
            x,
            n,
            mbits,
            b_blocks,
            lowband,
            lm,
            lowband_out,
            1.0,
            fill_mut,
        );
        rebalance = mbits - (rebalance - ctx.remaining_bits);
        if rebalance > (3 << 3) && sctx.itheta != 0 {
            sbits += rebalance - (3 << 3);
        }
        cm |= quant_band(
            ctx,
            y,
            n,
            sbits,
            b_blocks,
            None,
            lm,
            None,
            side_gain,
            fill_mut >> b_blocks,
        );
    } else {
        cm = quant_band(
            ctx,
            y,
            n,
            sbits,
            b_blocks,
            None,
            lm,
            None,
            side_gain,
            fill_mut >> b_blocks,
        );
        rebalance = sbits - (rebalance - ctx.remaining_bits);
        if rebalance > (3 << 3) && sctx.itheta != 16384 {
            mbits += rebalance - (3 << 3);
        }
        cm |= quant_band(
            ctx,
            x,
            n,
            mbits,
            b_blocks,
            lowband,
            lm,
            lowband_out,
            1.0,
            fill_mut,
        );
    }

    if ctx.resynth {
        stereo_merge(x, y, mid_gain, side_gain, n);
        if sctx.inv {
            for yv in y[..n].iter_mut() {
                *yv = -*yv;
            }
        }
    }
    cm
}

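/// Top-level band loop, ported from `quant_all_bands()` in bands.c: hands
/// each band its bit budget from the rolling balance, tracks the folding
/// source via `lowband_offset`, and records per-band collapse masks for the
/// anti-collapse pass.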
2137#[allow(clippy::too_many_arguments)]
2138pub fn quant_all_bands(
2139 encode: bool,
2140 m: &CeltMode,
2141 start: usize,
2142 end: usize,
2143 x: &mut [f32],
2144 mut y: Option<&mut [f32]>,
2145 collapse_masks: &mut [u32],
2146 band_e: &[f32],
2147 pulses: &[i32],
2148 short_blocks: bool,
2149 spread: i32,
2150 dual_stereo: &mut bool,
2151 intensity: usize,
2152 tf_res: &[i32],
2153 total_bits: i32,
2154 balance: &mut i32,
2155 rc: &mut RangeCoder,
2156 lm: i32,
2157 coded_bands: i32,
2158 resynth: bool,
2159 seed: &mut u32,
2160) {
2161 let mut balance_val = *balance;
2162 let b_blocks = if short_blocks { 1 << lm } else { 1 };
2163 let c_channels = if y.is_some() { 2 } else { 1 };
2164 let m_val = 1usize << lm as usize;
2165
2166 let norm_offset = m_val * (m.e_bands[start] as usize);
2167 let norm_size = m_val * (m.e_bands[m.nb_ebands - 1] as usize) - norm_offset;
2168
2169 const MAX_NORM_SIZE: usize = 800;
2170 debug_assert!(norm_size <= MAX_NORM_SIZE);
2171
2172 let mut norm_buf = [std::mem::MaybeUninit::<f32>::uninit(); MAX_NORM_SIZE];
2173 let norm =
2174 unsafe { std::slice::from_raw_parts_mut(norm_buf.as_mut_ptr() as *mut f32, norm_size) };
2175
2176 let mut lowband_scratch_buf = [std::mem::MaybeUninit::<f32>::uninit(); MAX_PVQ_N];
2177 let lowband_scratch_ptr = lowband_scratch_buf.as_mut_ptr() as *mut f32;
2178
2179 let lowband_offset: usize = 0;
2180 let mut avoid_split_noise = b_blocks > 1;
2181
2182 let e_bands = &m.e_bands;
2183 let mut ctx_seed = *seed;
2184
2185 for i in start..end {
2186 let e_band_i = e_bands[i] as usize;
2187 let e_band_i1 = e_bands[i + 1] as usize;
2188 let offset = m_val * e_band_i;
2189 let n = m_val * (e_band_i1 - e_band_i);
2190 let last = i == end - 1;
2191
2192 let tell = tell_frac_inline!(rc);
2193 if i != start {
2194 balance_val -= tell;
2195 }
2196 let remaining_bits = total_bits - tell - 1;
2197
2198 let mut b = 0i32;
2199 if i < coded_bands as usize {
2200 let curr_balance = celt_sudiv(balance_val, 3i32.min(coded_bands - i as i32));
2201 b = 0i32.max(16383i32.min((remaining_bits + 1).min(pulses[i] + curr_balance)));
2202 }
2203
2204 let norm_pos = m_val * e_band_i - norm_offset;
2205 let tf_change = tf_res[i];
2206
2207 let mut effective_lowband: i32 = -1;
2208 let mut x_cm: u32;
2209 let mut y_cm: u32;
2210
2211 if lowband_offset != 0 && (spread != SPREAD_AGGRESSIVE || b_blocks > 1 || tf_change < 0) {
2212 effective_lowband = 0i32.max(
2213 (m_val * e_bands[lowband_offset] as usize) as i32 - norm_offset as i32 - n as i32,
2214 );
2215 let el_abs = effective_lowband as usize + norm_offset;
2216
2217 let mut fold_start = lowband_offset;
2218 loop {
2219 if fold_start == 0 {
2220 break;
2221 }
2222 fold_start -= 1;
2223 if m_val * (e_bands[fold_start] as usize) <= el_abs {
2224 break;
2225 }
2226 }
2227 let mut fold_end = lowband_offset.saturating_sub(1);
2228 while fold_end + 1 < i && m_val * (e_bands[fold_end + 1] as usize) < el_abs + n {
2229 fold_end += 1;
2230 }
2231
2232 x_cm = 0;
2233 y_cm = 0;
2234 for fi in fold_start..fold_end {
2235 x_cm |= collapse_masks[fi * c_channels];
2236 y_cm |= collapse_masks[fi * c_channels + c_channels - 1];
2237 }
2238 } else {
2239 x_cm = (1u32 << b_blocks) - 1;
2240 y_cm = (1u32 << b_blocks) - 1;
2241 }
2242
        let mut ctx = BandCtx {
            encode,
            m,
            i,
            band_e,
            // Explicit reborrow: moving `rc` into the struct would leave it
            // unusable on the next loop iteration.
            rc: &mut *rc,
            spread,
            remaining_bits,
            resynth,
            tf_change,
            intensity,
            theta_round: 0,
            avoid_split_noise,
            arch: 0,
            disable_inv: false,
            seed: ctx_seed,
        };

        if *dual_stereo && i == intensity {
            *dual_stereo = false;
        }

        // Copy the folding source out of `norm` so the band can read the
        // source while its own output is written back into `norm`.
        let mut lowband_scratch: Option<&mut [f32]> = if effective_lowband >= 0 {
            let lb_start = effective_lowband as usize;
            let lb_end = lb_start + n;
            if lb_end <= norm.len() && n <= MAX_PVQ_N {
                lowband_scratch_buf[..n].copy_from_slice(&norm[lb_start..lb_end]);
                Some(&mut lowband_scratch_buf[..n])
            } else {
                None
            }
        } else {
            None
        };

        let x_slice = &mut x[offset..offset + n];
        if *dual_stereo {
            let y_slice = &mut y.as_mut().unwrap()[offset..offset + n];
            let lb_x = lowband_scratch.as_deref_mut();
            let lb_out_x = if !last && norm_pos + n <= norm.len() {
                Some(&mut norm[norm_pos..norm_pos + n])
            } else {
                None
            };
            x_cm = quant_band(
                &mut ctx,
                x_slice,
                n,
                b / 2,
                b_blocks,
                lb_x,
                lm,
                lb_out_x,
                1.0,
                x_cm,
            );
            y_cm = quant_band(
                &mut ctx,
                y_slice,
                n,
                b / 2,
                b_blocks,
                None,
                lm,
                None,
                1.0,
                y_cm,
            );
        } else if let Some(y_all) = y.as_mut() {
            let y_slice = &mut y_all[offset..offset + n];
            let lb = lowband_scratch.as_deref_mut();
            let lb_out = if !last && norm_pos + n <= norm.len() {
                Some(&mut norm[norm_pos..norm_pos + n])
            } else {
                None
            };
            x_cm = quant_band_stereo(
                &mut ctx,
                x_slice,
                y_slice,
                n,
                b,
                b_blocks,
                lb,
                lm,
                lb_out,
                1.0,
                x_cm | y_cm,
            );
            y_cm = x_cm;
        } else {
            let lb = lowband_scratch;
            let lb_out = if !last && norm_pos + n <= norm.len() {
                Some(&mut norm[norm_pos..norm_pos + n])
            } else {
                None
            };
            x_cm = quant_band(&mut ctx, x_slice, n, b, b_blocks, lb, lm, lb_out, 1.0, x_cm);
            y_cm = x_cm;
        }

        collapse_masks[i * c_channels] = x_cm & 0xFF;
        if c_channels == 2 {
            collapse_masks[i * c_channels + 1] = y_cm & 0xFF;
        }

        balance_val += pulses[i] + tell;
        ctx_seed = ctx.seed;

        // Only refresh the folding source next time if this band received
        // more than one bit per sample (`b` is in 1/8-bit units).
        update_lowband = b > ((n as i32) << BITRES);
        avoid_split_noise = false;
    }
    *balance = balance_val;
    *seed = ctx_seed;
}

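/// Band energy as `sqrt(1e-27 + sum(x^2))`, vectorised with NEON. The tiny
/// bias keeps the result strictly positive so callers can divide by it.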
#[cfg(target_arch = "aarch64")]
fn compute_band_energy_neon(band: &[f32]) -> f32 {
    use std::arch::aarch64::*;

    let n = band.len();
    let mut sum = 1e-27f32;

    unsafe {
        let n16 = n & !15;
        if n16 > 0 {
            let mut acc0 = vdupq_n_f32(0.0);
            let mut acc1 = vdupq_n_f32(0.0);
            let mut acc2 = vdupq_n_f32(0.0);
            let mut acc3 = vdupq_n_f32(0.0);

            for i in (0..n16).step_by(16) {
                let v0 = vld1q_f32(band.as_ptr().add(i));
                let v1 = vld1q_f32(band.as_ptr().add(i + 4));
                let v2 = vld1q_f32(band.as_ptr().add(i + 8));
                let v3 = vld1q_f32(band.as_ptr().add(i + 12));

                acc0 = vfmaq_f32(acc0, v0, v0);
                acc1 = vfmaq_f32(acc1, v1, v1);
                acc2 = vfmaq_f32(acc2, v2, v2);
                acc3 = vfmaq_f32(acc3, v3, v3);
            }

            acc0 = vaddq_f32(acc0, acc1);
            acc2 = vaddq_f32(acc2, acc3);
            acc0 = vaddq_f32(acc0, acc2);
            sum += vaddvq_f32(acc0);
        }

        let n4 = (n & !3) - n16;
        if n4 > 0 {
            let mut acc = vdupq_n_f32(0.0);
            for i in (n16..n16 + n4).step_by(4) {
                let v = vld1q_f32(band.as_ptr().add(i));
                acc = vfmaq_f32(acc, v, v);
            }
            sum += vaddvq_f32(acc);
        }

        for i in (n16 + n4)..n {
            let v = band[i];
            sum += v * v;
        }
    }

    sum.sqrt()
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn compute_band_energy_avx2(band: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let n = band.len();
    let mut i = 0usize;

    let mut acc0 = _mm256_setzero_ps();
    let mut acc1 = _mm256_setzero_ps();

    while i + 16 <= n {
        let v0 = _mm256_loadu_ps(band.as_ptr().add(i));
        let v1 = _mm256_loadu_ps(band.as_ptr().add(i + 8));
        acc0 = _mm256_fmadd_ps(v0, v0, acc0);
        acc1 = _mm256_fmadd_ps(v1, v1, acc1);
        i += 16;
    }

    if i + 8 <= n {
        let v0 = _mm256_loadu_ps(band.as_ptr().add(i));
        acc0 = _mm256_fmadd_ps(v0, v0, acc0);
        i += 8;
    }

    let acc = _mm256_add_ps(acc0, acc1);
    let hi = _mm256_extractf128_ps(acc, 1);
    let lo = _mm256_castps256_ps128(acc);
    let s4 = _mm_add_ps(lo, hi);
    let t1 = _mm_movehl_ps(s4, s4);
    let s2 = _mm_add_ps(s4, t1);
    let t2 = _mm_shuffle_ps(s2, s2, 0x55);
    let mut sum = 1e-27f32 + _mm_cvtss_f32(_mm_add_ss(s2, t2));

    for &v in &band[i..] {
        sum += v * v;
    }

    sum.sqrt()
}

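/// Per-band, per-channel energies of the MDCT spectrum `x`, written to
/// `band_e[c * nb_ebands + i]`. Dispatches to the SIMD kernels above where
/// available, with a scalar fold as fallback.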
pub fn compute_band_energies(
    m: &CeltMode,
    x: &[f32],
    band_e: &mut [f32],
    end: usize,
    channels: usize,
    lm: usize,
) {
    let frame_size = m.short_mdct_size << lm;

    #[cfg(target_arch = "x86_64")]
    let use_avx2 = std::arch::is_x86_feature_detected!("avx2");

    for c in 0..channels {
        let ch = &x[c * frame_size..(c + 1) * frame_size];
        for i in 0..end {
            let offset = (m.e_bands[i] as usize) << lm;
            let n = ((m.e_bands[i + 1] - m.e_bands[i]) as usize) << lm;
            let band = &ch[offset..offset + n];

            #[cfg(target_arch = "aarch64")]
            {
                band_e[c * m.nb_ebands + i] = compute_band_energy_neon(band);
            }
            #[cfg(target_arch = "x86_64")]
            {
                if n >= 8 && use_avx2 {
                    band_e[c * m.nb_ebands + i] = unsafe { compute_band_energy_avx2(band) };
                } else {
                    let sum = band.iter().fold(1e-27f32, |acc, &v| acc + v * v);
                    band_e[c * m.nb_ebands + i] = sum.sqrt();
                }
            }
            #[cfg(all(not(target_arch = "aarch64"), not(target_arch = "x86_64")))]
            {
                let sum = band.iter().fold(1e-27f32, |acc, &v| acc + v * v);
                band_e[c * m.nb_ebands + i] = sum.sqrt();
            }
        }
    }
}

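/// Convert linear band energies into the log domain used for coding:
/// `log2(max(E, 1e-10)) - e_means[i]`. Bands below `start` get a floor of
/// -14 (in log2 units) so they behave as silence.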
pub fn amp2log2(
    m: &CeltMode,
    start: usize,
    end: usize,
    band_e: &[f32],
    band_log_e: &mut [f32],
    channels: usize,
) {
    for c in 0..channels {
        for i in 0..start {
            band_log_e[c * m.nb_ebands + i] = -14.0;
        }
        for i in start..end {
            let val = band_e[c * m.nb_ebands + i].max(1e-10);
            band_log_e[c * m.nb_ebands + i] = val.log2() - m.e_means[i];
        }
    }
}

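/// Partial inverse of `amp2log2`: the per-band means are added back, but the
/// values stay in the log domain; `denormalise_bands` applies the final
/// `2^x` when scaling the spectrum.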
pub fn log2amp(m: &CeltMode, end: usize, band_e: &mut [f32], band_log_e: &[f32], channels: usize) {
    for c in 0..channels {
        for i in 0..end {
            band_e[c * m.nb_ebands + i] = band_log_e[c * m.nb_ebands + i] + m.e_means[i];
        }
    }
}

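/// Scale each band of `freq` to roughly unit energy, writing the result to
/// `x`: `x = freq / (1e-27 + band_e)`.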
pub fn normalise_bands(
    m: &CeltMode,
    freq: &[f32],
    x: &mut [f32],
    band_e: &[f32],
    end: usize,
    channels: usize,
    m_val: usize,
) {
    let lm = m_val.trailing_zeros() as usize;
    let frame_size = m.short_mdct_size << lm;
    #[cfg(target_arch = "x86_64")]
    let use_avx2 = std::arch::is_x86_feature_detected!("avx2");
    for c in 0..channels {
        for i in 0..end {
            let base = c * frame_size + ((m.e_bands[i] as usize) << lm);
            let n = ((m.e_bands[i + 1] - m.e_bands[i]) as usize) << lm;
            let norm = 1.0 / (1e-27 + band_e[c * m.nb_ebands + i]);
            let src = &freq[base..base + n];
            let dst = &mut x[base..base + n];
            #[cfg(target_arch = "x86_64")]
            if n >= 8 && use_avx2 {
                unsafe { scale_slice_avx2(src, dst, norm, n) };
                continue;
            }
            #[cfg(target_arch = "aarch64")]
            if n >= 8 {
                unsafe { scale_slice_neon(src, dst, norm, n) };
                continue;
            }
            for (d, &s) in dst.iter_mut().zip(src) {
                *d = s * norm;
            }
        }
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn scale_slice_avx2(src: &[f32], dst: &mut [f32], scale: f32, n: usize) {
    use std::arch::x86_64::*;
    let vscale = _mm256_set1_ps(scale);
    let mut i = 0;

    while i + 16 <= n {
        let s0 = _mm256_loadu_ps(src.as_ptr().add(i));
        let s1 = _mm256_loadu_ps(src.as_ptr().add(i + 8));
        _mm256_storeu_ps(dst.as_mut_ptr().add(i), _mm256_mul_ps(s0, vscale));
        _mm256_storeu_ps(dst.as_mut_ptr().add(i + 8), _mm256_mul_ps(s1, vscale));
        i += 16;
    }
    while i + 8 <= n {
        let sv = _mm256_loadu_ps(src.as_ptr().add(i));
        _mm256_storeu_ps(dst.as_mut_ptr().add(i), _mm256_mul_ps(sv, vscale));
        i += 8;
    }
    for j in i..n {
        dst[j] = src[j] * scale;
    }
}

#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn scale_slice_neon(src: &[f32], dst: &mut [f32], scale: f32, n: usize) {
    use std::arch::aarch64::*;
    let vscale = vdupq_n_f32(scale);
    let mut i = 0;

    while i + 16 <= n {
        let s0 = vld1q_f32(src.as_ptr().add(i));
        let s1 = vld1q_f32(src.as_ptr().add(i + 4));
        let s2 = vld1q_f32(src.as_ptr().add(i + 8));
        let s3 = vld1q_f32(src.as_ptr().add(i + 12));
        vst1q_f32(dst.as_mut_ptr().add(i), vmulq_f32(s0, vscale));
        vst1q_f32(dst.as_mut_ptr().add(i + 4), vmulq_f32(s1, vscale));
        vst1q_f32(dst.as_mut_ptr().add(i + 8), vmulq_f32(s2, vscale));
        vst1q_f32(dst.as_mut_ptr().add(i + 12), vmulq_f32(s3, vscale));
        i += 16;
    }
    while i + 8 <= n {
        let s0 = vld1q_f32(src.as_ptr().add(i));
        let s1 = vld1q_f32(src.as_ptr().add(i + 4));
        vst1q_f32(dst.as_mut_ptr().add(i), vmulq_f32(s0, vscale));
        vst1q_f32(dst.as_mut_ptr().add(i + 4), vmulq_f32(s1, vscale));
        i += 8;
    }
    while i + 4 <= n {
        let s0 = vld1q_f32(src.as_ptr().add(i));
        vst1q_f32(dst.as_mut_ptr().add(i), vmulq_f32(s0, vscale));
        i += 4;
    }
    for j in i..n {
        dst[j] = src[j] * scale;
    }
}

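/// Undo `normalise_bands`: multiply each unit-norm band by its gain
/// `2^band_e`, with the exponent clamped to 32 (the energies arrive in the
/// log domain).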
#[allow(clippy::too_many_arguments)]
pub fn denormalise_bands(
    m: &CeltMode,
    x: &[f32],
    freq: &mut [f32],
    band_e: &[f32],
    start: usize,
    end: usize,
    channels: usize,
    m_val: usize,
) {
    let lm = m_val.trailing_zeros() as usize;
    let frame_size = m.short_mdct_size << lm;
    #[cfg(target_arch = "x86_64")]
    let use_avx2 = std::arch::is_x86_feature_detected!("avx2");

    for c in 0..channels {
        for i in start..end {
            let base = c * frame_size + ((m.e_bands[i] as usize) << lm);
            let n = ((m.e_bands[i + 1] - m.e_bands[i]) as usize) << lm;
            let band_log = band_e[c * m.nb_ebands + i];

            let g = (2.0f32).powf(band_log.min(32.0));
            let src = &x[base..base + n];
            let dst = &mut freq[base..base + n];
            #[cfg(target_arch = "x86_64")]
            if n >= 8 && use_avx2 {
                unsafe { scale_slice_avx2(src, dst, g, n) };
                continue;
            }
            #[cfg(target_arch = "aarch64")]
            if n >= 8 {
                unsafe { scale_slice_neon(src, dst, g, n) };
                continue;
            }
            for (d, &s) in dst.iter_mut().zip(src) {
                *d = s * g;
            }
        }
    }
}

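/// One step of the 32-bit linear congruential generator used for noise
/// injection: `seed * 1664525 + 1013904223` (the classic Numerical Recipes
/// constants, matching the reference CELT implementation).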
pub fn celt_lcg_rand(seed: u32) -> u32 {
    seed.wrapping_mul(1664525).wrapping_add(1013904223)
}

#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn renormalise_vector_neon(x: &mut [f32], n: usize, gain: f32) {
    use std::arch::aarch64::*;

    let mut sum_vec = vdupq_n_f32(0.0);
    let mut i = 0;

    while i + 16 <= n {
        let x0 = vld1q_f32(x.as_ptr().add(i));
        let x1 = vld1q_f32(x.as_ptr().add(i + 4));
        let x2 = vld1q_f32(x.as_ptr().add(i + 8));
        let x3 = vld1q_f32(x.as_ptr().add(i + 12));
        sum_vec = vfmaq_f32(sum_vec, x0, x0);
        sum_vec = vfmaq_f32(sum_vec, x1, x1);
        sum_vec = vfmaq_f32(sum_vec, x2, x2);
        sum_vec = vfmaq_f32(sum_vec, x3, x3);
        i += 16;
    }

    while i + 8 <= n {
        let x0 = vld1q_f32(x.as_ptr().add(i));
        let x1 = vld1q_f32(x.as_ptr().add(i + 4));
        sum_vec = vfmaq_f32(sum_vec, x0, x0);
        sum_vec = vfmaq_f32(sum_vec, x1, x1);
        i += 8;
    }

    while i + 4 <= n {
        let x0 = vld1q_f32(x.as_ptr().add(i));
        sum_vec = vfmaq_f32(sum_vec, x0, x0);
        i += 4;
    }

    let mut e = 1e-15f32 + vaddvq_f32(sum_vec);

    for j in i..n {
        e += x[j] * x[j];
    }

    let norm = gain / e.sqrt();
    let vnorm = vdupq_n_f32(norm);

    i = 0;
    while i + 16 <= n {
        let x0 = vld1q_f32(x.as_ptr().add(i));
        let x1 = vld1q_f32(x.as_ptr().add(i + 4));
        let x2 = vld1q_f32(x.as_ptr().add(i + 8));
        let x3 = vld1q_f32(x.as_ptr().add(i + 12));
        vst1q_f32(x.as_mut_ptr().add(i), vmulq_f32(x0, vnorm));
        vst1q_f32(x.as_mut_ptr().add(i + 4), vmulq_f32(x1, vnorm));
        vst1q_f32(x.as_mut_ptr().add(i + 8), vmulq_f32(x2, vnorm));
        vst1q_f32(x.as_mut_ptr().add(i + 12), vmulq_f32(x3, vnorm));
        i += 16;
    }

    while i + 8 <= n {
        let x0 = vld1q_f32(x.as_ptr().add(i));
        let x1 = vld1q_f32(x.as_ptr().add(i + 4));
        vst1q_f32(x.as_mut_ptr().add(i), vmulq_f32(x0, vnorm));
        vst1q_f32(x.as_mut_ptr().add(i + 4), vmulq_f32(x1, vnorm));
        i += 8;
    }

    while i + 4 <= n {
        let x0 = vld1q_f32(x.as_ptr().add(i));
        vst1q_f32(x.as_mut_ptr().add(i), vmulq_f32(x0, vnorm));
        i += 4;
    }

    for j in i..n {
        x[j] *= norm;
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn renormalise_vector_avx2(x: &mut [f32], n: usize, gain: f32) {
    use std::arch::x86_64::*;

    let mut i = 0usize;

    let mut acc0 = _mm256_setzero_ps();
    let mut acc1 = _mm256_setzero_ps();

    while i + 16 <= n {
        let v0 = _mm256_loadu_ps(x.as_ptr().add(i));
        let v1 = _mm256_loadu_ps(x.as_ptr().add(i + 8));
        acc0 = _mm256_fmadd_ps(v0, v0, acc0);
        acc1 = _mm256_fmadd_ps(v1, v1, acc1);
        i += 16;
    }

    if i + 8 <= n {
        let v0 = _mm256_loadu_ps(x.as_ptr().add(i));
        acc0 = _mm256_fmadd_ps(v0, v0, acc0);
        i += 8;
    }

    let acc = _mm256_add_ps(acc0, acc1);
    let hi = _mm256_extractf128_ps(acc, 1);
    let lo = _mm256_castps256_ps128(acc);
    let s4 = _mm_add_ps(lo, hi);
    let t1 = _mm_movehl_ps(s4, s4);
    let s2 = _mm_add_ps(s4, t1);
    let t2 = _mm_shuffle_ps(s2, s2, 0x55);
    let mut e = 1e-15f32 + _mm_cvtss_f32(_mm_add_ss(s2, t2));

    for &v in &x[i..n] {
        e += v * v;
    }

    let norm = gain / e.sqrt();
    let vnorm = _mm256_set1_ps(norm);

    i = 0;
    while i + 16 <= n {
        let v0 = _mm256_loadu_ps(x.as_ptr().add(i));
        let v1 = _mm256_loadu_ps(x.as_ptr().add(i + 8));
        _mm256_storeu_ps(x.as_mut_ptr().add(i), _mm256_mul_ps(v0, vnorm));
        _mm256_storeu_ps(x.as_mut_ptr().add(i + 8), _mm256_mul_ps(v1, vnorm));
        i += 16;
    }
    while i + 8 <= n {
        let v = _mm256_loadu_ps(x.as_ptr().add(i));
        _mm256_storeu_ps(x.as_mut_ptr().add(i), _mm256_mul_ps(v, vnorm));
        i += 8;
    }
    for v in &mut x[i..n] {
        *v *= norm;
    }
}

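/// Rescale `x[..n]` so that its L2 norm becomes `gain` (a `1e-15` bias under
/// the square root guards against division by zero), using NEON or AVX2
/// where available.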
pub fn renormalise_vector(x: &mut [f32], n: usize, gain: f32) {
    #[cfg(target_arch = "aarch64")]
    unsafe {
        renormalise_vector_neon(x, n, gain);
    }
    #[cfg(target_arch = "x86_64")]
    unsafe {
        if n >= 16 && std::arch::is_x86_feature_detected!("avx2") {
            renormalise_vector_avx2(x, n, gain);
            return;
        }
    }
    // Scalar fallback: short x86_64 inputs, or any non-SIMD architecture.
    #[cfg(not(target_arch = "aarch64"))]
    {
        let mut e = 1e-15f32;
        for &xv in x[..n].iter() {
            e += xv * xv;
        }
        let norm = gain / e.sqrt();
        for xv in x[..n].iter_mut() {
            *xv *= norm;
        }
    }
}

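/// For every (band, channel, short-block) slot whose collapse-mask bit is
/// clear, fill the slot with pseudo-random +/-r noise whose level is derived
/// from the band's recent energy history, then renormalise the band.
/// Returns the advanced seed so encoder and decoder stay in sync.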
#[allow(clippy::too_many_arguments)]
pub fn anti_collapse(
    m: &CeltMode,
    x_buf: &mut [f32],
    collapse_masks: &[u32],
    lm: i32,
    channels: usize,
    size: usize,
    start: usize,
    end: usize,
    log_e: &[f32],
    prev1_log_e: &[f32],
    prev2_log_e: &[f32],
    pulses: &[i32],
    mut seed: u32,
) -> u32 {
    for i in start..end {
        let n0 = (m.e_bands[i + 1] - m.e_bands[i]) as usize;
        let depth = if n0 > 0 {
            ((1 + pulses[i]) / n0 as i32) >> lm
        } else {
            0
        };

        let thresh = 0.5 * (-(0.125 * depth as f32)).exp2();
        let sqrt_1 = 1.0 / ((n0 << lm) as f32).sqrt();

        for c in 0..channels {
            let p1 = prev1_log_e[c * m.nb_ebands + i];
            let p2 = prev2_log_e[c * m.nb_ebands + i];

            let (p1_adj, p2_adj) = if channels == 1 && prev1_log_e.len() >= 2 * m.nb_ebands {
                (
                    p1.max(prev1_log_e[m.nb_ebands + i]),
                    p2.max(prev2_log_e[m.nb_ebands + i]),
                )
            } else {
                (p1, p2)
            };

            let e_diff = log_e[c * m.nb_ebands + i] - p1_adj.min(p2_adj);
            let e_diff = e_diff.max(0.0);

            let mut r = 2.0 * (-e_diff).exp2();
            if lm == 3 {
                r *= std::f32::consts::SQRT_2;
            }
            r = r.min(thresh);
            r *= sqrt_1;

            let x_offset = c * size + ((m.e_bands[i] as usize) << lm);
            let mut renormalize = false;
            for k in 0..(1 << lm) {
                if (collapse_masks[i * channels + c] & (1 << k)) == 0 {
                    for j in 0..n0 {
                        seed = celt_lcg_rand(seed);
                        x_buf[x_offset + (j << lm) + k] = if (seed & 0x8000) != 0 { r } else { -r };
                    }
                    renormalize = true;
                }
            }
            if renormalize {
                renormalise_vector(&mut x_buf[x_offset..x_offset + (n0 << lm)], n0 << lm, 1.0);
            }
        }
    }
    seed
}
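
// Architecture-independent sanity checks for the helpers above. These are
// minimal sketches rather than exhaustive tests: the expected LCG value
// follows directly from the constants, and the renormalisation checks only
// assert the resulting norm and direction.
#[cfg(test)]
mod band_helper_tests {
    use super::*;

    #[test]
    fn lcg_matches_constants() {
        // 0 * 1664525 + 1013904223 wraps to the additive constant itself.
        assert_eq!(celt_lcg_rand(0), 1013904223);
        // The generator is a pure function of its seed.
        assert_eq!(celt_lcg_rand(12345), celt_lcg_rand(12345));
    }

    #[test]
    fn renormalise_vector_hits_requested_gain() {
        let mut x = [3.0f32, 4.0, 0.0, 0.0];
        renormalise_vector(&mut x, 4, 2.0);
        // [3, 4, 0, 0] has norm 5, so scaling to gain 2 yields [1.2, 1.6, 0, 0].
        let norm = x.iter().map(|v| v * v).sum::<f32>().sqrt();
        assert!((norm - 2.0).abs() < 1e-4);
        assert!((x[0] - 1.2).abs() < 1e-4);
        assert!((x[1] - 1.6).abs() < 1e-4);
    }
}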