alac_encoder/lib.rs

#![no_std]

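//! A minimal usage sketch, mirroring the crate's own tests below (it assumes the crate is
//! built as `alac_encoder`, as the file path suggests, and that the input is interleaved
//! 16-bit PCM; the silent input is purely illustrative):
//!
//! ```no_run
//! use alac_encoder::{AlacEncoder, FormatDescription};
//!
//! // Describe the PCM input and the desired ALAC output.
//! let input_format = FormatDescription::pcm::<i16>(44100.0, 2);
//! let output_format = FormatDescription::alac(44100.0, 4096, 2);
//!
//! let mut encoder = AlacEncoder::new(&output_format);
//!
//! // One packet of interleaved samples: 4096 frames * 2 channels * 2 bytes.
//! let pcm = vec![0u8; 4096 * 2 * 2];
//! let mut output = vec![0u8; output_format.max_packet_size()];
//!
//! let size = encoder.encode(&input_format, &pcm, &mut output);
//! let packet = &output[..size];
//! ```
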
#[macro_use]
extern crate alloc;

mod ag;
mod bit_buffer;
mod dp;
mod matrix;

use alloc::vec::Vec;

use ag::AgParams;
use bit_buffer::BitBuffer;
use log::debug;
use matrix::Source;

pub const DEFAULT_FRAME_SIZE: usize = 4096;
pub const DEFAULT_FRAMES_PER_PACKET: u32 = 4096;

#[deprecated(note = "Use FormatDescription::max_packet_size instead")]
pub const MAX_ESCAPE_HEADER_BYTES: usize = max_packet_size(0, 8, 0);

const MAX_CHANNELS: usize = 8;
const MAX_SAMPLE_SIZE: usize = 32;
const MAX_SEARCHES: usize = 16;
const MAX_COEFS: usize = 16;

const DEFAULT_MIX_BITS: u32 = 2;
const MAX_RES: u32 = 4;
const DEFAULT_NUM_UV: u32 = 8;

const MIN_UV: usize = 4;
const MAX_UV: usize = 8;

const MAX_RUN_DEFAULT: u16 = 255;

const ALAC_COMPATIBLE_VERSION: u8 = 0;

#[derive(Clone, Copy, Debug)]
enum ElementType {
    /// Single channel element
    Sce = 0,
    /// Channel pair element
    Cpe = 1,
    /// End of frame marker
    End = 7,
}

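// Channel layout to element-type maps, indexed by `channels_per_frame - 1`. Each row marks,
// at the channel index where an element begins, whether that element is a single channel
// element (SCE) or a channel pair element (CPE); the encode loop advances one or two
// channels accordingly.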
const CHANNEL_MAPS: [[Option<ElementType>; MAX_CHANNELS]; MAX_CHANNELS] = [
    [Some(ElementType::Sce), None, None, None, None, None, None, None],
    [Some(ElementType::Cpe), None, None, None, None, None, None, None],
    [Some(ElementType::Sce), Some(ElementType::Cpe), None, None, None, None, None, None],
    [Some(ElementType::Sce), Some(ElementType::Cpe), None, Some(ElementType::Sce), None, None, None, None],
    [Some(ElementType::Sce), Some(ElementType::Cpe), None, Some(ElementType::Cpe), None, None, None, None],
    [Some(ElementType::Sce), Some(ElementType::Cpe), None, Some(ElementType::Cpe), None, Some(ElementType::Sce), None, None],
    [Some(ElementType::Sce), Some(ElementType::Cpe), None, Some(ElementType::Cpe), None, Some(ElementType::Sce), Some(ElementType::Sce), None],
    [Some(ElementType::Sce), Some(ElementType::Cpe), None, Some(ElementType::Cpe), None, Some(ElementType::Cpe), None, Some(ElementType::Sce)],
];

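// Worked example for the bound below: 16-bit stereo at 4096 frames per packet is a single
// CPE, so channel_bits = 3 + 4 + 12 + 4 + (16 * 2 * 4096) = 131_095 bits; adding the 3-bit
// end tag and rounding up to whole bytes gives (131_095 + 3 + 7) / 8 = 16_388 bytes.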
#[must_use]
const fn max_packet_size(bit_depth: usize, channels_per_frame: usize, frames_per_packet: usize) -> usize {
    let sce = 3 + 4 + 12 + 4 + (bit_depth * frames_per_packet);
    let cpe = 3 + 4 + 12 + 4 + (bit_depth * 2 * frames_per_packet);

    let channel_bits = match channels_per_frame {
        1 => sce,
        2 => cpe,
        3 => sce + cpe,
        4 => sce + cpe + sce,
        5 => sce + cpe + cpe,
        6 => sce + cpe + cpe + sce,
        7 => sce + cpe + cpe + sce + sce,
        8 => sce + cpe + cpe + cpe + sce,
        _ => unreachable!(),
    };

    (channel_bits + 3 + 7) / 8
}

pub trait PcmFormat {
    fn bits() -> u32;
    fn bytes() -> u32;
    fn flags() -> u32;
}

impl PcmFormat for i16 {
    fn bits() -> u32 { 16 }
    fn bytes() -> u32 { 2 }
    fn flags() -> u32 { 4 }
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum FormatType {
    /// Apple Lossless
    AppleLossless,
    /// Linear PCM
    LinearPcm,
}

pub struct FormatDescription {
    sample_rate: f64,
    format_id: FormatType,
    bytes_per_packet: u32,
    frames_per_packet: u32,
    channels_per_frame: u32,
    bits_per_channel: u32,
}

impl FormatDescription {
    #[must_use]
    pub fn pcm<T: PcmFormat>(sample_rate: f64, channels: u32) -> FormatDescription {
        assert!(channels > 0 && channels <= MAX_CHANNELS as u32);

        FormatDescription {
            sample_rate,
            format_id: FormatType::LinearPcm,
            bytes_per_packet: channels * T::bytes(),
            frames_per_packet: 1,
            channels_per_frame: channels,
            bits_per_channel: T::bits(),
        }
    }

    #[must_use]
    pub const fn alac(sample_rate: f64, frames_per_packet: u32, channels: u32) -> FormatDescription {
        assert!(channels > 0 && channels <= MAX_CHANNELS as u32);

        FormatDescription {
            sample_rate,
            format_id: FormatType::AppleLossless,
            bytes_per_packet: 0,
            frames_per_packet,
            channels_per_frame: channels,
            bits_per_channel: 16,
        }
    }

    #[must_use]
    pub const fn max_packet_size(&self) -> usize {
        max_packet_size(self.bits_per_channel as usize, self.channels_per_frame as usize, self.frames_per_packet as usize)
    }
}

pub struct AlacEncoder {
    // ALAC encoder parameters
    bit_depth: usize,

    // encoding state
    last_mix_res: [i16; 8],

    // encoding buffers
    mix_buffer_u: Vec<i32>,
    mix_buffer_v: Vec<i32>,
    predictor_u: Vec<i32>,
    predictor_v: Vec<i32>,
    shift_buffer_uv: Vec<u16>,
    work_buffer: Vec<u8>,

    // per-channel coefficients buffers
    coefs_u: [[[i16; MAX_COEFS]; MAX_SEARCHES]; MAX_CHANNELS],
    coefs_v: [[[i16; MAX_COEFS]; MAX_SEARCHES]; MAX_CHANNELS],

    // encoding statistics
    total_bytes_generated: usize,
    avg_bit_rate: u32,
    max_frame_bytes: u32,
    frame_size: usize,
    num_channels: usize,
    output_sample_rate: u32,
}

impl AlacEncoder {
    pub fn new(output_format: &FormatDescription) -> AlacEncoder {
        assert_eq!(output_format.format_id, FormatType::AppleLossless);

        let frame_size = output_format.frames_per_packet as usize;
        let num_channels = output_format.channels_per_frame as usize;
        let max_output_bytes = frame_size * num_channels * ((10 + MAX_SAMPLE_SIZE) / 8) + 1;

        let mut coefs_u = [[[0i16; MAX_COEFS]; MAX_SEARCHES]; MAX_CHANNELS];
        let mut coefs_v = [[[0i16; MAX_COEFS]; MAX_SEARCHES]; MAX_CHANNELS];

        // allocate mix buffers
        let mix_buffer_u = vec![0i32; frame_size];
        let mix_buffer_v = vec![0i32; frame_size];

        // allocate dynamic predictor buffers
        let predictor_u = vec![0i32; frame_size];
        let predictor_v = vec![0i32; frame_size];

        // allocate combined shift buffer
        let shift_buffer_uv = vec![0u16; frame_size * 2];

        // allocate work buffer for search loop
        let work_buffer = vec![0u8; max_output_bytes];

        // initialize coefs arrays once b/c retaining state across blocks actually improves the encode ratio
        for channel in 0..num_channels {
            for search in 0..MAX_SEARCHES {
                dp::init_coefs(&mut coefs_u[channel][search], dp::DENSHIFT_DEFAULT);
                dp::init_coefs(&mut coefs_v[channel][search], dp::DENSHIFT_DEFAULT);
            }
        }

        AlacEncoder {
            // ALAC encoder parameters
            bit_depth: output_format.bits_per_channel as usize,

            // encoding state
            last_mix_res: [0; 8],

            // encoding buffers
            mix_buffer_u,
            mix_buffer_v,
            predictor_u,
            predictor_v,
            shift_buffer_uv,
            work_buffer,

            // per-channel coefficients buffers
            coefs_u,
            coefs_v,

            // encoding statistics
            total_bytes_generated: 0,
            avg_bit_rate: 0,
            max_frame_bytes: 0,
            frame_size,
            num_channels,
            output_sample_rate: output_format.sample_rate as u32,
        }
    }

    pub fn bit_depth(&self) -> usize {
        self.bit_depth
    }

    pub fn channels(&self) -> usize {
        self.num_channels
    }

    pub fn frames(&self) -> usize {
        self.frame_size
    }

    pub fn sample_rate(&self) -> usize {
        self.output_sample_rate as usize
    }

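    /// Builds the magic cookie for this encoder: an `ALACSpecificConfig`, followed by an
    /// `ALACAudioChannelLayout` when encoding more than two channels.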
    pub fn magic_cookie(&self) -> Vec<u8> {
        let mut result = Vec::with_capacity(if self.num_channels > 2 { 48 } else { 24 });

        /* ALACSpecificConfig */
        result.extend(u32::to_be_bytes(self.frame_size as u32));
        result.push(ALAC_COMPATIBLE_VERSION);
        result.push(self.bit_depth as u8);
        result.push(ag::PB0 as u8);
        result.push(ag::MB0 as u8);
        result.push(ag::KB0 as u8);
        result.push(self.num_channels as u8);
        result.extend(u16::to_be_bytes(MAX_RUN_DEFAULT));
        result.extend(u32::to_be_bytes(self.max_frame_bytes));
        result.extend(u32::to_be_bytes(self.avg_bit_rate));
        result.extend(u32::to_be_bytes(self.output_sample_rate));

        /* ALACAudioChannelLayout */
        if self.num_channels > 2 {
            let channel_layout_tag = match self.num_channels {
                1 => [1, 0, 100, 0],
                2 => [2, 0, 101, 0],
                3 => [3, 0, 113, 0],
                4 => [4, 0, 116, 0],
                5 => [5, 0, 120, 0],
                6 => [6, 0, 124, 0],
                7 => [7, 0, 142, 0],
                8 => [8, 0, 127, 0],
                _ => panic!("Unsupported number of channels"),
            };

            result.extend(u32::to_be_bytes(24));
            result.extend(b"chan");
            result.extend(u32::to_be_bytes(0));
            result.extend(&channel_layout_tag);
            result.extend(u32::to_be_bytes(0));
            result.extend(u32::to_be_bytes(0));
        }

        result
    }

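    /// Encodes one packet of interleaved PCM from `input_data` (described by `input_format`)
    /// into `output_data`, returning the number of bytes written. `input_data` may hold at
    /// most `frames_per_packet` frames and `output_data` must be at least
    /// `FormatDescription::max_packet_size` bytes long.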
    pub fn encode(&mut self, input_format: &FormatDescription, input_data: &[u8], output_data: &mut [u8]) -> usize {
        assert_eq!(input_format.format_id, FormatType::LinearPcm);

        let num_frames = input_data.len() / (input_format.bytes_per_packet as usize);
        assert!(num_frames <= self.frame_size);

        let minimum_buffer_size = max_packet_size(self.bit_depth, self.num_channels, self.frame_size);
        assert!(output_data.len() >= minimum_buffer_size);

        // create a bit buffer structure pointing to our output buffer
        let mut bitstream = BitBuffer::new(output_data);

        let input_increment = (self.bit_depth + 7) / 8;
        let mut input_position = 0usize;

        let mut channel_index = 0;
        let mut mono_element_tag = 0;
        let mut stereo_element_tag = 0;

        while channel_index < input_format.channels_per_frame {
            let tag = CHANNEL_MAPS[input_format.channels_per_frame as usize - 1][channel_index as usize].unwrap();

            bitstream.write_lte25(tag as u32, 3);

            match tag {
                ElementType::Sce => {
                    bitstream.write_lte25(mono_element_tag, 4);
                    let input_size = input_increment;
                    self.encode_mono(&mut bitstream, &input_data[input_position..], input_format.channels_per_frame as usize, channel_index as usize, num_frames);
                    input_position += input_size;
                    channel_index += 1;
                    mono_element_tag += 1;
                },
                ElementType::Cpe => {
                    bitstream.write_lte25(stereo_element_tag, 4);
                    let input_size = input_increment * 2;
                    self.encode_stereo(&mut bitstream, &input_data[input_position..], input_format.channels_per_frame as usize, channel_index as usize, num_frames);
                    input_position += input_size;
                    channel_index += 2;
                    stereo_element_tag += 1;
                },
                _ => panic!("Unexpected ElementType {:?}", tag),
            }
        }

        // add 3-bit frame end tag: ID_END
        bitstream.write_lte25(ElementType::End as u32, 3);

        // byte-align the output data
        bitstream.byte_align();

        let output_size = bitstream.position() / 8;
        debug_assert!(output_size <= bitstream.len());
        debug_assert!(output_size <= minimum_buffer_size);

        self.total_bytes_generated += output_size;
        self.max_frame_bytes = core::cmp::max(self.max_frame_bytes, output_size as u32);

        output_size
    }

    fn encode_mono(&mut self, bitstream: &mut BitBuffer, input: &[u8], stride: usize, channel_index: usize, num_samples: usize) {
        let start_position = bitstream.position();

        // reload coefs array from previous frame
        let coefs_u = &mut self.coefs_u[channel_index];

        // pick bit depth for actual encoding
        // - we lop off the lower byte(s) for 24-/32-bit encodings
        let bytes_shifted: usize = match self.bit_depth { 32 => 2, 24 => 1, _ => 0 };

        let shift = bytes_shifted * 8;
        let mask: u32 = (1u32 << shift) - 1;
        let chan_bits = self.bit_depth - shift;
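        // e.g. 24-bit input: bytes_shifted = 1, shift = 8, mask = 0xff, chan_bits = 16;
        //      32-bit input: bytes_shifted = 2, shift = 16, mask = 0xffff, chan_bits = 16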

        // flag whether or not this is a partial frame
        let partial_frame: u8 = if num_samples == self.frame_size { 0 } else { 1 };

        match self.bit_depth {
            16 => {
                // convert 16-bit data to 32-bit for predictor
                let input16 = unsafe { core::slice::from_raw_parts(input.as_ptr() as *const i16, num_samples * stride) };
                for index in 0..num_samples {
                    self.mix_buffer_u[index] = input16[index * stride] as i32;
                }
            },
            20 => {
                // convert 20-bit data to 32-bit for predictor
                matrix::copy20_to_predictor(input, stride, &mut self.mix_buffer_u, num_samples);
            },
            24 => {
                // convert 24-bit data to 32-bit for the predictor and extract the shifted off byte(s)
                matrix::copy24_to_predictor(input, stride, &mut self.mix_buffer_u, num_samples);
                for index in 0..num_samples {
                    self.shift_buffer_uv[index] = ((self.mix_buffer_u[index] as u32) & mask) as u16;
                    self.mix_buffer_u[index] >>= shift;
                }
            },
            32 => {
                // just copy the 32-bit input data for the predictor and extract the shifted off byte(s)
                let input32 = unsafe { core::slice::from_raw_parts(input.as_ptr() as *const i32, num_samples * stride) };

                for index in 0..num_samples {
                    let val = input32[index * stride];

                    self.shift_buffer_uv[index] = ((val as u32) & mask) as u16;
                    self.mix_buffer_u[index] = val >> shift;
                }
            },
            _ => panic!("Invalid mBitDepth"),
        }

        // brute-force encode optimization loop (implied "encode depth" of 0 if comparing to cmd line tool)
        // - run over variations of the encoding params to find the best choice
        let min_u = 4;
        let max_u = 8;
        let pb_factor = 4;

        let mut min_bits = 1 << 31;
        let mut best_u = min_u;

        for num_u in (min_u..max_u).step_by(4) {
            let dilate = 32usize;
            for _ in 0..7 {
                dp::pc_block(&self.mix_buffer_u, &mut self.predictor_u, num_samples / dilate, &mut coefs_u[num_u - 1], num_u, chan_bits, dp::DENSHIFT_DEFAULT);
            }

            let dilate = 8usize;
            dp::pc_block(&self.mix_buffer_u, &mut self.predictor_u, num_samples / dilate, &mut coefs_u[num_u - 1], num_u, chan_bits, dp::DENSHIFT_DEFAULT);

            let mut work_bits = BitBuffer::new(&mut self.work_buffer);
            let ag_params = AgParams::new(ag::MB0, (pb_factor * (ag::PB0)) / 4, ag::KB0, (num_samples / dilate) as u32, (num_samples / dilate) as u32);
            ag::dyn_comp(&ag_params, &self.predictor_u, &mut work_bits, num_samples / dilate, chan_bits);

            let num_bits = (dilate * work_bits.position()) + (16 * num_u);
            if num_bits < min_bits {
                best_u = num_u;
                min_bits = num_bits;
            }
        }

        // test for escape hatch if best calculated compressed size turns out to be more than the input size
        // - first, add bits for the header bytes mixRes/maxRes/shiftU/filterU
        min_bits += (4 /* mixRes/maxRes/etc. */ * 8) + (if partial_frame == (true as u8) { 32 } else { 0 });
        if bytes_shifted != 0 {
            min_bits += num_samples * bytes_shifted * 8;
        }

        let escape_bits = (num_samples * self.bit_depth) + (if partial_frame == (true as u8) { 32 } else { 0 }) + (2 * 8); /* 2 common header bytes */

        let mut do_escape = min_bits >= escape_bits;

        if !do_escape {
            // write bitstream header
            bitstream.write_lte25(0, 12);
            bitstream.write_lte25(((partial_frame as u32) << 3) | ((bytes_shifted as u32) << 1), 4);
            if partial_frame > 0 {
                bitstream.write(num_samples as u32, 32);
            }
            bitstream.write_lte25(0, 16); // mixBits = mixRes = 0

            // write the params and predictor coefs
            let mode_u = 0;
            bitstream.write_lte25((mode_u << 4) | dp::DENSHIFT_DEFAULT, 8);
            bitstream.write_lte25(((pb_factor as u32) << 5) | (best_u as u32), 8);
            for index in 0..best_u {
                bitstream.write_lte25(coefs_u[(best_u as usize) - 1][index] as u32, 16);
            }

            // if shift active, write the interleaved shift buffers
            if bytes_shifted != 0 {
                for index in 0..num_samples {
                    bitstream.write(self.shift_buffer_uv[index] as u32, shift);
                }
            }

            // run the dynamic predictor with the best result
            dp::pc_block(&self.mix_buffer_u, &mut self.predictor_u, num_samples, &mut coefs_u[best_u - 1], best_u, chan_bits, dp::DENSHIFT_DEFAULT);

            // do lossless compression
            let ag_params = AgParams::new_standard(num_samples as u32, num_samples as u32);
            ag::dyn_comp(&ag_params, &self.predictor_u, bitstream, num_samples, chan_bits);

            // if we happened to create a compressed packet that was actually bigger than an escape packet would be,
            // chuck it and do an escape packet
            let min_bits = bitstream.position() - start_position;
            if min_bits >= escape_bits {
                bitstream.set_position(start_position);
                do_escape = true;
                debug!("compressed frame too big: {} vs. {}", min_bits, escape_bits);
            }
        }

        if do_escape {
            // write bitstream header and coefs
            bitstream.write_lte25(0, 12);
            bitstream.write_lte25(((partial_frame as u32) << 3) | 1, 4); // LSB = 1 means "frame not compressed"
            if partial_frame > 0 {
                bitstream.write(num_samples as u32, 32);
            }

            // just copy the input data to the output buffer
            match self.bit_depth {
                16 => {
                    let input16 = unsafe { core::slice::from_raw_parts(input.as_ptr() as *const i16, num_samples * stride) };
                    for index in (0..(num_samples * stride)).step_by(stride) {
                        bitstream.write_lte25(input16[index] as u32, 16);
                    }
                },
                20 => {
                    // convert 20-bit data to 32-bit for simplicity
                    matrix::copy20_to_predictor(input, stride, &mut self.mix_buffer_u, num_samples);
                    for index in 0..num_samples {
                        bitstream.write_lte25(self.mix_buffer_u[index] as u32, 20);
                    }
                },
                24 => {
                    // convert 24-bit data to 32-bit for simplicity
                    matrix::copy24_to_predictor(input, stride, &mut self.mix_buffer_u, num_samples);
                    for index in 0..num_samples {
                        bitstream.write_lte25(self.mix_buffer_u[index] as u32, 24);
                    }
                },
                32 => {
                    let input32 = unsafe { core::slice::from_raw_parts(input.as_ptr() as *const i32, num_samples * stride) };
                    for index in (0..(num_samples * stride)).step_by(stride) {
                        bitstream.write(input32[index] as u32, 32);
                    }
                },
                _ => panic!("Invalid mBitDepth"),
            }
        }
    }

    fn encode_stereo_escape(&mut self, bitstream: &mut BitBuffer, input: &[u8], stride: usize, num_samples: usize) {
        // flag whether or not this is a partial frame
        let partial_frame: u8 = if num_samples == self.frame_size { 0 } else { 1 };

        // write bitstream header
        bitstream.write_lte25(0, 12);
        bitstream.write_lte25(((partial_frame as u32) << 3) | 1, 4); // LSB = 1 means "frame not compressed"
        if partial_frame > 0 {
            bitstream.write(num_samples as u32, 32);
        }

        // just copy the input data to the output buffer
        match self.bit_depth {
            16 => {
                let input16 = unsafe { core::slice::from_raw_parts(input.as_ptr() as *const i16, num_samples * stride) };

                for index in (0..(num_samples * stride)).step_by(stride) {
                    bitstream.write_lte25(input16[index] as u32, 16);
                    bitstream.write_lte25(input16[index + 1] as u32, 16);
                }
            },
            20 => {
                // mix20() with mixres param = 0 means de-interleave so use it to simplify things
                matrix::mix20(Source { data: input, stride, num_samples }, &mut self.mix_buffer_u, &mut self.mix_buffer_v, 0, 0);
                for index in 0..num_samples {
                    bitstream.write_lte25(self.mix_buffer_u[index] as u32, 20);
                    bitstream.write_lte25(self.mix_buffer_v[index] as u32, 20);
                }
            },
            24 => {
                // mix24() with mixres param = 0 means de-interleave so use it to simplify things
                matrix::mix24(Source { data: input, stride, num_samples }, &mut self.mix_buffer_u, &mut self.mix_buffer_v, 0, 0, &mut self.shift_buffer_uv, 0);
                for index in 0..num_samples {
                    bitstream.write_lte25(self.mix_buffer_u[index] as u32, 24);
                    bitstream.write_lte25(self.mix_buffer_v[index] as u32, 24);
                }
            },
            32 => {
                let input32 = unsafe { core::slice::from_raw_parts(input.as_ptr() as *const i32, num_samples * stride) };

                for index in (0..(num_samples * stride)).step_by(stride) {
                    bitstream.write(input32[index] as u32, 32);
                    bitstream.write(input32[index + 1] as u32, 32);
                }
            },
            _ => panic!("Invalid mBitDepth"),
        }
    }

    fn encode_stereo(&mut self, bitstream: &mut BitBuffer, input: &[u8], stride: usize, channel_index: usize, num_samples: usize) {
        let start_position = bitstream.position();

        // reload coefs pointers for this channel pair
        // - note that, while you might think they should be re-initialized per block, retaining state across blocks
        //   actually results in better overall compression
        // - strangely, re-using the same coefs for the different passes of the "mixRes" search loop instead of using
        //   different coefs for the different passes of "mixRes" results in even better compression
        let coefs_u = &mut self.coefs_u[channel_index];
        let coefs_v = &mut self.coefs_v[channel_index];

        // matrix encoding adds an extra bit but 32-bit inputs cannot be matrixed b/c 33 is too many
        // so enable 16-bit "shift off" and encode in 17-bit mode
        // - in addition, 24-bit mode really improves with one byte shifted off
        let bytes_shifted: usize = match self.bit_depth { 32 => 2, 24 => 1, _ => 0 };

        let chan_bits = self.bit_depth - (bytes_shifted * 8) + 1;
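        // e.g. 16-bit input is matrixed at 16 + 1 = 17 bits; 24-bit: 24 - 8 + 1 = 17; 32-bit: 32 - 16 + 1 = 17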

        // flag whether or not this is a partial frame
        let partial_frame: u8 = if num_samples == self.frame_size { 0 } else { 1 };

        // brute-force encode optimization loop
        // - run over variations of the encoding params to find the best choice
        let mix_bits: i32 = DEFAULT_MIX_BITS as i32;
        let max_res: i32 = MAX_RES as i32;
        let num_u: u32 = DEFAULT_NUM_UV;
        let num_v: u32 = DEFAULT_NUM_UV;
        let mode: u32 = 0;
        let pb_factor: u32 = 4;

        let mut min_bits = 1 << 31;
        let mut best_res: i32 = self.last_mix_res[channel_index] as i32;

        for mix_res in 0..=max_res {
            let dilate = 8usize;
            let source = Source { data: input, stride, num_samples: num_samples / dilate };

            // mix the stereo inputs
            match self.bit_depth {
                16 => {
                    matrix::mix16(source, &mut self.mix_buffer_u, &mut self.mix_buffer_v, mix_bits, mix_res);
                },
                20 => {
                    matrix::mix20(source, &mut self.mix_buffer_u, &mut self.mix_buffer_v, mix_bits, mix_res);
                },
                24 => {
                    // includes extraction of shifted-off bytes
                    matrix::mix24(source, &mut self.mix_buffer_u, &mut self.mix_buffer_v, mix_bits, mix_res, &mut self.shift_buffer_uv, bytes_shifted);
                },
                32 => {
                    // includes extraction of shifted-off bytes
                    matrix::mix32(source, &mut self.mix_buffer_u, &mut self.mix_buffer_v, mix_bits, mix_res, &mut self.shift_buffer_uv, bytes_shifted);
                },
                _ => panic!("Invalid mBitDepth"),
            }

            // run the dynamic predictors
            dp::pc_block(&self.mix_buffer_u, &mut self.predictor_u, num_samples / dilate, &mut coefs_u[(num_u as usize) - 1], num_u as usize, chan_bits, dp::DENSHIFT_DEFAULT);
            dp::pc_block(&self.mix_buffer_v, &mut self.predictor_v, num_samples / dilate, &mut coefs_v[(num_v as usize) - 1], num_v as usize, chan_bits, dp::DENSHIFT_DEFAULT);

            // run the lossless compressor on each channel
            let mut work_bits = BitBuffer::new(&mut self.work_buffer);
            let ag_params = AgParams::new(ag::MB0, (pb_factor * (ag::PB0)) / 4, ag::KB0, (num_samples / dilate) as u32, (num_samples / dilate) as u32);
            ag::dyn_comp(&ag_params, &self.predictor_u, &mut work_bits, num_samples / dilate, chan_bits);
            ag::dyn_comp(&ag_params, &self.predictor_v, &mut work_bits, num_samples / dilate, chan_bits);

            // look for best match
            if work_bits.position() < min_bits {
                min_bits = work_bits.position();
                best_res = mix_res;
            }
        }

        self.last_mix_res[channel_index] = best_res as i16;

        // mix the stereo inputs with the current best mixRes
        let mix_res: i32 = self.last_mix_res[channel_index] as i32;
        let source = Source { data: input, stride, num_samples };
        match self.bit_depth {
            16 => {
                matrix::mix16(source, &mut self.mix_buffer_u, &mut self.mix_buffer_v, mix_bits, mix_res);
            },
            20 => {
                matrix::mix20(source, &mut self.mix_buffer_u, &mut self.mix_buffer_v, mix_bits, mix_res);
            },
            24 => {
                // also extracts the shifted off bytes into the shift buffers
                matrix::mix24(source, &mut self.mix_buffer_u, &mut self.mix_buffer_v, mix_bits, mix_res, &mut self.shift_buffer_uv, bytes_shifted);
            },
            32 => {
                // also extracts the shifted off bytes into the shift buffers
                matrix::mix32(source, &mut self.mix_buffer_u, &mut self.mix_buffer_v, mix_bits, mix_res, &mut self.shift_buffer_uv, bytes_shifted);
            },
            _ => panic!("Invalid mBitDepth"),
        }

        // now it's time for the predictor coefficient search loop
        let mut num_u: usize = MIN_UV;
        let mut num_v: usize = MIN_UV;
        let mut min_bits1: usize = 1 << 31;
        let mut min_bits2: usize = 1 << 31;

        for num_uv in (MIN_UV..=MAX_UV).step_by(4) {
            let dilate = 32usize;

            for _ in 0..8 {
                dp::pc_block(&self.mix_buffer_u, &mut self.predictor_u, num_samples / dilate, &mut coefs_u[(num_uv as usize) - 1], num_uv as usize, chan_bits, dp::DENSHIFT_DEFAULT);
                dp::pc_block(&self.mix_buffer_v, &mut self.predictor_v, num_samples / dilate, &mut coefs_v[(num_uv as usize) - 1], num_uv as usize, chan_bits, dp::DENSHIFT_DEFAULT);
            }

            let dilate = 8usize;
            let ag_params = AgParams::new(ag::MB0, (pb_factor * ag::PB0) / 4, ag::KB0, (num_samples / dilate) as u32, (num_samples / dilate) as u32);

            let mut work_bits = BitBuffer::new(&mut self.work_buffer);
            ag::dyn_comp(&ag_params, &self.predictor_u, &mut work_bits, num_samples / dilate, chan_bits);
            let bits1 = work_bits.position();

            if (bits1 * dilate + 16 * num_uv) < min_bits1 {
                min_bits1 = bits1 * dilate + 16 * num_uv;
                num_u = num_uv;
            }

            let mut work_bits = BitBuffer::new(&mut self.work_buffer);
            ag::dyn_comp(&ag_params, &self.predictor_v, &mut work_bits, num_samples / dilate, chan_bits);
            let bits2 = work_bits.position();

            if (bits2 * dilate + 16 * num_uv) < min_bits2 {
                min_bits2 = bits2 * dilate + 16 * num_uv;
                num_v = num_uv;
            }
        }

        // test for escape hatch if best calculated compressed size turns out to be more than the input size
        let mut min_bits = min_bits1 + min_bits2 + (8 /* mixRes/maxRes/etc. */ * 8) + (if partial_frame == (true as u8) { 32 } else { 0 });
        if bytes_shifted != 0 {
            min_bits += num_samples * bytes_shifted * 8 * 2;
        }

        let escape_bits = (num_samples * self.bit_depth * 2) + (if partial_frame == (true as u8) { 32 } else { 0 }) + (2 * 8); /* 2 common header bytes */

        let mut do_escape = min_bits >= escape_bits;

        if !do_escape {
            // write bitstream header and coefs
            bitstream.write_lte25(0, 12);
            bitstream.write_lte25(((partial_frame as u32) << 3) | ((bytes_shifted as u32) << 1), 4);
            if partial_frame > 0 {
                bitstream.write(num_samples as u32, 32);
            }
            bitstream.write_lte25(mix_bits as u32, 8);
            bitstream.write_lte25(mix_res as u32, 8);

            debug_assert!((mode < 16) && (dp::DENSHIFT_DEFAULT < 16));
            debug_assert!((pb_factor < 8) && (num_u < 32));
            debug_assert!((pb_factor < 8) && (num_v < 32));

            bitstream.write_lte25((mode << 4) | dp::DENSHIFT_DEFAULT, 8);
            bitstream.write_lte25((pb_factor << 5) | (num_u as u32), 8);
            for index in 0..num_u {
                bitstream.write_lte25(coefs_u[(num_u as usize) - 1][index as usize] as u32, 16);
            }

            bitstream.write_lte25((mode << 4) | dp::DENSHIFT_DEFAULT, 8);
            bitstream.write_lte25((pb_factor << 5) | (num_v as u32), 8);
            for index in 0..num_v {
                bitstream.write_lte25(coefs_v[(num_v as usize) - 1][index as usize] as u32, 16);
            }

            // if shift active, write the interleaved shift buffers
            if bytes_shifted != 0 {
                let bit_shift = bytes_shifted * 8;
                debug_assert!(bit_shift <= 16);

                for index in (0..(num_samples * 2)).step_by(2) {
                    let shifted_val: u32 = ((self.shift_buffer_uv[index] as u32) << bit_shift) | (self.shift_buffer_uv[index + 1] as u32);
                    bitstream.write(shifted_val, bit_shift * 2);
                }
            }

            // run the dynamic predictor and lossless compression for the "left" channel
            // - note: to avoid allocating more buffers, we're mixing and matching between the available buffers instead
            //   of only using "U" buffers for the U-channel and "V" buffers for the V-channel
            if mode == 0 {
                dp::pc_block(&self.mix_buffer_u, &mut self.predictor_u, num_samples, &mut coefs_u[(num_u as usize) - 1], num_u as usize, chan_bits, dp::DENSHIFT_DEFAULT);
            } else {
                dp::pc_block(&self.mix_buffer_u, &mut self.predictor_v, num_samples, &mut coefs_u[(num_u as usize) - 1], num_u as usize, chan_bits, dp::DENSHIFT_DEFAULT);
                dp::pc_block(&self.predictor_v, &mut self.predictor_u, num_samples, &mut [], 31, chan_bits, 0);
            }

            let ag_params = AgParams::new(ag::MB0, (pb_factor * ag::PB0) / 4, ag::KB0, num_samples as u32, num_samples as u32);
            ag::dyn_comp(&ag_params, &self.predictor_u, bitstream, num_samples, chan_bits);

            // run the dynamic predictor and lossless compression for the "right" channel
            if mode == 0 {
                dp::pc_block(&self.mix_buffer_v, &mut self.predictor_v, num_samples, &mut coefs_v[(num_v as usize) - 1], num_v as usize, chan_bits, dp::DENSHIFT_DEFAULT);
            } else {
                dp::pc_block(&self.mix_buffer_v, &mut self.predictor_u, num_samples, &mut coefs_v[(num_v as usize) - 1], num_v as usize, chan_bits, dp::DENSHIFT_DEFAULT);
                dp::pc_block(&self.predictor_u, &mut self.predictor_v, num_samples, &mut [], 31, chan_bits, 0);
            }

            let ag_params = AgParams::new(ag::MB0, (pb_factor * ag::PB0) / 4, ag::KB0, num_samples as u32, num_samples as u32);
            ag::dyn_comp(&ag_params, &self.predictor_v, bitstream, num_samples, chan_bits);

            // if we happened to create a compressed packet that was actually bigger than an escape packet would be,
            // chuck it and do an escape packet
            let min_bits = bitstream.position() - start_position;
            if min_bits >= escape_bits {
                bitstream.set_position(start_position);
                do_escape = true;
                debug!("compressed frame too big: {} vs. {}", min_bits, escape_bits);
            }
        }

        if do_escape {
            self.encode_stereo_escape(bitstream, input, stride, num_samples);
        }
    }
}

#[cfg(test)]
mod tests {
    extern crate std;

    use super::{AlacEncoder, FormatDescription};

    use std::{fs, vec::Vec};

    use bincode::deserialize;
    use serde::{Serialize, Deserialize};

    #[derive(Serialize, Deserialize, Eq, PartialEq, Debug)]
    struct EncodingResult {
        magic_cookie: Vec<u8>,
        alac_chunks: Vec<Vec<u8>>,
    }

    fn test_case(input: &str, expected: &str, frame_size: u32, channels: u32) {
        let input_format = FormatDescription::pcm::<i16>(44100.0, channels);
        let output_format = FormatDescription::alac(44100.0, frame_size, channels);

        let mut encoder = AlacEncoder::new(&output_format);

        let pcm = fs::read(format!("fixtures/{}", input)).unwrap();

        let mut output = vec![0u8; output_format.max_packet_size()];

        let mut result = EncodingResult {
            magic_cookie: encoder.magic_cookie(),
            alac_chunks: Vec::new(),
        };

        for chunk in pcm.chunks(frame_size as usize * channels as usize * 2) {
            let size = encoder.encode(&input_format, chunk, &mut output);
            result.alac_chunks.push(Vec::from(&output[0..size]));
        }

        // NOTE: Uncomment to write out actual result
        // fs::write(format!("actual_{}", expected), bincode::serialize(&result).unwrap()).unwrap();

        let expected: EncodingResult = deserialize(&fs::read(format!("fixtures/{}", expected)).unwrap()).unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn it_encodes_sample_352_2() {
        test_case("sample.pcm", "sample_352_2.bin", 352, 2);
    }

    #[test]
    fn it_encodes_sample_4096_2() {
        test_case("sample.pcm", "sample_4096_2.bin", 4096, 2);
    }

    #[test]
    fn it_encodes_random_352_2() {
        test_case("random.pcm", "random_352_2.bin", 352, 2);
    }

    #[test]
    fn it_encodes_random_352_3() {
        test_case("random.pcm", "random_352_3.bin", 352, 3);
    }

    #[test]
    fn it_encodes_random_352_4() {
        test_case("random.pcm", "random_352_4.bin", 352, 4);
    }

    #[test]
    fn it_encodes_random_352_5() {
        test_case("random.pcm", "random_352_5.bin", 352, 5);
    }

    #[test]
    fn it_encodes_random_352_6() {
        test_case("random.pcm", "random_352_6.bin", 352, 6);
    }

    #[test]
    fn it_encodes_random_352_7() {
        test_case("random.pcm", "random_352_7.bin", 352, 7);
    }

    #[test]
    fn it_encodes_random_352_8() {
        test_case("random.pcm", "random_352_8.bin", 352, 8);
    }

    #[test]
    fn it_encodes_random_4096_2() {
        test_case("random.pcm", "random_4096_2.bin", 4096, 2);
    }

    #[test]
    fn it_encodes_random_4096_3() {
        test_case("random.pcm", "random_4096_3.bin", 4096, 3);
    }

    #[test]
    fn it_encodes_random_4096_4() {
        test_case("random.pcm", "random_4096_4.bin", 4096, 4);
    }

    #[test]
    fn it_encodes_random_4096_5() {
        test_case("random.pcm", "random_4096_5.bin", 4096, 5);
    }

    #[test]
    fn it_encodes_random_4096_6() {
        test_case("random.pcm", "random_4096_6.bin", 4096, 6);
    }

    #[test]
    fn it_encodes_random_4096_7() {
        test_case("random.pcm", "random_4096_7.bin", 4096, 7);
    }

    #[test]
    fn it_encodes_random_4096_8() {
        test_case("random.pcm", "random_4096_8.bin", 4096, 8);
    }
}