lc3_codec/decoder/
arithmetic_codec.rs

1use super::{
2    buffer_reader::{BufferReader, BufferReaderError},
3    side_info::SideInfo,
4};
5use crate::{
6    common::{
7        config::FrameDuration,
8        constants::{MAX_LEN_FREQUENCY, MAX_LEN_SPECTRAL},
9    },
10    tables::{
11        spectral_data_tables::{AC_SPEC_CUMFREQ, AC_SPEC_FREQ, AC_SPEC_LOOKUP},
12        temporal_noise_shaping_tables::{
13            AC_TNS_COEF_CUMFREQ, AC_TNS_COEF_FREQ, AC_TNS_ORDER_CUMFREQ, AC_TNS_ORDER_FREQ, MAXLAG, TNS_NUMFILTERS_MAX,
14        },
15    },
16};
17use heapless::Vec;
18#[allow(unused_imports)]
19use num_traits::real::Real;
20
/// State of the range (arithmetic) decoder: the current `low` offset and interval `range`.
///
/// Unsigned 32-bit storage is sufficient: `ac_range` starts at `0x00ff_ffff` and never grows
/// past 24 bits during renormalization, and `ac_low` is masked to 24 bits whenever a new byte
/// is shifted in (see `ac_decode`).
#[derive(Debug)]
struct ArithmeticDecoderState {
    /// Offset of the code value inside the current interval.
    ac_low: u32,
    /// Width of the current interval.
    ac_range: u32,
}
26
27#[derive(Debug)]
28pub enum ArithmeticCodecError {
29    AcRangeFlOutOfRange(u32, u32),
30    BufferReader(BufferReaderError),
31}
32
33impl From<BufferReaderError> for ArithmeticCodecError {
34    fn from(err: BufferReaderError) -> Self {
35        Self::BufferReader(err)
36    }
37}
38
39#[derive(Debug)]
40pub enum ArithmeticDecodeError {
41    ArithmeticCodec(ArithmeticCodecError),
42    TnsOrder(usize, ArithmeticCodecError),
43    TnsCoef(usize, usize, ArithmeticCodecError),
44    SpectralData(usize, usize, ArithmeticCodecError),
45    SpectralBoolData(usize, usize, BufferReaderError),
46    NegativeResidualNumBits,
47    ResidualBoolData(bool, usize),
48    ResidualBoolDataOverflow(bool, usize, usize),
49}
50
51impl From<ArithmeticCodecError> for ArithmeticDecodeError {
52    fn from(err: ArithmeticCodecError) -> Self {
53        Self::ArithmeticCodec(err)
54    }
55}
56
57fn ac_dec_init(buf: &[u8], reader: &mut BufferReader) -> Result<ArithmeticDecoderState, ArithmeticCodecError> {
58    let ac_low_fl = reader.read_head_u24(buf)?;
59    let ac_range_fl = 0x00ffffff;
60
61    Ok(ArithmeticDecoderState {
62        ac_low: ac_low_fl,
63        ac_range: ac_range_fl,
64    })
65}
66
67fn ac_decode(
68    buf: &[u8],
69    reader: &mut BufferReader,
70    st: &mut ArithmeticDecoderState,
71    cum_freq: &[i16],
72    sym_freq: &[i16],
73) -> Result<usize, ArithmeticCodecError> {
74    let tmp = st.ac_range >> 10;
75
76    let limit = tmp << 10;
77    if st.ac_low >= limit {
78        return Err(ArithmeticCodecError::AcRangeFlOutOfRange(st.ac_low, limit));
79    }
80
81    let mut val = cum_freq.len() - 1;
82    while st.ac_low < (tmp * cum_freq[val] as u32) {
83        val -= 1;
84    }
85
86    st.ac_low -= tmp * cum_freq[val] as u32;
87    st.ac_range = tmp * sym_freq[val] as u32;
88
89    while st.ac_range < 0x10000 {
90        st.ac_low <<= 8;
91        st.ac_low &= 0x00ffffff;
92        st.ac_low += reader.read_head_byte(buf)? as u32;
93        st.ac_range <<= 8;
94    }
95
96    Ok(val)
97}
98
99#[derive(Debug, PartialEq)]
100pub struct ArithmeticData {
101    pub reflect_coef_order: [usize; 2], // also called rc_order
102    pub reflect_coef_ints: [usize; 16], // also called rc_i or tns_idx
103    pub residual_bits: Vec<bool, 480>,
104    pub noise_filling_seed: i32,
105    pub is_zero_frame: bool,
106    pub frame_num_bits: usize, // number of bits in the frame (nbits) (frame length * 8) (e.g. 1200)
107}
108
109pub fn decode(
110    buf: &[u8],                // the entire frame
111    reader: &mut BufferReader, // a cursor for reading parts of the frame
112    fs_ind: usize,             // sampling rate index
113    ne: usize,                 // number of encoded spectral lines (NE) (also known as L_spec or ylen)
114    side_info: &SideInfo,      // the side info already read from the frame
115    n_ms: &FrameDuration,
116    x: &mut [i32],
117) -> Result<ArithmeticData, ArithmeticDecodeError> {
118    let num_bytes = buf.len();
119    let nbits = num_bytes * 8;
120
121    // start decoding
122    let mut st = ac_dec_init(buf, reader)?;
123
124    // decode TNS data
125    let (tns_idx, tns_order) = decode_tns_data(buf, reader, side_info, &mut st, nbits, n_ms)?;
126
127    // spectral data (mutates st, x and save_lev)
128    let mut save_lev: [i32; MAX_LEN_SPECTRAL] = [0; MAX_LEN_SPECTRAL];
129    decode_spectral_data(buf, reader, side_info, nbits, fs_ind, ne, &mut st, x, &mut save_lev)?;
130
131    // residual data and finalization
132    for item in &mut x[side_info.lastnz..] {
133        *item = 0;
134    }
135
136    // mutates x and save_lev
137    let residual_bits = decode_residual_bits(buf, reader, side_info, &st, nbits, ne, x, &mut save_lev)?;
138
139    // noise filling seed
140    let noise_filling_seed = x[..ne]
141        .iter()
142        .enumerate()
143        .map(|(k, item)| item.abs() * k as i32)
144        .sum::<i32>()
145        & 0xFFFF;
146
147    // zero frame flag
148    let is_zero_frame = side_info.lastnz == 2 && x[0] == 0 && x[1] == 0 && side_info.global_gain_index == 0;
149
150    Ok(ArithmeticData {
151        is_zero_frame,
152        noise_filling_seed,
153        reflect_coef_ints: tns_idx,
154        reflect_coef_order: tns_order,
155        residual_bits,
156        frame_num_bits: nbits,
157    })
158}
159
160fn decode_residual_bits(
161    buf: &[u8],
162    reader: &mut BufferReader,
163    side_info: &SideInfo,
164    st: &ArithmeticDecoderState,
165    nbits: usize,
166    ne: usize,
167    x: &mut [i32],
168    save_lev: &mut [i32],
169) -> Result<Vec<bool, MAX_LEN_FREQUENCY>, ArithmeticDecodeError> {
170    // number of residual bits
171    let mut nbits_residual = calc_num_residual_bits(reader, st, nbits)?;
172    let lsb_mode = side_info.lsb_mode;
173    let mut residual_bits = Vec::new();
174
175    // decode residual bits
176    if !lsb_mode {
177        // Ne (from the spec - also called ylen) - number of encoded spectral lines
178        for (k, x_k) in x[..ne].iter().enumerate() {
179            if *x_k != 0 {
180                if residual_bits.len() == nbits_residual {
181                    break;
182                }
183
184                let bit = reader
185                    .read_tail_bool(buf)
186                    .map_err(|_| ArithmeticDecodeError::ResidualBoolData(lsb_mode, k))?;
187
188                residual_bits
189                    .push(bit)
190                    .map_err(|_| ArithmeticDecodeError::ResidualBoolDataOverflow(lsb_mode, k, residual_bits.len()))?;
191            }
192        }
193    } else {
194        for k in (0..side_info.lastnz).step_by(2) {
195            if save_lev[k] > 0 {
196                if !read_res_bit(x, reader, buf, k, &mut nbits_residual, lsb_mode)? {
197                    break;
198                }
199
200                if !read_res_bit(x, reader, buf, k + 1, &mut nbits_residual, lsb_mode)? {
201                    break;
202                }
203            }
204        }
205    }
206
207    Ok(residual_bits)
208}
209
210// 1.3 ms
211fn decode_spectral_data(
212    buf: &[u8],
213    reader: &mut BufferReader,
214    side_info: &SideInfo,
215    nbits: usize,
216    fs_ind: usize,
217    ne: usize,
218    st: &mut ArithmeticDecoderState,
219    x: &mut [i32],
220    save_lev: &mut [i32],
221) -> Result<(), ArithmeticDecodeError> {
222    // rate flag
223    let rate_flag = if nbits > (160 + fs_ind * 160) { 512 } else { 0 };
224    let mut c = 0;
225
226    for (k, chunk) in x[..side_info.lastnz].chunks_exact_mut(2).enumerate() {
227        let mut t = c + rate_flag + if (k * 2) > (ne / 2) { 256 } else { 0 };
228
229        // seems horrible but the only way to get a reference to this data
230        let (x_k, x_kplus1) = chunk.split_at_mut(1);
231        let x_k = &mut x_k[0];
232        let x_kplus1 = &mut x_kplus1[0];
233
234        *x_k = 0;
235        *x_kplus1 = 0;
236        let mut sym = 0;
237        let mut lev: usize = 0;
238
239        // 1.0 ms
240        while lev < 14 {
241            let pki_index = t + lev.min(3) * 1024;
242            let pki = AC_SPEC_LOOKUP[pki_index] as usize;
243
244            let cum_freq = &AC_SPEC_CUMFREQ[pki];
245            let spec_freq = &AC_SPEC_FREQ[pki];
246            sym = ac_decode(buf, reader, st, cum_freq, spec_freq)
247                .map_err(|err| ArithmeticDecodeError::SpectralData(k, lev, err))?;
248
249            if sym < 16 {
250                break;
251            }
252
253            if !side_info.lsb_mode || lev > 0 {
254                let bit = reader
255                    .read_tail_bool(buf)
256                    .map_err(|err| ArithmeticDecodeError::SpectralBoolData(k, lev, err))?
257                    as i32;
258                *x_k += bit << lev;
259                let bit = reader
260                    .read_tail_bool(buf)
261                    .map_err(|err| ArithmeticDecodeError::SpectralBoolData(k, lev, err))?
262                    as i32;
263                *x_kplus1 += bit << lev;
264            }
265
266            lev += 1;
267        }
268
269        if side_info.lsb_mode {
270            // used later for residual info
271            save_lev[k] = lev as i32;
272        }
273
274        let a = sym & 0x3;
275        let b = sym >> 2;
276
277        *x_k += (a as i32) << lev;
278        *x_kplus1 += (b as i32) << lev;
279
280        if *x_k > 0 {
281            let bit = reader
282                .read_tail_bool(buf)
283                .map_err(|err| ArithmeticDecodeError::SpectralBoolData(k, lev, err))?;
284            if bit {
285                *x_k = -*x_k;
286            }
287        }
288
289        if *x_kplus1 > 0 {
290            let bit = reader
291                .read_tail_bool(buf)
292                .map_err(|err| ArithmeticDecodeError::SpectralBoolData(k, lev, err))?;
293            if bit {
294                *x_kplus1 = -*x_kplus1;
295            }
296        }
297
298        lev = lev.min(3);
299        t = if lev <= 1 { 1 + (a + b) * (lev + 1) } else { 12 + lev };
300
301        c = (c & 15) * 16 + t;
302    }
303
304    Ok(())
305}
306
307fn decode_tns_data(
308    buf: &[u8],
309    reader: &mut BufferReader,
310    side_info: &SideInfo,
311    st: &mut ArithmeticDecoderState,
312    nbits: usize,
313    n_ms: &FrameDuration,
314) -> Result<([usize; 16], [usize; 2]), ArithmeticDecodeError> {
315    let max_bits = match n_ms {
316        FrameDuration::SevenPointFiveMs => 360,
317        FrameDuration::TenMs => 480,
318    };
319
320    let tns_lpc_weighting = nbits < max_bits; // enable linear predictive coding weighting
321    let tns_lpc_weighting_idx = tns_lpc_weighting as usize;
322
323    let mut tns_idx: [usize; TNS_NUMFILTERS_MAX * MAXLAG] = [0; TNS_NUMFILTERS_MAX * MAXLAG];
324    let mut tns_order = side_info.reflect_coef_order_ari_input; // a copy of tns_order is taken
325    for (f, tns_order_f) in tns_order[..side_info.num_tns_filters].iter_mut().enumerate() {
326        if *tns_order_f > 0 {
327            let cum_freq = &AC_TNS_ORDER_CUMFREQ[tns_lpc_weighting_idx];
328            let sym_freq = &AC_TNS_ORDER_FREQ[tns_lpc_weighting_idx];
329            let order = ac_decode(buf, reader, st, cum_freq, sym_freq)
330                .map_err(|err| ArithmeticDecodeError::TnsOrder(f, err))?;
331
332            *tns_order_f = order + 1;
333            for k in 0..*tns_order_f {
334                let idx = f * 8 + k;
335                let cum_freq = &AC_TNS_COEF_CUMFREQ[k];
336                let sym_freq = &AC_TNS_COEF_FREQ[k];
337                tns_idx[idx] = ac_decode(buf, reader, st, cum_freq, sym_freq)
338                    .map_err(|err| ArithmeticDecodeError::TnsCoef(f, k, err))?;
339            }
340        }
341    }
342
343    Ok((tns_idx, tns_order))
344}
345
346fn read_res_bit(
347    x: &mut [i32],
348    reader: &mut BufferReader,
349    buf: &[u8],
350    x_index: usize,
351    nbits_res: &mut usize,
352    lsb_mode: bool,
353) -> Result<bool, ArithmeticDecodeError> {
354    // check and read bit
355    if *nbits_res == 0 {
356        return Ok(false);
357    }
358    let bit = reader
359        .read_tail_bool(buf)
360        .map_err(|_| ArithmeticDecodeError::ResidualBoolData(lsb_mode, x_index))?;
361    *nbits_res -= 1;
362
363    if bit {
364        let val = &mut x[x_index];
365        match val {
366            v if *v > 0 => {
367                *v += 1;
368            }
369            v if *v < 0 => {
370                *v -= 1;
371            }
372            v => {
373                // check and read bit
374                if *nbits_res == 0 {
375                    return Ok(false);
376                }
377                let bit = reader
378                    .read_tail_bool(buf)
379                    .map_err(|_| ArithmeticDecodeError::ResidualBoolData(lsb_mode, x_index))?;
380                *nbits_res -= 1;
381
382                *v = if bit { -1 } else { 1 };
383            }
384        };
385    }
386
387    Ok(true)
388}
389
390fn calc_num_residual_bits(
391    reader: &BufferReader,
392    st: &ArithmeticDecoderState,
393    total_bits: usize,
394) -> Result<usize, ArithmeticDecodeError> {
395    let nbits_side = reader.get_tail_bit_cursor() - 8;
396
397    // TODO: surely there is a better way to do this
398    let nbits_ari = (reader.get_head_byte_cursor() + 1 - 3) * 8 + 25 - (st.ac_range as f64).log2().floor() as usize;
399
400    if total_bits >= (nbits_side + nbits_ari) {
401        Ok(total_bits - nbits_side - nbits_ari)
402    } else {
403        Err(ArithmeticDecodeError::NegativeResidualNumBits)
404    }
405}
406
#[cfg(test)]
mod tests {
    extern crate std;
    use crate::decoder::side_info::{Bandwidth, LongTermPostFilterInfo, SnsVq};

    use super::*;

    /// Decodes a known 150-byte, 10 ms full-band frame and checks every output field.
    #[test]
    fn arithmetic_decode() {
        let buf = [
            187, 56, 111, 155, 76, 236, 70, 99, 10, 135, 219, 76, 176, 3, 108, 203, 131, 111, 206, 221, 195, 25, 96,
            240, 18, 202, 163, 241, 109, 142, 198, 122, 176, 70, 37, 6, 35, 190, 110, 184, 251, 162, 71, 7, 151, 58,
            42, 79, 200, 192, 99, 157, 234, 156, 245, 43, 84, 64, 167, 32, 52, 106, 43, 75, 4, 102, 213, 123, 168, 120,
            213, 252, 208, 118, 78, 115, 154, 158, 157, 26, 152, 231, 121, 146, 203, 11, 169, 227, 75, 154, 237, 154,
            227, 145, 196, 182, 207, 94, 95, 26, 184, 248, 1, 118, 72, 47, 18, 205, 56, 96, 195, 139, 216, 240, 113,
            233, 44, 198, 245, 157, 139, 70, 162, 182, 139, 136, 165, 68, 79, 247, 161, 126, 17, 135, 36, 30, 229, 24,
            196, 2, 5, 65, 111, 80, 124, 168, 70, 156, 198, 60,
        ];
        let mut reader = BufferReader::new_at(0, 64);
        let fs_ind = 4;
        let ne = 400;
        let side_info = SideInfo {
            bandwidth: Bandwidth::FullBand,
            lastnz: 400,
            lsb_mode: false,
            global_gain_index: 204,
            num_tns_filters: 2,
            reflect_coef_order_ari_input: [1, 0],
            sns_vq: SnsVq {
                ind_lf: 13,
                ind_hf: 4,
                ls_inda: 1,
                ls_indb: 0,
                idx_a: 1718290,
                idx_b: 2,
                submode_lsb: 0,
                submode_msb: 0,
                g_ind: 0,
            },
            long_term_post_filter_info: LongTermPostFilterInfo {
                pitch_present: false,
                is_active: false,
                pitch_index: 0,
            },
            noise_factor: 3,
        };
        let n_ms = &FrameDuration::TenMs;
        let mut x = [0; MAX_LEN_SPECTRAL];

        let arithmetic_data = decode(&buf, &mut reader, fs_ind, ne, &side_info, n_ms, &mut x).unwrap();

        assert!(!arithmetic_data.is_zero_frame);
        assert_eq!(arithmetic_data.frame_num_bits, 1200);
        assert_eq!(arithmetic_data.noise_filling_seed, 56909);
        assert_eq!(
            arithmetic_data.reflect_coef_ints,
            [6, 10, 7, 8, 7, 9, 7, 7, 0, 0, 0, 0, 0, 0, 0, 0]
        );
        assert_eq!(
            arithmetic_data.residual_bits,
            [
                false, true, true, true, false, false, false, true, false, false, true, true, true, false, false,
                false, true, true, true, false, true, false, true, true, false, false, true, true, false, true, true,
                false, true, true, true, false, true, false, true, true, false, false, true, true, true
            ]
        );
        assert_eq!(arithmetic_data.reflect_coef_order, [8, 0]);
    }
}