Skip to main content

opus_rs/
quant_bands.rs

1use crate::modes::CeltMode;
2use crate::range_coder::{BITRES, RangeCoder};
3
/// Inter-mode prediction coefficients, indexed by `lm`.
///
/// In inter (non-intra) coding the quantizer predicts a band's energy as
/// `coef * old_e + prev`, where `old_e` is the stored `old_e_bands` value (clamped at -9)
/// and `prev` is the running across-band predictor. Intra coding uses `coef = 0`.
/// Values are written as Q15 fractions (`n / 32768`).
pub const PRED_COEF: [f32; 4] = [
    29440.0 / 32768.0,
    26112.0 / 32768.0,
    21248.0 / 32768.0,
    16384.0 / 32768.0,
];
/// Across-band predictor decay for inter coding, indexed by `lm` (Q15 fractions).
///
/// After a step `q` is coded, the running predictor is updated as `prev + q - beta * q`,
/// i.e. it keeps `(1 - beta)` of each coded step.
pub const BETA_COEF: [f32; 4] = [
    30147.0 / 32768.0,
    22282.0 / 32768.0,
    12124.0 / 32768.0,
    6554.0 / 32768.0,
];
/// Across-band predictor decay used for intra coding (Q15 fraction), paired with a zero
/// inter-mode prediction coefficient.
pub const BETA_INTRA: f32 = 4915.0 / 32768.0;
17
/// Laplace model parameters for coarse energy coding, indexed `[lm][mode]` where
/// `mode` is 0 for inter coding and 1 for intra coding.
///
/// Each 42-entry table holds 21 `(fs, decay)` byte pairs, one per band. Bands beyond
/// index 20 reuse the last pair (the coder indexes `2 * band.min(20)`). Before Laplace
/// coding, the coder scales them as `fs << 7` and `decay << 6`.
pub const E_PROB_MODEL: [[[u8; 42]; 2]; 4] = [
    // lm = 0
    [
        // inter
        [
            72, 127, 65, 129, 66, 128, 65, 128, 64, 128, 62, 128, 64, 128, 64, 128, 92, 78, 92, 79,
            92, 78, 90, 79, 116, 41, 115, 40, 114, 40, 132, 26, 132, 26, 145, 17, 161, 12, 176, 10,
            177, 11,
        ],
        // intra
        [
            24, 179, 48, 138, 54, 135, 54, 132, 53, 134, 56, 133, 55, 132, 55, 132, 61, 114, 70,
            96, 74, 88, 75, 88, 87, 74, 89, 66, 91, 67, 100, 59, 108, 50, 120, 40, 122, 37, 97, 43,
            78, 50,
        ],
    ],
    // lm = 1
    [
        // inter
        [
            83, 78, 84, 81, 88, 75, 86, 74, 87, 71, 90, 73, 93, 74, 93, 74, 109, 40, 114, 36, 117,
            34, 117, 34, 143, 17, 145, 18, 146, 19, 162, 12, 165, 10, 178, 7, 189, 6, 190, 8, 177,
            9,
        ],
        // intra
        [
            23, 178, 54, 115, 63, 102, 66, 98, 69, 99, 74, 89, 71, 91, 73, 91, 78, 89, 86, 80, 92,
            66, 93, 64, 102, 59, 103, 60, 104, 60, 117, 52, 123, 44, 138, 35, 133, 31, 97, 38, 77,
            45,
        ],
    ],
    // lm = 2
    [
        // inter
        [
            61, 90, 93, 60, 105, 42, 107, 41, 110, 45, 116, 38, 113, 38, 112, 38, 124, 26, 132, 27,
            136, 19, 140, 20, 155, 14, 159, 16, 158, 18, 170, 13, 177, 10, 187, 8, 192, 6, 175, 9,
            159, 10,
        ],
        // intra
        [
            21, 178, 59, 110, 71, 86, 75, 85, 84, 83, 91, 66, 88, 73, 87, 72, 92, 75, 98, 72, 105,
            58, 107, 54, 115, 52, 114, 55, 112, 56, 129, 51, 132, 40, 150, 33, 140, 29, 98, 35, 77,
            42,
        ],
    ],
    // lm = 3
    [
        // inter
        [
            42, 121, 96, 66, 108, 43, 111, 40, 117, 44, 123, 32, 120, 36, 119, 33, 127, 33, 134,
            34, 139, 21, 147, 23, 152, 20, 158, 25, 154, 26, 166, 21, 173, 16, 184, 13, 184, 10,
            150, 13, 139, 15,
        ],
        // intra
        [
            22, 178, 63, 114, 74, 82, 84, 83, 92, 82, 103, 62, 96, 72, 96, 67, 101, 73, 107, 72,
            113, 55, 118, 52, 125, 52, 118, 52, 117, 55, 135, 49, 137, 39, 157, 32, 145, 29, 97,
            33, 77, 40,
        ],
    ],
];
68
/// Inverse CDF used to code a coarse-energy step clamped to `-1..=1` when between 2 and 14
/// bits of budget remain (too few for the Laplace model). Symbols 0, 1, 2 correspond to
/// steps 0, -1, +1 via the encoder's `(2 * qi) ^ -(qi < 0)` / decoder's
/// `(s >> 1) ^ -(s & 1)` mapping.
pub const SMALL_ENERGY_ICDF: [u8; 3] = [2, 1, 0];
70
/// Squared-error distortion between the current band energies and the stored ones.
///
/// Each channel occupies `len` consecutive entries of both slices. The squared difference is
/// summed over bands `start..min(end, len)` of every channel, channel by channel. The total is
/// capped at 200.
fn loss_distortion(
    e_bands: &[f32],
    old_e_bands: &[f32],
    start: usize,
    end: usize,
    len: usize,
    channels: usize,
) -> f32 {
    let stop = end.min(len);
    let total = (0..channels)
        .flat_map(|ch| (start..stop).map(move |band| ch * len + band))
        .map(|idx| {
            let delta = e_bands[idx] - old_e_bands[idx];
            delta * delta
        })
        // Left-to-right fold keeps the exact accumulation order of a nested loop.
        .fold(0.0f32, |acc, sq| acc + sq);
    total.min(200.0)
}
89
90#[allow(clippy::too_many_arguments)]
91fn quant_coarse_energy_impl(
92    m: &CeltMode,
93    start: usize,
94    end: usize,
95    e_bands: &[f32],
96    old_e_bands: &mut [f32],
97    budget: u32,
98    tell_start: i32,
99    prob_model: &[u8; 42],
100    error: &mut [f32],
101    enc: &mut RangeCoder,
102    channels: usize,
103    lm: usize,
104    intra: bool,
105    max_decay: f32,
106    lfe: bool,
107) -> i32 {
108    let coef = if intra { 0.0 } else { PRED_COEF[lm] };
109    let beta = if intra { BETA_INTRA } else { BETA_COEF[lm] };
110    let mut prev = [0.0f32; 2];
111    let mut badness = 0i32;
112
113    if tell_start + 3 <= budget as i32 {
114        enc.encode_bit_logp(intra, 3);
115    }
116
117    for i in start..end {
118        for c in 0..channels {
119            let x = e_bands[c * m.nb_ebands + i];
120            let old_e_val = old_e_bands[c * m.nb_ebands + i];
121            let old_e = old_e_val.max(-9.0);
122            let f = x - coef * old_e - prev[c];
123
124            let mut qi = (f + 0.5).floor() as i32;
125            let qi0 = qi;
126
127            let decay_bound = old_e_val.max(-28.0) - max_decay;
128            if qi < 0 && x < decay_bound {
129                qi += ((decay_bound - x) as i32).max(0);
130                if qi > 0 {
131                    qi = 0;
132                }
133            }
134
135            let tell = enc.tell();
136            let bits_left = budget as i32 - tell - 3 * channels as i32 * (end - i) as i32;
137            if i != start && bits_left < 30 {
138                if bits_left < 24 {
139                    qi = qi.min(1);
140                }
141                if bits_left < 16 {
142                    qi = qi.max(-1);
143                }
144            }
145            if lfe && i >= 2 {
146                qi = qi.min(0);
147            }
148
149            if tell + 15 <= budget as i32 {
150                let prob_idx = 2 * i.min(20);
151                let fs = (prob_model[prob_idx] as u32) << 7;
152                let decay = (prob_model[prob_idx + 1] as i32) << 6;
153                enc.laplace_encode(&mut qi, fs, decay);
154            } else if tell + 2 <= budget as i32 {
155                qi = qi.clamp(-1, 1);
156                enc.encode_icdf(
157                    (2 * qi) ^ (if qi < 0 { -1 } else { 0 }),
158                    &SMALL_ENERGY_ICDF,
159                    2,
160                );
161            } else if tell < budget as i32 {
162                qi = qi.min(0);
163                enc.encode_bit_logp(qi != 0, 1);
164            } else {
165                qi = -1;
166            }
167
168            badness += (qi0 - qi).abs();
169
170            let q = qi as f32;
171            error[c * m.nb_ebands + i] = f - q;
172            let tmp = coef * old_e + prev[c] + q;
173            old_e_bands[c * m.nb_ebands + i] = tmp;
174            prev[c] = prev[c] + q - beta * q;
175        }
176    }
177
178    if lfe { 0 } else { badness }
179}
180
181#[allow(clippy::too_many_arguments)]
182pub fn quant_coarse_energy_advanced(
183    m: &CeltMode,
184    start: usize,
185    end: usize,
186    eff_end: usize,
187    e_bands: &[f32],
188    old_e_bands: &mut [f32],
189    budget: u32,
190    error: &mut [f32],
191    enc: &mut RangeCoder,
192    channels: usize,
193    lm: usize,
194    nb_available_bytes: usize,
195    force_intra: bool,
196    delayed_intra: &mut f32,
197    mut two_pass: bool,
198    loss_rate: i32,
199    lfe: bool,
200) {
201    let mut intra = force_intra
202        || (!two_pass
203            && *delayed_intra > 2.0 * channels as f32 * (end.saturating_sub(start)) as f32
204            && nb_available_bytes > (end.saturating_sub(start)) * channels);
205
206    let intra_bias = ((budget as f32) * (*delayed_intra) * (loss_rate as f32)
207        / ((channels as f32) * 512.0)) as i32;
208    let new_distortion =
209        loss_distortion(e_bands, old_e_bands, start, eff_end, m.nb_ebands, channels);
210
211    let tell = enc.tell();
212    if tell + 3 > budget as i32 {
213        two_pass = false;
214        intra = false;
215    }
216
217    let mut max_decay = if end - start > 10 {
218        16.0f32.min(0.125 * nb_available_bytes as f32)
219    } else {
220        16.0f32
221    };
222    if lfe {
223        max_decay = 3.0;
224    }
225
226    let enc_start_state = enc.clone();
227    let mut old_e_bands_intra = old_e_bands.to_vec();
228    let mut error_intra = error.to_vec();
229    let mut badness1 = 0i32;
230    let mut tell_intra = 0i32;
231    let intra_prob = &E_PROB_MODEL[lm][1];
232
233    if two_pass || intra {
234        badness1 = quant_coarse_energy_impl(
235            m,
236            start,
237            end,
238            e_bands,
239            &mut old_e_bands_intra,
240            budget,
241            tell,
242            intra_prob,
243            &mut error_intra,
244            enc,
245            channels,
246            lm,
247            true,
248            max_decay,
249            lfe,
250        );
251        tell_intra = crate::tell_frac_inline!(enc);
252    }
253
254    if !intra {
255        let enc_intra_state = enc.clone();
256
257        *enc = enc_start_state.clone();
258        let inter_prob = &E_PROB_MODEL[lm][0];
259        let badness2 = quant_coarse_energy_impl(
260            m,
261            start,
262            end,
263            e_bands,
264            old_e_bands,
265            budget,
266            tell,
267            inter_prob,
268            error,
269            enc,
270            channels,
271            lm,
272            false,
273            max_decay,
274            lfe,
275        );
276
277        if two_pass
278            && (badness1 < badness2
279                || (badness1 == badness2
280                    && crate::tell_frac_inline!(enc) + intra_bias > tell_intra))
281        {
282            *enc = enc_intra_state;
283            old_e_bands.copy_from_slice(&old_e_bands_intra);
284            error.copy_from_slice(&error_intra);
285            intra = true;
286        }
287    } else {
288        old_e_bands.copy_from_slice(&old_e_bands_intra);
289        error.copy_from_slice(&error_intra);
290    }
291
292    if intra {
293        *delayed_intra = new_distortion;
294    } else {
295        let pred2 = PRED_COEF[lm] * PRED_COEF[lm];
296        *delayed_intra = pred2 * *delayed_intra + new_distortion;
297    }
298}
299
300#[allow(clippy::too_many_arguments)]
301pub fn quant_coarse_energy(
302    m: &CeltMode,
303    start: usize,
304    end: usize,
305    e_bands: &[f32],
306    old_e_bands: &mut [f32],
307    budget: u32,
308    error: &mut [f32],
309    enc: &mut RangeCoder,
310    channels: usize,
311    lm: usize,
312    force_intra: bool,
313    nb_available_bytes: usize,
314) {
315    let mut delayed_intra = 0.0f32;
316    quant_coarse_energy_advanced(
317        m,
318        start,
319        end,
320        end,
321        e_bands,
322        old_e_bands,
323        budget,
324        error,
325        enc,
326        channels,
327        lm,
328        nb_available_bytes,
329        force_intra,
330        &mut delayed_intra,
331        false,
332        0,
333        false,
334    );
335}
336
337#[allow(clippy::too_many_arguments)]
338pub fn unquant_coarse_energy(
339    m: &CeltMode,
340    start: usize,
341    end: usize,
342    old_e_bands: &mut [f32],
343    budget: u32,
344    dec: &mut RangeCoder,
345    channels: usize,
346    lm: usize,
347) {
348    let intra: bool;
349    let tell = dec.tell();
350    if tell + 3 <= budget as i32 {
351        intra = dec.decode_bit_logp(3);
352    } else {
353        intra = false;
354    }
355    let prob_model = &E_PROB_MODEL[lm][if intra { 1 } else { 0 }];
356    let coef = if intra { 0.0 } else { PRED_COEF[lm] };
357    let beta = if intra { BETA_INTRA } else { BETA_COEF[lm] };
358    debug_assert!(channels <= 2);
359    let mut prev = [0.0f32; 2];
360
361    for i in start..end {
362        for c in 0..channels {
363            // Clamp in-place, matching C: oldEBands[i] = MAXG(-GCONST(9.f), oldEBands[i])
364            old_e_bands[c * m.nb_ebands + i] = old_e_bands[c * m.nb_ebands + i].max(-9.0);
365            let old_e = old_e_bands[c * m.nb_ebands + i];
366
367            let qi;
368            let tell = dec.tell();
369            if tell + 15 <= budget as i32 {
370                let prob_idx = 2 * i.min(20);
371                let fs = (prob_model[prob_idx] as u32) << 7;
372                let decay = (prob_model[prob_idx + 1] as i32) << 6;
373                qi = dec.laplace_decode(fs, decay);
374            } else if tell + 2 <= budget as i32 {
375                let s = dec.decode_icdf(&SMALL_ENERGY_ICDF, 2);
376                qi = (s >> 1) ^ -(s & 1);
377            } else if tell < budget as i32 {
378                qi = if dec.decode_bit_logp(1) { -1 } else { 0 };
379            } else {
380                qi = -1;
381            }
382
383            let q = qi as f32;
384            let tmp = coef * old_e + prev[c] + q;
385            old_e_bands[c * m.nb_ebands + i] = tmp;
386            prev[c] = prev[c] + q - beta * q;
387        }
388    }
389}
390
391#[allow(clippy::too_many_arguments)]
392pub fn quant_fine_energy(
393    m: &CeltMode,
394    start: usize,
395    end: usize,
396    old_e_bands: &mut [f32],
397    error: &mut [f32],
398    fine_quant: &[i32],
399    enc: &mut RangeCoder,
400    channels: usize,
401) {
402    for i in start..end {
403        for c in 0..channels {
404            let bits = fine_quant[c * m.nb_ebands + i];
405            if bits <= 0 {
406                continue;
407            }
408            let mut q = ((error[c * m.nb_ebands + i] + 0.5) * (1 << bits) as f32).floor() as i32;
409            q = q.max(0).min((1 << bits) - 1);
410            enc.enc_bits(q as u32, bits as u32);
411            let offset = (q as f32 + 0.5) / (1 << bits) as f32 - 0.5;
412            old_e_bands[c * m.nb_ebands + i] += offset;
413            error[c * m.nb_ebands + i] -= offset;
414        }
415    }
416}
417
418pub fn unquant_fine_energy(
419    m: &CeltMode,
420    start: usize,
421    end: usize,
422    old_e_bands: &mut [f32],
423    fine_quant: &[i32],
424    dec: &mut RangeCoder,
425    channels: usize,
426) {
427    for i in start..end {
428        for c in 0..channels {
429            let bits = fine_quant[c * m.nb_ebands + i];
430            if bits <= 0 {
431                continue;
432            }
433            let q = dec.dec_bits(bits as u32);
434            let offset = (q as f32 + 0.5) / (1 << bits) as f32 - 0.5;
435            old_e_bands[c * m.nb_ebands + i] += offset;
436        }
437    }
438}
439
440#[allow(clippy::too_many_arguments)]
441pub fn quant_energy_finalise(
442    m: &CeltMode,
443    start: usize,
444    end: usize,
445    old_e_bands: &mut [f32],
446    error: &mut [f32],
447    fine_quant: &[i32],
448    fine_priority: &[i32],
449    bits_left: i32,
450    enc: &mut RangeCoder,
451    channels: usize,
452) {
453    let mut bits_left = bits_left;
454    for priority in 0..2 {
455        let mut i = start;
456        while i < end && bits_left >= channels as i32 {
457            if fine_quant[i] >= 8 || fine_priority[i] != priority {
458                i += 1;
459                continue;
460            }
461            let mut c = 0;
462            while c < channels {
463                let q2 = if error[i + c * m.nb_ebands] < 0.0 {
464                    0
465                } else {
466                    1
467                };
468                enc.enc_bits(q2 as u32, 1);
469                let offset =
470                    (q2 as f32 - 0.5) * (1i32 << (14 - fine_quant[i] - 1)) as f32 * (1.0 / 16384.0);
471                old_e_bands[i + c * m.nb_ebands] += offset;
472                error[i + c * m.nb_ebands] -= offset;
473                bits_left -= 1;
474                c += 1;
475            }
476            i += 1;
477        }
478    }
479}
480
481#[allow(clippy::too_many_arguments)]
482pub fn unquant_energy_finalise(
483    m: &CeltMode,
484    start: usize,
485    end: usize,
486    old_e_bands: &mut [f32],
487    fine_quant: &[i32],
488    fine_priority: &[i32],
489    bits_left: i32,
490    dec: &mut RangeCoder,
491    channels: usize,
492) {
493    let mut bits_left = bits_left;
494    for priority in 0..2 {
495        let mut i = start;
496        while i < end && bits_left >= channels as i32 {
497            if fine_quant[i] >= 8 || fine_priority[i] != priority {
498                i += 1;
499                continue;
500            }
501            let mut c = 0;
502            while c < channels {
503                let q2 = dec.dec_bits(1);
504                let offset =
505                    (q2 as f32 - 0.5) * (1i32 << (14 - fine_quant[i] - 1)) as f32 * (1.0 / 16384.0);
506                old_e_bands[i + c * m.nb_ebands] += offset;
507                bits_left -= 1;
508                c += 1;
509            }
510            i += 1;
511        }
512    }
513}
514
#[cfg(test)]
mod tests {
    use super::*;
    use crate::range_coder::RangeCoder;

    /// Round-trips a mono frame through coarse, fine and finalise energy coding. Checks that the
    /// decoder reconstructs the encoder's quantized band energies.
    #[test]
    fn test_coarse_fine_energy() {
        let mode = crate::modes::default_mode();
        let nb = mode.nb_ebands;

        let e_bands: Vec<f32> = (0..nb)
            .map(|i| 5.0 + (i as f32 * 0.5).sin() * 2.0)
            .collect();
        let mut old_e_bands = vec![0.0f32; nb];
        let mut error = vec![0.0f32; nb];
        let mut enc = RangeCoder::new_encoder(1000);

        quant_coarse_energy(
            mode,
            0,
            nb,
            &e_bands,
            &mut old_e_bands,
            10000,
            &mut error,
            &mut enc,
            1,
            3,
            false,
            80,
        );

        let fine_quant: Vec<i32> = (0..nb).map(|i| (i % 3) as i32).collect();
        quant_fine_energy(
            mode,
            0,
            nb,
            &mut old_e_bands,
            &mut error,
            &fine_quant,
            &mut enc,
            1,
        );

        let fine_priority: Vec<i32> = (0..nb).map(|i| (i % 2) as i32).collect();
        quant_energy_finalise(
            mode,
            0,
            nb,
            &mut old_e_bands,
            &mut error,
            &fine_quant,
            &fine_priority,
            10,
            &mut enc,
            1,
        );

        enc.done();

        let mut dec = RangeCoder::new_decoder(&enc.buf);
        let mut decoded_old_e_bands = vec![0.0f32; nb];

        unquant_coarse_energy(mode, 0, nb, &mut decoded_old_e_bands, 10000, &mut dec, 1, 3);
        unquant_fine_energy(
            mode,
            0,
            nb,
            &mut decoded_old_e_bands,
            &fine_quant,
            &mut dec,
            1,
        );
        unquant_energy_finalise(
            mode,
            0,
            nb,
            &mut decoded_old_e_bands,
            &fine_quant,
            &fine_priority,
            10,
            &mut dec,
            1,
        );

        for (band, (&dec_e, &enc_e)) in decoded_old_e_bands.iter().zip(&old_e_bands).enumerate() {
            let diff = (dec_e - enc_e).abs();
            assert!(
                diff < 1e-5,
                "Mismatch at band {}: enc={} dec={} diff={}",
                band,
                enc_e,
                dec_e,
                diff
            );
        }
    }
}