1use std::cmp::min;
2
3use {invalid_data, InvalidData, StreamInfo};
4use bitcursor::BitCursor;
5
/// An integer type that decoded audio samples can be written into.
///
/// The trait is sealed (see `private::Sealed`): it is implemented only for
/// `i16` and `i32`.
pub trait Sample: Copy + private::Sealed {
    /// Converts a decoded sample holding a `bits`-wide value into `Self`.
    ///
    /// Only the low `bits` bits of `sample` are significant; they are shifted
    /// so that the sample's most significant bit lands on the most significant
    /// bit of `Self`.
    fn from_decoder(sample: i32, bits: u8) -> Self;

    /// Width of `Self` in bits.
    fn bits() -> u8;
}

impl Sample for i16 {
    #[inline(always)]
    fn from_decoder(sample: i32, bits: u8) -> Self {
        // Left-align exactly like the `i32` impl so narrower streams are scaled
        // consistently. The `as` cast keeps the low 16 bits, which also
        // sign-extends values that were read without sign extension (the
        // uncompressed path). For `bits >= 16` the shift is zero, which is the
        // previous behaviour.
        (sample << 16u8.saturating_sub(bits)) as i16
    }

    #[inline(always)]
    fn bits() -> u8 {
        16
    }
}

impl Sample for i32 {
    #[inline(always)]
    fn from_decoder(sample: i32, bits: u8) -> Self {
        sample << (32 - bits)
    }

    #[inline(always)]
    fn bits() -> u8 {
        32
    }
}

/// Private supertrait that prevents `Sample` from being implemented outside
/// this crate.
mod private {
    pub trait Sealed {}

    impl Sealed for i16 {}
    impl Sealed for i32 {}
}
44
45pub struct Decoder {
47 config: StreamInfo,
48 buf: Box<[i32]>,
49}
50
// Element tags: each element in a packet is introduced by a 3-bit tag.

/// One-channel audio element.
const ID_SCE: u8 = 0;
/// Two-channel (pair) audio element.
const ID_CPE: u8 = 1;
/// CCE element; the decoder rejects it as unsupported.
const ID_CCE: u8 = 2;
/// One-channel audio element, decoded the same way as `ID_SCE`.
const ID_LFE: u8 = 3;
/// Data stream element; its payload is skipped.
const ID_DSE: u8 = 4;
/// PCE element; the decoder rejects it as unsupported.
const ID_PCE: u8 = 5;
57const ID_FIL: u8 = 6; const ID_END: u8 = 7; impl Decoder {
61 pub fn new(config: StreamInfo) -> Decoder {
63 Decoder {
64 buf: vec![0; config.frame_length as usize * 2].into_boxed_slice(),
65 config,
66 }
67 }
68
69 pub fn stream_info(&self) -> &StreamInfo {
71 &self.config
72 }
73
74 pub fn decode_packet<'a, S: Sample>(
81 &mut self,
82 packet: &[u8],
83 out: &'a mut [S],
84 ) -> Result<&'a [S], InvalidData> {
85 let mut reader = BitCursor::new(packet);
86
87 let mut channel_index = 0;
88 let mut frame_samples = None;
89
90 assert!(out.len() >= self.config.max_samples_per_packet() as usize);
91 assert!(S::bits() >= self.config.bit_depth);
92
93 loop {
94 let tag = reader.read_u8(3)?;
95
96 match tag {
97 tag @ ID_SCE | tag @ ID_LFE | tag @ ID_CPE => {
98 let element_channels = match tag {
99 ID_SCE => 1,
100 ID_LFE => 1,
101 ID_CPE => 2,
102 _ => unreachable!(),
103 };
104
105 if channel_index + element_channels > self.config.num_channels {
107 return Err(invalid_data("packet contains more channels than expected"));
108 }
109
110 let element_samples = decode_audio_element(
111 self,
112 &mut reader,
113 out,
114 channel_index,
115 element_channels,
116 )?;
117
118 if let Some(frame_samples) = frame_samples {
120 if frame_samples != element_samples {
121 return Err(invalid_data(
122 "all channels in a packet must contain the same number of samples",
123 ));
124 }
125 } else {
126 frame_samples = Some(element_samples);
127 }
128
129 channel_index += element_channels;
130 }
131 ID_CCE | ID_PCE => {
132 return Err(invalid_data("packet cce and pce elements are unsupported"));
133 }
134 ID_DSE => {
135 let _element_instance_tag = reader.read_u8(4)?;
140 let data_byte_align_flag = reader.read_bit()?;
141
142 let mut skip_bytes = reader.read_u8(8)? as usize;
144 if skip_bytes == 255 {
145 skip_bytes += reader.read_u8(8)? as usize;
146 }
147
148 if data_byte_align_flag {
151 reader.skip_to_byte()?;
152 }
153
154 reader.skip(skip_bytes * 8)?;
155 }
156 ID_FIL => {
157 let mut skip_bytes = reader.read_u8(4)? as usize;
162 if skip_bytes == 15 {
163 skip_bytes += reader.read_u8(8)? as usize - 1
164 }
165
166 reader.skip(skip_bytes * 8)?;
167 }
168 ID_END => {
169 reader.skip_to_byte()?;
173
174 if channel_index != self.config.num_channels {
176 return Err(invalid_data("packet contains fewer channels than expected"));
177 }
178
179 let frame_samples = frame_samples.unwrap_or(self.config.frame_length);
180 return Ok(&out[..frame_samples as usize * channel_index as usize]);
181 }
182 _ => unreachable!(),
184 }
185 }
186 }
187}
188
189fn decode_audio_element<'a, S: Sample>(
190 this: &mut Decoder,
191 reader: &mut BitCursor<'a>,
192 out: &mut [S],
193 channel_index: u8,
194 element_channels: u8,
195) -> Result<u32, InvalidData> {
196 let _element_instance_tag = reader.read_u8(4)?;
198
199 let unused = reader.read_u16(12)?;
200 if unused != 0 {
201 return Err(invalid_data("unused channel header bits must be zero"));
202 }
203
204 let partial_frame = reader.read_bit()?;
206
207 let sample_shift_bytes = reader.read_u8(2)?;
208 if sample_shift_bytes > 2 {
209 return Err(invalid_data(
210 "channel sample shift must not be greater than 16",
211 ));
212 }
213 let sample_shift = sample_shift_bytes * 8;
214
215 let is_uncompressed = reader.read_bit()?;
216
217 let num_samples = if partial_frame {
219 let num_samples = reader.read_u32(32)?;
221
222 if num_samples > this.config.frame_length {
223 return Err(invalid_data("channel contains more samples than expected"));
224 }
225
226 num_samples as usize
227 } else {
228 this.config.frame_length as usize
229 };
230
231 if !is_uncompressed {
232 let (buf_u, buf_v) = this.buf.split_at_mut(this.config.frame_length as usize);
233 let mut mix_buf = [&mut buf_u[..num_samples], &mut buf_v[..num_samples]];
234
235 let chan_bits = this.config.bit_depth - sample_shift + element_channels - 1;
236 if chan_bits > 32 {
237 return Err(invalid_data("channel bit depth cannot be greater than 32"));
239 }
240
241 let mix_bits: u8 = reader.read_u8(8)?;
243 let mix_res: i8 = reader.read_u8(8)? as i8;
244
245 let mut lpc_mode = [0; 2]; let mut lpc_quant = [0; 2]; let mut pb_factor = [0; 2]; let mut lpc_order = [0; 2]; let mut lpc_coefs = [[0; 32]; 2]; for i in 0..(element_channels as usize) {
252 lpc_mode[i] = reader.read_u8(4)?;
253 lpc_quant[i] = reader.read_u8(4)? as u32;
254 pb_factor[i] = reader.read_u8(3)? as u16;
255 lpc_order[i] = reader.read_u8(5)?;
256
257 for j in (0..lpc_order[i] as usize).rev() {
259 lpc_coefs[i][j] = reader.read_u16(16)? as i16;
260 }
261 }
262
263 let extra_bits_reader = if sample_shift != 0 {
264 let extra_bits_reader = reader.clone();
265 reader.skip((sample_shift as usize) * num_samples * element_channels as usize)?;
266 Some(extra_bits_reader)
267 } else {
268 None
269 };
270
271 for i in 0..(element_channels as usize) {
275 rice_decompress(
276 reader,
277 &this.config,
278 &mut mix_buf[i],
279 chan_bits,
280 pb_factor[i],
281 )?;
282
283 if lpc_mode[i as usize] == 15 {
284 lpc_predict_order_31(mix_buf[i], chan_bits);
286 } else if lpc_mode[i as usize] > 0 {
287 return Err(invalid_data("invalid lpc mode"));
288 }
289
290 assert!(lpc_order[i] != 31);
292
293 let lpc_coefs = &mut lpc_coefs[i][..lpc_order[i] as usize];
294 lpc_predict(mix_buf[i], chan_bits, lpc_coefs, lpc_quant[i]);
295 }
296
297 if element_channels == 2 && mix_res != 0 {
298 unmix_stereo(&mut mix_buf, mix_bits, mix_res);
299 }
300
301 if let Some(mut extra_bits_reader) = extra_bits_reader {
304 append_extra_bits(
305 &mut extra_bits_reader,
306 &mut mix_buf,
307 element_channels,
308 sample_shift,
309 )?;
310 }
311
312 for i in 0..num_samples {
313 for j in 0..element_channels as usize {
314 let sample = mix_buf[j][i];
315
316 let idx = i * this.config.num_channels as usize + channel_index as usize + j;
317
318 out[idx] = S::from_decoder(sample, this.config.bit_depth);
319 }
320 }
321 } else {
322 if sample_shift != 0 {
328 return Err(invalid_data(
329 "sample shift cannot be greater than zero for uncompressed channels",
330 ));
331 }
332
333 for i in 0..num_samples {
334 for j in 0..element_channels as usize {
335 let sample = reader.read_u32(this.config.bit_depth as usize)? as i32;
336
337 let idx = i * this.config.num_channels as usize + channel_index as usize + j;
338
339 out[idx] = S::from_decoder(sample, this.config.bit_depth);
340 }
341 }
342 }
343
344 Ok(num_samples as u32)
345}
346
347#[inline]
348fn decode_rice_symbol<'a>(
349 reader: &mut BitCursor<'a>,
350 m: u32,
351 k: u8,
352 bps: u8,
353) -> Result<u32, InvalidData> {
354 debug_assert!(k != 0);
362
363 let k = k as usize;
364
365 let mut q = 0;
369 while q != 9 && reader.read_bit()? == true {
370 q += 1;
371 }
372
373 if q == 9 {
374 return Ok(reader.read_u32(bps as usize)?);
375 }
376
377 if k == 1 {
382 return Ok(q);
383 }
384
385 let mut r = reader.read_u32(k - 1)?;
389 if r > 0 {
390 let extra_bit = reader.read_bit()? as u32;
391 r = (r << 1) + extra_bit - 1;
392 }
393
394 let s = q * m + r;
397
398 Ok(s)
399}
400
401fn rice_decompress<'a>(
402 reader: &mut BitCursor<'a>,
403 config: &StreamInfo,
404 buf: &mut [i32],
405 bps: u8,
406 pb_factor: u16,
407) -> Result<(), InvalidData> {
408 #[inline(always)]
409 fn log_2(x: u32) -> u32 {
410 31 - (x | 1).leading_zeros()
411 }
412
413 let mut rice_history: u32 = config.mb as u32;
414 let rice_history_mult = (config.pb as u32 * pb_factor as u32) / 4;
415 let k_max = config.kb;
416 let mut sign_modifier = 0;
417
418 let mut i = 0;
419 while i < buf.len() {
420 let k = log_2((rice_history >> 9) + 3);
421 let k = min(k as u8, k_max);
422 let m = (1 << k) - 1;
424 let val = decode_rice_symbol(reader, m, k, bps)?;
425 let val = val + sign_modifier;
432 sign_modifier = 0;
433 buf[i] = ((val >> 1) as i32) ^ -((val & 1) as i32);
435
436 if val > 0xffff {
438 rice_history = 0xffff;
439 } else {
440 rice_history = (rice_history + val * rice_history_mult)
442 - ((rice_history * rice_history_mult) >> 9);
443 }
444
445 if (rice_history < 128) && (i + 1 < buf.len()) {
447 let k = rice_history.leading_zeros() - 24 + ((rice_history + 16) >> 6);
449 if k as u8 > k_max {
453 debug_assert!(
454 false,
455 "k ({}) greater than rice limit ({}). Unsure how to continue.",
456 k, k_max
457 );
458 }
459
460 let k = k as u8;
462 let wb_local = (1 << k_max) - 1;
463 let m = ((1 << k) - 1) & wb_local;
464 let zero_block_len = decode_rice_symbol(reader, m, k, 16)? as usize;
470
471 if zero_block_len > 0 {
472 if zero_block_len >= buf.len() - i {
473 return Err(invalid_data(
474 "zero block contains too many samples for channel",
475 ));
476 }
477 let buf = &mut buf[i + 1..];
479 for j in 0..zero_block_len {
480 buf[j] = 0;
481 }
482 i += zero_block_len;
483 }
484 if zero_block_len <= 0xffff {
485 sign_modifier = 1;
486 }
487 rice_history = 0;
488 }
489
490 i += 1;
491 }
492 Ok(())
493}
494
/// Sign-extends the low `bits` bits of `val` to a full `i32`.
///
/// `bits` must be in `1..=32`. Zero would shift by 32, and values above 32
/// underflow.
#[inline(always)]
fn sign_extend(val: i32, bits: u8) -> i32 {
    let shift = 32 - bits;
    (val << shift) >> shift
}

/// Undoes first-order prediction in place.
///
/// Each sample becomes the running sum of the residuals, truncated to `bps`
/// bits. The sum wraps in 32 bits; because the result is truncated to `bps`
/// bits anyway, this gives the same value without panicking on overflow.
fn lpc_predict_order_31(buf: &mut [i32], bps: u8) {
    for i in 1..buf.len() {
        buf[i] = sign_extend(buf[i].wrapping_add(buf[i - 1]), bps);
    }
}

/// Undoes adaptive linear prediction with `lpc_coefs.len()` taps, in place.
///
/// On entry `buf` holds residuals. On exit it holds reconstructed samples,
/// truncated to `bps` bits. `lpc_coefs[j]` weights the sample `lpc_order - j`
/// positions back, and the coefficients are adapted as decoding proceeds.
/// `lpc_quant` is the right shift applied to the filter output.
///
/// With zero taps the residuals already are the samples, so `buf` is left
/// unchanged; the reference decoders copy them through. All arithmetic wraps,
/// as release builds already did, so crafted input cannot panic a debug build.
fn lpc_predict(buf: &mut [i32], bps: u8, lpc_coefs: &mut [i16], lpc_quant: u32) {
    let lpc_order = lpc_coefs.len();

    if lpc_order == 0 {
        return;
    }

    // Warm-up: the first `lpc_order` samples after the first have too little
    // history for the filter, so they use first-order prediction.
    for i in 1..min(lpc_order + 1, buf.len()) {
        buf[i] = sign_extend(buf[i].wrapping_add(buf[i - 1]), bps);
    }

    // Round-to-nearest offset for the final shift. With no shift there is
    // nothing to round, and `lpc_quant - 1` would underflow.
    let rounding = if lpc_quant == 0 { 0 } else { 1i32 << (lpc_quant - 1) };

    for i in (lpc_order + 1)..buf.len() {
        // The sample just before the window is the reference that every tap
        // input is measured against.
        let mean = buf[i - lpc_order - 1];

        // window[..lpc_order] holds the previous samples; window[lpc_order] is
        // the residual being decoded.
        let window = &mut buf[i - lpc_order..i + 1];

        let mut predicted: i32 = 0;
        for (&x, &coef) in window.iter().zip(lpc_coefs.iter()) {
            predicted =
                predicted.wrapping_add(x.wrapping_sub(mean).wrapping_mul(i32::from(coef)));
        }
        let predicted = predicted.wrapping_add(rounding) >> lpc_quant;

        let prediction_error = window[lpc_order];
        let sample = predicted.wrapping_add(mean).wrapping_add(prediction_error);
        window[lpc_order] = sign_extend(sample, bps);

        // Adapt each tap by the sign of (tap input × error). Stop once the
        // accumulated correction has used up the error's magnitude.
        if prediction_error != 0 {
            let error_sign = prediction_error.signum();
            let mut remaining = error_sign.wrapping_mul(prediction_error);

            for j in 0..lpc_order {
                let delta = window[j].wrapping_sub(mean);
                let sign = delta.signum() * error_sign;
                lpc_coefs[j] = lpc_coefs[j].wrapping_add(sign as i16);

                let step = (delta.wrapping_mul(sign) >> lpc_quant).wrapping_mul(j as i32 + 1);
                remaining = remaining.wrapping_sub(error_sign.wrapping_mul(step));

                if remaining <= 0 {
                    break;
                }
            }
        }
    }
}
563
/// Converts a mixed channel pair back to two output channels, in place.
///
/// On entry `buf[0]` and `buf[1]` hold the mixed values `u` and `v`. On exit
/// they hold `l` and `r`, where `r = u - ((v * mix_res) >> mix_bits)` and
/// `l = r + v`. Only the common length of the two slices is processed.
///
/// `mix_bits` and `mix_res` come straight from the packet. The arithmetic
/// therefore wraps, and the shift amount is taken modulo 32. This is exactly
/// what release builds already computed, but crafted values can no longer
/// panic a debug build.
fn unmix_stereo(buf: &mut [&mut [i32]; 2], mix_bits: u8, mix_res: i8) {
    debug_assert_eq!(buf[0].len(), buf[1].len());

    let mix_res = i32::from(mix_res);
    let shift = u32::from(mix_bits);

    let (first, second) = buf.split_at_mut(1);
    for (u, v) in first[0].iter_mut().zip(second[0].iter_mut()) {
        let r = u.wrapping_sub(v.wrapping_mul(mix_res).wrapping_shr(shift));
        let l = r.wrapping_add(*v);

        *u = l;
        *v = r;
    }
}
581
582fn append_extra_bits<'a>(
583 reader: &mut BitCursor<'a>,
584 buf: &mut [&mut [i32]; 2],
585 channels: u8,
586 sample_shift: u8,
587) -> Result<(), InvalidData> {
588 debug_assert_eq!(buf[0].len(), buf[1].len());
589
590 let channels = min(channels as usize, buf.len());
591 let num_samples = min(buf[0].len(), buf[1].len());
592 let sample_shift = sample_shift as usize;
593
594 for i in 0..num_samples {
595 for j in 0..channels {
596 let extra_bits = reader.read_u16(sample_shift)? as i32;
597 buf[j][i] = (buf[j][i] << sample_shift) | extra_bits as i32;
598 }
599 }
600
601 Ok(())
602}