1use crate::modes::CeltMode;
2use crate::range_coder::{BITRES, RangeCoder};
3
/// Prediction coefficients applied to the clamped previous band energy when
/// coding in inter mode, indexed by `lm`. Values are Q15 numerators over 32768.
pub const PRED_COEF: [f32; 4] = [
    29440.0 / 32768.0, // lm = 0
    26112.0 / 32768.0, // lm = 1
    21248.0 / 32768.0, // lm = 2
    16384.0 / 32768.0, // lm = 3
];

/// Inter-mode `beta`, indexed by `lm` (Q15 numerators over 32768).
///
/// After each quantised step `q`, the per-channel running predictor `prev`
/// gains `q - beta * q`, so `beta` is the fraction of the step that is not
/// carried forward to the next band.
pub const BETA_COEF: [f32; 4] = [
    30147.0 / 32768.0, // lm = 0
    22282.0 / 32768.0, // lm = 1
    12124.0 / 32768.0, // lm = 2
    6554.0 / 32768.0,  // lm = 3
];

/// `beta` used in intra mode (where the prediction coefficient is 0).
pub const BETA_INTRA: f32 = 4915.0 / 32768.0;

/// Laplace-model parameters for coarse energy coding.
///
/// Indexed as `E_PROB_MODEL[lm][mode]`, where `mode` is 0 for inter and 1 for
/// intra. Each 42-byte row holds 21 `(fs, decay)` pairs, one per band. Bands
/// past index 20 reuse the last pair. The coders shift `fs` left by 7 and
/// `decay` left by 6 before passing them to the Laplace coder.
#[rustfmt::skip]
pub const E_PROB_MODEL: [[[u8; 42]; 2]; 4] = [
    // lm = 0
    [
        // inter
        [
            72, 127, 65, 129, 66, 128, 65, 128, 64, 128, 62, 128, 64, 128,
            64, 128, 92, 78, 92, 79, 92, 78, 90, 79, 116, 41, 115, 40,
            114, 40, 132, 26, 132, 26, 145, 17, 161, 12, 176, 10, 177, 11,
        ],
        // intra
        [
            24, 179, 48, 138, 54, 135, 54, 132, 53, 134, 56, 133, 55, 132,
            55, 132, 61, 114, 70, 96, 74, 88, 75, 88, 87, 74, 89, 66,
            91, 67, 100, 59, 108, 50, 120, 40, 122, 37, 97, 43, 78, 50,
        ],
    ],
    // lm = 1
    [
        // inter
        [
            83, 78, 84, 81, 88, 75, 86, 74, 87, 71, 90, 73, 93, 74,
            93, 74, 109, 40, 114, 36, 117, 34, 117, 34, 143, 17, 145, 18,
            146, 19, 162, 12, 165, 10, 178, 7, 189, 6, 190, 8, 177, 9,
        ],
        // intra
        [
            23, 178, 54, 115, 63, 102, 66, 98, 69, 99, 74, 89, 71, 91,
            73, 91, 78, 89, 86, 80, 92, 66, 93, 64, 102, 59, 103, 60,
            104, 60, 117, 52, 123, 44, 138, 35, 133, 31, 97, 38, 77, 45,
        ],
    ],
    // lm = 2
    [
        // inter
        [
            61, 90, 93, 60, 105, 42, 107, 41, 110, 45, 116, 38, 113, 38,
            112, 38, 124, 26, 132, 27, 136, 19, 140, 20, 155, 14, 159, 16,
            158, 18, 170, 13, 177, 10, 187, 8, 192, 6, 175, 9, 159, 10,
        ],
        // intra
        [
            21, 178, 59, 110, 71, 86, 75, 85, 84, 83, 91, 66, 88, 73,
            87, 72, 92, 75, 98, 72, 105, 58, 107, 54, 115, 52, 114, 55,
            112, 56, 129, 51, 132, 40, 150, 33, 140, 29, 98, 35, 77, 42,
        ],
    ],
    // lm = 3
    [
        // inter
        [
            42, 121, 96, 66, 108, 43, 111, 40, 117, 44, 123, 32, 120, 36,
            119, 33, 127, 33, 134, 34, 139, 21, 147, 23, 152, 20, 158, 25,
            154, 26, 166, 21, 173, 16, 184, 13, 184, 10, 150, 13, 139, 15,
        ],
        // intra
        [
            22, 178, 63, 114, 74, 82, 84, 83, 92, 82, 103, 62, 96, 72,
            96, 67, 101, 73, 107, 72, 113, 55, 118, 52, 125, 52, 118, 52,
            117, 55, 135, 49, 137, 39, 157, 32, 145, 29, 97, 33, 77, 40,
        ],
    ],
];

/// Inverse CDF used (with `ftb = 2`) to code a coarse value restricted to
/// {-1, 0, 1} when only a few bits remain.
///
/// Symbols map as 0 -> 0, 1 -> -1, 2 -> +1.
pub const SMALL_ENERGY_ICDF: [u8; 3] = [2, 1, 0];
70
/// Sum of squared differences between `e_bands` and `old_e_bands` over bands
/// `start..min(end, len)` of every channel, saturated at 200.
///
/// Both slices are laid out channel-major with a stride of `len` entries per
/// channel. Channels are summed in order, and within each channel the bands
/// are summed in order.
fn loss_distortion(
    e_bands: &[f32],
    old_e_bands: &[f32],
    start: usize,
    end: usize,
    len: usize,
    channels: usize,
) -> f32 {
    let stop = end.min(len);
    let squared_error = (0..channels)
        .flat_map(|ch| (start..stop).map(move |band| ch * len + band))
        .fold(0.0f32, |acc, idx| {
            let diff = e_bands[idx] - old_e_bands[idx];
            acc + diff * diff
        });
    squared_error.min(200.0)
}
89
90#[allow(clippy::too_many_arguments)]
91fn quant_coarse_energy_impl(
92 m: &CeltMode,
93 start: usize,
94 end: usize,
95 e_bands: &[f32],
96 old_e_bands: &mut [f32],
97 budget: u32,
98 tell_start: i32,
99 prob_model: &[u8; 42],
100 error: &mut [f32],
101 enc: &mut RangeCoder,
102 channels: usize,
103 lm: usize,
104 intra: bool,
105 max_decay: f32,
106 lfe: bool,
107) -> i32 {
108 let coef = if intra { 0.0 } else { PRED_COEF[lm] };
109 let beta = if intra { BETA_INTRA } else { BETA_COEF[lm] };
110 let mut prev = [0.0f32; 2];
111 let mut badness = 0i32;
112
113 if tell_start + 3 <= budget as i32 {
114 enc.encode_bit_logp(intra, 3);
115 }
116
117 for i in start..end {
118 for c in 0..channels {
119 let x = e_bands[c * m.nb_ebands + i];
120 let old_e_val = old_e_bands[c * m.nb_ebands + i];
121 let old_e = old_e_val.max(-9.0);
122 let f = x - coef * old_e - prev[c];
123
124 let mut qi = (f + 0.5).floor() as i32;
125 let qi0 = qi;
126
127 let decay_bound = old_e_val.max(-28.0) - max_decay;
128 if qi < 0 && x < decay_bound {
129 qi += ((decay_bound - x) as i32).max(0);
130 if qi > 0 {
131 qi = 0;
132 }
133 }
134
135 let tell = enc.tell();
136 let bits_left = budget as i32 - tell - 3 * channels as i32 * (end - i) as i32;
137 if i != start && bits_left < 30 {
138 if bits_left < 24 {
139 qi = qi.min(1);
140 }
141 if bits_left < 16 {
142 qi = qi.max(-1);
143 }
144 }
145 if lfe && i >= 2 {
146 qi = qi.min(0);
147 }
148
149 if tell + 15 <= budget as i32 {
150 let prob_idx = 2 * i.min(20);
151 let fs = (prob_model[prob_idx] as u32) << 7;
152 let decay = (prob_model[prob_idx + 1] as i32) << 6;
153 enc.laplace_encode(&mut qi, fs, decay);
154 } else if tell + 2 <= budget as i32 {
155 qi = qi.clamp(-1, 1);
156 enc.encode_icdf(
157 (2 * qi) ^ (if qi < 0 { -1 } else { 0 }),
158 &SMALL_ENERGY_ICDF,
159 2,
160 );
161 } else if tell < budget as i32 {
162 qi = qi.min(0);
163 enc.encode_bit_logp(qi != 0, 1);
164 } else {
165 qi = -1;
166 }
167
168 badness += (qi0 - qi).abs();
169
170 let q = qi as f32;
171 error[c * m.nb_ebands + i] = f - q;
172 let tmp = coef * old_e + prev[c] + q;
173 old_e_bands[c * m.nb_ebands + i] = tmp;
174 prev[c] = prev[c] + q - beta * q;
175 }
176 }
177
178 if lfe { 0 } else { badness }
179}
180
181#[allow(clippy::too_many_arguments)]
182pub fn quant_coarse_energy_advanced(
183 m: &CeltMode,
184 start: usize,
185 end: usize,
186 eff_end: usize,
187 e_bands: &[f32],
188 old_e_bands: &mut [f32],
189 budget: u32,
190 error: &mut [f32],
191 enc: &mut RangeCoder,
192 channels: usize,
193 lm: usize,
194 nb_available_bytes: usize,
195 force_intra: bool,
196 delayed_intra: &mut f32,
197 mut two_pass: bool,
198 loss_rate: i32,
199 lfe: bool,
200) {
201 let mut intra = force_intra
202 || (!two_pass
203 && *delayed_intra > 2.0 * channels as f32 * (end.saturating_sub(start)) as f32
204 && nb_available_bytes > (end.saturating_sub(start)) * channels);
205
206 let intra_bias = ((budget as f32) * (*delayed_intra) * (loss_rate as f32)
207 / ((channels as f32) * 512.0)) as i32;
208 let new_distortion =
209 loss_distortion(e_bands, old_e_bands, start, eff_end, m.nb_ebands, channels);
210
211 let tell = enc.tell();
212 if tell + 3 > budget as i32 {
213 two_pass = false;
214 intra = false;
215 }
216
217 let mut max_decay = if end - start > 10 {
218 16.0f32.min(0.125 * nb_available_bytes as f32)
219 } else {
220 16.0f32
221 };
222 if lfe {
223 max_decay = 3.0;
224 }
225
226 let enc_start_state = enc.clone();
227 let mut old_e_bands_intra = old_e_bands.to_vec();
228 let mut error_intra = error.to_vec();
229 let mut badness1 = 0i32;
230 let mut tell_intra = 0i32;
231 let intra_prob = &E_PROB_MODEL[lm][1];
232
233 if two_pass || intra {
234 badness1 = quant_coarse_energy_impl(
235 m,
236 start,
237 end,
238 e_bands,
239 &mut old_e_bands_intra,
240 budget,
241 tell,
242 intra_prob,
243 &mut error_intra,
244 enc,
245 channels,
246 lm,
247 true,
248 max_decay,
249 lfe,
250 );
251 tell_intra = crate::tell_frac_inline!(enc);
252 }
253
254 if !intra {
255 let enc_intra_state = enc.clone();
256
257 *enc = enc_start_state.clone();
258 let inter_prob = &E_PROB_MODEL[lm][0];
259 let badness2 = quant_coarse_energy_impl(
260 m,
261 start,
262 end,
263 e_bands,
264 old_e_bands,
265 budget,
266 tell,
267 inter_prob,
268 error,
269 enc,
270 channels,
271 lm,
272 false,
273 max_decay,
274 lfe,
275 );
276
277 if two_pass
278 && (badness1 < badness2
279 || (badness1 == badness2
280 && crate::tell_frac_inline!(enc) + intra_bias > tell_intra))
281 {
282 *enc = enc_intra_state;
283 old_e_bands.copy_from_slice(&old_e_bands_intra);
284 error.copy_from_slice(&error_intra);
285 intra = true;
286 }
287 } else {
288 old_e_bands.copy_from_slice(&old_e_bands_intra);
289 error.copy_from_slice(&error_intra);
290 }
291
292 if intra {
293 *delayed_intra = new_distortion;
294 } else {
295 let pred2 = PRED_COEF[lm] * PRED_COEF[lm];
296 *delayed_intra = pred2 * *delayed_intra + new_distortion;
297 }
298}
299
300#[allow(clippy::too_many_arguments)]
301pub fn quant_coarse_energy(
302 m: &CeltMode,
303 start: usize,
304 end: usize,
305 e_bands: &[f32],
306 old_e_bands: &mut [f32],
307 budget: u32,
308 error: &mut [f32],
309 enc: &mut RangeCoder,
310 channels: usize,
311 lm: usize,
312 force_intra: bool,
313 nb_available_bytes: usize,
314) {
315 let mut delayed_intra = 0.0f32;
316 quant_coarse_energy_advanced(
317 m,
318 start,
319 end,
320 end,
321 e_bands,
322 old_e_bands,
323 budget,
324 error,
325 enc,
326 channels,
327 lm,
328 nb_available_bytes,
329 force_intra,
330 &mut delayed_intra,
331 false,
332 0,
333 false,
334 );
335}
336
337#[allow(clippy::too_many_arguments)]
338pub fn unquant_coarse_energy(
339 m: &CeltMode,
340 start: usize,
341 end: usize,
342 old_e_bands: &mut [f32],
343 intra: bool,
344 dec: &mut RangeCoder,
345 channels: usize,
346 lm: usize,
347) {
348 let prob_model = &E_PROB_MODEL[lm][if intra { 1 } else { 0 }];
349 let coef = if intra { 0.0 } else { PRED_COEF[lm] };
350 let beta = if intra { BETA_INTRA } else { BETA_COEF[lm] };
351 debug_assert!(channels <= 2);
352 let mut prev = [0.0f32; 2];
353 let budget = (dec.storage * 8) as i32;
354
355 for i in start..end {
356 for c in 0..channels {
357 let qi;
358 let tell = dec.tell();
359 if budget - tell >= 15 {
360 let prob_idx = 2 * i.min(20);
361 let fs = (prob_model[prob_idx] as u32) << 7;
362 let decay = (prob_model[prob_idx + 1] as i32) << 6;
363 qi = dec.laplace_decode(fs, decay);
364 } else if budget - tell >= 2 {
365 let s = dec.decode_icdf(&SMALL_ENERGY_ICDF, 2);
366 qi = (s >> 1) ^ -(s & 1);
367 } else if budget - tell >= 1 {
368 qi = if dec.decode_bit_logp(1) { -1 } else { 0 };
369 } else {
370 qi = -1;
371 }
372
373 old_e_bands[c * m.nb_ebands + i] = old_e_bands[c * m.nb_ebands + i].max(-9.0);
375 let old_e = old_e_bands[c * m.nb_ebands + i];
376
377 let q = qi as f32;
378 let tmp = coef * old_e + prev[c] + q;
379 old_e_bands[c * m.nb_ebands + i] = tmp;
380 prev[c] = prev[c] + q - beta * q;
381 }
382 }
383}
384
385#[allow(clippy::too_many_arguments)]
386pub fn quant_fine_energy(
387 m: &CeltMode,
388 start: usize,
389 end: usize,
390 old_e_bands: &mut [f32],
391 error: &mut [f32],
392 fine_quant: &[i32],
393 enc: &mut RangeCoder,
394 channels: usize,
395) {
396 for i in start..end {
397 for c in 0..channels {
398 let bits = fine_quant[i];
399 if bits <= 0 {
400 continue;
401 }
402 let mut q = ((error[c * m.nb_ebands + i] + 0.5) * (1 << bits) as f32).floor() as i32;
403 q = q.max(0).min((1 << bits) - 1);
404 enc.enc_bits(q as u32, bits as u32);
405 let offset = (q as f32 + 0.5) / (1 << bits) as f32 - 0.5;
406 old_e_bands[c * m.nb_ebands + i] += offset;
407 error[c * m.nb_ebands + i] -= offset;
408 }
409 }
410}
411
412pub fn unquant_fine_energy(
413 m: &CeltMode,
414 start: usize,
415 end: usize,
416 old_e_bands: &mut [f32],
417 fine_quant: &[i32],
418 dec: &mut RangeCoder,
419 channels: usize,
420) {
421 for i in start..end {
422 for c in 0..channels {
423 let bits = fine_quant[i];
424 if bits <= 0 {
425 continue;
426 }
427 let q = dec.dec_bits(bits as u32);
428 let offset = (q as f32 + 0.5) / (1 << bits) as f32 - 0.5;
429 old_e_bands[c * m.nb_ebands + i] += offset;
430 }
431 }
432}
433
434#[allow(clippy::too_many_arguments)]
435pub fn quant_energy_finalise(
436 m: &CeltMode,
437 start: usize,
438 end: usize,
439 old_e_bands: &mut [f32],
440 error: &mut [f32],
441 fine_quant: &[i32],
442 fine_priority: &[i32],
443 bits_left: i32,
444 enc: &mut RangeCoder,
445 channels: usize,
446) {
447 let mut bits_left = bits_left;
448 for priority in 0..2 {
449 let mut i = start;
450 while i < end && bits_left >= channels as i32 {
451 if fine_quant[i] >= 8 || fine_priority[i] != priority {
452 i += 1;
453 continue;
454 }
455 let mut c = 0;
456 while c < channels {
457 let q2 = if error[i + c * m.nb_ebands] < 0.0 {
458 0
459 } else {
460 1
461 };
462 enc.enc_bits(q2 as u32, 1);
463 let offset =
464 (q2 as f32 - 0.5) * (1i32 << (14 - fine_quant[i] - 1)) as f32 * (1.0 / 16384.0);
465 old_e_bands[i + c * m.nb_ebands] += offset;
466 error[i + c * m.nb_ebands] -= offset;
467 bits_left -= 1;
468 c += 1;
469 }
470 i += 1;
471 }
472 }
473}
474
475#[allow(clippy::too_many_arguments)]
476pub fn unquant_energy_finalise(
477 m: &CeltMode,
478 start: usize,
479 end: usize,
480 old_e_bands: &mut [f32],
481 fine_quant: &[i32],
482 fine_priority: &[i32],
483 bits_left: i32,
484 dec: &mut RangeCoder,
485 channels: usize,
486) {
487 let mut bits_left = bits_left;
488 for priority in 0..2 {
489 let mut i = start;
490 while i < end && bits_left >= channels as i32 {
491 if fine_quant[i] >= 8 || fine_priority[i] != priority {
492 i += 1;
493 continue;
494 }
495 let mut c = 0;
496 while c < channels {
497 let q2 = dec.dec_bits(1);
498 let offset =
499 (q2 as f32 - 0.5) * (1i32 << (14 - fine_quant[i] - 1)) as f32 * (1.0 / 16384.0);
500 old_e_bands[i + c * m.nb_ebands] += offset;
501 bits_left -= 1;
502 c += 1;
503 }
504 i += 1;
505 }
506 }
507}
508
#[cfg(test)]
mod tests {
    use super::*;
    use crate::range_coder::RangeCoder;

    /// Round trip for `channels` channels of synthetic band energies.
    ///
    /// The energies pass through the coarse, fine and finalise quantisers, are
    /// decoded again, and the test asserts that the decoder reconstructs
    /// exactly the energies the encoder ended up with.
    fn round_trip(channels: usize) {
        let mode = crate::modes::default_mode();
        let nb = mode.nb_ebands;
        let lm = 3;

        // Same per-band curve as before; each extra channel is offset by -1.
        let e_bands: Vec<f32> = (0..channels * nb)
            .map(|k| {
                let band = k % nb;
                let ch = k / nb;
                5.0 + (band as f32 * 0.5).sin() * 2.0 - ch as f32
            })
            .collect();

        let mut old_e_bands = vec![0.0; channels * nb];
        let mut error = vec![0.0; channels * nb];

        // Assumes `new_encoder(1000)` provides a 1000-byte buffer. The decoder
        // derives its budget as `storage * 8`, so the encoder must use the same
        // 8000-bit budget for both sides to pick the same coding branches.
        let mut enc = RangeCoder::new_encoder(1000);
        let budget = 1000 * 8;

        quant_coarse_energy(
            mode,
            0,
            nb,
            &e_bands,
            &mut old_e_bands,
            budget,
            &mut error,
            &mut enc,
            channels,
            lm,
            false,
            80,
        );

        let fine_quant: Vec<i32> = (0..nb).map(|i| (i % 3) as i32).collect();
        quant_fine_energy(
            mode,
            0,
            nb,
            &mut old_e_bands,
            &mut error,
            &fine_quant,
            &mut enc,
            channels,
        );

        let fine_priority: Vec<i32> = (0..nb).map(|i| (i % 2) as i32).collect();
        let finalise_bits = 10;
        quant_energy_finalise(
            mode,
            0,
            nb,
            &mut old_e_bands,
            &mut error,
            &fine_quant,
            &fine_priority,
            finalise_bits,
            &mut enc,
            channels,
        );

        enc.done();

        let mut dec = RangeCoder::new_decoder(&enc.buf);
        let mut decoded_old_e_bands = vec![0.0; channels * nb];

        // With no forced intra and a zero delayed-intra state, the encoder
        // must have chosen inter mode.
        let intra = dec.decode_bit_logp(3);
        assert!(!intra, "expected inter-coded frame for channels={}", channels);

        unquant_coarse_energy(
            mode,
            0,
            nb,
            &mut decoded_old_e_bands,
            intra,
            &mut dec,
            channels,
            lm,
        );
        unquant_fine_energy(
            mode,
            0,
            nb,
            &mut decoded_old_e_bands,
            &fine_quant,
            &mut dec,
            channels,
        );
        unquant_energy_finalise(
            mode,
            0,
            nb,
            &mut decoded_old_e_bands,
            &fine_quant,
            &fine_priority,
            finalise_bits,
            &mut dec,
            channels,
        );

        for (idx, (&dec_e, &enc_e)) in decoded_old_e_bands.iter().zip(&old_e_bands).enumerate() {
            let diff = (dec_e - enc_e).abs();
            assert!(
                diff < 1e-5,
                "Mismatch at band {} (channel {}): enc={} dec={} diff={}",
                idx % nb,
                idx / nb,
                enc_e,
                dec_e,
                diff
            );
        }
    }

    #[test]
    fn test_coarse_fine_energy() {
        round_trip(1);
    }

    #[test]
    fn test_coarse_fine_energy_stereo() {
        round_trip(2);
    }
}