alac/
dec.rs

1use std::cmp::min;
2
3use {invalid_data, InvalidData, StreamInfo};
4use bitcursor::BitCursor;
5
/// A type that can be used to represent audio samples.
///
/// The trait requires `private::Sealed`, which lives in a private module, so it can only be
/// implemented for the types this file implements it for (`i16` and `i32`).
pub trait Sample: Copy + private::Sealed {
    /// Constructs `Self` from a right-aligned sample with bit depth `bits`.
    fn from_decoder(sample: i32, bits: u8) -> Self;

    /// Returns the number of bits this sample type can hold.
    ///
    /// `Decoder::decode_packet` asserts that this is at least the bit depth of the stream.
    fn bits() -> u8;
}
13
impl Sample for i16 {
    /// Converts a decoded sample to `i16` by keeping its low 16 bits.
    ///
    /// The bit depth argument is ignored.
    #[inline(always)]
    fn from_decoder(sample: i32, _: u8) -> Self {
        sample as i16
    }

    /// An `i16` sample holds 16 bits.
    #[inline(always)]
    fn bits() -> u8 {
        16
    }
}
25
impl Sample for i32 {
    /// Left-aligns a decoded sample by shifting it up by `32 - bits`, so that the most
    /// significant bit of the `bits`-wide sample ends up in bit 31 of the result.
    ///
    /// NOTE: `bits` has to be in `1..=32`; any other value makes the shift amount invalid.
    #[inline(always)]
    fn from_decoder(sample: i32, bits: u8) -> Self {
        sample << (32 - bits)
    }

    /// An `i32` sample holds 32 bits.
    #[inline(always)]
    fn bits() -> u8 {
        32
    }
}
37
// Private module backing the sealed-trait pattern used by `Sample`: `Sealed` is public, but
// the module containing it is not, so code outside this file cannot name or implement it.
mod private {
    /// Sealed prevents other crates from implementing any traits that use it.
    pub trait Sealed {}
    // The only types that are allowed to implement `Sample`.
    impl Sealed for i16 {}
    impl Sealed for i32 {}
}
44
/// An ALAC packet decoder.
pub struct Decoder {
    /// Description of the stream being decoded, as passed to `Decoder::new`.
    config: StreamInfo,
    /// Scratch space of `2 * frame_length` samples allocated by `Decoder::new`. Compressed
    /// audio elements are decoded into it (one half per channel) before being written to the
    /// caller's output buffer.
    buf: Box<[i32]>,
}
50
// Element tags: each element of a packet starts with one of these 3-bit identifiers.
const ID_SCE: u8 = 0; // Single Channel Element
const ID_CPE: u8 = 1; // Channel Pair Element
const ID_CCE: u8 = 2; // Coupling Channel Element (rejected as unsupported)
const ID_LFE: u8 = 3; // LFE Channel Element
const ID_DSE: u8 = 4; // Data Stream Element (parsed but ignored)
const ID_PCE: u8 = 5; // Program Config Element (rejected as unsupported)
const ID_FIL: u8 = 6; // filler element (parsed but ignored)
const ID_END: u8 = 7; // frame end
59
60impl Decoder {
61    /// Creates a `Decoder` for a stream described by the `StreamInfo`.
62    pub fn new(config: StreamInfo) -> Decoder {
63        Decoder {
64            buf: vec![0; config.frame_length as usize * 2].into_boxed_slice(),
65            config,
66        }
67    }
68
69    /// Returns the `StreamInfo` used to create this decoder.
70    pub fn stream_info(&self) -> &StreamInfo {
71        &self.config
72    }
73
74    /// Decodes an ALAC packet into `out`.
75    ///
76    /// Channels are interleaved, e.g. for a stereo packet `out` would contains samples in the
77    /// order `[left, right, left, right, ..]`.
78    ///
79    /// Panics if `out` is shorter than `StreamInfo::max_samples_per_packet`.
80    pub fn decode_packet<'a, S: Sample>(
81        &mut self,
82        packet: &[u8],
83        out: &'a mut [S],
84    ) -> Result<&'a [S], InvalidData> {
85        let mut reader = BitCursor::new(packet);
86
87        let mut channel_index = 0;
88        let mut frame_samples = None;
89
90        assert!(out.len() >= self.config.max_samples_per_packet() as usize);
91        assert!(S::bits() >= self.config.bit_depth);
92
93        loop {
94            let tag = reader.read_u8(3)?;
95
96            match tag {
97                tag @ ID_SCE | tag @ ID_LFE | tag @ ID_CPE => {
98                    let element_channels = match tag {
99                        ID_SCE => 1,
100                        ID_LFE => 1,
101                        ID_CPE => 2,
102                        _ => unreachable!(),
103                    };
104
105                    // Check that there aren't too many channels in this packet.
106                    if channel_index + element_channels > self.config.num_channels {
107                        return Err(invalid_data("packet contains more channels than expected"));
108                    }
109
110                    let element_samples = decode_audio_element(
111                        self,
112                        &mut reader,
113                        out,
114                        channel_index,
115                        element_channels,
116                    )?;
117
118                    // Check that the number of samples are consistent within elements of a frame.
119                    if let Some(frame_samples) = frame_samples {
120                        if frame_samples != element_samples {
121                            return Err(invalid_data(
122                                "all channels in a packet must contain the same number of samples",
123                            ));
124                        }
125                    } else {
126                        frame_samples = Some(element_samples);
127                    }
128
129                    channel_index += element_channels;
130                }
131                ID_CCE | ID_PCE => {
132                    return Err(invalid_data("packet cce and pce elements are unsupported"));
133                }
134                ID_DSE => {
135                    // data stream element -- parse but ignore
136
137                    // the tag associates this data stream element with a given audio element
138                    // Unused
139                    let _element_instance_tag = reader.read_u8(4)?;
140                    let data_byte_align_flag = reader.read_bit()?;
141
142                    // 8-bit count or (8-bit + 8-bit count) if 8-bit count == 255
143                    let mut skip_bytes = reader.read_u8(8)? as usize;
144                    if skip_bytes == 255 {
145                        skip_bytes += reader.read_u8(8)? as usize;
146                    }
147
148                    // the align flag means the bitstream should be byte-aligned before reading the
149                    // following data bytes
150                    if data_byte_align_flag {
151                        reader.skip_to_byte()?;
152                    }
153
154                    reader.skip(skip_bytes * 8)?;
155                }
156                ID_FIL => {
157                    // fill element -- parse but ignore
158
159                    // 4-bit count or (4-bit + 8-bit count) if 4-bit count == 15
160                    // - plus this weird -1 thing I still don't fully understand
161                    let mut skip_bytes = reader.read_u8(4)? as usize;
162                    if skip_bytes == 15 {
163                        skip_bytes += reader.read_u8(8)? as usize - 1
164                    }
165
166                    reader.skip(skip_bytes * 8)?;
167                }
168                ID_END => {
169                    // We've finished decoding the frame. Skip to the end of this byte. There may
170                    // be data left in the packet.
171                    // TODO: Should we throw an error about leftover data.
172                    reader.skip_to_byte()?;
173
174                    // Check that there were as many channels in the packet as there ought to be.
175                    if channel_index != self.config.num_channels {
176                        return Err(invalid_data("packet contains fewer channels than expected"));
177                    }
178
179                    let frame_samples = frame_samples.unwrap_or(self.config.frame_length);
180                    return Ok(&out[..frame_samples as usize * channel_index as usize]);
181                }
182                // `tag` is 3 bits long and we've exhaused all 8 options.
183                _ => unreachable!(),
184            }
185        }
186    }
187}
188
189fn decode_audio_element<'a, S: Sample>(
190    this: &mut Decoder,
191    reader: &mut BitCursor<'a>,
192    out: &mut [S],
193    channel_index: u8,
194    element_channels: u8,
195) -> Result<u32, InvalidData> {
196    // Unused
197    let _element_instance_tag = reader.read_u8(4)?;
198
199    let unused = reader.read_u16(12)?;
200    if unused != 0 {
201        return Err(invalid_data("unused channel header bits must be zero"));
202    }
203
204    // read the 1-bit "partial frame" flag, 2-bit "shift-off" flag & 1-bit "escape" flag
205    let partial_frame = reader.read_bit()?;
206
207    let sample_shift_bytes = reader.read_u8(2)?;
208    if sample_shift_bytes > 2 {
209        return Err(invalid_data(
210            "channel sample shift must not be greater than 16",
211        ));
212    }
213    let sample_shift = sample_shift_bytes * 8;
214
215    let is_uncompressed = reader.read_bit()?;
216
217    // check for partial frame to override requested numSamples
218    let num_samples = if partial_frame {
219        // TODO: this could change within a frame. That would be bad
220        let num_samples = reader.read_u32(32)?;
221
222        if num_samples > this.config.frame_length {
223            return Err(invalid_data("channel contains more samples than expected"));
224        }
225
226        num_samples as usize
227    } else {
228        this.config.frame_length as usize
229    };
230
231    if !is_uncompressed {
232        let (buf_u, buf_v) = this.buf.split_at_mut(this.config.frame_length as usize);
233        let mut mix_buf = [&mut buf_u[..num_samples], &mut buf_v[..num_samples]];
234
235        let chan_bits = this.config.bit_depth - sample_shift + element_channels - 1;
236        if chan_bits > 32 {
237            // unimplemented - could in theory be 33
238            return Err(invalid_data("channel bit depth cannot be greater than 32"));
239        }
240
241        // compressed frame, read rest of parameters
242        let mix_bits: u8 = reader.read_u8(8)?;
243        let mix_res: i8 = reader.read_u8(8)? as i8;
244
245        let mut lpc_mode = [0; 2]; //u8
246        let mut lpc_quant = [0; 2]; //u32
247        let mut pb_factor = [0; 2]; //u16
248        let mut lpc_order = [0; 2]; //u8
249        let mut lpc_coefs = [[0; 32]; 2]; //i16*
250
251        for i in 0..(element_channels as usize) {
252            lpc_mode[i] = reader.read_u8(4)?;
253            lpc_quant[i] = reader.read_u8(4)? as u32;
254            pb_factor[i] = reader.read_u8(3)? as u16;
255            lpc_order[i] = reader.read_u8(5)?;
256
257            // Coefficients are used in reverse order of storage for prediction
258            for j in (0..lpc_order[i] as usize).rev() {
259                lpc_coefs[i][j] = reader.read_u16(16)? as i16;
260            }
261        }
262
263        let extra_bits_reader = if sample_shift != 0 {
264            let extra_bits_reader = reader.clone();
265            reader.skip((sample_shift as usize) * num_samples * element_channels as usize)?;
266            Some(extra_bits_reader)
267        } else {
268            None
269        };
270
271        // TODO: Tidy and comment these steps see below for an example
272        // https://github.com/ruud-v-a/claxon/blob/master/src/subframe.rs
273        // It should be possible to it without allocating buffers quite easily
274        for i in 0..(element_channels as usize) {
275            rice_decompress(
276                reader,
277                &this.config,
278                &mut mix_buf[i],
279                chan_bits,
280                pb_factor[i],
281            )?;
282
283            if lpc_mode[i as usize] == 15 {
284                // the special "numActive == 31" mode can be done in-place
285                lpc_predict_order_31(mix_buf[i], chan_bits);
286            } else if lpc_mode[i as usize] > 0 {
287                return Err(invalid_data("invalid lpc mode"));
288            }
289
290            // We have a seperate function for this
291            assert!(lpc_order[i] != 31);
292
293            let lpc_coefs = &mut lpc_coefs[i][..lpc_order[i] as usize];
294            lpc_predict(mix_buf[i], chan_bits, lpc_coefs, lpc_quant[i]);
295        }
296
297        if element_channels == 2 && mix_res != 0 {
298            unmix_stereo(&mut mix_buf, mix_bits, mix_res);
299        }
300
301        // now read the shifted values into the shift buffer
302        // We directly apply the shifts to avoid needing a buffer
303        if let Some(mut extra_bits_reader) = extra_bits_reader {
304            append_extra_bits(
305                &mut extra_bits_reader,
306                &mut mix_buf,
307                element_channels,
308                sample_shift,
309            )?;
310        }
311
312        for i in 0..num_samples {
313            for j in 0..element_channels as usize {
314                let sample = mix_buf[j][i];
315
316                let idx = i * this.config.num_channels as usize + channel_index as usize + j;
317
318                out[idx] = S::from_decoder(sample, this.config.bit_depth);
319            }
320        }
321    } else {
322        // uncompressed frame, copy data into the mix buffers to use common output code
323
324        // Here we deviate here from the reference implementation and just copy
325        // straight to the output buffer.
326
327        if sample_shift != 0 {
328            return Err(invalid_data(
329                "sample shift cannot be greater than zero for uncompressed channels",
330            ));
331        }
332
333        for i in 0..num_samples {
334            for j in 0..element_channels as usize {
335                let sample = reader.read_u32(this.config.bit_depth as usize)? as i32;
336
337                let idx = i * this.config.num_channels as usize + channel_index as usize + j;
338
339                out[idx] = S::from_decoder(sample, this.config.bit_depth);
340            }
341        }
342    }
343
344    Ok(num_samples as u32)
345}
346
347#[inline]
348fn decode_rice_symbol<'a>(
349    reader: &mut BitCursor<'a>,
350    m: u32,
351    k: u8,
352    bps: u8,
353) -> Result<u32, InvalidData> {
354    // Rice coding encodes a symbol S as the product of a quotient Q and a
355    // modulus M added to a remainder R. Q is encoded in unary (Q 1s followed
356    // by a 0) and R in binary in K bits.
357    //
358    // S = Q × M + R where M = 2^K
359
360    // K cannot be zero as a modulus is 2^K - 1 is used instead of 2^K.
361    debug_assert!(k != 0);
362
363    let k = k as usize;
364
365    // First we need to try to read Q which is encoded in unary and is at most
366    // 9. If it is greater than 8 the entire symbol is simply encoded in binary
367    // after Q.
368    let mut q = 0;
369    while q != 9 && reader.read_bit()? == true {
370        q += 1;
371    }
372
373    if q == 9 {
374        return Ok(reader.read_u32(bps as usize)?);
375    }
376
377    // A modulus of 2^K - 1 is used instead of 2^K. Therefore if K = 1 then
378    // M = 1 and there is no remainder (here K cannot be 0 as it comes from
379    // log_2 which cannot be 0). This is presumably an optimisation that aims
380    // to store small numbers more efficiently.
381    if k == 1 {
382        return Ok(q);
383    }
384
385    // Next we read the remainder which is at most K bits. If it is zero it is
386    // stored as K - 1 zeros. Otherwise it is stored in K bits as R + 1. This
387    // saves one bit in cases where the remainder is zero.
388    let mut r = reader.read_u32(k - 1)?;
389    if r > 0 {
390        let extra_bit = reader.read_bit()? as u32;
391        r = (r << 1) + extra_bit - 1;
392    }
393
394    // Due to the issue mentioned in rice_decompress we use a parameter for m
395    // rather than calculating it here (e.g. let mut s = (q << k) - q);
396    let s = q * m + r;
397
398    Ok(s)
399}
400
401fn rice_decompress<'a>(
402    reader: &mut BitCursor<'a>,
403    config: &StreamInfo,
404    buf: &mut [i32],
405    bps: u8,
406    pb_factor: u16,
407) -> Result<(), InvalidData> {
408    #[inline(always)]
409    fn log_2(x: u32) -> u32 {
410        31 - (x | 1).leading_zeros()
411    }
412
413    let mut rice_history: u32 = config.mb as u32;
414    let rice_history_mult = (config.pb as u32 * pb_factor as u32) / 4;
415    let k_max = config.kb;
416    let mut sign_modifier = 0;
417
418    let mut i = 0;
419    while i < buf.len() {
420        let k = log_2((rice_history >> 9) + 3);
421        let k = min(k as u8, k_max);
422        // See below for info on the m thing
423        let m = (1 << k) - 1;
424        let val = decode_rice_symbol(reader, m, k, bps)?;
425        // The least significant bit of val is the sign bit - the plus is weird tho
426        // if val and sgn mod = 0 then nothing happens
427        // if one is 1 the lsb = 1
428        // val & 1 = 1 => val is all 1s => flip all the bits
429        // if they are both 1 then val_eff += 2
430        // val & 1 = 0 => nothing happens...?
431        let val = val + sign_modifier;
432        sign_modifier = 0;
433        // As lsb sign bit right shift by 1
434        buf[i] = ((val >> 1) as i32) ^ -((val & 1) as i32);
435
436        // Update the history value
437        if val > 0xffff {
438            rice_history = 0xffff;
439        } else {
440            // Avoid += as that has a tendency to underflow
441            rice_history = (rice_history + val * rice_history_mult)
442                - ((rice_history * rice_history_mult) >> 9);
443        }
444
445        // There may be a compressed block of zeros. See if there is.
446        if (rice_history < 128) && (i + 1 < buf.len()) {
447            // calculate rice param and decode block size
448            let k = rice_history.leading_zeros() - 24 + ((rice_history + 16) >> 6);
449            // The maximum value k above can take is 7. The rice limit seems to always be higher
450            // than this. This is called infrequently enough that the if statement below should
451            // have a minimal effect on performance.
452            if k as u8 > k_max {
453                debug_assert!(
454                    false,
455                    "k ({}) greater than rice limit ({}). Unsure how to continue.",
456                    k, k_max
457                );
458            }
459
460            // Apple version
461            let k = k as u8;
462            let wb_local = (1 << k_max) - 1;
463            let m = ((1 << k) - 1) & wb_local;
464            // FFMPEG version
465            // let k = min(k as u8, k_max);
466            // let mz = ((1 << k) - 1);
467            // End versions
468
469            let zero_block_len = decode_rice_symbol(reader, m, k, 16)? as usize;
470
471            if zero_block_len > 0 {
472                if zero_block_len >= buf.len() - i {
473                    return Err(invalid_data(
474                        "zero block contains too many samples for channel",
475                    ));
476                }
477                // TODO: Use memset equivalent here.
478                let buf = &mut buf[i + 1..];
479                for j in 0..zero_block_len {
480                    buf[j] = 0;
481                }
482                i += zero_block_len;
483            }
484            if zero_block_len <= 0xffff {
485                sign_modifier = 1;
486            }
487            rice_history = 0;
488        }
489
490        i += 1;
491    }
492    Ok(())
493}
494
/// Sign extends the low `bits` bits of `val` to a full `i32`.
///
/// The value is shifted left until its sign bit sits in bit 31 and then shifted back with an
/// arithmetic right shift, which replicates the sign bit. Any bits above `bits` are dropped.
///
/// `bits` must be in `1..=32`.
#[inline(always)]
fn sign_extend(val: i32, bits: u8) -> i32 {
    let unused_bits = 32 - bits;
    let left_aligned = val << unused_bits;
    left_aligned >> unused_bits
}
500
501fn lpc_predict_order_31(buf: &mut [i32], bps: u8) {
502    // When lpc_order is 31 samples are encoded using differential coding. Samples values are the
503    // sum of the previous and the difference between the previous and current sample.
504    for i in 1..buf.len() {
505        buf[i] = sign_extend(buf[i] + buf[i - 1], bps);
506    }
507}
508
509fn lpc_predict(buf: &mut [i32], bps: u8, lpc_coefs: &mut [i16], lpc_quant: u32) {
510    let lpc_order = lpc_coefs.len();
511
512    // Prediction needs lpc_order + 1 previous decoded samples.
513    for i in 1..min(lpc_order + 1, buf.len()) {
514        buf[i] = sign_extend(buf[i] + buf[i - 1], bps);
515    }
516
517    for i in (lpc_order + 1)..buf.len() {
518        // The (lpc_order - 1)'th predicted sample is used as the mean signal value for this
519        // prediction.
520        let mean = buf[i - lpc_order - 1];
521
522        // The previous lpc_order samples are used to predict this sample.
523        let buf = &mut buf[i - lpc_order..i + 1];
524
525        // Predict the next sample using linear predictive coding.
526        let mut predicted = 0;
527        for (x, coef) in buf.iter().zip(lpc_coefs.iter()) {
528            predicted += (x - mean) * (*coef as i32);
529        }
530
531        // Round up to and then truncate by lpc_quant bits.
532        // 1 << (lpc_quant - 1) sets the (lpc_quant - 1)'th bit.
533        let predicted = (predicted + (1 << (lpc_quant - 1))) >> lpc_quant;
534
535        // Store the sample for output and to be used in the next prediction.
536        let prediction_error = buf[lpc_order];
537        let sample = predicted + mean + prediction_error;
538        buf[lpc_order] = sign_extend(sample, bps);
539
540        if prediction_error != 0 {
541            // The prediction was not exact so adjust LPC coefficients to try to reduce the size
542            // of the next prediction error. Add or subtract 1 from each coefficient until the
543            // sign of error has changed or we run out of coefficients to adjust.
544            let error_sign = prediction_error.signum();
545
546            // This implementation always uses a positive prediction error.
547            let mut prediction_error = error_sign * prediction_error;
548
549            for j in 0..lpc_order {
550                let predicted = buf[j] - mean;
551                let sign = predicted.signum() * error_sign;
552                lpc_coefs[j] += sign as i16;
553                // Update the prediction error now we have changed a coefficient.
554                prediction_error -= error_sign * (predicted * sign >> lpc_quant) * (j as i32 + 1);
555                // Stop updating coefficients if the prediction error changes sign.
556                if prediction_error <= 0 {
557                    break;
558                }
559            }
560        }
561    }
562}
563
/// Converts a decoded channel pair from its mixed representation back to left/right, in
/// place.
///
/// On entry `buf[0]` and `buf[1]` hold the two mixed channels `u` and `v`. For every sample
/// the right channel is `u - ((v * mix_res) >> mix_bits)` and the left channel is
/// `right + v`. On return `buf[0]` holds the left channel and `buf[1]` the right channel.
///
/// Both parameters come straight from the bitstream, so the arithmetic wraps on overflow and
/// the shift amount is reduced modulo 32 instead of panicking on malformed input.
fn unmix_stereo(buf: &mut [&mut [i32]; 2], mix_bits: u8, mix_res: i8) {
    debug_assert_eq!(buf[0].len(), buf[1].len());

    let mix_res = i32::from(mix_res);
    let mix_bits = u32::from(mix_bits);

    // Split the pair so both channels can be borrowed mutably at the same time. Zipping stops
    // at the end of the shorter channel.
    let (first, second) = buf.split_at_mut(1);
    for (u, v) in first[0].iter_mut().zip(second[0].iter_mut()) {
        let (mixed_u, mixed_v) = (*u, *v);

        let right = mixed_u.wrapping_sub(mixed_v.wrapping_mul(mix_res).wrapping_shr(mix_bits));
        let left = right.wrapping_add(mixed_v);

        *u = left;
        *v = right;
    }
}
581
582fn append_extra_bits<'a>(
583    reader: &mut BitCursor<'a>,
584    buf: &mut [&mut [i32]; 2],
585    channels: u8,
586    sample_shift: u8,
587) -> Result<(), InvalidData> {
588    debug_assert_eq!(buf[0].len(), buf[1].len());
589
590    let channels = min(channels as usize, buf.len());
591    let num_samples = min(buf[0].len(), buf[1].len());
592    let sample_shift = sample_shift as usize;
593
594    for i in 0..num_samples {
595        for j in 0..channels {
596            let extra_bits = reader.read_u16(sample_shift)? as i32;
597            buf[j][i] = (buf[j][i] << sample_shift) | extra_bits as i32;
598        }
599    }
600
601    Ok(())
602}