1use crate::modes::CeltMode;
2use crate::range_coder::RangeCoder;
3
/// Denominator shared by the coefficient tables below (2^15): every entry is
/// written as an integer numerator over this scale.
const Q15_ONE: f32 = 32768.0;

/// Inter-frame prediction weights, one per `lm` value.
///
/// The coarse-energy coder multiplies the previous frame's band energy by
/// `PRED_COEF[lm]` when the frame is not coded as intra.
pub const PRED_COEF: [f32; 4] = [
    29440.0 / Q15_ONE,
    26112.0 / Q15_ONE,
    21248.0 / Q15_ONE,
    16384.0 / Q15_ONE,
];

/// Inter-band prediction weights, one per `lm` value.
///
/// After each coded step `q` the running predictor advances by
/// `q - BETA_COEF[lm] * q` when the frame is not coded as intra.
pub const BETA_COEF: [f32; 4] = [
    30147.0 / Q15_ONE,
    22282.0 / Q15_ONE,
    12124.0 / Q15_ONE,
    6554.0 / Q15_ONE,
];

/// Inter-band prediction weight used instead of [`BETA_COEF`] for intra frames.
pub const BETA_INTRA: f32 = 4915.0 / Q15_ONE;
17
/// Laplace model parameters for the coarse-energy symbols.
///
/// Indexed as `E_PROB_MODEL[lm][intra as usize][k]`: the first index selects
/// the `lm` value (0..=3), the second is 0 for inter frames and 1 for intra
/// frames. Each innermost row holds 21 pairs; band `i` uses the pair starting
/// at `k = 2 * min(i, 20)`, so bands from 20 upwards all share the last pair.
/// Within a pair, the first byte shifted left by 7 is passed to the Laplace
/// coder as `fs`, and the second byte shifted left by 6 is passed as `decay`.
pub const E_PROB_MODEL: [[[u8; 42]; 2]; 4] = [
    // lm = 0
    [
        // inter
        [
            72, 127, 65, 129, 66, 128, 65, 128, 64, 128, 62, 128, 64, 128, 64, 128, 92, 78, 92, 79,
            92, 78, 90, 79, 116, 41, 115, 40, 114, 40, 132, 26, 132, 26, 145, 17, 161, 12, 176, 10,
            177, 11,
        ],
        // intra
        [
            24, 179, 48, 138, 54, 135, 54, 132, 53, 134, 56, 133, 55, 132, 55, 132, 61, 114, 70,
            96, 74, 88, 75, 88, 87, 74, 89, 66, 91, 67, 100, 59, 108, 50, 120, 40, 122, 37, 97, 43,
            78, 50,
        ],
    ],
    // lm = 1
    [
        // inter
        [
            83, 78, 84, 81, 88, 75, 86, 74, 87, 71, 90, 73, 93, 74, 93, 74, 109, 40, 114, 36, 117,
            34, 117, 34, 143, 17, 145, 18, 146, 19, 162, 12, 165, 10, 178, 7, 189, 6, 190, 8, 177,
            9,
        ],
        // intra
        [
            23, 178, 54, 115, 63, 102, 66, 98, 69, 99, 74, 89, 71, 91, 73, 91, 78, 89, 86, 80, 92,
            66, 93, 64, 102, 59, 103, 60, 104, 60, 117, 52, 123, 44, 138, 35, 133, 31, 97, 38, 77,
            45,
        ],
    ],
    // lm = 2
    [
        // inter
        [
            61, 90, 93, 60, 105, 42, 107, 41, 110, 45, 116, 38, 113, 38, 112, 38, 124, 26, 132, 27,
            136, 19, 140, 20, 155, 14, 159, 16, 158, 18, 170, 13, 177, 10, 187, 8, 192, 6, 175, 9,
            159, 10,
        ],
        // intra
        [
            21, 178, 59, 110, 71, 86, 75, 85, 84, 83, 91, 66, 88, 73, 87, 72, 92, 75, 98, 72, 105,
            58, 107, 54, 115, 52, 114, 55, 112, 56, 129, 51, 132, 40, 150, 33, 140, 29, 98, 35, 77,
            42,
        ],
    ],
    // lm = 3
    [
        // inter
        [
            42, 121, 96, 66, 108, 43, 111, 40, 117, 44, 123, 32, 120, 36, 119, 33, 127, 33, 134,
            34, 139, 21, 147, 23, 152, 20, 158, 25, 154, 26, 166, 21, 173, 16, 184, 13, 184, 10,
            150, 13, 139, 15,
        ],
        // intra
        [
            22, 178, 63, 114, 74, 82, 84, 83, 92, 82, 103, 62, 96, 72, 96, 67, 101, 73, 107, 72,
            113, 55, 118, 52, 125, 52, 118, 52, 117, 55, 135, 49, 137, 39, 157, 32, 145, 29, 97,
            33, 77, 40,
        ],
    ],
];
68
/// Inverse-CDF table handed to the range coder (with `ftb = 2`) for the
/// three-symbol fallback used by the coarse-energy coder when the Laplace code
/// no longer fits in the budget. The coded symbols 0, 1 and 2 stand for the
/// steps 0, -1 and +1 respectively.
pub const SMALL_ENERGY_ICDF: [u8; 3] = [2, 1, 0];
70
71#[allow(clippy::too_many_arguments)]
72pub fn quant_coarse_energy(
73 m: &CeltMode,
74 start: usize,
75 end: usize,
76 e_bands: &[f32],
77 old_e_bands: &mut [f32],
78 budget: u32,
79 error: &mut [f32],
80 enc: &mut RangeCoder,
81 channels: usize,
82 lm: usize,
83 intra: bool,
84 nb_available_bytes: usize,
85) {
86 let prob_model = &E_PROB_MODEL[lm][if intra { 1 } else { 0 }];
87 let coef = if intra { 0.0 } else { PRED_COEF[lm] };
88 let beta = if intra { BETA_INTRA } else { BETA_COEF[lm] };
89 debug_assert!(channels <= 2);
90 let mut prev = [0.0f32; 2];
91
92 let max_decay = if end - start > 10 {
94 16.0f32.min(0.125 * nb_available_bytes as f32)
95 } else {
96 16.0f32
97 };
98
99 enc.encode_bit_logp(intra, 3);
100
101 for i in start..end {
102 for c in 0..channels {
103 let x = e_bands[c * m.nb_ebands + i];
104 let old_e_val = old_e_bands[c * m.nb_ebands + i];
105 let old_e = old_e_val.max(-9.0);
106 let f = x - coef * old_e - prev[c];
107
108 let mut qi = (f + 0.5).floor() as i32;
109
110 let decay_bound = old_e_val.max(-28.0) - max_decay;
111 if qi < 0 && x < decay_bound {
112 qi += (decay_bound - x).floor() as i32;
113 if qi > 0 {
114 qi = 0;
115 }
116 }
117
118 let tell = enc.tell();
119 let bits_left = budget as i32 - tell - 3 * channels as i32 * (end - i) as i32;
120 if i != start && bits_left < 30 {
121 if bits_left < 24 {
122 qi = qi.min(1);
123 }
124 if bits_left < 16 {
125 qi = qi.max(-1);
126 }
127 }
128
129 if tell + 15 <= budget as i32 {
130 let prob_idx = 2 * i.min(20);
131 let fs = (prob_model[prob_idx] as u32) << 7;
132 let decay = (prob_model[prob_idx + 1] as i32) << 6;
133 enc.laplace_encode(&mut qi, fs, decay);
134 } else if tell + 2 <= budget as i32 {
135 qi = qi.clamp(-1, 1);
136 enc.encode_icdf(
137 (2 * qi) ^ (if qi < 0 { -1 } else { 0 }),
138 &SMALL_ENERGY_ICDF,
139 2,
140 );
141 } else if tell < budget as i32 {
142 qi = qi.min(0);
143 enc.encode_bit_logp(qi != 0, 1);
144 } else {
145 qi = -1;
146 }
147
148 let q = qi as f32;
149 error[c * m.nb_ebands + i] = f - q;
150 let tmp = coef * old_e + prev[c] + q;
151 old_e_bands[c * m.nb_ebands + i] = tmp;
152 prev[c] = prev[c] + q - beta * q;
153
154 if i < 3 {}
155 }
156 }
157}
158
159#[allow(clippy::too_many_arguments)]
160pub fn unquant_coarse_energy(
161 m: &CeltMode,
162 start: usize,
163 end: usize,
164 old_e_bands: &mut [f32],
165 budget: u32,
166 dec: &mut RangeCoder,
167 channels: usize,
168 lm: usize,
169 mut intra: bool,
170) {
171 let tell = dec.tell();
172 if tell + 3 <= budget as i32 {
173 intra = dec.decode_bit_logp(3);
174 }
175 let prob_model = &E_PROB_MODEL[lm][if intra { 1 } else { 0 }];
176 let coef = if intra { 0.0 } else { PRED_COEF[lm] };
177 let beta = if intra { BETA_INTRA } else { BETA_COEF[lm] };
178 debug_assert!(channels <= 2);
179 let mut prev = [0.0f32; 2];
180
181 for i in start..end {
182 for c in 0..channels {
183 let old_e = old_e_bands[c * m.nb_ebands + i].max(-9.0);
184
185 let qi;
186 let tell = dec.tell();
187 if tell + 15 <= budget as i32 {
188 let prob_idx = 2 * i.min(20);
189 let fs = (prob_model[prob_idx] as u32) << 7;
190 let decay = (prob_model[prob_idx + 1] as i32) << 6;
191 qi = dec.laplace_decode(fs, decay);
192 } else if tell + 2 <= budget as i32 {
193 let s = dec.decode_icdf(&SMALL_ENERGY_ICDF, 2);
194 qi = (s >> 1) ^ -(s & 1);
195 } else if tell < budget as i32 {
196 qi = if dec.decode_bit_logp(1) { -1 } else { 0 };
197 } else {
198 qi = -1;
199 }
200
201 let q = qi as f32;
202 let tmp = coef * old_e + prev[c] + q;
203 old_e_bands[c * m.nb_ebands + i] = tmp;
204 prev[c] = prev[c] + q - beta * q;
205
206 if i < 3 {}
207 }
208 }
209}
210
211#[allow(clippy::too_many_arguments)]
212pub fn quant_fine_energy(
213 m: &CeltMode,
214 start: usize,
215 end: usize,
216 old_e_bands: &mut [f32],
217 error: &mut [f32],
218 fine_quant: &[i32],
219 enc: &mut RangeCoder,
220 channels: usize,
221) {
222 for i in start..end {
223 for c in 0..channels {
224 let bits = fine_quant[c * m.nb_ebands + i];
225 if bits <= 0 {
226 continue;
227 }
228 let mut q = ((error[c * m.nb_ebands + i] + 0.5) * (1 << bits) as f32).floor() as i32;
229 q = q.max(0).min((1 << bits) - 1);
230 enc.enc_bits(q as u32, bits as u32);
231 let offset = (q as f32 + 0.5) / (1 << bits) as f32 - 0.5;
232 old_e_bands[c * m.nb_ebands + i] += offset;
233 error[c * m.nb_ebands + i] -= offset;
234 }
235 }
236}
237
238pub fn unquant_fine_energy(
239 m: &CeltMode,
240 start: usize,
241 end: usize,
242 old_e_bands: &mut [f32],
243 fine_quant: &[i32],
244 dec: &mut RangeCoder,
245 channels: usize,
246) {
247 for i in start..end {
248 for c in 0..channels {
249 let bits = fine_quant[c * m.nb_ebands + i];
250 if bits <= 0 {
251 continue;
252 }
253 let q = dec.dec_bits(bits as u32);
254 let offset = (q as f32 + 0.5) / (1 << bits) as f32 - 0.5;
255 old_e_bands[c * m.nb_ebands + i] += offset;
256 }
257 }
258}
259
260#[allow(clippy::too_many_arguments)]
261pub fn quant_energy_finalise(
262 m: &CeltMode,
263 start: usize,
264 end: usize,
265 old_e_bands: &mut [f32],
266 error: &mut [f32],
267 fine_quant: &[i32],
268 fine_priority: &[i32],
269 bits_left: i32,
270 enc: &mut RangeCoder,
271 channels: usize,
272) {
273 let mut bits_left = bits_left;
274 for priority in 0..2 {
275 for i in start..end {
276 for c in 0..channels {
277 if bits_left >= 8
278 && fine_priority[c * m.nb_ebands + i] == priority
279 && fine_quant[c * m.nb_ebands + i] < 7
280 {
281 let q = if error[c * m.nb_ebands + i] >= 0.0 {
282 1
283 } else {
284 0
285 };
286 enc.enc_bits(q as u32, 1);
287 let offset = if q == 1 { 0.25 } else { -0.25 };
288 old_e_bands[c * m.nb_ebands + i] += offset;
289 error[c * m.nb_ebands + i] -= offset;
290 bits_left -= 8;
291 }
292 }
293 }
294 }
295}
296
297#[allow(clippy::too_many_arguments)]
298pub fn unquant_energy_finalise(
299 m: &CeltMode,
300 start: usize,
301 end: usize,
302 old_e_bands: &mut [f32],
303 fine_quant: &[i32],
304 fine_priority: &[i32],
305 bits_left: i32,
306 dec: &mut RangeCoder,
307 channels: usize,
308) {
309 let mut bits_left = bits_left;
310 for priority in 0..2 {
311 for i in start..end {
312 for c in 0..channels {
313 if bits_left >= 8
314 && fine_priority[c * m.nb_ebands + i] == priority
315 && fine_quant[c * m.nb_ebands + i] < 7
316 {
317 let q = dec.dec_bits(1);
318 let offset = if q == 1 { 0.25 } else { -0.25 };
319 old_e_bands[c * m.nb_ebands + i] += offset;
320 bits_left -= 8;
321 }
322 }
323 }
324 }
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use crate::range_coder::RangeCoder;

    /// Runs the coarse, fine and final energy coders back to back and checks
    /// that the decoder reconstructs the same band energies as the encoder.
    #[test]
    fn test_coarse_fine_energy() {
        let mode = crate::modes::default_mode();
        let nb = mode.nb_ebands;

        // Deterministic inputs: a slowly varying energy curve, 0..=2 fine bits
        // per band and alternating priorities.
        let e_bands: Vec<f32> = (0..nb)
            .map(|i| 5.0 + (i as f32 * 0.5).sin() * 2.0)
            .collect();
        let fine_quant: Vec<i32> = (0..nb).map(|i| (i % 3) as i32).collect();
        let fine_priority: Vec<i32> = (0..nb).map(|i| (i % 2) as i32).collect();

        // ---- encode ----
        let mut old_e_bands = vec![0.0f32; nb];
        let mut error = vec![0.0f32; nb];
        let mut enc = RangeCoder::new_encoder(1000);

        quant_coarse_energy(
            mode,
            0,
            nb,
            &e_bands,
            &mut old_e_bands,
            10000,
            &mut error,
            &mut enc,
            1,
            3,
            false,
            80,
        );
        quant_fine_energy(
            mode,
            0,
            nb,
            &mut old_e_bands,
            &mut error,
            &fine_quant,
            &mut enc,
            1,
        );
        quant_energy_finalise(
            mode,
            0,
            nb,
            &mut old_e_bands,
            &mut error,
            &fine_quant,
            &fine_priority,
            10,
            &mut enc,
            1,
        );
        enc.done();

        // ---- decode ----
        let mut dec = RangeCoder::new_decoder(&enc.buf);
        let mut decoded_old_e_bands = vec![0.0f32; nb];

        unquant_coarse_energy(
            mode,
            0,
            nb,
            &mut decoded_old_e_bands,
            10000,
            &mut dec,
            1,
            3,
            false,
        );
        unquant_fine_energy(
            mode,
            0,
            nb,
            &mut decoded_old_e_bands,
            &fine_quant,
            &mut dec,
            1,
        );
        unquant_energy_finalise(
            mode,
            0,
            nb,
            &mut decoded_old_e_bands,
            &fine_quant,
            &fine_priority,
            10,
            &mut dec,
            1,
        );

        // ---- compare ----
        for (band, (&got, &want)) in decoded_old_e_bands.iter().zip(&old_e_bands).enumerate() {
            let diff = (got - want).abs();
            if diff >= 1e-5 {
                println!(
                    "Mismatch at band {}: enc={} dec={} diff={}",
                    band, want, got, diff
                );
            }
            assert!(diff < 1e-5);
        }
    }
}