Skip to main content

opus_rs/
quant_bands.rs

1use crate::modes::CeltMode;
2use crate::range_coder::RangeCoder;
3
/// Prediction coefficients for coarse energy coding, one entry per `lm` value (0..=3).
///
/// Each entry is written as a fraction of 32768 (2^15). In the coarse energy coder of this
/// file, the value already stored in `old_e_bands` is multiplied by `PRED_COEF[lm]` to
/// predict the current band energy when the frame is not intra-coded; intra frames use a
/// coefficient of 0 instead.
pub const PRED_COEF: [f32; 4] = [
    29440.0 / 32768.0,
    26112.0 / 32768.0,
    21248.0 / 32768.0,
    16384.0 / 32768.0,
];
/// Per-`lm` `beta` factors (fractions of 32768) used for non-intra frames.
///
/// After each coded step `q`, the coarse energy coder updates its running per-channel
/// prediction term as `prev + q - beta * q`, with `beta = BETA_COEF[lm]`.
pub const BETA_COEF: [f32; 4] = [
    30147.0 / 32768.0,
    22282.0 / 32768.0,
    12124.0 / 32768.0,
    6554.0 / 32768.0,
];
/// `beta` factor (a fraction of 32768) used in place of `BETA_COEF[lm]` for intra frames.
pub const BETA_INTRA: f32 = 4915.0 / 32768.0;
17
/// Laplace model parameters for the coarse energy symbols, indexed `[lm][intra as usize]`.
///
/// Each innermost row holds 21 byte pairs. The coarse energy coder selects pair
/// `min(band, 20)`: the first byte shifted left by 7 is passed as `fs` and the second byte
/// shifted left by 6 is passed as `decay` to the Laplace encoder/decoder.
pub const E_PROB_MODEL: [[[u8; 42]; 2]; 4] = [
    [
        [
            72, 127, 65, 129, 66, 128, 65, 128, 64, 128, 62, 128, 64, 128, 64, 128, 92, 78, 92, 79,
            92, 78, 90, 79, 116, 41, 115, 40, 114, 40, 132, 26, 132, 26, 145, 17, 161, 12, 176, 10,
            177, 11,
        ],
        [
            24, 179, 48, 138, 54, 135, 54, 132, 53, 134, 56, 133, 55, 132, 55, 132, 61, 114, 70,
            96, 74, 88, 75, 88, 87, 74, 89, 66, 91, 67, 100, 59, 108, 50, 120, 40, 122, 37, 97, 43,
            78, 50,
        ],
    ],
    [
        [
            83, 78, 84, 81, 88, 75, 86, 74, 87, 71, 90, 73, 93, 74, 93, 74, 109, 40, 114, 36, 117,
            34, 117, 34, 143, 17, 145, 18, 146, 19, 162, 12, 165, 10, 178, 7, 189, 6, 190, 8, 177,
            9,
        ],
        [
            23, 178, 54, 115, 63, 102, 66, 98, 69, 99, 74, 89, 71, 91, 73, 91, 78, 89, 86, 80, 92,
            66, 93, 64, 102, 59, 103, 60, 104, 60, 117, 52, 123, 44, 138, 35, 133, 31, 97, 38, 77,
            45,
        ],
    ],
    [
        [
            61, 90, 93, 60, 105, 42, 107, 41, 110, 45, 116, 38, 113, 38, 112, 38, 124, 26, 132, 27,
            136, 19, 140, 20, 155, 14, 159, 16, 158, 18, 170, 13, 177, 10, 187, 8, 192, 6, 175, 9,
            159, 10,
        ],
        [
            21, 178, 59, 110, 71, 86, 75, 85, 84, 83, 91, 66, 88, 73, 87, 72, 92, 75, 98, 72, 105,
            58, 107, 54, 115, 52, 114, 55, 112, 56, 129, 51, 132, 40, 150, 33, 140, 29, 98, 35, 77,
            42,
        ],
    ],
    [
        [
            42, 121, 96, 66, 108, 43, 111, 40, 117, 44, 123, 32, 120, 36, 119, 33, 127, 33, 134,
            34, 139, 21, 147, 23, 152, 20, 158, 25, 154, 26, 166, 21, 173, 16, 184, 13, 184, 10,
            150, 13, 139, 15,
        ],
        [
            22, 178, 63, 114, 74, 82, 84, 83, 92, 82, 103, 62, 96, 72, 96, 67, 101, 73, 107, 72,
            113, 55, 118, 52, 125, 52, 118, 52, 117, 55, 135, 49, 137, 39, 157, 32, 145, 29, 97,
            33, 77, 40,
        ],
    ],
];
68
/// Inverse-CDF table for the three-symbol fallback code used by the coarse energy coder when
/// fewer than 15 but at least 2 bits of budget remain.
///
/// The encoder maps the clamped step `qi` in `{-1, 0, 1}` to the symbols `{1, 0, 2}` before
/// coding it with this table; the decoder undoes the mapping with `(s >> 1) ^ -(s & 1)`.
pub const SMALL_ENERGY_ICDF: [u8; 3] = [2, 1, 0];
70
71#[allow(clippy::too_many_arguments)]
72pub fn quant_coarse_energy(
73    m: &CeltMode,
74    start: usize,
75    end: usize,
76    e_bands: &[f32],
77    old_e_bands: &mut [f32],
78    budget: u32,
79    error: &mut [f32],
80    enc: &mut RangeCoder,
81    channels: usize,
82    lm: usize,
83    intra: bool,
84    nb_available_bytes: usize,
85) {
86    let prob_model = &E_PROB_MODEL[lm][if intra { 1 } else { 0 }];
87    let coef = if intra { 0.0 } else { PRED_COEF[lm] };
88    let beta = if intra { BETA_INTRA } else { BETA_COEF[lm] };
89    debug_assert!(channels <= 2);
90    let mut prev = [0.0f32; 2];
91
92    // Match C: max_decay = min(16, 0.125 * nbAvailableBytes) when end-start > 10
93    let max_decay = if end - start > 10 {
94        16.0f32.min(0.125 * nb_available_bytes as f32)
95    } else {
96        16.0f32
97    };
98
99    enc.encode_bit_logp(intra, 3);
100
101    for i in start..end {
102        for c in 0..channels {
103            let x = e_bands[c * m.nb_ebands + i];
104            let old_e_val = old_e_bands[c * m.nb_ebands + i];
105            let old_e = old_e_val.max(-9.0);
106            let f = x - coef * old_e - prev[c];
107
108            let mut qi = (f + 0.5).floor() as i32;
109
110            let decay_bound = old_e_val.max(-28.0) - max_decay;
111            if qi < 0 && x < decay_bound {
112                qi += (decay_bound - x).floor() as i32;
113                if qi > 0 {
114                    qi = 0;
115                }
116            }
117
118            let tell = enc.tell();
119            let bits_left = budget as i32 - tell - 3 * channels as i32 * (end - i) as i32;
120            if i != start && bits_left < 30 {
121                if bits_left < 24 {
122                    qi = qi.min(1);
123                }
124                if bits_left < 16 {
125                    qi = qi.max(-1);
126                }
127            }
128
129            if tell + 15 <= budget as i32 {
130                let prob_idx = 2 * i.min(20);
131                let fs = (prob_model[prob_idx] as u32) << 7;
132                let decay = (prob_model[prob_idx + 1] as i32) << 6;
133                enc.laplace_encode(&mut qi, fs, decay);
134            } else if tell + 2 <= budget as i32 {
135                qi = qi.clamp(-1, 1);
136                enc.encode_icdf(
137                    (2 * qi) ^ (if qi < 0 { -1 } else { 0 }),
138                    &SMALL_ENERGY_ICDF,
139                    2,
140                );
141            } else if tell < budget as i32 {
142                qi = qi.min(0);
143                enc.encode_bit_logp(qi != 0, 1);
144            } else {
145                qi = -1;
146            }
147
148            let q = qi as f32;
149            error[c * m.nb_ebands + i] = f - q;
150            let tmp = coef * old_e + prev[c] + q;
151            old_e_bands[c * m.nb_ebands + i] = tmp;
152            prev[c] = prev[c] + q - beta * q;
153
154            if i < 3 {}
155        }
156    }
157}
158
159#[allow(clippy::too_many_arguments)]
160pub fn unquant_coarse_energy(
161    m: &CeltMode,
162    start: usize,
163    end: usize,
164    old_e_bands: &mut [f32],
165    budget: u32,
166    dec: &mut RangeCoder,
167    channels: usize,
168    lm: usize,
169    mut intra: bool,
170) {
171    let tell = dec.tell();
172    if tell + 3 <= budget as i32 {
173        intra = dec.decode_bit_logp(3);
174    }
175    let prob_model = &E_PROB_MODEL[lm][if intra { 1 } else { 0 }];
176    let coef = if intra { 0.0 } else { PRED_COEF[lm] };
177    let beta = if intra { BETA_INTRA } else { BETA_COEF[lm] };
178    debug_assert!(channels <= 2);
179    let mut prev = [0.0f32; 2];
180
181    for i in start..end {
182        for c in 0..channels {
183            let old_e = old_e_bands[c * m.nb_ebands + i].max(-9.0);
184
185            let qi;
186            let tell = dec.tell();
187            if tell + 15 <= budget as i32 {
188                let prob_idx = 2 * i.min(20);
189                let fs = (prob_model[prob_idx] as u32) << 7;
190                let decay = (prob_model[prob_idx + 1] as i32) << 6;
191                qi = dec.laplace_decode(fs, decay);
192            } else if tell + 2 <= budget as i32 {
193                let s = dec.decode_icdf(&SMALL_ENERGY_ICDF, 2);
194                qi = (s >> 1) ^ -(s & 1);
195            } else if tell < budget as i32 {
196                qi = if dec.decode_bit_logp(1) { -1 } else { 0 };
197            } else {
198                qi = -1;
199            }
200
201            let q = qi as f32;
202            let tmp = coef * old_e + prev[c] + q;
203            old_e_bands[c * m.nb_ebands + i] = tmp;
204            prev[c] = prev[c] + q - beta * q;
205
206            if i < 3 {}
207        }
208    }
209}
210
211#[allow(clippy::too_many_arguments)]
212pub fn quant_fine_energy(
213    m: &CeltMode,
214    start: usize,
215    end: usize,
216    old_e_bands: &mut [f32],
217    error: &mut [f32],
218    fine_quant: &[i32],
219    enc: &mut RangeCoder,
220    channels: usize,
221) {
222    for i in start..end {
223        for c in 0..channels {
224            let bits = fine_quant[c * m.nb_ebands + i];
225            if bits <= 0 {
226                continue;
227            }
228            let mut q = ((error[c * m.nb_ebands + i] + 0.5) * (1 << bits) as f32).floor() as i32;
229            q = q.max(0).min((1 << bits) - 1);
230            enc.enc_bits(q as u32, bits as u32);
231            let offset = (q as f32 + 0.5) / (1 << bits) as f32 - 0.5;
232            old_e_bands[c * m.nb_ebands + i] += offset;
233            error[c * m.nb_ebands + i] -= offset;
234        }
235    }
236}
237
238pub fn unquant_fine_energy(
239    m: &CeltMode,
240    start: usize,
241    end: usize,
242    old_e_bands: &mut [f32],
243    fine_quant: &[i32],
244    dec: &mut RangeCoder,
245    channels: usize,
246) {
247    for i in start..end {
248        for c in 0..channels {
249            let bits = fine_quant[c * m.nb_ebands + i];
250            if bits <= 0 {
251                continue;
252            }
253            let q = dec.dec_bits(bits as u32);
254            let offset = (q as f32 + 0.5) / (1 << bits) as f32 - 0.5;
255            old_e_bands[c * m.nb_ebands + i] += offset;
256        }
257    }
258}
259
260#[allow(clippy::too_many_arguments)]
261pub fn quant_energy_finalise(
262    m: &CeltMode,
263    start: usize,
264    end: usize,
265    old_e_bands: &mut [f32],
266    error: &mut [f32],
267    fine_quant: &[i32],
268    fine_priority: &[i32],
269    bits_left: i32,
270    enc: &mut RangeCoder,
271    channels: usize,
272) {
273    let mut bits_left = bits_left;
274    for priority in 0..2 {
275        for i in start..end {
276            for c in 0..channels {
277                if bits_left >= 8
278                    && fine_priority[c * m.nb_ebands + i] == priority
279                    && fine_quant[c * m.nb_ebands + i] < 7
280                {
281                    let q = if error[c * m.nb_ebands + i] >= 0.0 {
282                        1
283                    } else {
284                        0
285                    };
286                    enc.enc_bits(q as u32, 1);
287                    let offset = if q == 1 { 0.25 } else { -0.25 };
288                    old_e_bands[c * m.nb_ebands + i] += offset;
289                    error[c * m.nb_ebands + i] -= offset;
290                    bits_left -= 8;
291                }
292            }
293        }
294    }
295}
296
297#[allow(clippy::too_many_arguments)]
298pub fn unquant_energy_finalise(
299    m: &CeltMode,
300    start: usize,
301    end: usize,
302    old_e_bands: &mut [f32],
303    fine_quant: &[i32],
304    fine_priority: &[i32],
305    bits_left: i32,
306    dec: &mut RangeCoder,
307    channels: usize,
308) {
309    let mut bits_left = bits_left;
310    for priority in 0..2 {
311        for i in start..end {
312            for c in 0..channels {
313                if bits_left >= 8
314                    && fine_priority[c * m.nb_ebands + i] == priority
315                    && fine_quant[c * m.nb_ebands + i] < 7
316                {
317                    let q = dec.dec_bits(1);
318                    let offset = if q == 1 { 0.25 } else { -0.25 };
319                    old_e_bands[c * m.nb_ebands + i] += offset;
320                    bits_left -= 8;
321                }
322            }
323        }
324    }
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use crate::range_coder::RangeCoder;

    /// Encodes a synthetic single-channel energy envelope through the coarse, fine and
    /// finalise stages, decodes it again and checks that the decoder reconstructs the same
    /// band energies the encoder ended up with.
    #[test]
    fn test_coarse_fine_energy() {
        let mode = crate::modes::default_mode();
        let bands = mode.nb_ebands;

        // Smooth envelope around 5.0; fine bit counts cycle 0, 1, 2 and priorities
        // alternate 0, 1 so every code path of the fine stages is visited.
        let e_bands: Vec<f32> = (0..bands)
            .map(|i| 5.0 + (i as f32 * 0.5).sin() * 2.0)
            .collect();
        let fine_quant: Vec<i32> = (0..bands).map(|i| (i % 3) as i32).collect();
        let fine_priority: Vec<i32> = (0..bands).map(|i| (i % 2) as i32).collect();

        // --- Encoder side ---
        let mut encoded = vec![0.0f32; bands];
        let mut residual = vec![0.0f32; bands];
        let mut enc = RangeCoder::new_encoder(1000);

        quant_coarse_energy(
            mode,
            0,
            bands,
            &e_bands,
            &mut encoded,
            10000,
            &mut residual,
            &mut enc,
            1,
            3,
            false,
            80,
        );
        quant_fine_energy(
            mode,
            0,
            bands,
            &mut encoded,
            &mut residual,
            &fine_quant,
            &mut enc,
            1,
        );
        quant_energy_finalise(
            mode,
            0,
            bands,
            &mut encoded,
            &mut residual,
            &fine_quant,
            &fine_priority,
            10,
            &mut enc,
            1,
        );
        enc.done();

        // --- Decoder side, mirroring the encoder calls ---
        let mut dec = RangeCoder::new_decoder(&enc.buf);
        let mut decoded = vec![0.0f32; bands];

        unquant_coarse_energy(mode, 0, bands, &mut decoded, 10000, &mut dec, 1, 3, false);
        unquant_fine_energy(mode, 0, bands, &mut decoded, &fine_quant, &mut dec, 1);
        unquant_energy_finalise(
            mode,
            0,
            bands,
            &mut decoded,
            &fine_quant,
            &fine_priority,
            10,
            &mut dec,
            1,
        );

        for (band, (&expected, &actual)) in encoded.iter().zip(decoded.iter()).enumerate() {
            let diff = (actual - expected).abs();
            if diff >= 1e-5 {
                println!(
                    "Mismatch at band {}: enc={} dec={} diff={}",
                    band, expected, actual, diff
                );
            }
            assert!(diff < 1e-5);
        }
    }
}
446}