1use crate::modes::CeltMode;
2use crate::range_coder::{BITRES, RangeCoder};
3
/// Inter-frame prediction coefficients for coarse energy, indexed by `lm`.
///
/// Values are Q15 constants written out as `n / 32768`. In inter mode the
/// encoder and decoder multiply the previous frame's (clamped) band energy by
/// `PRED_COEF[lm]`; the encoder also uses its square to decay the
/// `delayed_intra` distortion state between frames.
pub const PRED_COEF: [f32; 4] = [
    29440.0 / 32768.0,
    26112.0 / 32768.0,
    21248.0 / 32768.0,
    16384.0 / 32768.0,
];
/// Per-`lm` decay factors (Q15 written as `n / 32768`) for the running
/// inter-band predictor `prev` in inter mode: after each band,
/// `prev = prev + q - beta * q`.
pub const BETA_COEF: [f32; 4] = [
    30147.0 / 32768.0,
    22282.0 / 32768.0,
    12124.0 / 32768.0,
    6554.0 / 32768.0,
];
/// Decay factor for the running inter-band predictor `prev` in intra mode
/// (where the inter-frame prediction coefficient is 0).
pub const BETA_INTRA: f32 = 4915.0 / 32768.0;
17
/// Laplace model parameters for coarse-energy residuals.
///
/// Indexed as `E_PROB_MODEL[lm][intra]`, where the inner index is 0 for inter
/// and 1 for intra coding. Each row holds 21 `(fs, decay)` byte pairs. Band `i`
/// uses the pair at offset `2 * min(i, 20)`. Before the Laplace coder is called,
/// `fs` is scaled by `<< 7` and `decay` by `<< 6`.
pub const E_PROB_MODEL: [[[u8; 42]; 2]; 4] = [
    [
        [
            72, 127, 65, 129, 66, 128, 65, 128, 64, 128, 62, 128, 64, 128, 64, 128, 92, 78, 92, 79,
            92, 78, 90, 79, 116, 41, 115, 40, 114, 40, 132, 26, 132, 26, 145, 17, 161, 12, 176, 10,
            177, 11,
        ],
        [
            24, 179, 48, 138, 54, 135, 54, 132, 53, 134, 56, 133, 55, 132, 55, 132, 61, 114, 70,
            96, 74, 88, 75, 88, 87, 74, 89, 66, 91, 67, 100, 59, 108, 50, 120, 40, 122, 37, 97, 43,
            78, 50,
        ],
    ],
    [
        [
            83, 78, 84, 81, 88, 75, 86, 74, 87, 71, 90, 73, 93, 74, 93, 74, 109, 40, 114, 36, 117,
            34, 117, 34, 143, 17, 145, 18, 146, 19, 162, 12, 165, 10, 178, 7, 189, 6, 190, 8, 177,
            9,
        ],
        [
            23, 178, 54, 115, 63, 102, 66, 98, 69, 99, 74, 89, 71, 91, 73, 91, 78, 89, 86, 80, 92,
            66, 93, 64, 102, 59, 103, 60, 104, 60, 117, 52, 123, 44, 138, 35, 133, 31, 97, 38, 77,
            45,
        ],
    ],
    [
        [
            61, 90, 93, 60, 105, 42, 107, 41, 110, 45, 116, 38, 113, 38, 112, 38, 124, 26, 132, 27,
            136, 19, 140, 20, 155, 14, 159, 16, 158, 18, 170, 13, 177, 10, 187, 8, 192, 6, 175, 9,
            159, 10,
        ],
        [
            21, 178, 59, 110, 71, 86, 75, 85, 84, 83, 91, 66, 88, 73, 87, 72, 92, 75, 98, 72, 105,
            58, 107, 54, 115, 52, 114, 55, 112, 56, 129, 51, 132, 40, 150, 33, 140, 29, 98, 35, 77,
            42,
        ],
    ],
    [
        [
            42, 121, 96, 66, 108, 43, 111, 40, 117, 44, 123, 32, 120, 36, 119, 33, 127, 33, 134,
            34, 139, 21, 147, 23, 152, 20, 158, 25, 154, 26, 166, 21, 173, 16, 184, 13, 184, 10,
            150, 13, 139, 15,
        ],
        [
            22, 178, 63, 114, 74, 82, 84, 83, 92, 82, 103, 62, 96, 72, 96, 67, 101, 73, 107, 72,
            113, 55, 118, 52, 125, 52, 118, 52, 117, 55, 135, 49, 137, 39, 157, 32, 145, 29, 97,
            33, 77, 40,
        ],
    ],
];
68
/// Inverse CDF for the 3-symbol coarse-energy fallback, used when fewer than
/// 15 but at least 2 bits of budget remain.
///
/// `qi` values 0, -1 and +1 map to symbols 0, 1 and 2 via `(2 * qi) ^ -(qi < 0)`.
/// The table is passed to the ICDF coder with `2` as its precision argument.
/// Presumably this gives probabilities 1/2, 1/4 and 1/4 — confirm against the
/// range coder's ICDF convention.
pub const SMALL_ENERGY_ICDF: [u8; 3] = [2, 1, 0];
70
/// Squared-error distortion between the current and previous band energies,
/// saturated at 200.
///
/// Channel `ch` occupies indices `ch * len ..` of both slices. For every channel,
/// bands `start..min(end, len)` are accumulated as `(new - old)^2`, in
/// channel-major order.
fn loss_distortion(
    e_bands: &[f32],
    old_e_bands: &[f32],
    start: usize,
    end: usize,
    len: usize,
    channels: usize,
) -> f32 {
    // Never read past a channel's own stride, even if `end` overshoots it.
    let last = end.min(len);
    let total = (0..channels)
        .flat_map(|ch| (start..last).map(move |band| ch * len + band))
        .fold(0.0f32, |acc, idx| {
            let delta = e_bands[idx] - old_e_bands[idx];
            acc + delta * delta
        });
    total.min(200.0)
}
89
90#[allow(clippy::too_many_arguments)]
91fn quant_coarse_energy_impl(
92 m: &CeltMode,
93 start: usize,
94 end: usize,
95 e_bands: &[f32],
96 old_e_bands: &mut [f32],
97 budget: u32,
98 tell_start: i32,
99 prob_model: &[u8; 42],
100 error: &mut [f32],
101 enc: &mut RangeCoder,
102 channels: usize,
103 lm: usize,
104 intra: bool,
105 max_decay: f32,
106 lfe: bool,
107) -> i32 {
108 let coef = if intra { 0.0 } else { PRED_COEF[lm] };
109 let beta = if intra { BETA_INTRA } else { BETA_COEF[lm] };
110 let mut prev = [0.0f32; 2];
111 let mut badness = 0i32;
112
113 if tell_start + 3 <= budget as i32 {
114 enc.encode_bit_logp(intra, 3);
115 }
116
117 for i in start..end {
118 for c in 0..channels {
119 let x = e_bands[c * m.nb_ebands + i];
120 let old_e_val = old_e_bands[c * m.nb_ebands + i];
121 let old_e = old_e_val.max(-9.0);
122 let f = x - coef * old_e - prev[c];
123
124 let mut qi = (f + 0.5).floor() as i32;
125 let qi0 = qi;
126
127 let decay_bound = old_e_val.max(-28.0) - max_decay;
128 if qi < 0 && x < decay_bound {
129 qi += ((decay_bound - x) as i32).max(0);
130 if qi > 0 {
131 qi = 0;
132 }
133 }
134
135 let tell = enc.tell();
136 let bits_left = budget as i32 - tell - 3 * channels as i32 * (end - i) as i32;
137 if i != start && bits_left < 30 {
138 if bits_left < 24 {
139 qi = qi.min(1);
140 }
141 if bits_left < 16 {
142 qi = qi.max(-1);
143 }
144 }
145 if lfe && i >= 2 {
146 qi = qi.min(0);
147 }
148
149 if tell + 15 <= budget as i32 {
150 let prob_idx = 2 * i.min(20);
151 let fs = (prob_model[prob_idx] as u32) << 7;
152 let decay = (prob_model[prob_idx + 1] as i32) << 6;
153 enc.laplace_encode(&mut qi, fs, decay);
154 } else if tell + 2 <= budget as i32 {
155 qi = qi.clamp(-1, 1);
156 enc.encode_icdf(
157 (2 * qi) ^ (if qi < 0 { -1 } else { 0 }),
158 &SMALL_ENERGY_ICDF,
159 2,
160 );
161 } else if tell < budget as i32 {
162 qi = qi.min(0);
163 enc.encode_bit_logp(qi != 0, 1);
164 } else {
165 qi = -1;
166 }
167
168 badness += (qi0 - qi).abs();
169
170 let q = qi as f32;
171 error[c * m.nb_ebands + i] = f - q;
172 let tmp = coef * old_e + prev[c] + q;
173 old_e_bands[c * m.nb_ebands + i] = tmp;
174 prev[c] = prev[c] + q - beta * q;
175 }
176 }
177
178 if lfe { 0 } else { badness }
179}
180
181#[allow(clippy::too_many_arguments)]
182pub fn quant_coarse_energy_advanced(
183 m: &CeltMode,
184 start: usize,
185 end: usize,
186 eff_end: usize,
187 e_bands: &[f32],
188 old_e_bands: &mut [f32],
189 budget: u32,
190 error: &mut [f32],
191 enc: &mut RangeCoder,
192 channels: usize,
193 lm: usize,
194 nb_available_bytes: usize,
195 force_intra: bool,
196 delayed_intra: &mut f32,
197 mut two_pass: bool,
198 loss_rate: i32,
199 lfe: bool,
200) {
201 let mut intra = force_intra
202 || (!two_pass
203 && *delayed_intra > 2.0 * channels as f32 * (end.saturating_sub(start)) as f32
204 && nb_available_bytes > (end.saturating_sub(start)) * channels);
205
206 let intra_bias = ((budget as f32) * (*delayed_intra) * (loss_rate as f32)
207 / ((channels as f32) * 512.0)) as i32;
208 let new_distortion =
209 loss_distortion(e_bands, old_e_bands, start, eff_end, m.nb_ebands, channels);
210
211 let tell = enc.tell();
212 if tell + 3 > budget as i32 {
213 two_pass = false;
214 intra = false;
215 }
216
217 let mut max_decay = if end - start > 10 {
218 16.0f32.min(0.125 * nb_available_bytes as f32)
219 } else {
220 16.0f32
221 };
222 if lfe {
223 max_decay = 3.0;
224 }
225
226 let enc_start_state = enc.clone();
227 let mut old_e_bands_intra = old_e_bands.to_vec();
228 let mut error_intra = error.to_vec();
229 let mut badness1 = 0i32;
230 let mut tell_intra = 0i32;
231 let intra_prob = &E_PROB_MODEL[lm][1];
232
233 if two_pass || intra {
234 badness1 = quant_coarse_energy_impl(
235 m,
236 start,
237 end,
238 e_bands,
239 &mut old_e_bands_intra,
240 budget,
241 tell,
242 intra_prob,
243 &mut error_intra,
244 enc,
245 channels,
246 lm,
247 true,
248 max_decay,
249 lfe,
250 );
251 tell_intra = crate::tell_frac_inline!(enc);
252 }
253
254 if !intra {
255 let enc_intra_state = enc.clone();
256
257 *enc = enc_start_state.clone();
258 let inter_prob = &E_PROB_MODEL[lm][0];
259 let badness2 = quant_coarse_energy_impl(
260 m,
261 start,
262 end,
263 e_bands,
264 old_e_bands,
265 budget,
266 tell,
267 inter_prob,
268 error,
269 enc,
270 channels,
271 lm,
272 false,
273 max_decay,
274 lfe,
275 );
276
277 if two_pass
278 && (badness1 < badness2
279 || (badness1 == badness2
280 && crate::tell_frac_inline!(enc) + intra_bias > tell_intra))
281 {
282 *enc = enc_intra_state;
283 old_e_bands.copy_from_slice(&old_e_bands_intra);
284 error.copy_from_slice(&error_intra);
285 intra = true;
286 }
287 } else {
288 old_e_bands.copy_from_slice(&old_e_bands_intra);
289 error.copy_from_slice(&error_intra);
290 }
291
292 if intra {
293 *delayed_intra = new_distortion;
294 } else {
295 let pred2 = PRED_COEF[lm] * PRED_COEF[lm];
296 *delayed_intra = pred2 * *delayed_intra + new_distortion;
297 }
298}
299
300#[allow(clippy::too_many_arguments)]
301pub fn quant_coarse_energy(
302 m: &CeltMode,
303 start: usize,
304 end: usize,
305 e_bands: &[f32],
306 old_e_bands: &mut [f32],
307 budget: u32,
308 error: &mut [f32],
309 enc: &mut RangeCoder,
310 channels: usize,
311 lm: usize,
312 force_intra: bool,
313 nb_available_bytes: usize,
314) {
315 let mut delayed_intra = 0.0f32;
316 quant_coarse_energy_advanced(
317 m,
318 start,
319 end,
320 end,
321 e_bands,
322 old_e_bands,
323 budget,
324 error,
325 enc,
326 channels,
327 lm,
328 nb_available_bytes,
329 force_intra,
330 &mut delayed_intra,
331 false,
332 0,
333 false,
334 );
335}
336
337#[allow(clippy::too_many_arguments)]
338pub fn unquant_coarse_energy(
339 m: &CeltMode,
340 start: usize,
341 end: usize,
342 old_e_bands: &mut [f32],
343 budget: u32,
344 dec: &mut RangeCoder,
345 channels: usize,
346 lm: usize,
347) {
348 let intra: bool;
349 let tell = dec.tell();
350 if tell + 3 <= budget as i32 {
351 intra = dec.decode_bit_logp(3);
352 } else {
353 intra = false;
354 }
355 let prob_model = &E_PROB_MODEL[lm][if intra { 1 } else { 0 }];
356 let coef = if intra { 0.0 } else { PRED_COEF[lm] };
357 let beta = if intra { BETA_INTRA } else { BETA_COEF[lm] };
358 debug_assert!(channels <= 2);
359 let mut prev = [0.0f32; 2];
360
361 for i in start..end {
362 for c in 0..channels {
363 old_e_bands[c * m.nb_ebands + i] = old_e_bands[c * m.nb_ebands + i].max(-9.0);
365 let old_e = old_e_bands[c * m.nb_ebands + i];
366
367 let qi;
368 let tell = dec.tell();
369 if tell + 15 <= budget as i32 {
370 let prob_idx = 2 * i.min(20);
371 let fs = (prob_model[prob_idx] as u32) << 7;
372 let decay = (prob_model[prob_idx + 1] as i32) << 6;
373 qi = dec.laplace_decode(fs, decay);
374 } else if tell + 2 <= budget as i32 {
375 let s = dec.decode_icdf(&SMALL_ENERGY_ICDF, 2);
376 qi = (s >> 1) ^ -(s & 1);
377 } else if tell < budget as i32 {
378 qi = if dec.decode_bit_logp(1) { -1 } else { 0 };
379 } else {
380 qi = -1;
381 }
382
383 let q = qi as f32;
384 let tmp = coef * old_e + prev[c] + q;
385 old_e_bands[c * m.nb_ebands + i] = tmp;
386 prev[c] = prev[c] + q - beta * q;
387 }
388 }
389}
390
391#[allow(clippy::too_many_arguments)]
392pub fn quant_fine_energy(
393 m: &CeltMode,
394 start: usize,
395 end: usize,
396 old_e_bands: &mut [f32],
397 error: &mut [f32],
398 fine_quant: &[i32],
399 enc: &mut RangeCoder,
400 channels: usize,
401) {
402 for i in start..end {
403 for c in 0..channels {
404 let bits = fine_quant[c * m.nb_ebands + i];
405 if bits <= 0 {
406 continue;
407 }
408 let mut q = ((error[c * m.nb_ebands + i] + 0.5) * (1 << bits) as f32).floor() as i32;
409 q = q.max(0).min((1 << bits) - 1);
410 enc.enc_bits(q as u32, bits as u32);
411 let offset = (q as f32 + 0.5) / (1 << bits) as f32 - 0.5;
412 old_e_bands[c * m.nb_ebands + i] += offset;
413 error[c * m.nb_ebands + i] -= offset;
414 }
415 }
416}
417
418pub fn unquant_fine_energy(
419 m: &CeltMode,
420 start: usize,
421 end: usize,
422 old_e_bands: &mut [f32],
423 fine_quant: &[i32],
424 dec: &mut RangeCoder,
425 channels: usize,
426) {
427 for i in start..end {
428 for c in 0..channels {
429 let bits = fine_quant[c * m.nb_ebands + i];
430 if bits <= 0 {
431 continue;
432 }
433 let q = dec.dec_bits(bits as u32);
434 let offset = (q as f32 + 0.5) / (1 << bits) as f32 - 0.5;
435 old_e_bands[c * m.nb_ebands + i] += offset;
436 }
437 }
438}
439
440#[allow(clippy::too_many_arguments)]
441pub fn quant_energy_finalise(
442 m: &CeltMode,
443 start: usize,
444 end: usize,
445 old_e_bands: &mut [f32],
446 error: &mut [f32],
447 fine_quant: &[i32],
448 fine_priority: &[i32],
449 bits_left: i32,
450 enc: &mut RangeCoder,
451 channels: usize,
452) {
453 let mut bits_left = bits_left;
454 for priority in 0..2 {
455 let mut i = start;
456 while i < end && bits_left >= channels as i32 {
457 if fine_quant[i] >= 8 || fine_priority[i] != priority {
458 i += 1;
459 continue;
460 }
461 let mut c = 0;
462 while c < channels {
463 let q2 = if error[i + c * m.nb_ebands] < 0.0 {
464 0
465 } else {
466 1
467 };
468 enc.enc_bits(q2 as u32, 1);
469 let offset =
470 (q2 as f32 - 0.5) * (1i32 << (14 - fine_quant[i] - 1)) as f32 * (1.0 / 16384.0);
471 old_e_bands[i + c * m.nb_ebands] += offset;
472 error[i + c * m.nb_ebands] -= offset;
473 bits_left -= 1;
474 c += 1;
475 }
476 i += 1;
477 }
478 }
479}
480
481#[allow(clippy::too_many_arguments)]
482pub fn unquant_energy_finalise(
483 m: &CeltMode,
484 start: usize,
485 end: usize,
486 old_e_bands: &mut [f32],
487 fine_quant: &[i32],
488 fine_priority: &[i32],
489 bits_left: i32,
490 dec: &mut RangeCoder,
491 channels: usize,
492) {
493 let mut bits_left = bits_left;
494 for priority in 0..2 {
495 let mut i = start;
496 while i < end && bits_left >= channels as i32 {
497 if fine_quant[i] >= 8 || fine_priority[i] != priority {
498 i += 1;
499 continue;
500 }
501 let mut c = 0;
502 while c < channels {
503 let q2 = dec.dec_bits(1);
504 let offset =
505 (q2 as f32 - 0.5) * (1i32 << (14 - fine_quant[i] - 1)) as f32 * (1.0 / 16384.0);
506 old_e_bands[i + c * m.nb_ebands] += offset;
507 bits_left -= 1;
508 c += 1;
509 }
510 i += 1;
511 }
512 }
513}
514
#[cfg(test)]
mod tests {
    use super::*;
    use crate::range_coder::RangeCoder;

    /// Encodes a synthetic mono energy envelope (coarse, fine, then finalise),
    /// decodes it again, and checks that the decoder reproduces the encoder's
    /// quantized energies to within 1e-5.
    #[test]
    fn test_coarse_fine_energy() {
        let mode = crate::modes::default_mode();
        let nb = mode.nb_ebands;

        let e_bands: Vec<f32> = (0..nb)
            .map(|i| 5.0 + (i as f32 * 0.5).sin() * 2.0)
            .collect();
        let mut old_e_bands = vec![0.0f32; nb];
        let mut error = vec![0.0f32; nb];
        let mut enc = RangeCoder::new_encoder(1000);

        quant_coarse_energy(
            mode,
            0,
            nb,
            &e_bands,
            &mut old_e_bands,
            10000,
            &mut error,
            &mut enc,
            1,
            3,
            false,
            80,
        );

        let fine_quant: Vec<i32> = (0..nb).map(|i| (i % 3) as i32).collect();
        quant_fine_energy(
            mode,
            0,
            nb,
            &mut old_e_bands,
            &mut error,
            &fine_quant,
            &mut enc,
            1,
        );

        let fine_priority: Vec<i32> = (0..nb).map(|i| (i % 2) as i32).collect();
        quant_energy_finalise(
            mode,
            0,
            nb,
            &mut old_e_bands,
            &mut error,
            &fine_quant,
            &fine_priority,
            10,
            &mut enc,
            1,
        );

        enc.done();

        let mut dec = RangeCoder::new_decoder(&enc.buf);
        let mut decoded_old_e_bands = vec![0.0f32; nb];
        unquant_coarse_energy(mode, 0, nb, &mut decoded_old_e_bands, 10000, &mut dec, 1, 3);
        unquant_fine_energy(
            mode,
            0,
            nb,
            &mut decoded_old_e_bands,
            &fine_quant,
            &mut dec,
            1,
        );
        unquant_energy_finalise(
            mode,
            0,
            nb,
            &mut decoded_old_e_bands,
            &fine_quant,
            &fine_priority,
            10,
            &mut dec,
            1,
        );

        for (band, (dec_e, enc_e)) in decoded_old_e_bands
            .iter()
            .zip(old_e_bands.iter())
            .enumerate()
        {
            let diff = (dec_e - enc_e).abs();
            if diff >= 1e-5 {
                println!(
                    "Mismatch at band {}: enc={} dec={} diff={}",
                    band, enc_e, dec_e, diff
                );
            }
            assert!(diff < 1e-5);
        }
    }
}