//! `opus_rs/rate.rs` — CELT per-band bit allocation.

1use crate::modes::CeltMode;
2use crate::range_coder::RangeCoder;
3use std::cmp::{max, min};
4
/// Capacity of the fixed-size per-band scratch arrays used by the allocator.
/// Slicing those arrays to `m.nb_ebands` panics if a mode has more bands.
const MAX_EBANDS: usize = 21;

/// Fractional-bit resolution: allocations are kept in units of
/// `1 / (1 << BITRES)` of a bit, so `1 << BITRES` represents one whole bit.
pub const BITRES: i32 = 3;

/// Constant term subtracted (per degree of freedom) when deriving the number
/// of fine-energy bits from a band's budget.
pub const FINE_OFFSET: i32 = 21;

/// Not referenced in this section; presumably consumed by the band quantizer.
/// NOTE(review): confirm meaning at the use sites.
pub const QTHETA_OFFSET: i32 = 4;

/// Not referenced in this section; presumably the two-phase variant of
/// `QTHETA_OFFSET`. NOTE(review): confirm at the use sites.
pub const QTHETA_OFFSET_TWOPHASE: i32 = 16;

/// Upper bound on fine-energy bits assigned to a single band.
pub const MAX_FINE_BITS: i32 = 8;

/// `LOG2_FRAC_TABLE[i] == ceil(8 * log2(i + 1))`: the cost, in 1/8-bit units,
/// of a uniformly distributed symbol with `i + 1` possible values. The
/// allocator uses it to size the intensity-stereo reservation.
pub const LOG2_FRAC_TABLE: [u8; 24] = [
    0, 8, 13, 16, 19, 21, 23, 24, // i = 0..=7
    26, 27, 28, 29, 30, 31, 32, 32, // i = 8..=15
    33, 34, 34, 35, 36, 36, 37, 37, // i = 16..=23
];
15
/// Expands a pseudo-pulse index into an actual pulse count.
///
/// Indices below 8 map to themselves. For `i >= 8`, writing
/// `i = 8 * (s + 1) + r` with `0 <= r < 8`, the result is `(8 + r) << s`, a
/// small mantissa/exponent encoding that grows roughly geometrically.
///
/// Counts that cannot be represented in an `i32` saturate to `i32::MAX`
/// rather than wrapping to a negative value.
#[inline(always)]
pub fn get_pulses(i: i32) -> i32 {
    if i < 8 {
        return i;
    }
    let shift = (i >> 3) - 1;
    let mantissa = 8 + (i & 7);
    // The mantissa is in 8..=15, so any shift of 28 or more overflows i32
    // (8 << 28 == 2^31). Saturate from there instead of wrapping negative.
    if shift >= 28 {
        i32::MAX
    } else {
        mantissa << shift
    }
}
28
29#[inline(always)]
30pub fn bits2pulses(m: &CeltMode, band: usize, mut lm: i32, bits: i32) -> i32 {
31    lm += 1;
32    let idx = lm as usize * m.nb_ebands + band;
33    let cache_index = unsafe { *m.cache.index.get_unchecked(idx) };
34    if cache_index < 0 {
35        return 0;
36    }
37    let cache = &m.cache.bits[cache_index as usize..];
38    let cache_ptr = cache.as_ptr();
39
40    let mut lo = 0i32;
41    let mut hi = unsafe { *cache_ptr } as i32;
42    let bits = bits - 1; // bits--
43
44    unsafe {
45        for _ in 0..6 {
46            // LOG_MAX_PSEUDO = 6
47            let mid = (lo + hi + 1) >> 1; // round up, matches C
48            if *cache_ptr.add(mid as usize) as i32 >= bits {
49                hi = mid;
50            } else {
51                lo = mid;
52            }
53        }
54
55        let lo_val = if lo == 0 {
56            -1i32
57        } else {
58            *cache_ptr.add(lo as usize) as i32
59        };
60        let hi_val = *cache_ptr.add(hi as usize) as i32;
61        if bits - lo_val <= hi_val - bits {
62            lo
63        } else {
64            hi
65        }
66    }
67}
68
69#[inline(always)]
70pub fn pulses2bits(m: &CeltMode, band: usize, mut lm: i32, pulses: i32) -> i32 {
71    if pulses == 0 {
72        return 0;
73    }
74    lm += 1;
75    let idx = lm as usize * m.nb_ebands + band;
76    let cache_index = unsafe { *m.cache.index.get_unchecked(idx) };
77    if cache_index < 0 {
78        return 0;
79    }
80    let cache = &m.cache.bits[cache_index as usize..];
81
82    unsafe { (*cache.as_ptr().add(pulses as usize) as i32) + 1 }
83}
84
85#[allow(clippy::too_many_arguments)]
86pub fn clt_compute_allocation(
87    m: &CeltMode,
88    start: usize,
89    end: usize,
90    offsets: &[i32],
91    cap: &[i32],
92    alloc_trim: i32,
93    intensity: &mut i32,
94    dual_stereo: &mut i32,
95    mut total: i32,
96    balance_out: &mut i32,
97    pulses: &mut [i32],
98    ebits: &mut [i32],
99    fine_priority: &mut [i32],
100    c: i32,
101    lm: i32,
102    rc: &mut RangeCoder,
103    encode: bool,
104    prev: i32,
105    signal_bandwidth: i32,
106) -> i32 {
107    total = max(total, 0);
108    let nb_ebands = m.nb_ebands;
109    let mut skip_start = start;
110
111    let skip_rsv = if total >= (1 << BITRES) {
112        1 << BITRES
113    } else {
114        0
115    };
116    total -= skip_rsv;
117
118    let mut intensity_rsv = 0;
119    let mut dual_stereo_rsv = 0;
120    if c == 2 {
121        intensity_rsv = LOG2_FRAC_TABLE[end - start] as i32;
122        if intensity_rsv > total {
123            intensity_rsv = 0;
124        } else {
125            total -= intensity_rsv;
126            dual_stereo_rsv = if total >= (1 << BITRES) {
127                1 << BITRES
128            } else {
129                0
130            };
131            total -= dual_stereo_rsv;
132        }
133    }
134
135    let mut thresh_buf = [0i32; MAX_EBANDS];
136    let thresh = &mut thresh_buf[..nb_ebands];
137    let mut trim_offset_buf = [0i32; MAX_EBANDS];
138    let trim_offset = &mut trim_offset_buf[..nb_ebands];
139
140    for j in start..end {
141        thresh[j] = max(
142            c << BITRES,
143            ((3 * (m.e_bands[j + 1] - m.e_bands[j]) as i32) << (lm + BITRES)) >> 4,
144        );
145        trim_offset[j] = (c
146            * (m.e_bands[j + 1] - m.e_bands[j]) as i32
147            * (alloc_trim - 5 - lm)
148            * (end - j - 1) as i32
149            * (1 << (lm + BITRES)))
150            >> 6;
151        if (m.e_bands[j + 1] - m.e_bands[j]) << lm == 1 {
152            trim_offset[j] -= c << BITRES;
153        }
154    }
155
156    let mut lo = 1;
157    let mut hi = m.nb_alloc_vectors as i32 - 1;
158    while lo <= hi {
159        let mut done = false;
160        let mut psum = 0;
161        let mid = (lo + hi) >> 1;
162        for j in (start..end).rev() {
163            let n = (m.e_bands[j + 1] - m.e_bands[j]) as i32;
164            let raw = m.alloc_vectors[mid as usize * m.alloc_stride + j] as i32;
165            let mut bitsj = (c * n * raw) << lm >> 2;
166            if bitsj > 0 {
167                bitsj = max(0, bitsj + trim_offset[j]);
168            }
169            bitsj += offsets[j];
170            if bitsj >= thresh[j] || done {
171                done = true;
172                psum += min(bitsj, cap[j]);
173            } else if bitsj >= (c << BITRES) {
174                psum += c << BITRES;
175            }
176        }
177        if psum > total {
178            hi = mid - 1;
179        } else {
180            lo = mid + 1;
181        }
182    }
183
184    let hi_final = lo as usize;
185    let lo_final = (lo - 1) as usize;
186
187    let mut bits1_buf = [0i32; MAX_EBANDS];
188    let bits1 = &mut bits1_buf[..nb_ebands];
189    let mut bits2_buf = [0i32; MAX_EBANDS];
190    let bits2 = &mut bits2_buf[..nb_ebands];
191
192    for j in start..end {
193        let n = (m.e_bands[j + 1] - m.e_bands[j]) as i32;
194        let mut bits1j = (c * n * m.alloc_vectors[lo_final * m.alloc_stride + j] as i32) << lm >> 2;
195        let mut bits2j = if hi_final >= m.nb_alloc_vectors {
196            cap[j]
197        } else {
198            (c * n * m.alloc_vectors[hi_final * m.alloc_stride + j] as i32) << lm >> 2
199        };
200
201        if bits1j > 0 {
202            bits1j = max(0, bits1j + trim_offset[j]);
203        }
204        if bits2j > 0 {
205            bits2j = max(0, bits2j + trim_offset[j]);
206        }
207        if lo_final > 0 {
208            bits1j += offsets[j];
209        }
210        bits2j += offsets[j];
211        if offsets[j] > 0 {
212            skip_start = j;
213        }
214        bits2j = max(0, bits2j - bits1j);
215        bits1[j] = bits1j;
216        bits2[j] = bits2j;
217    }
218
219    interp_bits2pulses(
220        m,
221        start,
222        end,
223        skip_start,
224        bits1,
225        bits2,
226        thresh,
227        cap,
228        total,
229        balance_out,
230        skip_rsv,
231        intensity,
232        intensity_rsv,
233        dual_stereo,
234        dual_stereo_rsv,
235        pulses,
236        ebits,
237        fine_priority,
238        c,
239        lm,
240        rc,
241        encode,
242        prev,
243        signal_bandwidth,
244    )
245}
246
247#[allow(clippy::too_many_arguments)]
248fn interp_bits2pulses(
249    m: &CeltMode,
250    start: usize,
251    end: usize,
252    skip_start: usize,
253    bits1: &[i32],
254    bits2: &[i32],
255    thresh: &[i32],
256    cap: &[i32],
257    total: i32,
258    balance_out: &mut i32,
259    skip_rsv: i32,
260    intensity: &mut i32,
261    mut intensity_rsv: i32,
262    dual_stereo: &mut i32,
263    dual_stereo_rsv: i32,
264    pulses: &mut [i32],
265    ebits: &mut [i32],
266    fine_priority: &mut [i32],
267    c: i32,
268    lm: i32,
269    rc: &mut RangeCoder,
270    encode: bool,
271    prev: i32,
272    signal_bandwidth: i32,
273) -> i32 {
274    let mut psum: i32;
275    let mut lo = 0;
276    let mut hi = 1 << 6;
277    let alloc_floor = c << BITRES;
278    let stereo = if c > 1 { 1 } else { 0 };
279    let log_m = lm << BITRES;
280
281    let mut bits_buf = [0i32; MAX_EBANDS];
282    let bits = &mut bits_buf[..m.nb_ebands];
283
284    for _ in 0..6 {
285        let mid = (lo + hi) >> 1;
286        psum = 0;
287        let mut done = false;
288        for j in (start..end).rev() {
289            let tmp = bits1[j] + ((mid * bits2[j]) >> 6);
290            if tmp >= thresh[j] || done {
291                done = true;
292                psum += min(tmp, cap[j]);
293            } else if tmp >= alloc_floor {
294                psum += alloc_floor;
295            }
296        }
297        if psum > total {
298            hi = mid;
299        } else {
300            lo = mid;
301        }
302    }
303    psum = 0;
304    let mut done = false;
305    for j in (start..end).rev() {
306        let mut tmp = bits1[j] + ((lo * bits2[j]) >> 6);
307        if tmp < thresh[j] && !done {
308            if tmp >= alloc_floor {
309                tmp = alloc_floor;
310            } else {
311                tmp = 0;
312            }
313        } else {
314            done = true;
315        }
316        tmp = min(tmp, cap[j]);
317        bits[j] = tmp;
318        psum += tmp;
319    }
320
321    let mut coded_bands = end;
322    let mut total_with_rsv = total;
323    loop {
324        if coded_bands <= start {
325            break;
326        }
327        let j = coded_bands - 1;
328        if j <= skip_start {
329            total_with_rsv += skip_rsv;
330            break;
331        }
332
333        let left = total_with_rsv - psum;
334        let nb_samples = (m.e_bands[coded_bands] - m.e_bands[start]) as i32;
335        let percoeff = left / nb_samples;
336        let left_rem = left - nb_samples * percoeff;
337        let rem = max(left_rem - (m.e_bands[j] - m.e_bands[start]) as i32, 0);
338        let band_width = (m.e_bands[coded_bands] - m.e_bands[j]) as i32;
339        let mut band_bits = bits[j] + percoeff * band_width + rem;
340
341        if band_bits >= max(thresh[j], alloc_floor + (1 << BITRES)) {
342            if encode {
343                let depth_threshold = if coded_bands > 17 {
344                    if (j as i32) < prev { 7 } else { 9 }
345                } else {
346                    0
347                };
348                if coded_bands <= start + 2
349                    || (band_bits > ((depth_threshold * band_width) << lm << BITRES) >> 4
350                        && (j as i32) <= signal_bandwidth)
351                {
352                    rc.encode_bit_logp(true, 1);
353                    break;
354                }
355                rc.encode_bit_logp(false, 1);
356            } else {
357                let bit = rc.decode_bit_logp(1);
358                if bit {
359                    break;
360                }
361            }
362            psum += 1 << BITRES;
363            band_bits -= 1 << BITRES;
364        }
365        psum -= bits[j] + intensity_rsv;
366        if intensity_rsv > 0 {
367            intensity_rsv = LOG2_FRAC_TABLE[j - start] as i32;
368        }
369        psum += intensity_rsv;
370        if band_bits >= alloc_floor {
371            psum += alloc_floor;
372            bits[j] = alloc_floor;
373        } else {
374            bits[j] = 0;
375        }
376        coded_bands -= 1;
377    }
378
379    if intensity_rsv > 0 {
380        if encode {
381            *intensity = min(*intensity, coded_bands as i32);
382            rc.enc_uint(
383                (*intensity - start as i32) as u32,
384                (coded_bands + 1 - start) as u32,
385            );
386        } else {
387            *intensity = start as i32 + rc.dec_uint((coded_bands + 1 - start) as u32) as i32;
388        }
389    } else {
390        *intensity = 0;
391    }
392
393    let mut dual_stereo_rsv_final = dual_stereo_rsv;
394    if *intensity <= start as i32 {
395        total_with_rsv += dual_stereo_rsv_final;
396        dual_stereo_rsv_final = 0;
397    }
398    if dual_stereo_rsv_final > 0 {
399        if encode {
400            rc.encode_bit_logp(*dual_stereo != 0, 1);
401        } else {
402            *dual_stereo = if rc.decode_bit_logp(1) { 1 } else { 0 };
403        }
404    } else {
405        *dual_stereo = 0;
406    }
407
408    let mut left = total_with_rsv - psum;
409    let nb_samples = (m.e_bands[coded_bands] - m.e_bands[start]) as i32;
410    let percoeff = left / nb_samples;
411    left -= nb_samples * percoeff;
412    for (j, bits_j) in bits[start..coded_bands]
413        .iter_mut()
414        .enumerate()
415        .map(|(i, v)| (i + start, v))
416    {
417        *bits_j += percoeff * (m.e_bands[j + 1] - m.e_bands[j]) as i32;
418    }
419    for (j, bits_j) in bits[start..coded_bands]
420        .iter_mut()
421        .enumerate()
422        .map(|(i, v)| (i + start, v))
423    {
424        let tmp = min(left, (m.e_bands[j + 1] - m.e_bands[j]) as i32);
425        *bits_j += tmp;
426        left -= tmp;
427    }
428
429    let mut balance = 0;
430    for j in start..coded_bands {
431        let n0 = (m.e_bands[j + 1] - m.e_bands[j]) as i32;
432        let n = n0 << lm;
433        let bit = bits[j] + balance;
434
435        let mut excess;
436        if n > 1 {
437            excess = max(bit - cap[j], 0);
438            bits[j] = bit - excess;
439
440            let den = c * n
441                + (if c == 2 && n > 2 && *dual_stereo == 0 && (j as i32) < *intensity {
442                    1
443                } else {
444                    0
445                });
446            let nc_log_n = den * (m.log_n[j] as i32 + log_m);
447            let mut offset = (nc_log_n >> 1) - den * FINE_OFFSET;
448
449            if n == 2 {
450                offset += den << BITRES >> 2;
451            }
452
453            if bits[j] + offset < (den * 2) << BITRES {
454                offset += nc_log_n >> 2;
455            } else if bits[j] + offset < (den * 3) << BITRES {
456                offset += nc_log_n >> 3;
457            }
458
459            ebits[j] = max(0, bits[j] + offset + (den << (BITRES - 1)));
460
461            let num = ebits[j];
462            if den > 0 {
463                ebits[j] = ((num as u32 / den as u32) >> BITRES) as i32;
464            } else {
465                ebits[j] = 0;
466            }
467
468            if c * ebits[j] > (bits[j] >> BITRES) {
469                ebits[j] = bits[j] >> stereo >> BITRES;
470            }
471            ebits[j] = min(ebits[j], MAX_FINE_BITS);
472            fine_priority[j] = if ebits[j] * (den << BITRES) >= bits[j] + offset {
473                1
474            } else {
475                0
476            };
477            bits[j] -= (c * ebits[j]) << BITRES;
478        } else {
479            excess = max(0, bit - (c << BITRES));
480            bits[j] = bit - excess;
481            ebits[j] = 0;
482            fine_priority[j] = 1;
483        }
484
485        if excess > 0 {
486            let extra_fine = min(excess >> (stereo + BITRES), MAX_FINE_BITS - ebits[j]);
487            ebits[j] += extra_fine;
488            let extra_bits = (extra_fine * c) << BITRES;
489            fine_priority[j] = if extra_bits >= excess - balance { 1 } else { 0 };
490            excess -= extra_bits;
491        }
492        balance = excess;
493        pulses[j] = bits[j];
494    }
495    *balance_out = balance;
496
497    for j in coded_bands..end {
498        ebits[j] = bits[j] >> stereo >> BITRES;
499        bits[j] = 0;
500        fine_priority[j] = if ebits[j] < 1 { 1 } else { 0 };
501        pulses[j] = 0;
502    }
503
504    coded_bands as i32
505}