1use super::{
2 bandwidth_detector::BandwidthDetectorResult, buffer_writer::BufferWriter,
3 long_term_post_filter::LongTermPostFilterResult, residual_spectrum::ResidualBits,
4 spectral_noise_shaping::SnsResult, spectral_quantization::SpectralQuantizationResult,
5 temporal_noise_shaping::TnsResult,
6};
7use crate::tables::{spec_noise_shape_quant_tables::*, spectral_data_tables::*, temporal_noise_shaping_tables::*};
8use heapless::Vec;
9
10#[allow(unused_imports)]
11use num_traits::real::Real;
12
13const MAX_NBITS_LSB: usize = 480 * 8; #[derive(Default)]
16pub struct BitstreamEncoding {
17 ne: usize,
18 nbytes: usize,
19 nbits: usize,
20 nbits_side_initial: usize,
21 nlsbs: usize,
22 lsbs: Vec<u8, MAX_NBITS_LSB>,
23 st: ArithmeticEncoderState,
24 writer: BufferWriter,
25}
26
/// State of the range coder driven by `ac_encode`, `ac_shift` and `ac_enc_finish`.
///
/// Every field defaults to zero; `ac_enc_init` sets the real starting values.
#[derive(Default)]
struct ArithmeticEncoderState {
    /// Lower end of the coding interval, masked to 24 bits after each update.
    pub low: u32,
    /// Width of the coding interval.
    pub range: u32,
    /// Carry (0 or 1) that still has to be added to `cache` when it is flushed.
    pub carry: i32,
    /// Number of deferred bytes queued behind `cache`, flushed as `(carry + 0xff) & 0xff`.
    pub carry_count: i32,
    /// Byte held back until any carry into it is known; `-1` means no byte yet.
    pub cache: i32,
}
35
36impl BitstreamEncoding {
37 pub fn new(ne: usize) -> Self {
38 Self {
39 ne,
40 ..Default::default()
41 }
42 }
43
44 fn write_uint_backward(&mut self, value: usize, num_bits: usize, bytes: &mut [u8]) {
45 self.writer.write_uint_backward(bytes, value, num_bits)
46 }
47
48 fn write_bool_backward(&mut self, value: bool, bytes: &mut [u8]) {
49 self.writer.write_bool_backward(bytes, value)
50 }
51
52 fn write_byte_forward(&mut self, value: u8, bytes: &mut [u8]) {
53 self.writer.write_byte_forward(bytes, value)
54 }
55
56 fn write_uint_forward(&mut self, value: usize, num_bits: usize, bytes: &mut [u8]) {
57 self.writer.write_uint_forward(bytes, value as u16, num_bits)
58 }
59
60 fn nbits_side_written(&self) -> usize {
61 self.writer.nbits_side_written(self.nbits)
62 }
63
64 fn nbits_side_forcast(&self) -> usize {
65 let mut nbits_ari = (self.writer.bp * 8) as i32;
66 nbits_ari += 25 - (self.st.range as f64).log2().floor() as i32;
67 if self.st.carry >= 0 {
68 nbits_ari += 8;
69 }
70 if self.st.carry_count > 0 {
71 nbits_ari += self.st.carry_count * 8;
72 }
73
74 nbits_ari as usize
75 }
76
77 pub fn encode<'a>(
78 &mut self,
79 bandwidth: BandwidthDetectorResult,
80 sns: SnsResult,
81 tns: TnsResult,
82 post_filter: LongTermPostFilterResult,
83 spec: SpectralQuantizationResult,
84 residual_bits: ResidualBits<'a>,
85 noise_factor: usize,
86 spec_output: &[i16],
87 buf_out: &mut [u8],
88 ) {
89 self.init(buf_out);
91
92 self.bandwidth(bandwidth.bandwidth_ind, bandwidth.nbits_bandwidth, buf_out);
94 self.last_non_zero_tuple(spec.lastnz_trunc, buf_out);
95 self.lsb_mode_bit(spec.lsb_mode, buf_out);
96 self.global_gain(spec.gg_ind as usize, buf_out);
97 self.tns_activation_flag(tns.num_tns_filters, &tns.rc_order, buf_out);
98 self.pitch_present_flag(post_filter.pitch_present, buf_out);
99
100 self.encode_scf_vq_1st_stage(sns.ind_lf, sns.ind_hf, buf_out);
102
103 self.encode_scf_vq_2nd_stage(sns.shape_j, sns.gind, sns.ls_inda as usize, sns.index_joint_j, buf_out);
105
106 if post_filter.pitch_present {
107 self.ltpf_data(post_filter.ltpf_active, post_filter.pitch_index, buf_out);
108 }
109
110 self.noise_factor(noise_factor, buf_out);
111
112 self.ac_enc_init();
114
115 self.tns_data(
117 tns.lpc_weighting,
118 tns.num_tns_filters,
119 &tns.rc_order,
120 &tns.rc_i,
121 buf_out,
122 );
123
124 self.spectral_data(
126 spec.lastnz_trunc,
127 spec.rate_flag,
128 spec.lsb_mode,
129 spec_output,
130 spec.nbits_lsb,
131 buf_out,
132 );
133
134 self.residual_data_and_finalization(spec.lsb_mode, residual_bits, buf_out);
136 }
137
138 fn init(&mut self, bytes: &mut [u8]) {
139 self.nbytes = bytes.len();
140 self.nbits = self.nbytes * 8;
141 self.writer = BufferWriter::new(bytes.len());
142 bytes.fill(0);
143 self.nlsbs = 0;
144 }
145
146 fn bandwidth(&mut self, p_bw: usize, nbits_bw: usize, bytes: &mut [u8]) {
147 if nbits_bw > 0 {
148 self.write_uint_backward(p_bw, nbits_bw, bytes);
149 }
150 }
151
152 fn last_non_zero_tuple(&mut self, lastnz_trunc: usize, bytes: &mut [u8]) {
153 let value = (lastnz_trunc >> 1) - 1;
154 let num_bits = (self.ne as f64 / 2.0).log2().ceil() as usize;
155 self.write_uint_backward(value, num_bits, bytes)
156 }
157
158 fn lsb_mode_bit(&mut self, lsb_mode: bool, bytes: &mut [u8]) {
159 self.write_bool_backward(lsb_mode, bytes)
160 }
161
162 fn global_gain(&mut self, gg_ind: usize, bytes: &mut [u8]) {
163 self.write_uint_backward(gg_ind, 8, bytes)
164 }
165
166 fn tns_activation_flag(&mut self, num_tns_filters: usize, rc_order: &[usize], bytes: &mut [u8]) {
167 for rc_order_f in rc_order[..num_tns_filters].iter() {
168 let value = *rc_order_f != 0;
169 self.write_bool_backward(value, bytes);
170 }
171 }
172
173 fn pitch_present_flag(&mut self, pitch_present: bool, bytes: &mut [u8]) {
174 self.write_bool_backward(pitch_present, bytes)
175 }
176
177 fn encode_scf_vq_1st_stage(&mut self, ind_lf: usize, ind_hf: usize, bytes: &mut [u8]) {
178 self.write_uint_backward(ind_lf, 5, bytes);
179 self.write_uint_backward(ind_hf, 5, bytes);
180 }
181
182 fn encode_scf_vq_2nd_stage(
183 &mut self,
184 shape_j: usize,
185 gain_i: usize,
186 ls_inda: usize,
187 index_joint_j: usize,
188 bytes: &mut [u8],
189 ) {
190 let submode_msb = (shape_j >> 1) != 0;
191 self.write_bool_backward(submode_msb, bytes);
192 let gain_msbs_num_bits = SNS_GAIN_MSB_BITS[shape_j];
193 let gain_msbs = gain_i >> SNS_GAIN_LSB_BITS[shape_j];
194 self.write_uint_backward(gain_msbs, gain_msbs_num_bits, bytes);
195 let ls_inda_flag = ls_inda != 0;
196 self.write_bool_backward(ls_inda_flag, bytes);
197
198 if !submode_msb {
199 self.write_uint_backward(index_joint_j, 13, bytes);
200 self.write_uint_backward(index_joint_j >> 13, 12, bytes);
201 } else {
202 self.write_uint_backward(index_joint_j, 12, bytes);
203 self.write_uint_backward(index_joint_j >> 12, 12, bytes);
204 }
205 }
206
207 pub fn ltpf_data(&mut self, ltpf_active: bool, pitch_index: usize, bytes: &mut [u8]) {
208 self.write_bool_backward(ltpf_active, bytes);
209 self.write_uint_backward(pitch_index, 9, bytes);
210 }
211
212 pub fn noise_factor(&mut self, f_nf: usize, bytes: &mut [u8]) {
213 self.write_uint_backward(f_nf, 3, bytes);
214 }
215
216 pub fn ac_enc_init(&mut self) {
217 self.st.low = 0;
218 self.st.range = 0x00ff_ffff;
219 self.st.cache = -1;
220 self.st.carry = 0;
221 self.st.carry_count = 0;
222 }
223
224 pub fn tns_data(
225 &mut self,
226 tns_lpc_weighting: u8, num_tns_filters: usize,
228 rc_order: &[usize],
229 rc_i: &[usize],
230 bytes: &mut [u8],
231 ) {
232 for f in 0..num_tns_filters {
233 if rc_order[f] > 0 {
234 let cum_freq = AC_TNS_ORDER_CUMFREQ[tns_lpc_weighting as usize][rc_order[f] - 1];
235 let sym_freq = AC_TNS_ORDER_FREQ[tns_lpc_weighting as usize][rc_order[f] - 1];
236 self.ac_encode(cum_freq, sym_freq, bytes);
237 for k in 0..rc_order[f] {
238 let cum_freq = AC_TNS_COEF_CUMFREQ[k][rc_i[k + 8 * f]];
239 let sym_freq = AC_TNS_COEF_FREQ[k][rc_i[k + 8 * f]];
240 self.ac_encode(cum_freq, sym_freq, bytes);
241 }
242 }
243 }
244 }
245
246 pub fn spectral_data(
247 &mut self,
248 lastnz_trunc: usize,
249 rate_flag: usize,
250 lsb_mode: bool,
251 x_q: &[i16],
252 nbits_lsb: usize,
253 bytes: &mut [u8],
254 ) {
255 self.nbits_side_initial = self.nbits_side_written();
256 self.lsbs.clear();
257 for _ in 0..nbits_lsb {
258 self.lsbs.push(0).unwrap();
259 }
260
261 let mut c = 0;
262 for k in (0..lastnz_trunc).step_by(2) {
263 let mut t = c + rate_flag + if k > (self.ne / 2) { 256 } else { 0 };
264 let mut a = x_q[k].unsigned_abs();
265 let mut a_lsb = a;
266 let mut b = x_q[k + 1].unsigned_abs();
267 let mut b_lsb = b;
268 let mut lev = 0;
269 let mut lsb0: u8 = 0;
270 let mut lsb1: u8 = 0;
271 while a.max(b) >= 4 {
272 let pki_index = t + lev.min(3) * 1024;
273 let pki = AC_SPEC_LOOKUP[pki_index] as usize;
274 let cum_freq = AC_SPEC_CUMFREQ[pki][16];
275 let sym_freq = AC_SPEC_FREQ[pki][16];
276 self.ac_encode(cum_freq, sym_freq, bytes);
277 if lsb_mode && lev == 0 {
278 lsb0 = a as u8 & 1;
279 lsb1 = b as u8 & 1;
280 } else {
281 self.write_bool_backward((a & 1) == 1, bytes);
282 self.write_bool_backward((b & 1) == 1, bytes);
283 }
284 a >>= 1;
285 b >>= 1;
286 lev += 1;
287 }
288 let pki_index = t + lev.min(3) * 1024;
289 let pki = AC_SPEC_LOOKUP[pki_index] as usize;
290 let sym = (a + 4 * b) as usize;
291 let cum_freq = AC_SPEC_CUMFREQ[pki][sym];
292 let sym_freq = AC_SPEC_FREQ[pki][sym];
293 self.ac_encode(cum_freq, sym_freq, bytes);
294
295 if lsb_mode && lev > 0 {
296 a_lsb >>= 1;
297 b_lsb >>= 1;
298
299 self.lsbs[self.nlsbs] = lsb0;
300 self.nlsbs += 1;
301 if a_lsb == 0 && x_q[k] != 0 {
302 self.lsbs[self.nlsbs] = if x_q[k] > 0 { 0 } else { 1 };
303 self.nlsbs += 1;
304 }
305 self.lsbs[self.nlsbs] = lsb1;
306 self.nlsbs += 1;
307 if b_lsb == 0 && x_q[k + 1] != 0 {
308 self.lsbs[self.nlsbs] = if x_q[k + 1] > 0 { 0 } else { 1 };
309 self.nlsbs += 1;
310 }
311 }
312 if a_lsb > 0 {
313 self.write_bool_backward(x_q[k] <= 0, bytes);
314 }
315 if b_lsb > 0 {
316 self.write_bool_backward(x_q[k + 1] <= 0, bytes);
317 }
318 lev = lev.min(3);
319 t = if lev <= 1 {
320 1 + ((a + b) as usize) * (lev + 1)
321 } else {
322 12 + lev
323 };
324 c = (c & 15) * 16 + t;
325 }
326 }
327
328 pub fn residual_data_and_finalization(
329 &mut self,
330 lsb_mode: bool,
331 residual_bits: impl Iterator<Item = bool>,
332 bytes: &mut [u8],
333 ) {
334 let nbits_side = self.nbits_side_written();
335 let nbits_ari = self.nbits_side_forcast();
336 let nbits_residual_enc = self.nbits as i32 - (nbits_side + nbits_ari) as i32;
337 let nbits_residual_enc = nbits_residual_enc.max(0) as usize;
338
339 if !lsb_mode {
340 for res_bit in residual_bits.take(nbits_residual_enc) {
341 self.write_bool_backward(res_bit, bytes);
342 }
343 } else {
344 let nbits_residual_enc = nbits_residual_enc.min(self.nlsbs);
345 for k in 0..nbits_residual_enc {
346 let value = self.lsbs[k] == 1;
347 self.write_bool_backward(value, bytes);
348 }
349 }
350
351 self.ac_enc_finish(bytes);
352 }
353
354 fn ac_enc_finish(&mut self, bytes: &mut [u8]) {
355 let mut bits: i8 = 1;
356 while (self.st.range >> (24 - bits)) == 0 {
357 bits += 1;
358 }
359 let mut mask = 0x00ff_ffff >> bits;
360 let mut val = self.st.low + mask;
361 let over1 = val >> 24;
362 let high = self.st.low + self.st.range;
363 let over2 = high >> 24;
364 val &= 0x00ff_ffff & !mask;
365 if over1 == over2 {
366 if (val + mask) >= high {
367 bits += 1;
368 mask >>= 1;
369 val = ((self.st.low + mask) & 0x00ff_ffff) & !mask;
370 }
371 if val < self.st.low {
372 self.st.carry = 1;
373 }
374 }
375 self.st.low = val;
376 while bits > 0 {
377 self.ac_shift(bytes);
378 bits -= 8;
379 }
380 bits += 8;
381 if bits < 0 {
382 panic!("bits is negative: {}", bits);
383 }
384 if self.st.carry_count > 0 {
385 self.write_byte_forward(self.st.cache as u8, bytes);
386 while self.st.carry_count > 1 {
387 self.write_byte_forward(0xff, bytes);
388 self.st.carry_count -= 1;
389 }
390 let value = 0xff >> (8 - bits);
391 self.write_uint_forward(value, bits as usize, bytes);
392 } else {
393 self.write_uint_forward(self.st.cache as usize, bits as usize, bytes);
394 }
395 }
396
397 fn ac_shift(&mut self, bytes: &mut [u8]) {
398 if self.st.low < 0x00ff_0000 || self.st.carry == 1 {
399 if self.st.cache >= 0 {
400 let byte = ((self.st.cache + self.st.carry) & 0xff) as u8;
401 self.write_byte_forward(byte, bytes);
402 }
403 while self.st.carry_count > 0 {
404 let byte = ((self.st.carry + 0xff) & 0xff) as u8;
405 self.write_byte_forward(byte, bytes);
406 self.st.carry_count -= 1;
407 }
408 self.st.cache = (self.st.low >> 16) as i32;
409 self.st.carry = 0;
410 } else {
411 self.st.carry_count += 1;
412 }
413 self.st.low <<= 8;
414 self.st.low &= 0x00ff_ffff;
415 }
416
417 fn ac_encode(&mut self, cum_freq: i16, sym_freq: i16, bytes: &mut [u8]) {
418 let r = self.st.range >> 10;
419 self.st.low += r * cum_freq as u32;
420 if self.st.low >> 24 != 0 {
421 self.st.carry = 1;
422 }
423 self.st.low &= 0x00ff_ffff;
424 self.st.range = r * sym_freq as u32;
425 while self.st.range < 0x10000 {
426 self.st.range <<= 8;
427 self.ac_shift(bytes);
428 }
429 }
430}
431
#[cfg(test)]
mod tests {
    extern crate std;

    use core::slice::Iter;

    use super::*;

    /// Adapts a borrowed slice of bits to the `Iterator<Item = bool>` consumed by
    /// `residual_data_and_finalization`.
    pub struct ResidualBitsTest<'a> {
        inner: Iter<'a, bool>,
    }

    impl<'a> Iterator for ResidualBitsTest<'a> {
        type Item = bool;

        fn next(&mut self) -> Option<Self::Item> {
            self.inner.next().copied()
        }
    }

    /// Drives every encoding stage, in the same order as `encode`, on a fixed 150-byte frame
    /// (NE = 400) and compares the output against reference bytes.
    #[test]
    fn bitstream_encoding_run() {
        let mut bitstream_encoding = BitstreamEncoding::new(400);
        let mut buf_out = [0; 150];

        bitstream_encoding.init(&mut buf_out);

        bitstream_encoding.bandwidth(4, 3, &mut buf_out);

        bitstream_encoding.last_non_zero_tuple(350, &mut buf_out);
        bitstream_encoding.lsb_mode_bit(false, &mut buf_out);
        bitstream_encoding.global_gain(193, &mut buf_out);
        let rc_order = [8, 6];
        bitstream_encoding.tns_activation_flag(2, &rc_order, &mut buf_out);
        let pitch_present = true;
        bitstream_encoding.pitch_present_flag(pitch_present, &mut buf_out);

        bitstream_encoding.encode_scf_vq_1st_stage(8, 17, &mut buf_out);

        bitstream_encoding.encode_scf_vq_2nd_stage(3, 0, 0, 15253432, &mut buf_out);

        if pitch_present {
            bitstream_encoding.ltpf_data(false, 0, &mut buf_out);
        }

        bitstream_encoding.noise_factor(6, &mut buf_out);

        bitstream_encoding.ac_enc_init();

        let rc_i = [10, 7, 8, 9, 7, 9, 8, 9, 14, 11, 6, 9, 7, 9, 8, 8];
        bitstream_encoding.tns_data(0, 2, &rc_order, &rc_i, &mut buf_out);

        let x_q = [
            102, -146, -18, -14, -104, -128, 264, 254, -417, -180, 94, -28, 20, -38, 21, -62, -125, 10, -15, -4, 27,
            -9, -4, 3, 3, -1, 0, -13, -2, 0, -11, 3, 5, 4, -10, -18, -22, 4, 10, -5, 17, 4, -6, 2, 6, 11, -3, -3, 29,
            16, -15, 3, 4, 7, 4, -3, 5, 0, 6, 0, -6, 1, 0, -1, -2, 7, 6, 2, -9, -4, 3, -5, 3, 6, 4, -1, 3, 5, -1, -10,
            -16, 1, 1, 0, -4, -1, 7, -5, -4, -2, 0, -4, 1, 4, -1, -2, -7, 1, -2, 1, 1, -7, 1, 4, -1, -1, 2, 0, -1, -2,
            1, 3, -5, -1, 0, 2, 0, 0, 2, 0, 1, -3, 1, 2, 0, -5, -1, 5, -1, 0, -3, 0, 0, -1, 0, -2, 2, -3, 0, 1, -2, -1,
            -2, 0, 1, 2, -2, 0, -1, -3, -2, -1, 3, -2, -2, 0, 1, 0, -3, 1, 0, 0, -1, 0, 1, 0, 1, -2, 1, 1, 0, -1, 0, 0,
            1, 2, -1, 0, -1, 1, 0, -1, 1, -1, 1, -1, 0, 0, 0, -1, -1, 0, -2, 1, -1, -1, -1, -1, 0, -2, 0, -1, -1, 0, 0,
            0, 1, 0, -1, 0, 1, 1, 0, 0, 0, -1, 0, 0, -2, -1, 0, 1, 0, 0, 0, 1, -1, -1, 1, 0, 0, -1, -2, -1, -1, 0, 0,
            0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, -1, -1, 0, 0, 0, 0, 0, -1, -1, 0, 0, 0, 0, -1, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
            0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        ];
        bitstream_encoding.spectral_data(350, 512, false, &x_q, 107, &mut buf_out);

        let res_bits = [
            false, true, false, false, false, false, true, true, false, true, false, true, true, true, true, false,
            false, true, false, true, true, true, false, true, true, true, false, true, false, true, true, true, false,
            false, true, true, true, true, false, false, true, true, true, false, true, false, false, true, false,
            true, true, true, true, false, true, false, true, false, false, true, false, true, true, false, false,
            true, false, false, false, true, false, true, true, true, false, false, true, false, false, true, true,
            false, true, true, false, false, true, false, false, true, false, true, false, false, false, true, false,
            true, false, true, true, true, true, true, true, false, false, true, true, false, false, true, false,
            false, false, true, true, true, false, true, false, true, true, true, true, false, false, false, true,
            true, true, false, true, true, true, true, true, true, true, true,
        ];

        let res_bits = ResidualBitsTest { inner: res_bits.iter() };
        bitstream_encoding.residual_data_and_finalization(false, res_bits, &mut buf_out);

        let buf_out_expected = [
            230, 243, 160, 169, 152, 75, 36, 156, 223, 96, 241, 214, 150, 248, 180, 106, 115, 92, 147, 213, 56, 100,
            96, 52, 194, 178, 44, 31, 222, 246, 83, 116, 240, 220, 40, 241, 82, 228, 209, 57, 128, 152, 9, 144, 112,
            249, 48, 46, 135, 182, 250, 59, 135, 221, 129, 46, 204, 178, 232, 100, 172, 27, 177, 120, 86, 253, 35, 137,
            19, 253, 191, 202, 97, 240, 10, 45, 124, 110, 234, 149, 49, 115, 209, 177, 153, 231, 93, 211, 214, 19, 127,
            143, 103, 47, 239, 86, 73, 91, 231, 94, 248, 143, 54, 54, 190, 51, 47, 136, 92, 157, 13, 226, 13, 96, 104,
            159, 17, 206, 66, 25, 157, 51, 5, 252, 166, 135, 213, 118, 107, 152, 226, 253, 51, 136, 74, 186, 52, 64,
            236, 152, 115, 0, 29, 23, 247, 3, 20, 124, 21, 116,
        ];

        assert_eq!(buf_out, buf_out_expected);
    }
}