1use super::{
2 buffer_reader::{BufferReader, BufferReaderError},
3 side_info::SideInfo,
4};
5use crate::{
6 common::{
7 config::FrameDuration,
8 constants::{MAX_LEN_FREQUENCY, MAX_LEN_SPECTRAL},
9 },
10 tables::{
11 spectral_data_tables::{AC_SPEC_CUMFREQ, AC_SPEC_FREQ, AC_SPEC_LOOKUP},
12 temporal_noise_shaping_tables::{
13 AC_TNS_COEF_CUMFREQ, AC_TNS_COEF_FREQ, AC_TNS_ORDER_CUMFREQ, AC_TNS_ORDER_FREQ, MAXLAG, TNS_NUMFILTERS_MAX,
14 },
15 },
16};
17use heapless::Vec;
18#[allow(unused_imports)]
19use num_traits::real::Real;
20
/// Running state of the arithmetic (range) decoder.
#[derive(Debug)]
struct ArithmeticDecoderState {
    /// Low end of the decoder window; masked to 24 bits during renormalisation in `ac_decode`.
    ac_low: u32,
    /// Width of the decoder window.
    ac_range: u32,
}
26
27#[derive(Debug)]
28pub enum ArithmeticCodecError {
29 AcRangeFlOutOfRange(u32, u32),
30 BufferReader(BufferReaderError),
31}
32
33impl From<BufferReaderError> for ArithmeticCodecError {
34 fn from(err: BufferReaderError) -> Self {
35 Self::BufferReader(err)
36 }
37}
38
39#[derive(Debug)]
40pub enum ArithmeticDecodeError {
41 ArithmeticCodec(ArithmeticCodecError),
42 TnsOrder(usize, ArithmeticCodecError),
43 TnsCoef(usize, usize, ArithmeticCodecError),
44 SpectralData(usize, usize, ArithmeticCodecError),
45 SpectralBoolData(usize, usize, BufferReaderError),
46 NegativeResidualNumBits,
47 ResidualBoolData(bool, usize),
48 ResidualBoolDataOverflow(bool, usize, usize),
49}
50
51impl From<ArithmeticCodecError> for ArithmeticDecodeError {
52 fn from(err: ArithmeticCodecError) -> Self {
53 Self::ArithmeticCodec(err)
54 }
55}
56
57fn ac_dec_init(buf: &[u8], reader: &mut BufferReader) -> Result<ArithmeticDecoderState, ArithmeticCodecError> {
58 let ac_low_fl = reader.read_head_u24(buf)?;
59 let ac_range_fl = 0x00ffffff;
60
61 Ok(ArithmeticDecoderState {
62 ac_low: ac_low_fl,
63 ac_range: ac_range_fl,
64 })
65}
66
67fn ac_decode(
68 buf: &[u8],
69 reader: &mut BufferReader,
70 st: &mut ArithmeticDecoderState,
71 cum_freq: &[i16],
72 sym_freq: &[i16],
73) -> Result<usize, ArithmeticCodecError> {
74 let tmp = st.ac_range >> 10;
75
76 let limit = tmp << 10;
77 if st.ac_low >= limit {
78 return Err(ArithmeticCodecError::AcRangeFlOutOfRange(st.ac_low, limit));
79 }
80
81 let mut val = cum_freq.len() - 1;
82 while st.ac_low < (tmp * cum_freq[val] as u32) {
83 val -= 1;
84 }
85
86 st.ac_low -= tmp * cum_freq[val] as u32;
87 st.ac_range = tmp * sym_freq[val] as u32;
88
89 while st.ac_range < 0x10000 {
90 st.ac_low <<= 8;
91 st.ac_low &= 0x00ffffff;
92 st.ac_low += reader.read_head_byte(buf)? as u32;
93 st.ac_range <<= 8;
94 }
95
96 Ok(val)
97}
98
99#[derive(Debug, PartialEq)]
100pub struct ArithmeticData {
101 pub reflect_coef_order: [usize; 2], pub reflect_coef_ints: [usize; 16], pub residual_bits: Vec<bool, 480>,
104 pub noise_filling_seed: i32,
105 pub is_zero_frame: bool,
106 pub frame_num_bits: usize, }
108
109pub fn decode(
110 buf: &[u8], reader: &mut BufferReader, fs_ind: usize, ne: usize, side_info: &SideInfo, n_ms: &FrameDuration,
116 x: &mut [i32],
117) -> Result<ArithmeticData, ArithmeticDecodeError> {
118 let num_bytes = buf.len();
119 let nbits = num_bytes * 8;
120
121 let mut st = ac_dec_init(buf, reader)?;
123
124 let (tns_idx, tns_order) = decode_tns_data(buf, reader, side_info, &mut st, nbits, n_ms)?;
126
127 let mut save_lev: [i32; MAX_LEN_SPECTRAL] = [0; MAX_LEN_SPECTRAL];
129 decode_spectral_data(buf, reader, side_info, nbits, fs_ind, ne, &mut st, x, &mut save_lev)?;
130
131 for item in &mut x[side_info.lastnz..] {
133 *item = 0;
134 }
135
136 let residual_bits = decode_residual_bits(buf, reader, side_info, &st, nbits, ne, x, &mut save_lev)?;
138
139 let noise_filling_seed = x[..ne]
141 .iter()
142 .enumerate()
143 .map(|(k, item)| item.abs() * k as i32)
144 .sum::<i32>()
145 & 0xFFFF;
146
147 let is_zero_frame = side_info.lastnz == 2 && x[0] == 0 && x[1] == 0 && side_info.global_gain_index == 0;
149
150 Ok(ArithmeticData {
151 is_zero_frame,
152 noise_filling_seed,
153 reflect_coef_ints: tns_idx,
154 reflect_coef_order: tns_order,
155 residual_bits,
156 frame_num_bits: nbits,
157 })
158}
159
160fn decode_residual_bits(
161 buf: &[u8],
162 reader: &mut BufferReader,
163 side_info: &SideInfo,
164 st: &ArithmeticDecoderState,
165 nbits: usize,
166 ne: usize,
167 x: &mut [i32],
168 save_lev: &mut [i32],
169) -> Result<Vec<bool, MAX_LEN_FREQUENCY>, ArithmeticDecodeError> {
170 let mut nbits_residual = calc_num_residual_bits(reader, st, nbits)?;
172 let lsb_mode = side_info.lsb_mode;
173 let mut residual_bits = Vec::new();
174
175 if !lsb_mode {
177 for (k, x_k) in x[..ne].iter().enumerate() {
179 if *x_k != 0 {
180 if residual_bits.len() == nbits_residual {
181 break;
182 }
183
184 let bit = reader
185 .read_tail_bool(buf)
186 .map_err(|_| ArithmeticDecodeError::ResidualBoolData(lsb_mode, k))?;
187
188 residual_bits
189 .push(bit)
190 .map_err(|_| ArithmeticDecodeError::ResidualBoolDataOverflow(lsb_mode, k, residual_bits.len()))?;
191 }
192 }
193 } else {
194 for k in (0..side_info.lastnz).step_by(2) {
195 if save_lev[k] > 0 {
196 if !read_res_bit(x, reader, buf, k, &mut nbits_residual, lsb_mode)? {
197 break;
198 }
199
200 if !read_res_bit(x, reader, buf, k + 1, &mut nbits_residual, lsb_mode)? {
201 break;
202 }
203 }
204 }
205 }
206
207 Ok(residual_bits)
208}
209
210fn decode_spectral_data(
212 buf: &[u8],
213 reader: &mut BufferReader,
214 side_info: &SideInfo,
215 nbits: usize,
216 fs_ind: usize,
217 ne: usize,
218 st: &mut ArithmeticDecoderState,
219 x: &mut [i32],
220 save_lev: &mut [i32],
221) -> Result<(), ArithmeticDecodeError> {
222 let rate_flag = if nbits > (160 + fs_ind * 160) { 512 } else { 0 };
224 let mut c = 0;
225
226 for (k, chunk) in x[..side_info.lastnz].chunks_exact_mut(2).enumerate() {
227 let mut t = c + rate_flag + if (k * 2) > (ne / 2) { 256 } else { 0 };
228
229 let (x_k, x_kplus1) = chunk.split_at_mut(1);
231 let x_k = &mut x_k[0];
232 let x_kplus1 = &mut x_kplus1[0];
233
234 *x_k = 0;
235 *x_kplus1 = 0;
236 let mut sym = 0;
237 let mut lev: usize = 0;
238
239 while lev < 14 {
241 let pki_index = t + lev.min(3) * 1024;
242 let pki = AC_SPEC_LOOKUP[pki_index] as usize;
243
244 let cum_freq = &AC_SPEC_CUMFREQ[pki];
245 let spec_freq = &AC_SPEC_FREQ[pki];
246 sym = ac_decode(buf, reader, st, cum_freq, spec_freq)
247 .map_err(|err| ArithmeticDecodeError::SpectralData(k, lev, err))?;
248
249 if sym < 16 {
250 break;
251 }
252
253 if !side_info.lsb_mode || lev > 0 {
254 let bit = reader
255 .read_tail_bool(buf)
256 .map_err(|err| ArithmeticDecodeError::SpectralBoolData(k, lev, err))?
257 as i32;
258 *x_k += bit << lev;
259 let bit = reader
260 .read_tail_bool(buf)
261 .map_err(|err| ArithmeticDecodeError::SpectralBoolData(k, lev, err))?
262 as i32;
263 *x_kplus1 += bit << lev;
264 }
265
266 lev += 1;
267 }
268
269 if side_info.lsb_mode {
270 save_lev[k] = lev as i32;
272 }
273
274 let a = sym & 0x3;
275 let b = sym >> 2;
276
277 *x_k += (a as i32) << lev;
278 *x_kplus1 += (b as i32) << lev;
279
280 if *x_k > 0 {
281 let bit = reader
282 .read_tail_bool(buf)
283 .map_err(|err| ArithmeticDecodeError::SpectralBoolData(k, lev, err))?;
284 if bit {
285 *x_k = -*x_k;
286 }
287 }
288
289 if *x_kplus1 > 0 {
290 let bit = reader
291 .read_tail_bool(buf)
292 .map_err(|err| ArithmeticDecodeError::SpectralBoolData(k, lev, err))?;
293 if bit {
294 *x_kplus1 = -*x_kplus1;
295 }
296 }
297
298 lev = lev.min(3);
299 t = if lev <= 1 { 1 + (a + b) * (lev + 1) } else { 12 + lev };
300
301 c = (c & 15) * 16 + t;
302 }
303
304 Ok(())
305}
306
307fn decode_tns_data(
308 buf: &[u8],
309 reader: &mut BufferReader,
310 side_info: &SideInfo,
311 st: &mut ArithmeticDecoderState,
312 nbits: usize,
313 n_ms: &FrameDuration,
314) -> Result<([usize; 16], [usize; 2]), ArithmeticDecodeError> {
315 let max_bits = match n_ms {
316 FrameDuration::SevenPointFiveMs => 360,
317 FrameDuration::TenMs => 480,
318 };
319
320 let tns_lpc_weighting = nbits < max_bits; let tns_lpc_weighting_idx = tns_lpc_weighting as usize;
322
323 let mut tns_idx: [usize; TNS_NUMFILTERS_MAX * MAXLAG] = [0; TNS_NUMFILTERS_MAX * MAXLAG];
324 let mut tns_order = side_info.reflect_coef_order_ari_input; for (f, tns_order_f) in tns_order[..side_info.num_tns_filters].iter_mut().enumerate() {
326 if *tns_order_f > 0 {
327 let cum_freq = &AC_TNS_ORDER_CUMFREQ[tns_lpc_weighting_idx];
328 let sym_freq = &AC_TNS_ORDER_FREQ[tns_lpc_weighting_idx];
329 let order = ac_decode(buf, reader, st, cum_freq, sym_freq)
330 .map_err(|err| ArithmeticDecodeError::TnsOrder(f, err))?;
331
332 *tns_order_f = order + 1;
333 for k in 0..*tns_order_f {
334 let idx = f * 8 + k;
335 let cum_freq = &AC_TNS_COEF_CUMFREQ[k];
336 let sym_freq = &AC_TNS_COEF_FREQ[k];
337 tns_idx[idx] = ac_decode(buf, reader, st, cum_freq, sym_freq)
338 .map_err(|err| ArithmeticDecodeError::TnsCoef(f, k, err))?;
339 }
340 }
341 }
342
343 Ok((tns_idx, tns_order))
344}
345
346fn read_res_bit(
347 x: &mut [i32],
348 reader: &mut BufferReader,
349 buf: &[u8],
350 x_index: usize,
351 nbits_res: &mut usize,
352 lsb_mode: bool,
353) -> Result<bool, ArithmeticDecodeError> {
354 if *nbits_res == 0 {
356 return Ok(false);
357 }
358 let bit = reader
359 .read_tail_bool(buf)
360 .map_err(|_| ArithmeticDecodeError::ResidualBoolData(lsb_mode, x_index))?;
361 *nbits_res -= 1;
362
363 if bit {
364 let val = &mut x[x_index];
365 match val {
366 v if *v > 0 => {
367 *v += 1;
368 }
369 v if *v < 0 => {
370 *v -= 1;
371 }
372 v => {
373 if *nbits_res == 0 {
375 return Ok(false);
376 }
377 let bit = reader
378 .read_tail_bool(buf)
379 .map_err(|_| ArithmeticDecodeError::ResidualBoolData(lsb_mode, x_index))?;
380 *nbits_res -= 1;
381
382 *v = if bit { -1 } else { 1 };
383 }
384 };
385 }
386
387 Ok(true)
388}
389
390fn calc_num_residual_bits(
391 reader: &BufferReader,
392 st: &ArithmeticDecoderState,
393 total_bits: usize,
394) -> Result<usize, ArithmeticDecodeError> {
395 let nbits_side = reader.get_tail_bit_cursor() - 8;
396
397 let nbits_ari = (reader.get_head_byte_cursor() + 1 - 3) * 8 + 25 - (st.ac_range as f64).log2().floor() as usize;
399
400 if total_bits >= (nbits_side + nbits_ari) {
401 Ok(total_bits - nbits_side - nbits_ari)
402 } else {
403 Err(ArithmeticDecodeError::NegativeResidualNumBits)
404 }
405}
406
#[cfg(test)]
mod tests {
    extern crate std;
    use crate::decoder::side_info::{Bandwidth, LongTermPostFilterInfo, SnsVq};

    use super::*;

    /// Decodes a known 150-byte, 10 ms full-band frame and checks every output field.
    #[test]
    fn arithmetic_decode() {
        let buf = [
            187, 56, 111, 155, 76, 236, 70, 99, 10, 135, 219, 76, 176, 3, 108, 203, 131, 111, 206, 221, 195, 25, 96,
            240, 18, 202, 163, 241, 109, 142, 198, 122, 176, 70, 37, 6, 35, 190, 110, 184, 251, 162, 71, 7, 151, 58,
            42, 79, 200, 192, 99, 157, 234, 156, 245, 43, 84, 64, 167, 32, 52, 106, 43, 75, 4, 102, 213, 123, 168, 120,
            213, 252, 208, 118, 78, 115, 154, 158, 157, 26, 152, 231, 121, 146, 203, 11, 169, 227, 75, 154, 237, 154,
            227, 145, 196, 182, 207, 94, 95, 26, 184, 248, 1, 118, 72, 47, 18, 205, 56, 96, 195, 139, 216, 240, 113,
            233, 44, 198, 245, 157, 139, 70, 162, 182, 139, 136, 165, 68, 79, 247, 161, 126, 17, 135, 36, 30, 229, 24,
            196, 2, 5, 65, 111, 80, 124, 168, 70, 156, 198, 60,
        ];
        let mut reader = BufferReader::new_at(0, 64);
        let fs_ind = 4;
        let ne = 400;
        let side_info = SideInfo {
            bandwidth: Bandwidth::FullBand,
            lastnz: 400,
            lsb_mode: false,
            global_gain_index: 204,
            num_tns_filters: 2,
            reflect_coef_order_ari_input: [1, 0],
            sns_vq: SnsVq {
                ind_lf: 13,
                ind_hf: 4,
                ls_inda: 1,
                ls_indb: 0,
                idx_a: 1718290,
                idx_b: 2,
                submode_lsb: 0,
                submode_msb: 0,
                g_ind: 0,
            },
            long_term_post_filter_info: LongTermPostFilterInfo {
                pitch_present: false,
                is_active: false,
                pitch_index: 0,
            },
            noise_factor: 3,
        };
        // Already a reference; pass it straight through.
        let n_ms = &FrameDuration::TenMs;
        let mut x = [0; MAX_LEN_SPECTRAL];

        let arithmetic_data = decode(&buf, &mut reader, fs_ind, ne, &side_info, n_ms, &mut x).unwrap();

        assert!(!arithmetic_data.is_zero_frame);
        assert_eq!(arithmetic_data.frame_num_bits, 1200);
        assert_eq!(arithmetic_data.noise_filling_seed, 56909);
        assert_eq!(
            arithmetic_data.reflect_coef_ints,
            [6, 10, 7, 8, 7, 9, 7, 7, 0, 0, 0, 0, 0, 0, 0, 0]
        );
        assert_eq!(
            arithmetic_data.residual_bits,
            [
                false, true, true, true, false, false, false, true, false, false, true, true, true, false, false,
                false, true, true, true, false, true, false, true, true, false, false, true, true, false, true, true,
                false, true, true, true, false, true, false, true, true, false, false, true, true, true
            ]
        );
        assert_eq!(arithmetic_data.reflect_coef_order, [8, 0]);
        // Coefficients from `lastnz` onwards are cleared by `decode`.
        assert!(x[400..].iter().all(|&v| v == 0));
    }
}