lc3_codec/encoder/
bitstream_encoding.rs

1use super::{
2    bandwidth_detector::BandwidthDetectorResult, buffer_writer::BufferWriter,
3    long_term_post_filter::LongTermPostFilterResult, residual_spectrum::ResidualBits,
4    spectral_noise_shaping::SnsResult, spectral_quantization::SpectralQuantizationResult,
5    temporal_noise_shaping::TnsResult,
6};
7use crate::tables::{spec_noise_shape_quant_tables::*, spectral_data_tables::*, temporal_noise_shaping_tables::*};
8use heapless::Vec;
9
10#[allow(unused_imports)]
11use num_traits::real::Real;
12
/// Capacity of the deferred-LSB buffer (`BitstreamEncoding::lsbs`), i.e. the largest
/// `nbits_lsb` that `spectral_data` can accept.
///
/// `spectral_data` stores at most four entries per coded two-coefficient tuple.
/// NOTE: the original author flagged 480 * 8 as a guess.
/// TODO: derive the exact bound from the maximum supported `ne`.
const MAX_NBITS_LSB: usize = 8 * 480;
14
15#[derive(Default)]
16pub struct BitstreamEncoding {
17    ne: usize,
18    nbytes: usize,
19    nbits: usize,
20    nbits_side_initial: usize,
21    nlsbs: usize,
22    lsbs: Vec<u8, MAX_NBITS_LSB>,
23    st: ArithmeticEncoderState,
24    writer: BufferWriter,
25}
26
/// Registers of the range coder that encodes the TNS and spectral data.
///
/// Reset per frame by `BitstreamEncoding::ac_enc_init`.
#[derive(Default)]
struct ArithmeticEncoderState {
    /// Lower end of the current interval, kept within 24 bits.
    pub low: u32,
    /// Width of the current interval.
    pub range: u32,
    /// Byte held back until carry resolution is known; -1 while no byte is pending.
    pub cache: i32,
    /// 1 when a carry out of `low` must be added to `cache`, otherwise 0.
    pub carry: i32,
    /// Count of deferred bytes that are written once the carry is resolved.
    pub carry_count: i32,
}
35
36impl BitstreamEncoding {
37    pub fn new(ne: usize) -> Self {
38        Self {
39            ne,
40            ..Default::default()
41        }
42    }
43
44    fn write_uint_backward(&mut self, value: usize, num_bits: usize, bytes: &mut [u8]) {
45        self.writer.write_uint_backward(bytes, value, num_bits)
46    }
47
48    fn write_bool_backward(&mut self, value: bool, bytes: &mut [u8]) {
49        self.writer.write_bool_backward(bytes, value)
50    }
51
52    fn write_byte_forward(&mut self, value: u8, bytes: &mut [u8]) {
53        self.writer.write_byte_forward(bytes, value)
54    }
55
56    fn write_uint_forward(&mut self, value: usize, num_bits: usize, bytes: &mut [u8]) {
57        self.writer.write_uint_forward(bytes, value as u16, num_bits)
58    }
59
60    fn nbits_side_written(&self) -> usize {
61        self.writer.nbits_side_written(self.nbits)
62    }
63
64    fn nbits_side_forcast(&self) -> usize {
65        let mut nbits_ari = (self.writer.bp * 8) as i32;
66        nbits_ari += 25 - (self.st.range as f64).log2().floor() as i32;
67        if self.st.carry >= 0 {
68            nbits_ari += 8;
69        }
70        if self.st.carry_count > 0 {
71            nbits_ari += self.st.carry_count * 8;
72        }
73
74        nbits_ari as usize
75    }
76
77    pub fn encode<'a>(
78        &mut self,
79        bandwidth: BandwidthDetectorResult,
80        sns: SnsResult,
81        tns: TnsResult,
82        post_filter: LongTermPostFilterResult,
83        spec: SpectralQuantizationResult,
84        residual_bits: ResidualBits<'a>,
85        noise_factor: usize,
86        spec_output: &[i16],
87        buf_out: &mut [u8],
88    ) {
89        // bitstream encoding
90        self.init(buf_out);
91
92        // side information
93        self.bandwidth(bandwidth.bandwidth_ind, bandwidth.nbits_bandwidth, buf_out);
94        self.last_non_zero_tuple(spec.lastnz_trunc, buf_out);
95        self.lsb_mode_bit(spec.lsb_mode, buf_out);
96        self.global_gain(spec.gg_ind as usize, buf_out);
97        self.tns_activation_flag(tns.num_tns_filters, &tns.rc_order, buf_out);
98        self.pitch_present_flag(post_filter.pitch_present, buf_out);
99
100        // encode SCF VQ parameters - 1st stage
101        self.encode_scf_vq_1st_stage(sns.ind_lf, sns.ind_hf, buf_out);
102
103        // encode SCF VQ parameters - 2nd stage
104        self.encode_scf_vq_2nd_stage(sns.shape_j, sns.gind, sns.ls_inda as usize, sns.index_joint_j, buf_out);
105
106        if post_filter.pitch_present {
107            self.ltpf_data(post_filter.ltpf_active, post_filter.pitch_index, buf_out);
108        }
109
110        self.noise_factor(noise_factor, buf_out);
111
112        // arithmetic encoding
113        self.ac_enc_init();
114
115        // tns data
116        self.tns_data(
117            tns.lpc_weighting,
118            tns.num_tns_filters,
119            &tns.rc_order,
120            &tns.rc_i,
121            buf_out,
122        );
123
124        // spectral data
125        self.spectral_data(
126            spec.lastnz_trunc,
127            spec.rate_flag,
128            spec.lsb_mode,
129            spec_output,
130            spec.nbits_lsb,
131            buf_out,
132        );
133
134        // residual data and finalization
135        self.residual_data_and_finalization(spec.lsb_mode, residual_bits, buf_out);
136    }
137
138    fn init(&mut self, bytes: &mut [u8]) {
139        self.nbytes = bytes.len();
140        self.nbits = self.nbytes * 8;
141        self.writer = BufferWriter::new(bytes.len());
142        bytes.fill(0);
143        self.nlsbs = 0;
144    }
145
146    fn bandwidth(&mut self, p_bw: usize, nbits_bw: usize, bytes: &mut [u8]) {
147        if nbits_bw > 0 {
148            self.write_uint_backward(p_bw, nbits_bw, bytes);
149        }
150    }
151
152    fn last_non_zero_tuple(&mut self, lastnz_trunc: usize, bytes: &mut [u8]) {
153        let value = (lastnz_trunc >> 1) - 1;
154        let num_bits = (self.ne as f64 / 2.0).log2().ceil() as usize;
155        self.write_uint_backward(value, num_bits, bytes)
156    }
157
158    fn lsb_mode_bit(&mut self, lsb_mode: bool, bytes: &mut [u8]) {
159        self.write_bool_backward(lsb_mode, bytes)
160    }
161
162    fn global_gain(&mut self, gg_ind: usize, bytes: &mut [u8]) {
163        self.write_uint_backward(gg_ind, 8, bytes)
164    }
165
166    fn tns_activation_flag(&mut self, num_tns_filters: usize, rc_order: &[usize], bytes: &mut [u8]) {
167        for rc_order_f in rc_order[..num_tns_filters].iter() {
168            let value = *rc_order_f != 0;
169            self.write_bool_backward(value, bytes);
170        }
171    }
172
173    fn pitch_present_flag(&mut self, pitch_present: bool, bytes: &mut [u8]) {
174        self.write_bool_backward(pitch_present, bytes)
175    }
176
177    fn encode_scf_vq_1st_stage(&mut self, ind_lf: usize, ind_hf: usize, bytes: &mut [u8]) {
178        self.write_uint_backward(ind_lf, 5, bytes);
179        self.write_uint_backward(ind_hf, 5, bytes);
180    }
181
182    fn encode_scf_vq_2nd_stage(
183        &mut self,
184        shape_j: usize,
185        gain_i: usize,
186        ls_inda: usize,
187        index_joint_j: usize,
188        bytes: &mut [u8],
189    ) {
190        let submode_msb = (shape_j >> 1) != 0;
191        self.write_bool_backward(submode_msb, bytes);
192        let gain_msbs_num_bits = SNS_GAIN_MSB_BITS[shape_j];
193        let gain_msbs = gain_i >> SNS_GAIN_LSB_BITS[shape_j];
194        self.write_uint_backward(gain_msbs, gain_msbs_num_bits, bytes);
195        let ls_inda_flag = ls_inda != 0;
196        self.write_bool_backward(ls_inda_flag, bytes);
197
198        if !submode_msb {
199            self.write_uint_backward(index_joint_j, 13, bytes);
200            self.write_uint_backward(index_joint_j >> 13, 12, bytes);
201        } else {
202            self.write_uint_backward(index_joint_j, 12, bytes);
203            self.write_uint_backward(index_joint_j >> 12, 12, bytes);
204        }
205    }
206
207    pub fn ltpf_data(&mut self, ltpf_active: bool, pitch_index: usize, bytes: &mut [u8]) {
208        self.write_bool_backward(ltpf_active, bytes);
209        self.write_uint_backward(pitch_index, 9, bytes);
210    }
211
212    pub fn noise_factor(&mut self, f_nf: usize, bytes: &mut [u8]) {
213        self.write_uint_backward(f_nf, 3, bytes);
214    }
215
216    pub fn ac_enc_init(&mut self) {
217        self.st.low = 0;
218        self.st.range = 0x00ff_ffff;
219        self.st.cache = -1;
220        self.st.carry = 0;
221        self.st.carry_count = 0;
222    }
223
224    pub fn tns_data(
225        &mut self,
226        tns_lpc_weighting: u8, // TODO: this is actually used as a usize
227        num_tns_filters: usize,
228        rc_order: &[usize],
229        rc_i: &[usize],
230        bytes: &mut [u8],
231    ) {
232        for f in 0..num_tns_filters {
233            if rc_order[f] > 0 {
234                let cum_freq = AC_TNS_ORDER_CUMFREQ[tns_lpc_weighting as usize][rc_order[f] - 1];
235                let sym_freq = AC_TNS_ORDER_FREQ[tns_lpc_weighting as usize][rc_order[f] - 1];
236                self.ac_encode(cum_freq, sym_freq, bytes);
237                for k in 0..rc_order[f] {
238                    let cum_freq = AC_TNS_COEF_CUMFREQ[k][rc_i[k + 8 * f]];
239                    let sym_freq = AC_TNS_COEF_FREQ[k][rc_i[k + 8 * f]];
240                    self.ac_encode(cum_freq, sym_freq, bytes);
241                }
242            }
243        }
244    }
245
246    pub fn spectral_data(
247        &mut self,
248        lastnz_trunc: usize,
249        rate_flag: usize,
250        lsb_mode: bool,
251        x_q: &[i16],
252        nbits_lsb: usize,
253        bytes: &mut [u8],
254    ) {
255        self.nbits_side_initial = self.nbits_side_written();
256        self.lsbs.clear();
257        for _ in 0..nbits_lsb {
258            self.lsbs.push(0).unwrap();
259        }
260
261        let mut c = 0;
262        for k in (0..lastnz_trunc).step_by(2) {
263            let mut t = c + rate_flag + if k > (self.ne / 2) { 256 } else { 0 };
264            let mut a = x_q[k].unsigned_abs();
265            let mut a_lsb = a;
266            let mut b = x_q[k + 1].unsigned_abs();
267            let mut b_lsb = b;
268            let mut lev = 0;
269            let mut lsb0: u8 = 0;
270            let mut lsb1: u8 = 0;
271            while a.max(b) >= 4 {
272                let pki_index = t + lev.min(3) * 1024;
273                let pki = AC_SPEC_LOOKUP[pki_index] as usize;
274                let cum_freq = AC_SPEC_CUMFREQ[pki][16];
275                let sym_freq = AC_SPEC_FREQ[pki][16];
276                self.ac_encode(cum_freq, sym_freq, bytes);
277                if lsb_mode && lev == 0 {
278                    lsb0 = a as u8 & 1;
279                    lsb1 = b as u8 & 1;
280                } else {
281                    self.write_bool_backward((a & 1) == 1, bytes);
282                    self.write_bool_backward((b & 1) == 1, bytes);
283                }
284                a >>= 1;
285                b >>= 1;
286                lev += 1;
287            }
288            let pki_index = t + lev.min(3) * 1024;
289            let pki = AC_SPEC_LOOKUP[pki_index] as usize;
290            let sym = (a + 4 * b) as usize;
291            let cum_freq = AC_SPEC_CUMFREQ[pki][sym];
292            let sym_freq = AC_SPEC_FREQ[pki][sym];
293            self.ac_encode(cum_freq, sym_freq, bytes);
294
295            if lsb_mode && lev > 0 {
296                a_lsb >>= 1;
297                b_lsb >>= 1;
298
299                self.lsbs[self.nlsbs] = lsb0;
300                self.nlsbs += 1;
301                if a_lsb == 0 && x_q[k] != 0 {
302                    self.lsbs[self.nlsbs] = if x_q[k] > 0 { 0 } else { 1 };
303                    self.nlsbs += 1;
304                }
305                self.lsbs[self.nlsbs] = lsb1;
306                self.nlsbs += 1;
307                if b_lsb == 0 && x_q[k + 1] != 0 {
308                    self.lsbs[self.nlsbs] = if x_q[k + 1] > 0 { 0 } else { 1 };
309                    self.nlsbs += 1;
310                }
311            }
312            if a_lsb > 0 {
313                self.write_bool_backward(x_q[k] <= 0, bytes);
314            }
315            if b_lsb > 0 {
316                self.write_bool_backward(x_q[k + 1] <= 0, bytes);
317            }
318            lev = lev.min(3);
319            t = if lev <= 1 {
320                1 + ((a + b) as usize) * (lev + 1)
321            } else {
322                12 + lev
323            };
324            c = (c & 15) * 16 + t;
325        }
326    }
327
328    pub fn residual_data_and_finalization(
329        &mut self,
330        lsb_mode: bool,
331        residual_bits: impl Iterator<Item = bool>,
332        bytes: &mut [u8],
333    ) {
334        let nbits_side = self.nbits_side_written();
335        let nbits_ari = self.nbits_side_forcast();
336        let nbits_residual_enc = self.nbits as i32 - (nbits_side + nbits_ari) as i32;
337        let nbits_residual_enc = nbits_residual_enc.max(0) as usize;
338
339        if !lsb_mode {
340            for res_bit in residual_bits.take(nbits_residual_enc) {
341                self.write_bool_backward(res_bit, bytes);
342            }
343        } else {
344            let nbits_residual_enc = nbits_residual_enc.min(self.nlsbs);
345            for k in 0..nbits_residual_enc {
346                let value = self.lsbs[k] == 1;
347                self.write_bool_backward(value, bytes);
348            }
349        }
350
351        self.ac_enc_finish(bytes);
352    }
353
354    fn ac_enc_finish(&mut self, bytes: &mut [u8]) {
355        let mut bits: i8 = 1;
356        while (self.st.range >> (24 - bits)) == 0 {
357            bits += 1;
358        }
359        let mut mask = 0x00ff_ffff >> bits;
360        let mut val = self.st.low + mask;
361        let over1 = val >> 24;
362        let high = self.st.low + self.st.range;
363        let over2 = high >> 24;
364        val &= 0x00ff_ffff & !mask;
365        if over1 == over2 {
366            if (val + mask) >= high {
367                bits += 1;
368                mask >>= 1;
369                val = ((self.st.low + mask) & 0x00ff_ffff) & !mask;
370            }
371            if val < self.st.low {
372                self.st.carry = 1;
373            }
374        }
375        self.st.low = val;
376        while bits > 0 {
377            self.ac_shift(bytes);
378            bits -= 8;
379        }
380        bits += 8;
381        if bits < 0 {
382            panic!("bits is negative: {}", bits);
383        }
384        if self.st.carry_count > 0 {
385            self.write_byte_forward(self.st.cache as u8, bytes);
386            while self.st.carry_count > 1 {
387                self.write_byte_forward(0xff, bytes);
388                self.st.carry_count -= 1;
389            }
390            let value = 0xff >> (8 - bits);
391            self.write_uint_forward(value, bits as usize, bytes);
392        } else {
393            self.write_uint_forward(self.st.cache as usize, bits as usize, bytes);
394        }
395    }
396
397    fn ac_shift(&mut self, bytes: &mut [u8]) {
398        if self.st.low < 0x00ff_0000 || self.st.carry == 1 {
399            if self.st.cache >= 0 {
400                let byte = ((self.st.cache + self.st.carry) & 0xff) as u8;
401                self.write_byte_forward(byte, bytes);
402            }
403            while self.st.carry_count > 0 {
404                let byte = ((self.st.carry + 0xff) & 0xff) as u8;
405                self.write_byte_forward(byte, bytes);
406                self.st.carry_count -= 1;
407            }
408            self.st.cache = (self.st.low >> 16) as i32;
409            self.st.carry = 0;
410        } else {
411            self.st.carry_count += 1;
412        }
413        self.st.low <<= 8;
414        self.st.low &= 0x00ff_ffff;
415    }
416
417    fn ac_encode(&mut self, cum_freq: i16, sym_freq: i16, bytes: &mut [u8]) {
418        let r = self.st.range >> 10;
419        self.st.low += r * cum_freq as u32;
420        if self.st.low >> 24 != 0 {
421            self.st.carry = 1;
422        }
423        self.st.low &= 0x00ff_ffff;
424        self.st.range = r * sym_freq as u32;
425        while self.st.range < 0x10000 {
426            self.st.range <<= 8;
427            self.ac_shift(bytes);
428        }
429    }
430}
431
#[cfg(test)]
mod tests {
    extern crate std;

    use core::slice::Iter;

    use super::*;

    /// Adapts a `&[bool]` iterator into the owned-`bool` iterator that
    /// `residual_data_and_finalization` consumes.
    pub struct ResidualBitsTest<'a> {
        inner: Iter<'a, bool>,
    }

    impl<'a> Iterator for ResidualBitsTest<'a> {
        type Item = bool;

        fn next(&mut self) -> Option<Self::Item> {
            self.inner.next().copied()
        }
    }

    // TODO: this test sucks, fix it
    // Golden-output test: drives every encoding stage by hand with fixed inputs and
    // compares the whole 150-byte frame against a reference byte sequence.
    #[test]
    fn bitstream_encoding_run() {
        let mut bitstream_encoding = BitstreamEncoding::new(400);
        let mut buf_out = [0; 150];

        bitstream_encoding.init(&mut buf_out);

        bitstream_encoding.bandwidth(4, 3, &mut buf_out);

        bitstream_encoding.last_non_zero_tuple(350, &mut buf_out);
        bitstream_encoding.lsb_mode_bit(false, &mut buf_out);
        bitstream_encoding.global_gain(193, &mut buf_out);
        let rc_order = [8, 6];
        bitstream_encoding.tns_activation_flag(2, &rc_order, &mut buf_out);
        let pitch_present = true;
        bitstream_encoding.pitch_present_flag(pitch_present, &mut buf_out);

        bitstream_encoding.encode_scf_vq_1st_stage(8, 17, &mut buf_out);

        bitstream_encoding.encode_scf_vq_2nd_stage(3, 0, 0, 15253432, &mut buf_out);

        if pitch_present {
            bitstream_encoding.ltpf_data(false, 0, &mut buf_out);
        }

        bitstream_encoding.noise_factor(6, &mut buf_out);

        bitstream_encoding.ac_enc_init();

        let rc_i = [10, 7, 8, 9, 7, 9, 8, 9, 14, 11, 6, 9, 7, 9, 8, 8];
        bitstream_encoding.tns_data(0, 2, &rc_order, &rc_i, &mut buf_out);

        let x_q = [
            102, -146, -18, -14, -104, -128, 264, 254, -417, -180, 94, -28, 20, -38, 21, -62, -125, 10, -15, -4, 27,
            -9, -4, 3, 3, -1, 0, -13, -2, 0, -11, 3, 5, 4, -10, -18, -22, 4, 10, -5, 17, 4, -6, 2, 6, 11, -3, -3, 29,
            16, -15, 3, 4, 7, 4, -3, 5, 0, 6, 0, -6, 1, 0, -1, -2, 7, 6, 2, -9, -4, 3, -5, 3, 6, 4, -1, 3, 5, -1, -10,
            -16, 1, 1, 0, -4, -1, 7, -5, -4, -2, 0, -4, 1, 4, -1, -2, -7, 1, -2, 1, 1, -7, 1, 4, -1, -1, 2, 0, -1, -2,
            1, 3, -5, -1, 0, 2, 0, 0, 2, 0, 1, -3, 1, 2, 0, -5, -1, 5, -1, 0, -3, 0, 0, -1, 0, -2, 2, -3, 0, 1, -2, -1,
            -2, 0, 1, 2, -2, 0, -1, -3, -2, -1, 3, -2, -2, 0, 1, 0, -3, 1, 0, 0, -1, 0, 1, 0, 1, -2, 1, 1, 0, -1, 0, 0,
            1, 2, -1, 0, -1, 1, 0, -1, 1, -1, 1, -1, 0, 0, 0, -1, -1, 0, -2, 1, -1, -1, -1, -1, 0, -2, 0, -1, -1, 0, 0,
            0, 1, 0, -1, 0, 1, 1, 0, 0, 0, -1, 0, 0, -2, -1, 0, 1, 0, 0, 0, 1, -1, -1, 1, 0, 0, -1, -2, -1, -1, 0, 0,
            0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, -1, -1, 0, 0, 0, 0, 0, -1, -1, 0, 0, 0, 0, -1, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
            0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        ];
        bitstream_encoding.spectral_data(350, 512, false, &x_q, 107, &mut buf_out);

        let res_bits = [
            false, true, false, false, false, false, true, true, false, true, false, true, true, true, true, false,
            false, true, false, true, true, true, false, true, true, true, false, true, false, true, true, true, false,
            false, true, true, true, true, false, false, true, true, true, false, true, false, false, true, false,
            true, true, true, true, false, true, false, true, false, false, true, false, true, true, false, false,
            true, false, false, false, true, false, true, true, true, false, false, true, false, false, true, true,
            false, true, true, false, false, true, false, false, true, false, true, false, false, false, true, false,
            true, false, true, true, true, true, true, true, false, false, true, true, false, false, true, false,
            false, false, true, true, true, false, true, false, true, true, true, true, false, false, false, true,
            true, true, false, true, true, true, true, true, true, true, true,
        ];

        let res_bits = ResidualBitsTest { inner: res_bits.iter() };
        bitstream_encoding.residual_data_and_finalization(false, res_bits, &mut buf_out);

        let buf_out_expected = [
            230, 243, 160, 169, 152, 75, 36, 156, 223, 96, 241, 214, 150, 248, 180, 106, 115, 92, 147, 213, 56, 100,
            96, 52, 194, 178, 44, 31, 222, 246, 83, 116, 240, 220, 40, 241, 82, 228, 209, 57, 128, 152, 9, 144, 112,
            249, 48, 46, 135, 182, 250, 59, 135, 221, 129, 46, 204, 178, 232, 100, 172, 27, 177, 120, 86, 253, 35, 137,
            19, 253, 191, 202, 97, 240, 10, 45, 124, 110, 234, 149, 49, 115, 209, 177, 153, 231, 93, 211, 214, 19, 127,
            143, 103, 47, 239, 86, 73, 91, 231, 94, 248, 143, 54, 54, 190, 51, 47, 136, 92, 157, 13, 226, 13, 96, 104,
            159, 17, 206, 66, 25, 157, 51, 5, 252, 166, 135, 213, 118, 107, 152, 226, 253, 51, 136, 74, 186, 52, 64,
            236, 152, 115, 0, 29, 23, 247, 3, 20, 124, 21, 116,
        ];

        assert_eq!(buf_out, buf_out_expected);
    }
}