Skip to main content

opus_rs/
quant_bands.rs

1use crate::modes::CeltMode;
2use crate::range_coder::{BITRES, RangeCoder};
3
/// Inter-frame prediction coefficients, indexed by `lm`.
///
/// When a frame is not coded in intra mode, the selected entry multiplies the
/// previous (clamped) band energy to form the prediction that is subtracted
/// before quantisation. Values are written as Q15 fractions (`n / 32768`).
pub const PRED_COEF: [f32; 4] = {
    const Q15_ONE: f32 = 32768.0;
    [
        29440.0 / Q15_ONE,
        26112.0 / Q15_ONE,
        21248.0 / Q15_ONE,
        16384.0 / Q15_ONE,
    ]
};
/// Inter-mode `beta` coefficients, indexed by `lm`.
///
/// After each quantised step `q`, the running per-channel prediction is
/// updated as `prev + q - beta * q`; this table supplies `beta` when the frame
/// is not intra coded. Values are written as Q15 fractions (`n / 32768`).
pub const BETA_COEF: [f32; 4] = {
    const Q15_ONE: f32 = 32768.0;
    [
        30147.0 / Q15_ONE,
        22282.0 / Q15_ONE,
        12124.0 / Q15_ONE,
        6554.0 / Q15_ONE,
    ]
};
/// `beta` coefficient used for intra-coded frames, in place of a
/// [`BETA_COEF`] entry. Written as a Q15 fraction (`4915 / 32768`).
pub const BETA_INTRA: f32 = {
    const Q15_ONE: f32 = 32768.0;
    4915.0 / Q15_ONE
};
17
/// Laplace model parameters for coarse energy coding.
///
/// Indexed as `E_PROB_MODEL[lm][mode]`, where `mode` is `0` for inter and `1`
/// for intra frames. Each 42-entry row holds 21 `(fs, decay)` pairs: band `i`
/// uses the pair starting at index `2 * min(i, 20)`. The first value of a pair
/// is shifted left by 7 and the second by 6 before being handed to the Laplace
/// encoder/decoder.
///
/// Rows below are laid out seven pairs per line.
#[rustfmt::skip]
pub const E_PROB_MODEL: [[[u8; 42]; 2]; 4] = [
    // lm = 0
    [
        // inter
        [
            72, 127, 65, 129, 66, 128, 65, 128, 64, 128, 62, 128, 64, 128,
            64, 128, 92, 78, 92, 79, 92, 78, 90, 79, 116, 41, 115, 40,
            114, 40, 132, 26, 132, 26, 145, 17, 161, 12, 176, 10, 177, 11,
        ],
        // intra
        [
            24, 179, 48, 138, 54, 135, 54, 132, 53, 134, 56, 133, 55, 132,
            55, 132, 61, 114, 70, 96, 74, 88, 75, 88, 87, 74, 89, 66,
            91, 67, 100, 59, 108, 50, 120, 40, 122, 37, 97, 43, 78, 50,
        ],
    ],
    // lm = 1
    [
        // inter
        [
            83, 78, 84, 81, 88, 75, 86, 74, 87, 71, 90, 73, 93, 74,
            93, 74, 109, 40, 114, 36, 117, 34, 117, 34, 143, 17, 145, 18,
            146, 19, 162, 12, 165, 10, 178, 7, 189, 6, 190, 8, 177, 9,
        ],
        // intra
        [
            23, 178, 54, 115, 63, 102, 66, 98, 69, 99, 74, 89, 71, 91,
            73, 91, 78, 89, 86, 80, 92, 66, 93, 64, 102, 59, 103, 60,
            104, 60, 117, 52, 123, 44, 138, 35, 133, 31, 97, 38, 77, 45,
        ],
    ],
    // lm = 2
    [
        // inter
        [
            61, 90, 93, 60, 105, 42, 107, 41, 110, 45, 116, 38, 113, 38,
            112, 38, 124, 26, 132, 27, 136, 19, 140, 20, 155, 14, 159, 16,
            158, 18, 170, 13, 177, 10, 187, 8, 192, 6, 175, 9, 159, 10,
        ],
        // intra
        [
            21, 178, 59, 110, 71, 86, 75, 85, 84, 83, 91, 66, 88, 73,
            87, 72, 92, 75, 98, 72, 105, 58, 107, 54, 115, 52, 114, 55,
            112, 56, 129, 51, 132, 40, 150, 33, 140, 29, 98, 35, 77, 42,
        ],
    ],
    // lm = 3
    [
        // inter
        [
            42, 121, 96, 66, 108, 43, 111, 40, 117, 44, 123, 32, 120, 36,
            119, 33, 127, 33, 134, 34, 139, 21, 147, 23, 152, 20, 158, 25,
            154, 26, 166, 21, 173, 16, 184, 13, 184, 10, 150, 13, 139, 15,
        ],
        // intra
        [
            22, 178, 63, 114, 74, 82, 84, 83, 92, 82, 103, 62, 96, 72,
            96, 67, 101, 73, 107, 72, 113, 55, 118, 52, 125, 52, 118, 52,
            117, 55, 135, 49, 137, 39, 157, 32, 145, 29, 97, 33, 77, 40,
        ],
    ],
];
68
/// Inverse-CDF table for the three-symbol fallback used when fewer than 15
/// but at least 2 bits remain in the budget; it is passed to the range coder
/// with an `ftb` of 2. Symbols `0`, `1`, `2` stand for steps `0`, `-1`, `+1`.
pub const SMALL_ENERGY_ICDF: [u8; 3] = [2_u8, 1_u8, 0_u8];
70
/// Sums the squared differences between `e_bands` and `old_e_bands` over bands
/// `start..min(end, len)` of every channel, and caps the result at `200.0`.
///
/// Channel `c` occupies indices `c * len ..` in both slices.
///
/// # Panics
/// Panics if either slice is too short for an index that is visited.
fn loss_distortion(
    e_bands: &[f32],
    old_e_bands: &[f32],
    start: usize,
    end: usize,
    len: usize,
    channels: usize,
) -> f32 {
    let stop = end.min(len);
    // Fold with an explicit 0.0 seed so the accumulation order (channel-major,
    // then band) and the empty-range result stay exactly as a plain loop.
    let total = (0..channels)
        .flat_map(|c| (start..stop).map(move |band| c * len + band))
        .fold(0.0f32, |acc, idx| {
            let delta = e_bands[idx] - old_e_bands[idx];
            acc + delta * delta
        });
    total.min(200.0)
}
89
90#[allow(clippy::too_many_arguments)]
91fn quant_coarse_energy_impl(
92    m: &CeltMode,
93    start: usize,
94    end: usize,
95    e_bands: &[f32],
96    old_e_bands: &mut [f32],
97    budget: u32,
98    tell_start: i32,
99    prob_model: &[u8; 42],
100    error: &mut [f32],
101    enc: &mut RangeCoder,
102    channels: usize,
103    lm: usize,
104    intra: bool,
105    max_decay: f32,
106    lfe: bool,
107) -> i32 {
108    let coef = if intra { 0.0 } else { PRED_COEF[lm] };
109    let beta = if intra { BETA_INTRA } else { BETA_COEF[lm] };
110    let mut prev = [0.0f32; 2];
111    let mut badness = 0i32;
112
113    if tell_start + 3 <= budget as i32 {
114        enc.encode_bit_logp(intra, 3);
115    }
116
117    for i in start..end {
118        for c in 0..channels {
119            let x = e_bands[c * m.nb_ebands + i];
120            let old_e_val = old_e_bands[c * m.nb_ebands + i];
121            let old_e = old_e_val.max(-9.0);
122            let f = x - coef * old_e - prev[c];
123
124            let mut qi = (f + 0.5).floor() as i32;
125            let qi0 = qi;
126
127            let decay_bound = old_e_val.max(-28.0) - max_decay;
128            if qi < 0 && x < decay_bound {
129                qi += ((decay_bound - x) as i32).max(0);
130                if qi > 0 {
131                    qi = 0;
132                }
133            }
134
135            let tell = enc.tell();
136            let bits_left = budget as i32 - tell - 3 * channels as i32 * (end - i) as i32;
137            if i != start && bits_left < 30 {
138                if bits_left < 24 {
139                    qi = qi.min(1);
140                }
141                if bits_left < 16 {
142                    qi = qi.max(-1);
143                }
144            }
145            if lfe && i >= 2 {
146                qi = qi.min(0);
147            }
148
149            if tell + 15 <= budget as i32 {
150                let prob_idx = 2 * i.min(20);
151                let fs = (prob_model[prob_idx] as u32) << 7;
152                let decay = (prob_model[prob_idx + 1] as i32) << 6;
153                enc.laplace_encode(&mut qi, fs, decay);
154            } else if tell + 2 <= budget as i32 {
155                qi = qi.clamp(-1, 1);
156                enc.encode_icdf(
157                    (2 * qi) ^ (if qi < 0 { -1 } else { 0 }),
158                    &SMALL_ENERGY_ICDF,
159                    2,
160                );
161            } else if tell < budget as i32 {
162                qi = qi.min(0);
163                enc.encode_bit_logp(qi != 0, 1);
164            } else {
165                qi = -1;
166            }
167
168            badness += (qi0 - qi).abs();
169
170            let q = qi as f32;
171            error[c * m.nb_ebands + i] = f - q;
172            let tmp = coef * old_e + prev[c] + q;
173            old_e_bands[c * m.nb_ebands + i] = tmp;
174            prev[c] = prev[c] + q - beta * q;
175        }
176    }
177
178    if lfe { 0 } else { badness }
179}
180
181#[allow(clippy::too_many_arguments)]
182pub fn quant_coarse_energy_advanced(
183    m: &CeltMode,
184    start: usize,
185    end: usize,
186    eff_end: usize,
187    e_bands: &[f32],
188    old_e_bands: &mut [f32],
189    budget: u32,
190    error: &mut [f32],
191    enc: &mut RangeCoder,
192    channels: usize,
193    lm: usize,
194    nb_available_bytes: usize,
195    force_intra: bool,
196    delayed_intra: &mut f32,
197    mut two_pass: bool,
198    loss_rate: i32,
199    lfe: bool,
200) {
201    let mut intra = force_intra
202        || (!two_pass
203            && *delayed_intra > 2.0 * channels as f32 * (end.saturating_sub(start)) as f32
204            && nb_available_bytes > (end.saturating_sub(start)) * channels);
205
206    let intra_bias = ((budget as f32) * (*delayed_intra) * (loss_rate as f32)
207        / ((channels as f32) * 512.0)) as i32;
208    let new_distortion =
209        loss_distortion(e_bands, old_e_bands, start, eff_end, m.nb_ebands, channels);
210
211    let tell = enc.tell();
212    if tell + 3 > budget as i32 {
213        two_pass = false;
214        intra = false;
215    }
216
217    let mut max_decay = if end - start > 10 {
218        16.0f32.min(0.125 * nb_available_bytes as f32)
219    } else {
220        16.0f32
221    };
222    if lfe {
223        max_decay = 3.0;
224    }
225
226    let enc_start_state = enc.clone();
227    let mut old_e_bands_intra = old_e_bands.to_vec();
228    let mut error_intra = error.to_vec();
229    let mut badness1 = 0i32;
230    let mut tell_intra = 0i32;
231    let intra_prob = &E_PROB_MODEL[lm][1];
232
233    if two_pass || intra {
234        badness1 = quant_coarse_energy_impl(
235            m,
236            start,
237            end,
238            e_bands,
239            &mut old_e_bands_intra,
240            budget,
241            tell,
242            intra_prob,
243            &mut error_intra,
244            enc,
245            channels,
246            lm,
247            true,
248            max_decay,
249            lfe,
250        );
251        tell_intra = crate::tell_frac_inline!(enc);
252    }
253
254    if !intra {
255        let enc_intra_state = enc.clone();
256
257        *enc = enc_start_state.clone();
258        let inter_prob = &E_PROB_MODEL[lm][0];
259        let badness2 = quant_coarse_energy_impl(
260            m,
261            start,
262            end,
263            e_bands,
264            old_e_bands,
265            budget,
266            tell,
267            inter_prob,
268            error,
269            enc,
270            channels,
271            lm,
272            false,
273            max_decay,
274            lfe,
275        );
276
277        if two_pass
278            && (badness1 < badness2
279                || (badness1 == badness2
280                    && crate::tell_frac_inline!(enc) + intra_bias > tell_intra))
281        {
282            *enc = enc_intra_state;
283            old_e_bands.copy_from_slice(&old_e_bands_intra);
284            error.copy_from_slice(&error_intra);
285            intra = true;
286        }
287    } else {
288        old_e_bands.copy_from_slice(&old_e_bands_intra);
289        error.copy_from_slice(&error_intra);
290    }
291
292    if intra {
293        *delayed_intra = new_distortion;
294    } else {
295        let pred2 = PRED_COEF[lm] * PRED_COEF[lm];
296        *delayed_intra = pred2 * *delayed_intra + new_distortion;
297    }
298}
299
300#[allow(clippy::too_many_arguments)]
301pub fn quant_coarse_energy(
302    m: &CeltMode,
303    start: usize,
304    end: usize,
305    e_bands: &[f32],
306    old_e_bands: &mut [f32],
307    budget: u32,
308    error: &mut [f32],
309    enc: &mut RangeCoder,
310    channels: usize,
311    lm: usize,
312    force_intra: bool,
313    nb_available_bytes: usize,
314) {
315    let mut delayed_intra = 0.0f32;
316    quant_coarse_energy_advanced(
317        m,
318        start,
319        end,
320        end,
321        e_bands,
322        old_e_bands,
323        budget,
324        error,
325        enc,
326        channels,
327        lm,
328        nb_available_bytes,
329        force_intra,
330        &mut delayed_intra,
331        false,
332        0,
333        false,
334    );
335}
336
337#[allow(clippy::too_many_arguments)]
338pub fn unquant_coarse_energy(
339    m: &CeltMode,
340    start: usize,
341    end: usize,
342    old_e_bands: &mut [f32],
343    intra: bool,
344    dec: &mut RangeCoder,
345    channels: usize,
346    lm: usize,
347) {
348    let prob_model = &E_PROB_MODEL[lm][if intra { 1 } else { 0 }];
349    let coef = if intra { 0.0 } else { PRED_COEF[lm] };
350    let beta = if intra { BETA_INTRA } else { BETA_COEF[lm] };
351    debug_assert!(channels <= 2);
352    let mut prev = [0.0f32; 2];
353    let budget = (dec.storage * 8) as i32;
354
355    for i in start..end {
356        for c in 0..channels {
357            let qi;
358            let tell = dec.tell();
359            if budget - tell >= 15 {
360                let prob_idx = 2 * i.min(20);
361                let fs = (prob_model[prob_idx] as u32) << 7;
362                let decay = (prob_model[prob_idx + 1] as i32) << 6;
363                qi = dec.laplace_decode(fs, decay);
364            } else if budget - tell >= 2 {
365                let s = dec.decode_icdf(&SMALL_ENERGY_ICDF, 2);
366                qi = (s >> 1) ^ -(s & 1);
367            } else if budget - tell >= 1 {
368                qi = if dec.decode_bit_logp(1) { -1 } else { 0 };
369            } else {
370                qi = -1;
371            }
372
373            // Clamp in-place, matching C: oldEBands[i] = MAXG(-GCONST(9.f), oldEBands[i])
374            old_e_bands[c * m.nb_ebands + i] = old_e_bands[c * m.nb_ebands + i].max(-9.0);
375            let old_e = old_e_bands[c * m.nb_ebands + i];
376
377            let q = qi as f32;
378            let tmp = coef * old_e + prev[c] + q;
379            old_e_bands[c * m.nb_ebands + i] = tmp;
380            prev[c] = prev[c] + q - beta * q;
381        }
382    }
383}
384
385#[allow(clippy::too_many_arguments)]
386pub fn quant_fine_energy(
387    m: &CeltMode,
388    start: usize,
389    end: usize,
390    old_e_bands: &mut [f32],
391    error: &mut [f32],
392    fine_quant: &[i32],
393    enc: &mut RangeCoder,
394    channels: usize,
395) {
396    for i in start..end {
397        for c in 0..channels {
398            let bits = fine_quant[i];
399            if bits <= 0 {
400                continue;
401            }
402            let mut q = ((error[c * m.nb_ebands + i] + 0.5) * (1 << bits) as f32).floor() as i32;
403            q = q.max(0).min((1 << bits) - 1);
404            enc.enc_bits(q as u32, bits as u32);
405            let offset = (q as f32 + 0.5) / (1 << bits) as f32 - 0.5;
406            old_e_bands[c * m.nb_ebands + i] += offset;
407            error[c * m.nb_ebands + i] -= offset;
408        }
409    }
410}
411
412pub fn unquant_fine_energy(
413    m: &CeltMode,
414    start: usize,
415    end: usize,
416    old_e_bands: &mut [f32],
417    fine_quant: &[i32],
418    dec: &mut RangeCoder,
419    channels: usize,
420) {
421    for i in start..end {
422        for c in 0..channels {
423            let bits = fine_quant[i];
424            if bits <= 0 {
425                continue;
426            }
427            let q = dec.dec_bits(bits as u32);
428            let offset = (q as f32 + 0.5) / (1 << bits) as f32 - 0.5;
429            old_e_bands[c * m.nb_ebands + i] += offset;
430        }
431    }
432}
433
434#[allow(clippy::too_many_arguments)]
435pub fn quant_energy_finalise(
436    m: &CeltMode,
437    start: usize,
438    end: usize,
439    old_e_bands: &mut [f32],
440    error: &mut [f32],
441    fine_quant: &[i32],
442    fine_priority: &[i32],
443    bits_left: i32,
444    enc: &mut RangeCoder,
445    channels: usize,
446) {
447    let mut bits_left = bits_left;
448    for priority in 0..2 {
449        let mut i = start;
450        while i < end && bits_left >= channels as i32 {
451            if fine_quant[i] >= 8 || fine_priority[i] != priority {
452                i += 1;
453                continue;
454            }
455            let mut c = 0;
456            while c < channels {
457                let q2 = if error[i + c * m.nb_ebands] < 0.0 {
458                    0
459                } else {
460                    1
461                };
462                enc.enc_bits(q2 as u32, 1);
463                let offset =
464                    (q2 as f32 - 0.5) * (1i32 << (14 - fine_quant[i] - 1)) as f32 * (1.0 / 16384.0);
465                old_e_bands[i + c * m.nb_ebands] += offset;
466                error[i + c * m.nb_ebands] -= offset;
467                bits_left -= 1;
468                c += 1;
469            }
470            i += 1;
471        }
472    }
473}
474
475#[allow(clippy::too_many_arguments)]
476pub fn unquant_energy_finalise(
477    m: &CeltMode,
478    start: usize,
479    end: usize,
480    old_e_bands: &mut [f32],
481    fine_quant: &[i32],
482    fine_priority: &[i32],
483    bits_left: i32,
484    dec: &mut RangeCoder,
485    channels: usize,
486) {
487    let mut bits_left = bits_left;
488    for priority in 0..2 {
489        let mut i = start;
490        while i < end && bits_left >= channels as i32 {
491            if fine_quant[i] >= 8 || fine_priority[i] != priority {
492                i += 1;
493                continue;
494            }
495            let mut c = 0;
496            while c < channels {
497                let q2 = dec.dec_bits(1);
498                let offset =
499                    (q2 as f32 - 0.5) * (1i32 << (14 - fine_quant[i] - 1)) as f32 * (1.0 / 16384.0);
500                old_e_bands[i + c * m.nb_ebands] += offset;
501                bits_left -= 1;
502                c += 1;
503            }
504            i += 1;
505        }
506    }
507}
508
#[cfg(test)]
mod tests {
    use super::*;
    use crate::range_coder::RangeCoder;

    /// Encodes coarse, fine and finalisation energy data for one channel,
    /// decodes it again, and checks that the decoder reconstructs the same
    /// band energies as the encoder (within 1e-5).
    #[test]
    fn test_coarse_fine_energy() {
        let mode = crate::modes::default_mode();
        let nb = mode.nb_ebands;

        // Inputs shared by encoder and decoder.
        let e_bands: Vec<f32> = (0..nb)
            .map(|i| 5.0 + (i as f32 * 0.5).sin() * 2.0)
            .collect();
        let fine_quant: Vec<i32> = (0..nb).map(|i| (i % 3) as i32).collect();
        let fine_priority: Vec<i32> = (0..nb).map(|i| (i % 2) as i32).collect();

        // --- Encode ---
        let mut old_e_bands = vec![0.0f32; nb];
        let mut error = vec![0.0f32; nb];
        let mut enc = RangeCoder::new_encoder(1000);

        quant_coarse_energy(
            mode,
            0,
            nb,
            &e_bands,
            &mut old_e_bands,
            10000,
            &mut error,
            &mut enc,
            1,
            3,
            false,
            80,
        );
        quant_fine_energy(
            mode,
            0,
            nb,
            &mut old_e_bands,
            &mut error,
            &fine_quant,
            &mut enc,
            1,
        );
        quant_energy_finalise(
            mode,
            0,
            nb,
            &mut old_e_bands,
            &mut error,
            &fine_quant,
            &fine_priority,
            10,
            &mut enc,
            1,
        );

        enc.done();
        let _compressed = &enc.buf;

        // --- Decode ---
        let mut dec = RangeCoder::new_decoder(&enc.buf);
        let mut decoded_old_e_bands = vec![0.0f32; nb];

        // The intra flag is written by the encoder ahead of the coarse data.
        let intra = dec.decode_bit_logp(3);
        unquant_coarse_energy(mode, 0, nb, &mut decoded_old_e_bands, intra, &mut dec, 1, 3);
        unquant_fine_energy(mode, 0, nb, &mut decoded_old_e_bands, &fine_quant, &mut dec, 1);
        unquant_energy_finalise(
            mode,
            0,
            nb,
            &mut decoded_old_e_bands,
            &fine_quant,
            &fine_priority,
            10,
            &mut dec,
            1,
        );

        // --- Compare ---
        let pairs = decoded_old_e_bands.iter().zip(old_e_bands.iter());
        for (band, (&got, &want)) in pairs.enumerate() {
            let diff = (got - want).abs();
            if diff >= 1e-5 {
                println!(
                    "Mismatch at band {}: enc={} dec={} diff={}",
                    band, want, got, diff
                );
            }
            assert!(diff < 1e-5);
        }
    }
}