flac_codec/
encode.rs

1// Copyright 2025 Brian Langenberger
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9//! For encoding PCM samples to FLAC files
10//!
11//! ## Multithreading
12//!
13//! Encoders will operate using multithreading if the optional `rayon` feature is enabled,
14//! typically boosting performance by processing channels in parallel.
15//! But because subframes must eventually be written serially, and their size cannot generally
16//! be known in advance, processing two channels across two threads will not
17//! encode twice as fast.
18
19use crate::audio::Frame;
20use crate::metadata::{
21    Application, BlockList, BlockSize, Cuesheet, Picture, SeekPoint, Streaminfo, VorbisComment,
22    write_blocks,
23};
24use crate::stream::{ChannelAssignment, FrameNumber, Independent, SampleRate};
25use crate::{Counter, Error};
26use arrayvec::ArrayVec;
27use bitstream_io::{BigEndian, BitRecorder, BitWrite, BitWriter, SignedBitCount};
28use std::collections::VecDeque;
29use std::fs::File;
30use std::io::BufWriter;
31use std::num::NonZero;
32use std::path::Path;
33
/// Maximum number of channels a FLAC stream may carry
const MAX_CHANNELS: usize = 8;
/// Maximum number of LPC coefficients
const MAX_LPC_COEFFS: usize = 32;
37
// Invent a vec!-like macro that the official crate lacks
//
// Builds an `ArrayVec` (whose type and capacity are inferred from the
// surrounding context) by pushing each expression in order.
// Like `vec!`, a trailing comma after the last element is accepted.
macro_rules! arrayvec {
    () => {
        ArrayVec::default()
    };
    ( $( $x:expr ),+ $(,)? ) => {
        {
            let mut v = ArrayVec::default();
            $( v.push($x); )+
            v
        }
    };
}
48
49/// A FLAC writer which accepts samples as bytes
50///
51/// # Example
52///
53/// ```
54/// use flac_codec::{
55///     byteorder::LittleEndian,
56///     encode::{FlacByteWriter, Options},
57///     decode::{FlacByteReader, Metadata},
58/// };
59/// use std::io::{Cursor, Read, Seek, Write};
60///
61/// let mut flac = Cursor::new(vec![]);  // a FLAC file in memory
62///
63/// let mut writer = FlacByteWriter::endian(
64///     &mut flac,           // our wrapped writer
65///     LittleEndian,        // .wav-style byte order
66///     Options::default(),  // default encoding options
67///     44100,               // sample rate
68///     16,                  // bits-per-sample
69///     1,                   // channel count
70///     Some(2000),          // total bytes
71/// ).unwrap();
72///
73/// // write 1000 samples as 16-bit, signed, little-endian bytes (2000 bytes total)
74/// let written_bytes = (0..1000).map(i16::to_le_bytes).flatten().collect::<Vec<u8>>();
75/// assert!(writer.write_all(&written_bytes).is_ok());
76///
77/// // finalize writing file
78/// assert!(writer.finalize().is_ok());
79///
80/// flac.rewind().unwrap();
81///
82/// // open reader around written FLAC file
83/// let mut reader = FlacByteReader::endian(flac, LittleEndian).unwrap();
84///
85/// // read 2000 bytes
86/// let mut read_bytes = vec![];
87/// assert!(reader.read_to_end(&mut read_bytes).is_ok());
88///
89/// // ensure MD5 sum of signed, little-endian samples matches hash in file
90/// let mut md5 = md5::Context::new();
91/// md5.consume(&read_bytes);
92/// assert_eq!(&md5.compute().0, reader.md5().unwrap());
93///
94/// // ensure input and output matches
95/// assert_eq!(read_bytes, written_bytes);
96/// ```
97pub struct FlacByteWriter<W: std::io::Write + std::io::Seek, E: crate::byteorder::Endianness> {
98    // the wrapped encoder
99    encoder: Encoder<W>,
100    // bytes that make up a partial FLAC frame
101    buf: VecDeque<u8>,
102    // a whole set of samples for a FLAC frame
103    frame: Frame,
104    // size of a single sample in bytes
105    bytes_per_sample: usize,
106    // size of single set of channel-independent samples in bytes
107    pcm_frame_size: usize,
108    // size of whole FLAC frame's samples in bytes
109    frame_byte_size: usize,
110    // whether the encoder has finalized the file
111    finalized: bool,
112    // the input bytes' endianness
113    endianness: std::marker::PhantomData<E>,
114}
115
116impl<W: std::io::Write + std::io::Seek, E: crate::byteorder::Endianness> FlacByteWriter<W, E> {
117    /// Creates new FLAC writer with the given parameters
118    ///
119    /// The writer should be positioned at the start of the file.
120    ///
121    /// `sample_rate` must be between 0 (for non-audio streams) and 2²⁰.
122    ///
123    /// `bits_per_sample` must be between 1 and 32.
124    ///
125    /// `channels` must be between 1 and 8.
126    ///
127    /// Note that if `total_bytes` is indicated,
128    /// the number of channel-independent samples written *must*
129    /// be equal to that amount or an error will occur when writing
130    /// or finalizing the stream.
131    ///
132    /// # Errors
133    ///
134    /// Returns I/O error if unable to write initial
135    /// metadata blocks.
136    /// Returns error if any of the encoding parameters are invalid.
137    pub fn new(
138        writer: W,
139        options: Options,
140        sample_rate: u32,
141        bits_per_sample: u32,
142        channels: u8,
143        total_bytes: Option<u64>,
144    ) -> Result<Self, Error> {
145        let bits_per_sample: SignedBitCount<32> = bits_per_sample
146            .try_into()
147            .map_err(|_| Error::InvalidBitsPerSample)?;
148
149        let bytes_per_sample = u32::from(bits_per_sample).div_ceil(8) as usize;
150
151        let pcm_frame_size = bytes_per_sample * channels as usize;
152
153        Ok(Self {
154            buf: VecDeque::default(),
155            frame: Frame::empty(channels.into(), bits_per_sample.into()),
156            bytes_per_sample,
157            pcm_frame_size,
158            frame_byte_size: pcm_frame_size * options.block_size as usize,
159            encoder: Encoder::new(
160                writer,
161                options,
162                sample_rate,
163                bits_per_sample,
164                channels,
165                total_bytes
166                    .map(|bytes| {
167                        exact_div(bytes, channels.into())
168                            .and_then(|s| exact_div(s, bytes_per_sample as u64))
169                            .ok_or(Error::SamplesNotDivisibleByChannels)
170                            .and_then(|b| NonZero::new(b).ok_or(Error::InvalidTotalBytes))
171                    })
172                    .transpose()?,
173            )?,
174            finalized: false,
175            endianness: std::marker::PhantomData,
176        })
177    }
178
179    /// Creates new FLAC writer with CDDA parameters
180    ///
181    /// The writer should be positioned at the start of the file.
182    ///
183    /// Sample rate is 44100 Hz, bits-per-sample is 16,
184    /// channels is 2.
185    ///
186    /// Note that if `total_bytes` is indicated,
187    /// the number of channel-independent samples written *must*
188    /// be equal to that amount or an error will occur when writing
189    /// or finalizing the stream.
190    ///
191    /// # Errors
192    ///
193    /// Returns I/O error if unable to write initial
194    /// metadata blocks.
195    /// Returns error if any of the encoding parameters are invalid.
196    pub fn new_cdda(writer: W, options: Options, total_bytes: Option<u64>) -> Result<Self, Error> {
197        Self::new(writer, options, 44100, 16, 2, total_bytes)
198    }
199
200    /// Creates new FLAC writer in the given endianness with the given parameters
201    ///
202    /// The writer should be positioned at the start of the file.
203    ///
204    /// `sample_rate` must be between 0 (for non-audio streams) and 2²⁰.
205    ///
206    /// `bits_per_sample` must be between 1 and 32.
207    ///
208    /// `channels` must be between 1 and 8.
209    ///
210    /// Note that if `total_bytes` is indicated,
211    /// the number of bytes written *must*
212    /// be equal to that amount or an error will occur when writing
213    /// or finalizing the stream.
214    ///
215    /// # Errors
216    ///
217    /// Returns I/O error if unable to write initial
218    /// metadata blocks.
219    #[inline]
220    pub fn endian(
221        writer: W,
222        _endianness: E,
223        options: Options,
224        sample_rate: u32,
225        bits_per_sample: u32,
226        channels: u8,
227        total_bytes: Option<u64>,
228    ) -> Result<Self, Error> {
229        Self::new(
230            writer,
231            options,
232            sample_rate,
233            bits_per_sample,
234            channels,
235            total_bytes,
236        )
237    }
238
239    fn finalize_inner(&mut self) -> Result<(), Error> {
240        if !self.finalized {
241            self.finalized = true;
242
243            // encode as many bytes as possible into final frame, if necessary
244            if !self.buf.is_empty() {
245                use crate::byteorder::LittleEndian;
246
247                // truncate buffer to whole PCM frames
248                let buf = self.buf.make_contiguous();
249                let buf_len = buf.len();
250                let buf = &mut buf[..(buf_len - buf_len % self.pcm_frame_size)];
251
252                // convert buffer to little-endian bytes
253                E::bytes_to_le(buf, self.bytes_per_sample);
254
255                // update MD5 sum with little-endian bytes
256                self.encoder.update_md5(buf);
257
258                self.encoder
259                    .encode(self.frame.fill_from_buf::<LittleEndian>(buf))?;
260            }
261
262            self.encoder.finalize_inner()
263        } else {
264            Ok(())
265        }
266    }
267
268    /// Attempt to finalize stream
269    ///
270    /// It is necessary to finalize the FLAC encoder
271    /// so that it will write any partially unwritten samples
272    /// to the stream and update the [`crate::metadata::Streaminfo`] and [`crate::metadata::SeekTable`] blocks
273    /// with their final values.
274    ///
275    /// Dropping the encoder will attempt to finalize the stream
276    /// automatically, but will ignore any errors that may occur.
277    pub fn finalize(mut self) -> Result<(), Error> {
278        self.finalize_inner()?;
279        Ok(())
280    }
281}
282
283impl<E: crate::byteorder::Endianness> FlacByteWriter<BufWriter<File>, E> {
284    /// Creates new FLAC file at the given path
285    ///
286    /// `sample_rate` must be between 0 (for non-audio streams) and 2²⁰.
287    ///
288    /// `bits_per_sample` must be between 1 and 32.
289    ///
290    /// `channels` must be between 1 and 8.
291    ///
292    /// Note that if `total_bytes` is indicated,
293    /// the number of bytes written *must*
294    /// be equal to that amount or an error will occur when writing
295    /// or finalizing the stream.
296    ///
297    /// # Errors
298    ///
299    /// Returns I/O error if unable to write initial
300    /// metadata blocks.
301    #[inline]
302    pub fn create<P: AsRef<Path>>(
303        path: P,
304        options: Options,
305        sample_rate: u32,
306        bits_per_sample: u32,
307        channels: u8,
308        total_bytes: Option<u64>,
309    ) -> Result<Self, Error> {
310        FlacByteWriter::new(
311            BufWriter::new(options.create(path)?),
312            options,
313            sample_rate,
314            bits_per_sample,
315            channels,
316            total_bytes,
317        )
318    }
319
320    /// Creates new FLAC file with CDDA parameters at the given path
321    ///
322    /// Sample rate is 44100 Hz, bits-per-sample is 16,
323    /// channels is 2.
324    ///
325    /// Note that if `total_bytes` is indicated,
326    /// the number of bytes written *must*
327    /// be equal to that amount or an error will occur when writing
328    /// or finalizing the stream.
329    ///
330    /// # Errors
331    ///
332    /// Returns I/O error if unable to write initial
333    /// metadata blocks.
334    pub fn create_cdda<P: AsRef<Path>>(
335        path: P,
336        options: Options,
337        total_bytes: Option<u64>,
338    ) -> Result<Self, Error> {
339        Self::create(path, options, 44100, 16, 2, total_bytes)
340    }
341}
342
343impl<W: std::io::Write + std::io::Seek, E: crate::byteorder::Endianness> std::io::Write
344    for FlacByteWriter<W, E>
345{
346    /// Writes a set of sample bytes to the FLAC file
347    ///
348    /// Samples are signed and encoded in the stream's given byte order.
349    ///
350    /// Samples are then interleaved by channel, like:
351    /// [left₀ , right₀ , left₁ , right₁ , left₂ , right₂ , …]
352    ///
353    /// This is the same format used by common PCM container
354    /// formats like WAVE and AIFF.
355    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
356        use crate::byteorder::LittleEndian;
357
358        // dump whole set of bytes into our internal buffer
359        self.buf.extend(buf);
360
361        // encode as many FLAC frames as possible (which may be 0)
362        let mut encoded_frames = 0;
363        for buf in self
364            .buf
365            .make_contiguous()
366            .chunks_exact_mut(self.frame_byte_size)
367        {
368            // convert buffer to little-endian bytes
369            E::bytes_to_le(buf, self.bytes_per_sample);
370
371            // update MD5 sum with little-endian bytes
372            self.encoder.update_md5(buf);
373
374            // encode fresh FLAC frame
375            self.encoder
376                .encode(self.frame.fill_from_buf::<LittleEndian>(buf))?;
377
378            encoded_frames += 1;
379        }
380        // TODO - use truncate_front whenever that stabilizes
381        self.buf.drain(0..self.frame_byte_size * encoded_frames);
382
383        // indicate whole buffer's been consumed
384        Ok(buf.len())
385    }
386
387    #[inline]
388    fn flush(&mut self) -> std::io::Result<()> {
389        // we don't want to flush a partial frame to disk,
390        // but we can at least flush our internal writer
391        self.encoder.writer.flush()
392    }
393}
394
395impl<W: std::io::Write + std::io::Seek, E: crate::byteorder::Endianness> Drop
396    for FlacByteWriter<W, E>
397{
398    fn drop(&mut self) {
399        let _ = self.finalize_inner();
400    }
401}
402
403/// A FLAC writer which accepts samples as signed integers
404///
405/// # Example
406///
407/// ```
408/// use flac_codec::{
409///     encode::{FlacSampleWriter, Options},
410///     decode::FlacSampleReader,
411/// };
412/// use std::io::{Cursor, Seek};
413///
414/// let mut flac = Cursor::new(vec![]);  // a FLAC file in memory
415///
416/// let mut writer = FlacSampleWriter::new(
417///     &mut flac,           // our wrapped writer
418///     Options::default(),  // default encoding options
419///     44100,               // sample rate
420///     16,                  // bits-per-sample
421///     1,                   // channel count
422///     Some(1000),          // total samples
423/// ).unwrap();
424///
425/// // write 1000 samples
426/// let written_samples = (0..1000).collect::<Vec<i32>>();
427/// assert!(writer.write(&written_samples).is_ok());
428///
429/// // finalize writing file
430/// assert!(writer.finalize().is_ok());
431///
432/// flac.rewind().unwrap();
433///
434/// // open reader around written FLAC file
435/// let mut reader = FlacSampleReader::new(flac).unwrap();
436///
437/// // read 1000 samples
438/// let mut read_samples = vec![0; 1000];
439/// assert!(matches!(reader.read(&mut read_samples), Ok(1000)));
440///
441/// // ensure they match
442/// assert_eq!(read_samples, written_samples);
443/// ```
444pub struct FlacSampleWriter<W: std::io::Write + std::io::Seek> {
445    // the wrapped encoder
446    encoder: Encoder<W>,
447    // samples that make up a partial FLAC frame
448    // in channel-interleaved order
449    // (must de-interleave later in case someone writes
450    // only partial set of channels in a single write call)
451    sample_buf: VecDeque<i32>,
452    // a whole set of samples for a FLAC frame
453    frame: Frame,
454    // size of a single frame in samples
455    frame_sample_size: usize,
456    // size of a single PCM frame in samples
457    pcm_frame_size: usize,
458    // size of a single sample in bytes
459    bytes_per_sample: usize,
460    // size of single set of channel-independent samples in bytes
461    // whether the encoder has finalized the file
462    finalized: bool,
463}
464
465impl<W: std::io::Write + std::io::Seek> FlacSampleWriter<W> {
466    /// Creates new FLAC writer with the given parameters
467    ///
468    /// `sample_rate` must be between 0 (for non-audio streams) and 2²⁰.
469    ///
470    /// `bits_per_sample` must be between 1 and 32.
471    ///
472    /// `channels` must be between 1 and 8.
473    ///
474    /// Note that if `total_samples` is indicated,
475    /// the number of samples written *must*
476    /// be equal to that amount or an error will occur when writing
477    /// or finalizing the stream.
478    ///
479    /// # Errors
480    ///
481    /// Returns I/O error if unable to write initial
482    /// metadata blocks.
483    /// Returns error if any of the encoding parameters are invalid.
484    pub fn new(
485        writer: W,
486        options: Options,
487        sample_rate: u32,
488        bits_per_sample: u32,
489        channels: u8,
490        total_samples: Option<u64>,
491    ) -> Result<Self, Error> {
492        let bits_per_sample: SignedBitCount<32> = bits_per_sample
493            .try_into()
494            .map_err(|_| Error::InvalidBitsPerSample)?;
495
496        let bytes_per_sample = u32::from(bits_per_sample).div_ceil(8) as usize;
497
498        let pcm_frame_size = usize::from(channels);
499
500        Ok(Self {
501            sample_buf: VecDeque::default(),
502            frame: Frame::empty(channels.into(), bits_per_sample.into()),
503            bytes_per_sample,
504            pcm_frame_size,
505            frame_sample_size: pcm_frame_size * options.block_size as usize,
506            encoder: Encoder::new(
507                writer,
508                options,
509                sample_rate,
510                bits_per_sample,
511                channels,
512                total_samples
513                    .map(|samples| {
514                        exact_div(samples, channels.into())
515                            .ok_or(Error::SamplesNotDivisibleByChannels)
516                            .and_then(|s| NonZero::new(s).ok_or(Error::InvalidTotalSamples))
517                    })
518                    .transpose()?,
519            )?,
520            finalized: false,
521        })
522    }
523
524    /// Creates new FLAC writer with CDDA parameters
525    ///
526    /// Sample rate is 44100 Hz, bits-per-sample is 16,
527    /// channels is 2.
528    ///
529    /// Note that if `total_samples` is indicated,
530    /// the number of samples written *must*
531    /// be equal to that amount or an error will occur when writing
532    /// or finalizing the stream.
533    ///
534    /// # Errors
535    ///
536    /// Returns I/O error if unable to write initial
537    /// metadata blocks.
538    /// Returns error if any of the encoding parameters are invalid.
539    pub fn new_cdda(
540        writer: W,
541        options: Options,
542        total_samples: Option<u64>,
543    ) -> Result<Self, Error> {
544        Self::new(writer, options, 44100, 16, 2, total_samples)
545    }
546
547    /// Given a set of samples, writes them to the FLAC file
548    ///
549    /// Samples are interleaved by channel, like:
550    /// [left₀ , right₀ , left₁ , right₁ , left₂ , right₂ , …]
551    ///
552    /// This may output 0 or more actual FLAC frames,
553    /// depending on the quantity of samples and the amount
554    /// previously written.
555    pub fn write(&mut self, samples: &[i32]) -> Result<(), Error> {
556        // dump whole set of samples into our internal buffer
557        self.sample_buf.extend(samples);
558
559        // encode as many FLAC frames as possible (which may be 0)
560        let mut encoded_frames = 0;
561        for buf in self
562            .sample_buf
563            .make_contiguous()
564            .chunks_exact_mut(self.frame_sample_size)
565        {
566            // update running MD5 sum calculation
567            // since samples are already interleaved in channel order
568            update_md5(&mut self.encoder, buf, self.bytes_per_sample);
569
570            // encode fresh FLAC frame
571            self.encoder.encode(self.frame.fill_from_samples(buf))?;
572
573            encoded_frames += 1;
574        }
575        self.sample_buf
576            .drain(0..self.frame_sample_size * encoded_frames);
577
578        Ok(())
579    }
580
581    fn finalize_inner(&mut self) -> Result<(), Error> {
582        if !self.finalized {
583            self.finalized = true;
584
585            // encode as many samples possible into final frame, if necessary
586            if !self.sample_buf.is_empty() {
587                // truncate buffer to whole PCM frames
588                let buf = self.sample_buf.make_contiguous();
589                let buf_len = buf.len();
590                let buf = &mut buf[..(buf_len - buf_len % self.pcm_frame_size)];
591
592                // update running MD5 sum calculation
593                // since samples are already interleaved in channel order
594                update_md5(&mut self.encoder, buf, self.bytes_per_sample);
595
596                // encode final FLAC frame
597                self.encoder.encode(self.frame.fill_from_samples(buf))?;
598            }
599
600            self.encoder.finalize_inner()
601        } else {
602            Ok(())
603        }
604    }
605
606    /// Attempt to finalize stream
607    ///
608    /// It is necessary to finalize the FLAC encoder
609    /// so that it will write any partially unwritten samples
610    /// to the stream and update the [`crate::metadata::Streaminfo`] and [`crate::metadata::SeekTable`] blocks
611    /// with their final values.
612    ///
613    /// Dropping the encoder will attempt to finalize the stream
614    /// automatically, but will ignore any errors that may occur.
615    pub fn finalize(mut self) -> Result<(), Error> {
616        self.finalize_inner()?;
617        Ok(())
618    }
619}
620
621impl FlacSampleWriter<BufWriter<File>> {
622    /// Creates new FLAC file at the given path
623    ///
624    /// `sample_rate` must be between 0 (for non-audio streams) and 2²⁰.
625    ///
626    /// `bits_per_sample` must be between 1 and 32.
627    ///
628    /// `channels` must be between 1 and 8.
629    ///
630    /// Note that if `total_bytes` is indicated,
631    /// the number of bytes written *must*
632    /// be equal to that amount or an error will occur when writing
633    /// or finalizing the stream.
634    ///
635    /// # Errors
636    ///
637    /// Returns I/O error if unable to write initial
638    /// metadata blocks.
639    #[inline]
640    pub fn create<P: AsRef<Path>>(
641        path: P,
642        options: Options,
643        sample_rate: u32,
644        bits_per_sample: u32,
645        channels: u8,
646        total_samples: Option<u64>,
647    ) -> Result<Self, Error> {
648        FlacSampleWriter::new(
649            BufWriter::new(options.create(path)?),
650            options,
651            sample_rate,
652            bits_per_sample,
653            channels,
654            total_samples,
655        )
656    }
657
658    /// Creates new FLAC file at the given path with CDDA parameters
659    ///
660    /// Sample rate is 44100 Hz, bits-per-sample is 16,
661    /// channels is 2.
662    ///
663    /// Note that if `total_bytes` is indicated,
664    /// the number of bytes written *must*
665    /// be equal to that amount or an error will occur when writing
666    /// or finalizing the stream.
667    ///
668    /// # Errors
669    ///
670    /// Returns I/O error if unable to write initial
671    /// metadata blocks.
672    #[inline]
673    pub fn create_cdda<P: AsRef<Path>>(
674        path: P,
675        options: Options,
676        total_samples: Option<u64>,
677    ) -> Result<Self, Error> {
678        Self::create(path, options, 44100, 16, 2, total_samples)
679    }
680}
681
682/// A FLAC writer which operates on streamed output
683///
684/// Because this encodes FLAC frames without any metadata
685/// blocks or finalizing, it does not need to be seekable.
686///
687/// # Example
688///
689/// ```
690/// use flac_codec::{
691///     decode::{FlacStreamReader, FrameBuf},
692///     encode::{FlacStreamWriter, Options},
693/// };
694/// use std::io::{Cursor, Seek};
695/// use bitstream_io::SignedBitCount;
696///
697/// let mut flac = Cursor::new(vec![]);
698///
699/// let samples = (0..100).collect::<Vec<i32>>();
700///
701/// let mut w = FlacStreamWriter::new(&mut flac, Options::default());
702///
703/// // write a single FLAC frame with some samples
704/// w.write(
705///     44100,  // sample rate
706///     1,      // channels
707///     16,     // bits-per-sample
708///     &samples,
709/// ).unwrap();
710///
711/// flac.rewind().unwrap();
712///
713/// let mut r = FlacStreamReader::new(&mut flac);
714///
715/// // read a single FLAC frame with some samples
716/// assert_eq!(
717///     r.read().unwrap(),
718///     FrameBuf {
719///         samples: &samples,
720///         sample_rate: 44100,
721///         channels: 1,
722///         bits_per_sample: 16,
723///     },
724/// );
725/// ```
726pub struct FlacStreamWriter<W> {
727    // the writer we're outputting to
728    writer: W,
729    // various encoding optins
730    options: EncoderOptions,
731    // various encoder caches
732    caches: EncodingCaches,
733    // a whole set of samples for a FLAC frame
734    frame: Frame,
735    // the current frame number
736    frame_number: FrameNumber,
737}
738
739impl<W: std::io::Write> FlacStreamWriter<W> {
740    /// Creates new stream writer
741    pub fn new(writer: W, options: Options) -> Self {
742        Self {
743            writer,
744            options: EncoderOptions {
745                max_partition_order: options.max_partition_order,
746                mid_side: options.mid_side,
747                seektable_interval: options.seektable_interval,
748                max_lpc_order: options.max_lpc_order,
749                window: options.window,
750                exhaustive_channel_correlation: options.exhaustive_channel_correlation,
751                use_rice2: false,
752            },
753            caches: EncodingCaches::default(),
754            frame: Frame::default(),
755            frame_number: FrameNumber::default(),
756        }
757    }
758
759    /// Writes a whole FLAC frame to our output stream
760    ///
761    /// Samples are interleaved by channel, like:
762    /// [left₀ , right₀ , left₁ , right₁ , left₂ , right₂ , …]
763    ///
764    /// This writes a whole FLAC frame to the output stream on each call.
765    ///
766    /// # Errors
767    ///
768    /// Returns an error of any of the parameters are invalid
769    /// or if an I/O error occurs when writing to the stream.
770    pub fn write(
771        &mut self,
772        sample_rate: u32,
773        channels: u8,
774        bits_per_sample: u32,
775        samples: &[i32],
776    ) -> Result<(), Error> {
777        use crate::crc::{Crc16, CrcWriter};
778        use crate::stream::{BitsPerSample, FrameHeader, SampleRate};
779
780        let bits_per_sample: SignedBitCount<32> = bits_per_sample
781            .try_into()
782            .map_err(|_| Error::NonSubsetBitsPerSample)?;
783
784        // samples must divide evenly into channels
785        if samples.len() % usize::from(channels) != 0 {
786            return Err(Error::SamplesNotDivisibleByChannels);
787        } else if !(1..=8).contains(&channels) {
788            return Err(Error::ExcessiveChannels);
789        }
790
791        self.options.use_rice2 = u32::from(bits_per_sample) > 16;
792
793        self.frame
794            .resize(bits_per_sample.into(), channels.into(), 0);
795        self.frame.fill_from_samples(samples);
796
797        // block size must be valid
798        let block_size: crate::stream::BlockSize<u16> = crate::stream::BlockSize::try_from(
799            u16::try_from(self.frame.pcm_frames()).map_err(|_| Error::InvalidBlockSize)?,
800        )
801        .map_err(|_| Error::InvalidBlockSize)?;
802
803        // sample rate must be valid for subset streams
804        let sample_rate: SampleRate<u32> = sample_rate.try_into().and_then(|rate| match rate {
805            SampleRate::Streaminfo(_) => Err(Error::NonSubsetSampleRate),
806            rate => Ok(rate),
807        })?;
808
809        // bits-per-sample must be valid for subset streams
810        let header_bits_per_sample = match BitsPerSample::from(bits_per_sample) {
811            BitsPerSample::Streaminfo(_) => return Err(Error::NonSubsetBitsPerSample),
812            bps => bps,
813        };
814
815        let mut w: CrcWriter<_, Crc16> = CrcWriter::new(&mut self.writer);
816        let mut bw: BitWriter<CrcWriter<&mut W, Crc16>, BigEndian>;
817
818        match self
819            .frame
820            .channels()
821            .collect::<ArrayVec<&[i32], MAX_CHANNELS>>()
822            .as_slice()
823        {
824            [channel] => {
825                FrameHeader {
826                    blocking_strategy: false,
827                    frame_number: self.frame_number,
828                    block_size: (channel.len() as u16)
829                        .try_into()
830                        .expect("frame cannot be empty"),
831                    sample_rate,
832                    bits_per_sample: header_bits_per_sample,
833                    channel_assignment: ChannelAssignment::Independent(Independent::Mono),
834                }
835                .write_subset(&mut w)?;
836
837                bw = BitWriter::new(w);
838
839                self.caches.channels.resize_with(1, ChannelCache::default);
840
841                encode_subframe(
842                    &self.options,
843                    &mut self.caches.channels[0],
844                    CorrelatedChannel::independent(bits_per_sample, channel),
845                )?
846                .playback(&mut bw)?;
847            }
848            [left, right] if self.options.exhaustive_channel_correlation => {
849                let Correlated {
850                    channel_assignment,
851                    channels: [channel_0, channel_1],
852                } = correlate_channels_exhaustive(
853                    &self.options,
854                    &mut self.caches.correlated,
855                    [left, right],
856                    bits_per_sample,
857                )?;
858
859                FrameHeader {
860                    blocking_strategy: false,
861                    frame_number: self.frame_number,
862                    block_size,
863                    sample_rate,
864                    bits_per_sample: header_bits_per_sample,
865                    channel_assignment,
866                }
867                .write_subset(&mut w)?;
868
869                bw = BitWriter::new(w);
870
871                channel_0.playback(&mut bw)?;
872                channel_1.playback(&mut bw)?;
873            }
874            [left, right] => {
875                let Correlated {
876                    channel_assignment,
877                    channels: [channel_0, channel_1],
878                } = correlate_channels(
879                    &self.options,
880                    &mut self.caches.correlated,
881                    [left, right],
882                    bits_per_sample,
883                );
884
885                FrameHeader {
886                    blocking_strategy: false,
887                    frame_number: self.frame_number,
888                    block_size,
889                    sample_rate,
890                    bits_per_sample: header_bits_per_sample,
891                    channel_assignment,
892                }
893                .write_subset(&mut w)?;
894
895                self.caches.channels.resize_with(2, ChannelCache::default);
896                let [cache_0, cache_1] = self.caches.channels.get_disjoint_mut([0, 1]).unwrap();
897                let (channel_0, channel_1) = join(
898                    || encode_subframe(&self.options, cache_0, channel_0),
899                    || encode_subframe(&self.options, cache_1, channel_1),
900                );
901
902                bw = BitWriter::new(w);
903
904                channel_0?.playback(&mut bw)?;
905                channel_1?.playback(&mut bw)?;
906            }
907            channels => {
908                // non-stereo frames are always encoded independently
909                FrameHeader {
910                    blocking_strategy: false,
911                    frame_number: self.frame_number,
912                    block_size,
913                    sample_rate,
914                    bits_per_sample: header_bits_per_sample,
915                    channel_assignment: ChannelAssignment::Independent(
916                        channels.len().try_into().expect("invalid channel count"),
917                    ),
918                }
919                .write_subset(&mut w)?;
920
921                bw = BitWriter::new(w);
922
923                self.caches
924                    .channels
925                    .resize_with(channels.len(), ChannelCache::default);
926
927                vec_map(
928                    self.caches.channels.iter_mut().zip(channels).collect(),
929                    |(cache, channel)| {
930                        encode_subframe(
931                            &self.options,
932                            cache,
933                            CorrelatedChannel::independent(bits_per_sample, channel),
934                        )
935                    },
936                )
937                .into_iter()
938                .try_for_each(|r| r.and_then(|r| r.playback(bw.by_ref()).map_err(Error::Io)))?;
939            }
940        }
941
942        let crc16: u16 = bw.aligned_writer()?.checksum().into();
943        bw.write_from(crc16)?;
944
945        if self.frame_number.try_increment().is_err() {
946            self.frame_number = FrameNumber::default();
947        }
948
949        Ok(())
950    }
951
952    /// Writes a whole FLAC frame to our output stream with CDDA parameters
953    ///
954    /// Samples are interleaved by channel, like:
955    /// [left₀ , right₀ , left₁ , right₁ , left₂ , right₂ , …]
956    ///
957    /// This writes a whole FLAC frame to the output stream on each call.
958    ///
959    /// # Errors
960    ///
961    /// Returns an error of any of the parameters are invalid
962    /// or if an I/O error occurs when writing to the stream.
963    pub fn write_cdda(&mut self, samples: &[i32]) -> Result<(), Error> {
964        self.write(44100, 2, 16, samples)
965    }
966}
967
968fn update_md5<W: std::io::Write + std::io::Seek>(
969    encoder: &mut Encoder<W>,
970    samples: &[i32],
971    bytes_per_sample: usize,
972) {
973    use crate::byteorder::{Endianness, LittleEndian};
974
975    match bytes_per_sample {
976        1 => {
977            for s in samples {
978                encoder.update_md5(&LittleEndian::i8_to_bytes(*s as i8));
979            }
980        }
981        2 => {
982            for s in samples {
983                encoder.update_md5(&LittleEndian::i16_to_bytes(*s as i16));
984            }
985        }
986        3 => {
987            for s in samples {
988                encoder.update_md5(&LittleEndian::i24_to_bytes(*s));
989            }
990        }
991        4 => {
992            for s in samples {
993                encoder.update_md5(&LittleEndian::i32_to_bytes(*s));
994            }
995        }
996        _ => panic!("unsupported number of bytes per sample"),
997    }
998}
999
/// The interval of seek points to generate
///
/// The default is one seek point every 10 seconds.
#[derive(Copy, Clone, Debug)]
pub enum SeekTableInterval {
    /// Generate a seekpoint every N seconds
    Seconds(NonZero<u8>),
    /// Generate a seekpoint every N frames
    Frames(NonZero<usize>),
}
1008
1009impl Default for SeekTableInterval {
1010    fn default() -> Self {
1011        Self::Seconds(NonZero::new(10).unwrap())
1012    }
1013}
1014
1015impl SeekTableInterval {
1016    // decimates full set of seekpoints based on the requested
1017    // seektable style, or returns None if no seektable is requested
1018    fn filter<'s>(
1019        self,
1020        sample_rate: u32,
1021        seekpoints: impl IntoIterator<Item = EncoderSeekPoint> + 's,
1022    ) -> Box<dyn Iterator<Item = EncoderSeekPoint> + 's> {
1023        match self {
1024            Self::Seconds(seconds) => {
1025                let nth_sample = u64::from(u32::from(seconds.get()) * sample_rate);
1026                let mut offset = 0;
1027                Box::new(seekpoints.into_iter().filter(move |point| {
1028                    if point.range().contains(&offset) {
1029                        offset += nth_sample;
1030                        true
1031                    } else {
1032                        false
1033                    }
1034                }))
1035            }
1036            Self::Frames(frames) => Box::new(seekpoints.into_iter().step_by(frames.get())),
1037        }
1038    }
1039}
1040
/// FLAC encoding options
///
/// Built with the chained setter methods, starting from
/// `Options::default()` or the `Options::fast()` / `Options::best()` presets.
#[derive(Clone, Debug)]
pub struct Options {
    // whether to clobber existing file
    clobber: bool,
    // samples per frame; `block_size()` rejects values below 16
    block_size: u16,
    // largest residual partition order to try; `max_partition_order()` allows 0 ..= 15
    max_partition_order: u32,
    // whether to use mid-side encoding
    mid_side: bool,
    // metadata blocks to write; STREAMINFO starts as a dummy placeholder
    // since the stream parameters aren't known until encoding begins
    metadata: BlockList,
    // spacing of generated seek points, or None for no SEEKTABLE
    seektable_interval: Option<SeekTableInterval>,
    // largest LPC order to try, or None to encode no LPC subframes
    max_lpc_order: Option<NonZero<u8>>,
    // windowing function to use for input samples
    window: Window,
    // whether to pick the channel correlation exhaustively rather than quickly
    exhaustive_channel_correlation: bool,
}
1055
1056impl Default for Options {
1057    fn default() -> Self {
1058        // a dummy placeholder value
1059        // since we can't know the stream parameters yet
1060        let mut metadata = BlockList::new(Streaminfo {
1061            minimum_block_size: 0,
1062            maximum_block_size: 0,
1063            minimum_frame_size: None,
1064            maximum_frame_size: None,
1065            sample_rate: 0,
1066            channels: NonZero::new(1).unwrap(),
1067            bits_per_sample: SignedBitCount::new::<4>(),
1068            total_samples: None,
1069            md5: None,
1070        });
1071
1072        metadata.insert(crate::metadata::Padding {
1073            size: 4096u16.into(),
1074        });
1075
1076        Self {
1077            clobber: false,
1078            block_size: 4096,
1079            mid_side: true,
1080            max_partition_order: 5,
1081            metadata,
1082            seektable_interval: Some(SeekTableInterval::default()),
1083            max_lpc_order: NonZero::new(8),
1084            window: Window::default(),
1085            exhaustive_channel_correlation: true,
1086        }
1087    }
1088}
1089
1090impl Options {
1091    /// Sets new block size
1092    ///
1093    /// Block size must be ≥ 16
1094    ///
1095    /// For subset streams, this must be ≤ 4608
1096    /// if the sample rate is ≤ 48 kHz -
1097    /// or ≤ 16384 for higher sample rates.
1098    pub fn block_size(self, block_size: u16) -> Result<Self, OptionsError> {
1099        match block_size {
1100            0..16 => Err(OptionsError::InvalidBlockSize),
1101            16.. => Ok(Self { block_size, ..self }),
1102        }
1103    }
1104
1105    /// Sets new maximum LPC order
1106    ///
1107    /// The valid range is: 0 < `max_lpc_order` ≤ 32
1108    ///
1109    /// A value of `None` means that no LPC subframes will be encoded.
1110    pub fn max_lpc_order(self, max_lpc_order: Option<u8>) -> Result<Self, OptionsError> {
1111        Ok(Self {
1112            max_lpc_order: max_lpc_order
1113                .map(|o| {
1114                    o.try_into()
1115                        .ok()
1116                        .filter(|o| *o <= NonZero::new(32).unwrap())
1117                        .ok_or(OptionsError::InvalidLpcOrder)
1118                })
1119                .transpose()?,
1120            ..self
1121        })
1122    }
1123
1124    /// Sets maximum residual partion order.
1125    ///
1126    /// The valid range is: 0 ≤ `max_partition_order` ≤ 15
1127    pub fn max_partition_order(self, max_partition_order: u32) -> Result<Self, OptionsError> {
1128        match max_partition_order {
1129            0..=15 => Ok(Self {
1130                max_partition_order,
1131                ..self
1132            }),
1133            16.. => Err(OptionsError::InvalidMaxPartitions),
1134        }
1135    }
1136
1137    /// Whether to use mid-side encoding
1138    ///
1139    /// The default is `true`.
1140    pub fn mid_side(self, mid_side: bool) -> Self {
1141        Self { mid_side, ..self }
1142    }
1143
1144    /// The windowing function to use for input samples
1145    pub fn window(self, window: Window) -> Self {
1146        Self { window, ..self }
1147    }
1148
1149    /// Whether to calculate the best channel correlation quickly
1150    ///
1151    /// The default is `false`
1152    pub fn fast_channel_correlation(self, fast: bool) -> Self {
1153        Self {
1154            exhaustive_channel_correlation: !fast,
1155            ..self
1156        }
1157    }
1158
1159    /// Updates size of padding block
1160    ///
1161    /// `size` must be < 2²⁴
1162    ///
1163    /// If `size` is set to 0, removes the block entirely.
1164    ///
1165    /// The default is to add a 4096 byte padding block.
1166    pub fn padding(mut self, size: u32) -> Result<Self, OptionsError> {
1167        use crate::metadata::Padding;
1168
1169        match size
1170            .try_into()
1171            .map_err(|_| OptionsError::ExcessivePadding)?
1172        {
1173            BlockSize::ZERO => self.metadata.remove::<Padding>(),
1174            size => self.metadata.update::<Padding>(|p| {
1175                p.size = size;
1176            }),
1177        }
1178        Ok(self)
1179    }
1180
1181    /// Remove any padding blocks from metadata
1182    ///
1183    /// This makes the file smaller, but will likely require
1184    /// rewriting it if any metadata needs to be modified later.
1185    pub fn no_padding(mut self) -> Self {
1186        self.metadata.remove::<crate::metadata::Padding>();
1187        self
1188    }
1189
1190    /// Adds new tag to comment metadata block
1191    ///
1192    /// Creates new [`crate::metadata::VorbisComment`] block if not already present.
1193    pub fn tag<S>(mut self, field: &str, value: S) -> Self
1194    where
1195        S: std::fmt::Display,
1196    {
1197        self.metadata
1198            .update::<VorbisComment>(|vc| vc.insert(field, value));
1199        self
1200    }
1201
1202    /// Replaces entire [`crate::metadata::VorbisComment`] metadata block
1203    ///
1204    /// This may be more convenient when adding many fields at once.
1205    pub fn comment(mut self, comment: VorbisComment) -> Self {
1206        self.metadata.insert(comment);
1207        self
1208    }
1209
1210    /// Add new [`crate::metadata::Picture`] block to metadata
1211    ///
1212    /// Files may contain multiple [`crate::metadata::Picture`] blocks,
1213    /// and this adds a new block each time it is used.
1214    pub fn picture(mut self, picture: Picture) -> Self {
1215        self.metadata.insert(picture);
1216        self
1217    }
1218
1219    /// Add new [`crate::metadata::Cuesheet`] block to metadata
1220    ///
1221    /// Files may (theoretically) contain multiple [`crate::metadata::Cuesheet`] blocks,
1222    /// and this adds a new block each time it is used.
1223    ///
1224    /// In practice, CD images almost always use only a single
1225    /// cue sheet.
1226    pub fn cuesheet(mut self, cuesheet: Cuesheet) -> Self {
1227        self.metadata.insert(cuesheet);
1228        self
1229    }
1230
1231    /// Add new [`crate::metadata::Application`] block to metadata
1232    ///
1233    /// Files may contain multiple [`crate::metadata::Application`] blocks,
1234    /// and this adds a new block each time it is used.
1235    pub fn application(mut self, application: Application) -> Self {
1236        self.metadata.insert(application);
1237        self
1238    }
1239
1240    /// Generate [`crate::metadata::SeekTable`] with the given number of seconds between seek points
1241    ///
1242    /// The default is to generate a SEEKTABLE with 10 seconds between seek points.
1243    ///
1244    /// If `seconds` is 0, removes the SEEKTABLE block.
1245    ///
1246    /// The interval between seek points may be larger than requested
1247    /// if the encoder's block size is larger than the seekpoint interval.
1248    pub fn seektable_seconds(mut self, seconds: u8) -> Self {
1249        // note that we can't drop a placeholder seektable
1250        // into the metadata blocks until we know
1251        // the sample rate and total samples of our stream
1252        self.seektable_interval = NonZero::new(seconds).map(SeekTableInterval::Seconds);
1253        self
1254    }
1255
1256    /// Generate [`crate::metadata::SeekTable`] with the given number of FLAC frames between seek points
1257    ///
1258    /// If `frames` is 0, removes the SEEKTABLE block
1259    pub fn seektable_frames(mut self, frames: usize) -> Self {
1260        self.seektable_interval = NonZero::new(frames).map(SeekTableInterval::Frames);
1261        self
1262    }
1263
1264    /// Do not generate a seektable in our encoded file
1265    pub fn no_seektable(self) -> Self {
1266        Self {
1267            seektable_interval: None,
1268            ..self
1269        }
1270    }
1271
1272    /// Overwrites existing file if it already exists
1273    ///
1274    /// This only applies to encoding functions which
1275    /// create files from paths.
1276    ///
1277    /// The default is to not overwrite files
1278    /// if they already exist.
1279    pub fn overwrite(mut self) -> Self {
1280        self.clobber = true;
1281        self
1282    }
1283
1284    /// Returns the fastest encoding options
1285    ///
1286    /// These are tuned to encode as quickly as possible.
1287    pub fn fast() -> Self {
1288        Self {
1289            block_size: 1152,
1290            mid_side: false,
1291            max_partition_order: 3,
1292            max_lpc_order: None,
1293            exhaustive_channel_correlation: false,
1294            ..Self::default()
1295        }
1296    }
1297
1298    /// Returns the fastest encoding options
1299    ///
1300    /// These are tuned to encode files as small as possible.
1301    pub fn best() -> Self {
1302        Self {
1303            block_size: 4096,
1304            mid_side: true,
1305            max_partition_order: 6,
1306            max_lpc_order: NonZero::new(12),
1307            ..Self::default()
1308        }
1309    }
1310
1311    /// Creates files according to whether clobber is set or not
1312    fn create<P: AsRef<Path>>(&self, path: P) -> std::io::Result<File> {
1313        if self.clobber {
1314            File::create(path)
1315        } else {
1316            use std::fs::OpenOptions;
1317
1318            OpenOptions::new()
1319                .write(true)
1320                .create_new(true)
1321                .open(path.as_ref())
1322        }
1323    }
1324}
1325
/// An error when specifying encoding options
///
/// Returned by the fallible `Options` setter methods.
#[derive(Debug)]
pub enum OptionsError {
    /// Selected block size is too small
    InvalidBlockSize,
    /// Maximum LPC order is zero or too large
    InvalidLpcOrder,
    /// Maximum residual partitions is too large
    InvalidMaxPartitions,
    /// Selected padding size is too large
    ExcessivePadding,
}

impl std::error::Error for OptionsError {}

impl std::fmt::Display for OptionsError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let message = match self {
            Self::InvalidBlockSize => "block size must be >= 16",
            Self::InvalidLpcOrder => "maximum LPC order must be <= 32",
            Self::InvalidMaxPartitions => "max partition order must be <= 15",
            Self::ExcessivePadding => "padding size is too large for block",
        };
        // `pad` honours width/fill flags exactly as `str`'s Display does
        f.pad(message)
    }
}
1351
/// A cut-down version of Options without the metadata blocks
///
/// `Encoder::new` copies these fields out of the user's `Options`
/// after taking its metadata blocks for writing.
struct EncoderOptions {
    max_partition_order: u32,
    mid_side: bool,
    seektable_interval: Option<SeekTableInterval>,
    max_lpc_order: Option<NonZero<u8>>,
    window: Window,
    exhaustive_channel_correlation: bool,
    // set when the stream's bits per sample exceeds 16;
    // presumably selects the 5-bit Rice parameter residual coding — TODO confirm
    use_rice2: bool,
}
1362
/// The method to use for windowing the input signal
#[derive(Copy, Clone, Debug)]
pub enum Window {
    /// Basic rectangular window
    Rectangle,
    /// Hann window
    Hann,
    /// Tukey window
    ///
    /// The parameter is the fraction of the window, split between
    /// both ends, that is tapered.  Values ≤ 0.0 give a rectangular
    /// window, values ≥ 1.0 give a Hann window, and NaN is treated as 0.5.
    Tukey(f32),
}

// TODO - add more windowing options

impl Window {
    /// Fills `window` with this window function's coefficients,
    /// one per element of the slice
    ///
    /// # Panics
    ///
    /// Panics if a Hann or Tukey window is longer than `u16::MAX`.
    fn generate(&self, window: &mut [f64]) {
        use std::f64::consts::PI;

        match self {
            Self::Rectangle => window.fill(1.0),
            Self::Hann => match window.len() {
                // The general formula divides by (len - 1), which is
                // 0 / 0 = NaN for a single sample, so leave such
                // a short window flat instead.
                0 | 1 => window.fill(1.0),
                len => {
                    // verified output against reference implementation
                    // See: FLAC__window_hann()

                    let np = f64::from(u16::try_from(len).expect("window size too large")) - 1.0;

                    window.iter_mut().zip(0u16..).for_each(|(w, n)| {
                        *w = 0.5 - 0.5 * (2.0 * PI * f64::from(n) / np).cos();
                    });
                }
            },
            Self::Tukey(p) => match p {
                // verified output against reference implementation
                // See: FLAC__window_tukey()
                ..=0.0 => {
                    window.fill(1.0);
                }
                1.0.. => {
                    Self::Hann.generate(window);
                }
                0.0..1.0 => {
                    // np is the taper length at each end of the window
                    match ((f64::from(*p) / 2.0 * window.len() as f64) as usize).checked_sub(1) {
                        Some(np) => match window.get_disjoint_mut([
                            0..np,
                            np..window.len() - np,
                            window.len() - np..window.len(),
                        ]) {
                            Ok([first, mid, last]) => {
                                // u16 is maximum block size
                                let np = u16::try_from(np).expect("window size too large");

                                // both tapers are mirror images of each other
                                for ((x, y), n) in
                                    first.iter_mut().zip(last.iter_mut().rev()).zip(0u16..)
                                {
                                    *x = 0.5 - 0.5 * (PI * f64::from(n) / f64::from(np)).cos();
                                    *y = *x;
                                }
                                mid.fill(1.0);
                            }
                            Err(_) => {
                                window.fill(1.0);
                            }
                        },
                        None => {
                            window.fill(1.0);
                        }
                    }
                }
                // only NaN reaches this arm
                _ => {
                    Self::Tukey(0.5).generate(window);
                }
            },
        }
    }

    /// Multiplies `samples` by this window and returns the windowed
    /// signal, which is stored in `cache`
    ///
    /// `window` holds previously generated coefficients and is only
    /// regenerated when its length differs from `samples`, so a given
    /// `window` buffer should always be used with the same window function.
    fn apply<'w>(
        &self,
        window: &mut Vec<f64>,
        cache: &'w mut Vec<f64>,
        samples: &[i32],
    ) -> &'w [f64] {
        if window.len() != samples.len() {
            // need to re-generate window to fit samples
            window.resize(samples.len(), 0.0);
            self.generate(window);
        }

        // window signal into cache and return cached slice
        cache.clear();
        cache.extend(samples.iter().zip(window).map(|(s, w)| f64::from(*s) * *w));
        cache.as_slice()
    }
}

impl Default for Window {
    /// A Tukey window with a parameter of 0.5
    fn default() -> Self {
        Self::Tukey(0.5)
    }
}
1461
/// Scratch buffers held by an encoder and reused from frame to frame
#[derive(Default)]
struct EncodingCaches {
    // one cache per channel, resized to each frame's channel count
    channels: Vec<ChannelCache>,
    // passed to the stereo channel-correlation routines
    correlated: CorrelationCache,
}
1467
/// Scratch state for choosing a stereo channel assignment
///
/// Handed to `correlate_channels` / `correlate_channels_exhaustive`
/// when encoding two-channel frames.
#[derive(Default)]
struct CorrelationCache {
    // the average channel samples
    average_samples: Vec<i32>,
    // the difference channel samples
    difference_samples: Vec<i32>,

    // per-candidate encoding caches, presumably one for each channel
    // that may end up in the frame (left, right, average, difference)
    // — TODO confirm against the correlation routines
    left_cache: ChannelCache,
    right_cache: ChannelCache,
    average_cache: ChannelCache,
    difference_cache: ChannelCache,
}
1480
/// Per-channel scratch buffers passed to `encode_subframe`
#[derive(Default)]
struct ChannelCache {
    // working buffers for FIXED subframe candidates
    fixed: FixedCache,
    // presumably the bits of the best FIXED subframe candidate — TODO confirm
    fixed_output: BitRecorder<u32, BigEndian>,
    // working buffers for LPC subframe candidates
    lpc: LpcCache,
    // presumably the bits of the best LPC subframe candidate — TODO confirm
    lpc_output: BitRecorder<u32, BigEndian>,
    // presumably the bits of a CONSTANT subframe candidate — TODO confirm
    constant_output: BitRecorder<u32, BigEndian>,
    // presumably the bits of a VERBATIM subframe candidate — TODO confirm
    verbatim_output: BitRecorder<u32, BigEndian>,
    // NOTE(review): looks like samples with wasted low bits removed — confirm
    wasted: Vec<i32>,
}
1491
/// Scratch buffers for FIXED subframe prediction
#[derive(Default)]
struct FixedCache {
    // FIXED subframe buffers, one per order 1-4
    fixed_buffers: [Vec<i32>; 4],
}
1497
/// Scratch buffers for LPC subframe analysis
#[derive(Default)]
struct LpcCache {
    // generated window coefficients; presumably passed to `Window::apply`
    // as `window`, which regenerates them when the length changes
    window: Vec<f64>,
    // windowed samples; presumably the `cache` argument of `Window::apply`
    windowed: Vec<f64>,
    // prediction residuals
    residuals: Vec<i32>,
}
1504
/// A FLAC encoder
///
/// Writes metadata on creation, encodes frames via `encode`, and
/// rewrites the metadata with final values in `finalize_inner`
/// (also invoked on drop).
struct Encoder<W: std::io::Write + std::io::Seek> {
    // the writer we're outputting to; its byte count starts
    // after the initial metadata blocks are written
    writer: Counter<W>,
    // the stream's starting offset in the writer, in bytes;
    // metadata is rewritten from here when finalizing
    start: u64,
    // various encoding options
    options: EncoderOptions,
    // various encoder caches
    caches: EncodingCaches,
    // our metadata blocks
    blocks: BlockList,
    // our stream's sample rate
    sample_rate: SampleRate<u32>,
    // the current frame number
    frame_number: FrameNumber,
    // the number of channel-independent samples written
    samples_written: u64,
    // all seekpoints, one per encoded frame
    seekpoints: Vec<EncoderSeekPoint>,
    // our running MD5 calculation
    md5: md5::Context,
    // whether the encoder has finalized the file,
    // so finalization only happens once
    finalized: bool,
}
1530
1531impl<W: std::io::Write + std::io::Seek> Encoder<W> {
1532    const MAX_SAMPLES: u64 = 68_719_476_736;
1533
1534    fn new(
1535        mut writer: W,
1536        options: Options,
1537        sample_rate: u32,
1538        bits_per_sample: SignedBitCount<32>,
1539        channels: u8,
1540        total_samples: Option<NonZero<u64>>,
1541    ) -> Result<Self, Error> {
1542        use crate::metadata::OptionalBlockType;
1543
1544        let mut blocks = options.metadata;
1545
1546        *blocks.streaminfo_mut() = Streaminfo {
1547            minimum_block_size: options.block_size,
1548            maximum_block_size: options.block_size,
1549            minimum_frame_size: None,
1550            maximum_frame_size: None,
1551            sample_rate: (0..1048576)
1552                .contains(&sample_rate)
1553                .then_some(sample_rate)
1554                .ok_or(Error::InvalidSampleRate)?,
1555            bits_per_sample,
1556            channels: (1..=8)
1557                .contains(&channels)
1558                .then_some(channels)
1559                .and_then(NonZero::new)
1560                .ok_or(Error::ExcessiveChannels)?,
1561            total_samples: match total_samples {
1562                None => None,
1563                total_samples @ Some(samples) => match samples.get() {
1564                    0..Self::MAX_SAMPLES => total_samples,
1565                    _ => return Err(Error::ExcessiveTotalSamples),
1566                },
1567            },
1568            md5: None,
1569        };
1570
1571        // insert a dummy SeekTable to be populated later
1572        if let Some(total_samples) = total_samples {
1573            if let Some(placeholders) = options.seektable_interval.map(|s| {
1574                s.filter(
1575                    sample_rate,
1576                    EncoderSeekPoint::placeholders(total_samples.get(), options.block_size),
1577                )
1578            }) {
1579                use crate::metadata::SeekTable;
1580
1581                blocks.insert(SeekTable {
1582                    // placeholder points should always be contiguous
1583                    points: placeholders
1584                        .take(SeekTable::MAX_POINTS)
1585                        .map(|p| p.into())
1586                        .collect::<Vec<_>>()
1587                        .try_into()
1588                        .unwrap(),
1589                });
1590            }
1591        }
1592
1593        let start = writer.stream_position()?;
1594
1595        // sort blocks to put more relevant items at the front
1596        blocks.sort_by(|block| match block {
1597            OptionalBlockType::VorbisComment => 0,
1598            OptionalBlockType::SeekTable => 1,
1599            OptionalBlockType::Picture => 2,
1600            OptionalBlockType::Application => 3,
1601            OptionalBlockType::Cuesheet => 4,
1602            OptionalBlockType::Padding => 5,
1603        });
1604
1605        write_blocks(writer.by_ref(), blocks.blocks())?;
1606
1607        Ok(Self {
1608            start,
1609            writer: Counter::new(writer),
1610            options: EncoderOptions {
1611                max_partition_order: options.max_partition_order,
1612                mid_side: options.mid_side,
1613                seektable_interval: options.seektable_interval,
1614                max_lpc_order: options.max_lpc_order,
1615                window: options.window,
1616                exhaustive_channel_correlation: options.exhaustive_channel_correlation,
1617                use_rice2: u32::from(bits_per_sample) > 16,
1618            },
1619            caches: EncodingCaches::default(),
1620            sample_rate: blocks
1621                .streaminfo()
1622                .sample_rate
1623                .try_into()
1624                .expect("invalid sample rate"),
1625            blocks,
1626            frame_number: FrameNumber::default(),
1627            samples_written: 0,
1628            seekpoints: Vec::new(),
1629            md5: md5::Context::new(),
1630            finalized: false,
1631        })
1632    }
1633
1634    /// Updates running MD5 calculation with signed, little-endian bytes
1635    fn update_md5(&mut self, frame_bytes: &[u8]) {
1636        self.md5.consume(frame_bytes)
1637    }
1638
1639    /// Encodes an audio frame of PCM samples
1640    ///
1641    /// Depending on the encoder's chosen block size,
1642    /// this may encode zero or more FLAC frames to disk.
1643    ///
1644    /// # Errors
1645    ///
1646    /// Returns an I/O error from the underlying stream,
1647    /// or if the frame's parameters are not a match
1648    /// for the encoder's.
1649    fn encode(&mut self, frame: &Frame) -> Result<(), Error> {
1650        // drop in a new seekpoint
1651        self.seekpoints.push(EncoderSeekPoint {
1652            sample_offset: self.samples_written,
1653            byte_offset: Some(self.writer.count),
1654            frame_samples: frame.pcm_frames() as u16,
1655        });
1656
1657        // update running total of samples written
1658        self.samples_written += frame.pcm_frames() as u64;
1659        if let Some(total_samples) = self.blocks.streaminfo().total_samples {
1660            if self.samples_written > total_samples.get() {
1661                return Err(Error::ExcessiveTotalSamples);
1662            }
1663        }
1664
1665        encode_frame(
1666            &self.options,
1667            &mut self.caches,
1668            &mut self.writer,
1669            self.blocks.streaminfo_mut(),
1670            &mut self.frame_number,
1671            self.sample_rate,
1672            frame.channels().collect(),
1673        )
1674    }
1675
1676    fn finalize_inner(&mut self) -> Result<(), Error> {
1677        if !self.finalized {
1678            use crate::metadata::SeekTable;
1679
1680            self.finalized = true;
1681
1682            // update SEEKTABLE metadata block with final values
1683            if let Some(encoded_points) = self
1684                .options
1685                .seektable_interval
1686                .map(|s| s.filter(self.sample_rate.into(), self.seekpoints.iter().cloned()))
1687            {
1688                match self.blocks.get_pair_mut() {
1689                    (Some(SeekTable { points }), _) => {
1690                        // placeholder SEEKTABLE already in place,
1691                        // so no need to adjust PADDING to fit
1692
1693                        // ensure points count is the same
1694
1695                        let points_len = points.len();
1696                        points.clear();
1697                        points
1698                            .try_extend(
1699                                encoded_points
1700                                    .into_iter()
1701                                    .map(|p| p.into())
1702                                    .chain(std::iter::repeat(SeekPoint::Placeholder))
1703                                    .take(points_len),
1704                            )
1705                            .unwrap();
1706                    }
1707                    (None, Some(crate::metadata::Padding { size: padding_size })) => {
1708                        // no SEEKTABLE, but there is a PADDING block,
1709                        // so try to shrink PADDING to fit SEEKTABLE
1710
1711                        use crate::metadata::MetadataBlock;
1712
1713                        let seektable = SeekTable {
1714                            points: encoded_points
1715                                .map(|p| p.into())
1716                                .collect::<Vec<_>>()
1717                                .try_into()
1718                                .unwrap(),
1719                        };
1720                        if let Some(new_padding_size) = seektable
1721                            .total_size()
1722                            .and_then(|seektable_size| padding_size.checked_sub(seektable_size))
1723                        {
1724                            *padding_size = new_padding_size;
1725                            self.blocks.insert(seektable);
1726                        }
1727                    }
1728                    (None, None) => { /* no seektable or padding, so nothing to do */ }
1729                }
1730            }
1731
1732            // verify or update final sample count
1733            match &mut self.blocks.streaminfo_mut().total_samples {
1734                Some(expected) => {
1735                    // ensure final sample count matches
1736                    if expected.get() != self.samples_written {
1737                        return Err(Error::SampleCountMismatch);
1738                    }
1739                }
1740                expected @ None => {
1741                    // update final sample count if possible
1742                    if self.samples_written < Self::MAX_SAMPLES {
1743                        *expected =
1744                            Some(NonZero::new(self.samples_written).ok_or(Error::NoSamples)?);
1745                    } else {
1746                        // TODO - should I just leave this blank
1747                        // if too many samples are written?
1748                        return Err(Error::ExcessiveTotalSamples);
1749                    }
1750                }
1751            }
1752
1753            // update STREAMINFO MD5 sum
1754            self.blocks.streaminfo_mut().md5 = Some(self.md5.clone().compute().0);
1755
1756            // rewrite metadata blocks, relative to the beginning
1757            // of the stream
1758            let writer = self.writer.stream();
1759            writer.seek(std::io::SeekFrom::Start(self.start))?;
1760            write_blocks(writer.by_ref(), self.blocks.blocks())
1761        } else {
1762            Ok(())
1763        }
1764    }
1765}
1766
1767impl<W: std::io::Write + std::io::Seek> Drop for Encoder<W> {
1768    fn drop(&mut self) {
1769        let _ = self.finalize_inner();
1770    }
1771}
1772
/// A seek point as tracked by the encoder while frames are being written.
///
/// Sample offset and frame length are always known. The byte offset
/// is `None` for a placeholder point whose frame has not been written yet.
/// Such points convert into `SeekPoint::Placeholder` entries.
#[derive(Debug, Clone)]
struct EncoderSeekPoint {
    /// index of the first sample in the point's frame
    sample_offset: u64,
    /// frame's byte offset, or `None` for a placeholder
    byte_offset: Option<u64>,
    /// number of samples in the point's frame
    frame_samples: u16,
}

impl EncoderSeekPoint {
    /// Yields one placeholder point per `block_size` samples of a stream
    /// holding `total_samples` samples.
    ///
    /// Every point covers `block_size` samples except possibly the last,
    /// which covers whatever remains.
    ///
    /// # Panics
    ///
    /// Panics if `block_size` is 0.
    fn placeholders(total_samples: u64, block_size: u16) -> impl Iterator<Item = EncoderSeekPoint> {
        let stride = usize::from(block_size);
        let full_frame = u64::from(block_size);

        (0..total_samples).step_by(stride).map(move |sample_offset| {
            // the final frame may be shorter than a full block
            let remaining = (total_samples - sample_offset).min(full_frame);
            EncoderSeekPoint {
                sample_offset,
                byte_offset: None,
                // `remaining` never exceeds `block_size`, so this always fits
                frame_samples: u16::try_from(remaining).unwrap_or(block_size),
            }
        })
    }

    /// Returns the half-open range of sample indexes covered by this point's frame.
    fn range(&self) -> std::ops::Range<u64> {
        let end = self.sample_offset + u64::from(self.frame_samples);
        self.sample_offset..end
    }
}
1802
1803impl From<EncoderSeekPoint> for SeekPoint {
1804    fn from(p: EncoderSeekPoint) -> Self {
1805        match p.byte_offset {
1806            Some(byte_offset) => Self::Defined {
1807                sample_offset: p.sample_offset,
1808                byte_offset,
1809                frame_samples: p.frame_samples,
1810            },
1811            None => Self::Placeholder,
1812        }
1813    }
1814}
1815
1816/// Given a FLAC stream, generates new seek table
1817///
1818/// Though encoders should add seek tables by default,
1819/// sometimes one isn't present.  This function takes
1820/// an existing FLAC file stream and generates a new
1821/// seek table suitable for adding to the file's metadata
1822/// via the [`crate::metadata::update`] function.
1823///
1824/// The stream should be rewound to the beginning of the file.
1825///
1826/// # Errors
1827///
1828/// Returns any error from the underlying stream.
1829///
1830/// # Example
1831/// ```
1832/// use flac_codec::{
1833///     encode::{FlacSampleWriter, Options, SeekTableInterval, generate_seektable},
1834///     metadata::{SeekTable, read_block},
1835/// };
1836/// use std::io::{Cursor, Seek};
1837///
1838/// let mut flac = Cursor::new(vec![]);  // a FLAC file in memory
1839///
1840/// // add a seekpoint every second
1841/// let options = Options::default().seektable_seconds(1);
1842///
1843/// let mut writer = FlacSampleWriter::new(
1844///     &mut flac,         // our wrapped writer
1845///     options,           // our seektable options
1846///     44100,             // sample rate
1847///     16,                // bits-per-sample
1848///     1,                 // channel count
1849///     Some(60 * 44100),  // one minute's worth of samples
1850/// ).unwrap();
1851///
1852/// // write one minute's worth of samples
1853/// writer.write(vec![0; 60 * 44100].as_slice()).unwrap();
1854///
1855/// // finalize writing file
1856/// assert!(writer.finalize().is_ok());
1857///
1858/// flac.rewind().unwrap();
1859///
1860/// // get existing seektable
1861/// let original_seektable = match read_block::<_, SeekTable>(&mut flac) {
1862///     Ok(Some(seektable)) => seektable,
1863///     _ => panic!("seektable not found"),
1864/// };
1865///
1866/// flac.rewind().unwrap();
1867///
1868/// // generate new seektable, also with seekpoints every second
1869/// let new_seektable = generate_seektable(
1870///     flac,
1871///     SeekTableInterval::Seconds(1.try_into().unwrap())
1872/// ).unwrap();
1873///
1874/// // ensure both seektables are identical
1875/// assert_eq!(original_seektable, new_seektable);
1876/// ```
1877pub fn generate_seektable<R: std::io::Read>(
1878    r: R,
1879    interval: SeekTableInterval,
1880) -> Result<crate::metadata::SeekTable, Error> {
1881    use crate::{
1882        metadata::{Metadata, SeekTable},
1883        stream::FrameIterator,
1884    };
1885
1886    let iter = FrameIterator::new(r)?;
1887    let metadata_len = iter.metadata_len();
1888    let sample_rate = iter.sample_rate();
1889    let mut sample_offset = 0;
1890
1891    iter.map(|r| {
1892        r.map(|(frame, offset)| EncoderSeekPoint {
1893            sample_offset,
1894            byte_offset: Some(offset - metadata_len),
1895            frame_samples: frame.header.block_size.into(),
1896        })
1897        .inspect(|p| {
1898            sample_offset += u64::from(p.frame_samples);
1899        })
1900    })
1901    .collect::<Result<Vec<_>, _>>()
1902    .map(|seekpoints| SeekTable {
1903        points: interval
1904            .filter(sample_rate, seekpoints)
1905            .take(SeekTable::MAX_POINTS)
1906            .map(|p| p.into())
1907            .collect::<Vec<_>>()
1908            .try_into()
1909            .unwrap(),
1910    })
1911}
1912
1913fn encode_frame<W>(
1914    options: &EncoderOptions,
1915    cache: &mut EncodingCaches,
1916    mut writer: W,
1917    streaminfo: &mut Streaminfo,
1918    frame_number: &mut FrameNumber,
1919    sample_rate: SampleRate<u32>,
1920    frame: ArrayVec<&[i32], MAX_CHANNELS>,
1921) -> Result<(), Error>
1922where
1923    W: std::io::Write,
1924{
1925    use crate::Counter;
1926    use crate::crc::{Crc16, CrcWriter};
1927    use crate::stream::FrameHeader;
1928    use bitstream_io::BigEndian;
1929
1930    debug_assert!(!frame.is_empty());
1931
1932    let size = Counter::new(writer.by_ref());
1933    let mut w: CrcWriter<_, Crc16> = CrcWriter::new(size);
1934    let mut bw: BitWriter<CrcWriter<Counter<&mut W>, Crc16>, BigEndian>;
1935
1936    match frame.as_slice() {
1937        [channel] => {
1938            FrameHeader {
1939                blocking_strategy: false,
1940                frame_number: *frame_number,
1941                block_size: (channel.len() as u16)
1942                    .try_into()
1943                    .expect("frame cannot be empty"),
1944                sample_rate,
1945                bits_per_sample: streaminfo.bits_per_sample.into(),
1946                channel_assignment: ChannelAssignment::Independent(Independent::Mono),
1947            }
1948            .write(&mut w, streaminfo)?;
1949
1950            bw = BitWriter::new(w);
1951
1952            cache.channels.resize_with(1, ChannelCache::default);
1953
1954            encode_subframe(
1955                options,
1956                &mut cache.channels[0],
1957                CorrelatedChannel::independent(streaminfo.bits_per_sample, channel),
1958            )?
1959            .playback(&mut bw)?;
1960        }
1961        [left, right] if options.exhaustive_channel_correlation => {
1962            let Correlated {
1963                channel_assignment,
1964                channels: [channel_0, channel_1],
1965            } = correlate_channels_exhaustive(
1966                options,
1967                &mut cache.correlated,
1968                [left, right],
1969                streaminfo.bits_per_sample,
1970            )?;
1971
1972            FrameHeader {
1973                blocking_strategy: false,
1974                frame_number: *frame_number,
1975                block_size: (frame[0].len() as u16)
1976                    .try_into()
1977                    .expect("frame cannot be empty"),
1978                sample_rate,
1979                bits_per_sample: streaminfo.bits_per_sample.into(),
1980                channel_assignment,
1981            }
1982            .write(&mut w, streaminfo)?;
1983
1984            bw = BitWriter::new(w);
1985
1986            channel_0.playback(&mut bw)?;
1987            channel_1.playback(&mut bw)?;
1988        }
1989        [left, right] => {
1990            let Correlated {
1991                channel_assignment,
1992                channels: [channel_0, channel_1],
1993            } = correlate_channels(
1994                options,
1995                &mut cache.correlated,
1996                [left, right],
1997                streaminfo.bits_per_sample,
1998            );
1999
2000            FrameHeader {
2001                blocking_strategy: false,
2002                frame_number: *frame_number,
2003                block_size: (frame[0].len() as u16)
2004                    .try_into()
2005                    .expect("frame cannot be empty"),
2006                sample_rate,
2007                bits_per_sample: streaminfo.bits_per_sample.into(),
2008                channel_assignment,
2009            }
2010            .write(&mut w, streaminfo)?;
2011
2012            cache.channels.resize_with(2, ChannelCache::default);
2013            let [cache_0, cache_1] = cache.channels.get_disjoint_mut([0, 1]).unwrap();
2014            let (channel_0, channel_1) = join(
2015                || encode_subframe(options, cache_0, channel_0),
2016                || encode_subframe(options, cache_1, channel_1),
2017            );
2018
2019            bw = BitWriter::new(w);
2020
2021            channel_0?.playback(&mut bw)?;
2022            channel_1?.playback(&mut bw)?;
2023        }
2024        channels => {
2025            // non-stereo frames are always encoded independently
2026
2027            FrameHeader {
2028                blocking_strategy: false,
2029                frame_number: *frame_number,
2030                block_size: (channels[0].len() as u16)
2031                    .try_into()
2032                    .expect("frame cannot be empty"),
2033                sample_rate,
2034                bits_per_sample: streaminfo.bits_per_sample.into(),
2035                channel_assignment: ChannelAssignment::Independent(
2036                    frame.len().try_into().expect("invalid channel count"),
2037                ),
2038            }
2039            .write(&mut w, streaminfo)?;
2040
2041            bw = BitWriter::new(w);
2042
2043            cache
2044                .channels
2045                .resize_with(channels.len(), ChannelCache::default);
2046
2047            vec_map(
2048                cache.channels.iter_mut().zip(channels).collect(),
2049                |(cache, channel)| {
2050                    encode_subframe(
2051                        options,
2052                        cache,
2053                        CorrelatedChannel::independent(streaminfo.bits_per_sample, channel),
2054                    )
2055                },
2056            )
2057            .into_iter()
2058            .try_for_each(|r| r.and_then(|r| r.playback(bw.by_ref()).map_err(Error::Io)))?;
2059        }
2060    }
2061
2062    let crc16: u16 = bw.aligned_writer()?.checksum().into();
2063    bw.write_from(crc16)?;
2064
2065    frame_number.try_increment()?;
2066
2067    // update minimum and maximum frame size values
2068    if let s @ Some(size) = u32::try_from(bw.into_writer().into_writer().count)
2069        .ok()
2070        .filter(|size| *size < Streaminfo::MAX_FRAME_SIZE)
2071        .and_then(NonZero::new)
2072    {
2073        match &mut streaminfo.minimum_frame_size {
2074            Some(min_size) => {
2075                *min_size = size.min(*min_size);
2076            }
2077            min_size @ None => {
2078                *min_size = s;
2079            }
2080        }
2081
2082        match &mut streaminfo.maximum_frame_size {
2083            Some(max_size) => {
2084                *max_size = size.max(*max_size);
2085            }
2086            max_size @ None => {
2087                *max_size = s;
2088            }
2089        }
2090    }
2091
2092    Ok(())
2093}
2094
2095struct Correlated<C> {
2096    channel_assignment: ChannelAssignment,
2097    channels: [C; 2],
2098}
2099
2100struct CorrelatedChannel<'c> {
2101    samples: &'c [i32],
2102    bits_per_sample: SignedBitCount<32>,
2103    // whether all samples are known to be 0
2104    all_0: bool,
2105}
2106
2107impl<'c> CorrelatedChannel<'c> {
2108    fn independent(bits_per_sample: SignedBitCount<32>, samples: &'c [i32]) -> Self {
2109        Self {
2110            all_0: samples.iter().all(|s| *s == 0),
2111            bits_per_sample,
2112            samples,
2113        }
2114    }
2115}
2116
/// Chooses a stereo channel assignment using a cheap heuristic
///
/// For each candidate assignment, this sums the absolute sample values
/// of its two channels and picks the candidate with the smallest total,
/// without encoding any subframes. The candidates are:
/// - independent left/right,
/// - left-side,
/// - side-right,
/// - mid-side, only when `options.mid_side` is set.
///
/// The side (difference) channel is built into `difference_samples` and
/// the mid (average) channel into `average_samples`. Both buffers are
/// reused from the `CorrelationCache`, and the returned channels may
/// borrow from them.
///
/// A side channel needs one more bit per sample than its inputs.
/// A 32 bps stream therefore cannot use one, and is always encoded
/// as independent stereo.
fn correlate_channels<'c>(
    options: &EncoderOptions,
    CorrelationCache {
        average_samples,
        difference_samples,
        ..
    }: &'c mut CorrelationCache,
    [left, right]: [&'c [i32]; 2],
    bits_per_sample: SignedBitCount<32>,
) -> Correlated<CorrelatedChannel<'c>> {
    // None here means bits_per_sample is already 32
    match bits_per_sample.checked_add::<32>(1) {
        Some(difference_bits_per_sample) if options.mid_side => {
            // running sums of absolute sample values, one per candidate channel
            let mut left_abs_sum = 0;
            let mut right_abs_sum = 0;
            let mut mid_abs_sum = 0;
            let mut side_abs_sum = 0;

            // build the mid and side channels (possibly in parallel),
            // accumulating each channel's absolute sum along the way
            join(
                || {
                    average_samples.clear();
                    average_samples.extend(
                        left.iter()
                            .inspect(|s| left_abs_sum += u64::from(s.unsigned_abs()))
                            .zip(
                                right
                                    .iter()
                                    .inspect(|s| right_abs_sum += u64::from(s.unsigned_abs())),
                            )
                            // arithmetic shift, so mid = floor((l + r) / 2);
                            // bits_per_sample is at most 31 in this branch, so
                            // l + r fits in i32 provided samples fit that width
                            .map(|(l, r)| (l + r) >> 1)
                            .inspect(|s| mid_abs_sum += u64::from(s.unsigned_abs())),
                    );
                },
                || {
                    difference_samples.clear();
                    difference_samples.extend(
                        left.iter()
                            .zip(right)
                            .map(|(l, r)| l - r)
                            .inspect(|s| side_abs_sum += u64::from(s.unsigned_abs())),
                    );
                },
            );

            // min_by_key returns the first of several equal minima,
            // so ties favor Independent, then LeftSide, then SideRight
            match [
                (
                    ChannelAssignment::Independent(Independent::Stereo),
                    left_abs_sum + right_abs_sum,
                ),
                (ChannelAssignment::LeftSide, left_abs_sum + side_abs_sum),
                (ChannelAssignment::SideRight, side_abs_sum + right_abs_sum),
                (ChannelAssignment::MidSide, mid_abs_sum + side_abs_sum),
            ]
            .into_iter()
            .min_by_key(|(_, total)| *total)
            .unwrap()
            .0
            {
                channel_assignment @ ChannelAssignment::LeftSide => Correlated {
                    channel_assignment,
                    channels: [
                        CorrelatedChannel {
                            samples: left,
                            bits_per_sample,
                            all_0: left_abs_sum == 0,
                        },
                        CorrelatedChannel {
                            samples: difference_samples,
                            bits_per_sample: difference_bits_per_sample,
                            all_0: side_abs_sum == 0,
                        },
                    ],
                },
                channel_assignment @ ChannelAssignment::SideRight => Correlated {
                    channel_assignment,
                    channels: [
                        CorrelatedChannel {
                            samples: difference_samples,
                            bits_per_sample: difference_bits_per_sample,
                            all_0: side_abs_sum == 0,
                        },
                        CorrelatedChannel {
                            samples: right,
                            bits_per_sample,
                            all_0: right_abs_sum == 0,
                        },
                    ],
                },
                channel_assignment @ ChannelAssignment::MidSide => Correlated {
                    channel_assignment,
                    channels: [
                        CorrelatedChannel {
                            samples: average_samples,
                            bits_per_sample,
                            all_0: mid_abs_sum == 0,
                        },
                        CorrelatedChannel {
                            samples: difference_samples,
                            bits_per_sample: difference_bits_per_sample,
                            all_0: side_abs_sum == 0,
                        },
                    ],
                },
                channel_assignment @ ChannelAssignment::Independent(_) => Correlated {
                    channel_assignment,
                    channels: [
                        CorrelatedChannel {
                            samples: left,
                            bits_per_sample,
                            all_0: left_abs_sum == 0,
                        },
                        CorrelatedChannel {
                            samples: right,
                            bits_per_sample,
                            all_0: right_abs_sum == 0,
                        },
                    ],
                },
            }
        }
        Some(difference_bits_per_sample) => {
            // mid-side disabled: only the side channel needs to be built
            let mut left_abs_sum = 0;
            let mut right_abs_sum = 0;
            let mut side_abs_sum = 0;

            difference_samples.clear();
            difference_samples.extend(
                left.iter()
                    .inspect(|s| left_abs_sum += u64::from(s.unsigned_abs()))
                    .zip(
                        right
                            .iter()
                            .inspect(|s| right_abs_sum += u64::from(s.unsigned_abs())),
                    )
                    .map(|(l, r)| l - r)
                    .inspect(|s| side_abs_sum += u64::from(s.unsigned_abs())),
            );

            // NOTE(review): candidates are listed in a different order than
            // in the mid-side branch, so ties favor LeftSide here but
            // Independent there (min_by_key keeps the first minimum) --
            // confirm this difference is intended
            match [
                (ChannelAssignment::LeftSide, left_abs_sum + side_abs_sum),
                (ChannelAssignment::SideRight, side_abs_sum + right_abs_sum),
                (
                    ChannelAssignment::Independent(Independent::Stereo),
                    left_abs_sum + right_abs_sum,
                ),
            ]
            .into_iter()
            .min_by_key(|(_, total)| *total)
            .unwrap()
            .0
            {
                channel_assignment @ ChannelAssignment::LeftSide => Correlated {
                    channel_assignment,
                    channels: [
                        CorrelatedChannel {
                            samples: left,
                            bits_per_sample,
                            all_0: left_abs_sum == 0,
                        },
                        CorrelatedChannel {
                            samples: difference_samples,
                            bits_per_sample: difference_bits_per_sample,
                            all_0: side_abs_sum == 0,
                        },
                    ],
                },
                channel_assignment @ ChannelAssignment::SideRight => Correlated {
                    channel_assignment,
                    channels: [
                        CorrelatedChannel {
                            samples: difference_samples,
                            bits_per_sample: difference_bits_per_sample,
                            all_0: side_abs_sum == 0,
                        },
                        CorrelatedChannel {
                            samples: right,
                            bits_per_sample,
                            all_0: right_abs_sum == 0,
                        },
                    ],
                },
                // MidSide is not among the candidates listed above
                ChannelAssignment::MidSide => unreachable!(),
                channel_assignment @ ChannelAssignment::Independent(_) => Correlated {
                    channel_assignment,
                    channels: [
                        CorrelatedChannel {
                            samples: left,
                            bits_per_sample,
                            all_0: left_abs_sum == 0,
                        },
                        CorrelatedChannel {
                            samples: right,
                            bits_per_sample,
                            all_0: right_abs_sum == 0,
                        },
                    ],
                },
            }
        }
        None => {
            // 32 bps stream, so forego difference channel
            // and encode them both independently

            Correlated {
                channel_assignment: ChannelAssignment::Independent(Independent::Stereo),
                channels: [
                    CorrelatedChannel::independent(bits_per_sample, left),
                    CorrelatedChannel::independent(bits_per_sample, right),
                ],
            }
        }
    }
}
2329
/// Chooses a stereo channel assignment by encoding every candidate
///
/// The steps are:
/// 1. Encode the left and right channels as full subframes.
/// 2. Encode the side (difference) channel as well, plus the mid
///    (average) channel when `options.mid_side` is set.
/// 3. Return the assignment whose two recorded subframes use the
///    fewest bits in total.
///
/// The returned recorders borrow from the per-channel caches in the
/// `CorrelationCache`.
///
/// 32 bps streams always use independent stereo, because their side
/// channel would need 33 bits per sample.
///
/// # Errors
///
/// Propagates any error returned while encoding one of the subframes.
fn correlate_channels_exhaustive<'c>(
    options: &EncoderOptions,
    CorrelationCache {
        average_samples,
        difference_samples,
        left_cache,
        right_cache,
        average_cache,
        difference_cache,
        ..
    }: &'c mut CorrelationCache,
    [left, right]: [&'c [i32]; 2],
    bits_per_sample: SignedBitCount<32>,
) -> Result<Correlated<&'c BitRecorder<u32, BigEndian>>, Error> {
    // left and right are candidates for every assignment, so always encode them.
    // all_0 is false because the channels are not pre-scanned here;
    // encode_subframe's wasted-bits scan still turns an all-zero
    // channel into a CONSTANT subframe
    let (left_recorder, right_recorder) = try_join(
        || {
            encode_subframe(
                options,
                left_cache,
                CorrelatedChannel {
                    samples: left,
                    bits_per_sample,
                    all_0: false,
                },
            )
        },
        || {
            encode_subframe(
                options,
                right_cache,
                CorrelatedChannel {
                    samples: right,
                    bits_per_sample,
                    all_0: false,
                },
            )
        },
    )?;

    // None here means bits_per_sample is already 32
    match bits_per_sample.checked_add::<32>(1) {
        Some(difference_bits_per_sample) if options.mid_side => {
            // build and encode the mid and side channels (possibly in parallel)
            let (average_recorder, difference_recorder) = try_join(
                || {
                    average_samples.clear();
                    average_samples
                        .extend(left.iter().zip(right.iter()).map(|(l, r)| (l + r) >> 1));
                    encode_subframe(
                        options,
                        average_cache,
                        CorrelatedChannel {
                            samples: average_samples,
                            bits_per_sample,
                            all_0: false,
                        },
                    )
                },
                || {
                    difference_samples.clear();
                    difference_samples.extend(left.iter().zip(right).map(|(l, r)| l - r));
                    encode_subframe(
                        options,
                        difference_cache,
                        CorrelatedChannel {
                            samples: difference_samples,
                            bits_per_sample: difference_bits_per_sample,
                            all_0: false,
                        },
                    )
                },
            )?;

            // compare the actual recorded bit counts of each pairing;
            // on ties, min_by_key keeps the earliest entry
            match [
                (
                    ChannelAssignment::Independent(Independent::Stereo),
                    left_recorder.written() + right_recorder.written(),
                ),
                (
                    ChannelAssignment::LeftSide,
                    left_recorder.written() + difference_recorder.written(),
                ),
                (
                    ChannelAssignment::SideRight,
                    difference_recorder.written() + right_recorder.written(),
                ),
                (
                    ChannelAssignment::MidSide,
                    average_recorder.written() + difference_recorder.written(),
                ),
            ]
            .into_iter()
            .min_by_key(|(_, total)| *total)
            .unwrap()
            .0
            {
                channel_assignment @ ChannelAssignment::LeftSide => Ok(Correlated {
                    channel_assignment,
                    channels: [left_recorder, difference_recorder],
                }),
                channel_assignment @ ChannelAssignment::SideRight => Ok(Correlated {
                    channel_assignment,
                    channels: [difference_recorder, right_recorder],
                }),
                channel_assignment @ ChannelAssignment::MidSide => Ok(Correlated {
                    channel_assignment,
                    channels: [average_recorder, difference_recorder],
                }),
                channel_assignment @ ChannelAssignment::Independent(_) => Ok(Correlated {
                    channel_assignment,
                    channels: [left_recorder, right_recorder],
                }),
            }
        }
        Some(difference_bits_per_sample) => {
            // mid-side disabled: only the side channel is needed
            let difference_recorder = {
                difference_samples.clear();
                difference_samples.extend(left.iter().zip(right).map(|(l, r)| l - r));
                encode_subframe(
                    options,
                    difference_cache,
                    CorrelatedChannel {
                        samples: difference_samples,
                        bits_per_sample: difference_bits_per_sample,
                        all_0: false,
                    },
                )?
            };

            match [
                (
                    ChannelAssignment::Independent(Independent::Stereo),
                    left_recorder.written() + right_recorder.written(),
                ),
                (
                    ChannelAssignment::LeftSide,
                    left_recorder.written() + difference_recorder.written(),
                ),
                (
                    ChannelAssignment::SideRight,
                    difference_recorder.written() + right_recorder.written(),
                ),
            ]
            .into_iter()
            .min_by_key(|(_, total)| *total)
            .unwrap()
            .0
            {
                channel_assignment @ ChannelAssignment::LeftSide => Ok(Correlated {
                    channel_assignment,
                    channels: [left_recorder, difference_recorder],
                }),
                channel_assignment @ ChannelAssignment::SideRight => Ok(Correlated {
                    channel_assignment,
                    channels: [difference_recorder, right_recorder],
                }),
                // MidSide is not among the candidates listed above
                ChannelAssignment::MidSide => unreachable!(),
                channel_assignment @ ChannelAssignment::Independent(_) => Ok(Correlated {
                    channel_assignment,
                    channels: [left_recorder, right_recorder],
                }),
            }
        }
        None => {
            // 32 bps stream, so forego difference channel
            // and encode them both independently

            Ok(Correlated {
                channel_assignment: ChannelAssignment::Independent(Independent::Stereo),
                channels: [left_recorder, right_recorder],
            })
        }
    }
}
2502
/// Encodes one channel's samples as the smallest subframe it finds
///
/// The strategy, in order:
/// 1. If the channel is known to be all zeros (`all_0`), emit a CONSTANT
///    subframe.
/// 2. Scan for "wasted" low-order zero bits shared by every sample.
///    If every sample turns out to be zero, emit a CONSTANT subframe.
///    Otherwise shift the wasted bits out before encoding.
/// 3. Try a FIXED subframe and, when `options.max_lpc_order` is set,
///    an LPC subframe, keeping whichever records fewer bits.
/// 4. Fall back to VERBATIM if every attempt in step 3 fails, or if the
///    best one is not smaller than the raw sample data.
///
/// The returned recorder is borrowed from the `ChannelCache`, ready to be
/// played back into the frame's bit writer.
///
/// # Errors
///
/// Propagates errors from writing a CONSTANT or VERBATIM subframe.
/// A FIXED or LPC attempt that fails is not propagated; that candidate
/// is simply skipped.
fn encode_subframe<'c>(
    options: &EncoderOptions,
    ChannelCache {
        fixed: fixed_cache,
        fixed_output,
        lpc: lpc_cache,
        lpc_output,
        constant_output,
        verbatim_output,
        wasted,
    }: &'c mut ChannelCache,
    CorrelatedChannel {
        samples: channel,
        bits_per_sample,
        all_0,
    }: CorrelatedChannel,
) -> Result<&'c BitRecorder<u32, BigEndian>, Error> {
    // i32::trailing_zeros() of 0 is 32, so a minimum of 32 across
    // all samples means every sample is 0
    const WASTED_MAX: NonZero<u32> = NonZero::new(32).unwrap();

    debug_assert!(!channel.is_empty());

    if all_0 {
        // all samples are 0
        constant_output.clear();
        encode_constant_subframe(constant_output, channel[0], bits_per_sample, 0)?;
        return Ok(constant_output);
    }

    // determine any wasted bits: the fold yields None as soon as any
    // sample has no trailing zeros, otherwise the smallest trailing-zero count
    let (channel, bits_per_sample, wasted_bps) =
        match channel.iter().try_fold(WASTED_MAX, |acc, sample| {
            NonZero::new(sample.trailing_zeros()).map(|sample| sample.min(acc))
        }) {
            None => (channel, bits_per_sample, 0),
            Some(WASTED_MAX) => {
                // every sample is 0
                constant_output.clear();
                encode_constant_subframe(constant_output, channel[0], bits_per_sample, 0)?;
                return Ok(constant_output);
            }
            Some(wasted_bps) => {
                // shift the shared zero bits out and encode the narrower samples
                let wasted_bps = wasted_bps.get();
                wasted.clear();
                wasted.extend(channel.iter().map(|sample| sample >> wasted_bps));
                (
                    wasted.as_slice(),
                    bits_per_sample.checked_sub(wasted_bps).unwrap(),
                    wasted_bps,
                )
            }
        };

    fixed_output.clear();

    let best = match options.max_lpc_order {
        Some(max_lpc_order) => {
            lpc_output.clear();

            // try FIXED and LPC (possibly in parallel)
            match join(
                || {
                    encode_fixed_subframe(
                        options,
                        fixed_cache,
                        fixed_output,
                        channel,
                        bits_per_sample,
                        wasted_bps,
                    )
                },
                || {
                    encode_lpc_subframe(
                        options,
                        max_lpc_order,
                        lpc_cache,
                        lpc_output,
                        channel,
                        bits_per_sample,
                        wasted_bps,
                    )
                },
            ) {
                // both succeeded: keep the smaller (FIXED wins ties)
                (Ok(()), Ok(())) => [fixed_output, lpc_output]
                    .into_iter()
                    .min_by_key(|c| c.written())
                    .unwrap(),
                (Err(_), Ok(())) => lpc_output,
                (Ok(()), Err(_)) => fixed_output,
                (Err(_), Err(_)) => {
                    verbatim_output.clear();
                    encode_verbatim_subframe(
                        verbatim_output,
                        channel,
                        bits_per_sample,
                        wasted_bps,
                    )?;
                    return Ok(verbatim_output);
                }
            }
        }
        _ => {
            // LPC disabled, so FIXED is the only predictive candidate
            match encode_fixed_subframe(
                options,
                fixed_cache,
                fixed_output,
                channel,
                bits_per_sample,
                wasted_bps,
            ) {
                Ok(()) => fixed_output,
                Err(_) => {
                    verbatim_output.clear();
                    encode_verbatim_subframe(
                        verbatim_output,
                        channel,
                        bits_per_sample,
                        wasted_bps,
                    )?;
                    return Ok(verbatim_output);
                }
            }
        }
    };

    // size of the raw sample data alone.
    // NOTE(review): if `written()` also counts the subframe header bits,
    // this comparison slightly favors VERBATIM over a marginally smaller
    // predictive subframe -- confirm against encode_verbatim_subframe
    let verbatim_len = channel.len() as u32 * u32::from(bits_per_sample);

    if best.written() < verbatim_len {
        Ok(best)
    } else {
        verbatim_output.clear();
        encode_verbatim_subframe(verbatim_output, channel, bits_per_sample, wasted_bps)?;
        Ok(verbatim_output)
    }
}
2635
2636fn encode_constant_subframe<W: BitWrite>(
2637    writer: &mut W,
2638    sample: i32,
2639    bits_per_sample: SignedBitCount<32>,
2640    wasted_bps: u32,
2641) -> Result<(), Error> {
2642    use crate::stream::{SubframeHeader, SubframeHeaderType};
2643
2644    writer.build(&SubframeHeader {
2645        type_: SubframeHeaderType::Constant,
2646        wasted_bps,
2647    })?;
2648
2649    writer
2650        .write_signed_counted(bits_per_sample, sample)
2651        .map_err(Error::Io)
2652}
2653
2654fn encode_verbatim_subframe<W: BitWrite>(
2655    writer: &mut W,
2656    channel: &[i32],
2657    bits_per_sample: SignedBitCount<32>,
2658    wasted_bps: u32,
2659) -> Result<(), Error> {
2660    use crate::stream::{SubframeHeader, SubframeHeaderType};
2661
2662    writer.build(&SubframeHeader {
2663        type_: SubframeHeaderType::Verbatim,
2664        wasted_bps,
2665    })?;
2666
2667    channel
2668        .iter()
2669        .try_for_each(|i| writer.write_signed_counted(bits_per_sample, *i))?;
2670
2671    Ok(())
2672}
2673
/// Encodes `channel` as a FIXED subframe into `writer`.
///
/// Candidate residual signals are built as repeated first differences of
/// `channel`. Order 0 is the channel itself, and each further order is the
/// first difference of the previous one, stored in the cache's
/// `fixed_buffers`. The candidate with the smallest sum of absolute values
/// is chosen. Then the subframe header, that order's warm-up samples
/// (the first `order` samples, at `bits_per_sample` width) and its
/// residuals are written.
///
/// NOTE(review): the maximum order is bounded by the number of buffers in
/// `FixedCache` and by the `ArrayVec` capacity of 5 below; confirm the
/// cache holds at most 4 buffers.
///
/// # Errors
///
/// Propagates any error from writing the header, the warm-up samples,
/// or the residuals.
fn encode_fixed_subframe<W: BitWrite>(
    options: &EncoderOptions,
    FixedCache {
        fixed_buffers: buffers,
    }: &mut FixedCache,
    writer: &mut W,
    channel: &[i32],
    bits_per_sample: SignedBitCount<32>,
    wasted_bps: u32,
) -> Result<(), Error> {
    use crate::stream::{SubframeHeader, SubframeHeaderType};

    // calculate residuals for FIXED subframe orders 0-4
    // (or fewer, if we don't have enough samples)
    let (order, warm_up, residuals) = {
        // candidate residual signals, indexed by predictor order;
        // order 0's "residuals" are the samples themselves
        let mut fixed_orders = ArrayVec::<&[i32], 5>::new();
        fixed_orders.push(channel);

        // accumulate a set of FIXED diffs
        // (each buffer receives the first difference of the previous order,
        // so every candidate is one sample shorter than the one before it)
        'outer: for buf in buffers.iter_mut() {
            let prev_order = fixed_orders.last().unwrap();
            match prev_order.split_at_checked(1) {
                Some((_, r)) => {
                    buf.clear();
                    // n is sample i + 1 and p is sample i of the previous order
                    for (n, p) in r.iter().zip(*prev_order) {
                        match n.checked_sub(*p) {
                            Some(v) => {
                                buf.push(v);
                            }
                            // a difference that doesn't fit in an i32 rules out
                            // this order and every higher one
                            None => break 'outer,
                        }
                    }
                    if buf.is_empty() {
                        // the previous order had only a single sample left
                        break;
                    } else {
                        fixed_orders.push(buf.as_slice());
                    }
                }
                // the previous order has no samples at all
                None => break,
            }
        }

        // the last candidate is the shortest; comparing every candidate
        // over only its final `min_fixed` residuals keeps all sums over
        // the same range of samples
        let min_fixed = fixed_orders.last().unwrap().len();

        // choose diff with the smallest abs sum
        // (summed as u64 to avoid overflowing i32; on ties `min_by_key`
        // keeps the first, i.e. lowest, order)
        fixed_orders
            .into_iter()
            .enumerate()
            .min_by_key(|(_, residuals)| {
                residuals[(residuals.len() - min_fixed)..]
                    .iter()
                    .map(|r| u64::from(r.unsigned_abs()))
                    .sum::<u64>()
            })
            // the first `order` samples are written verbatim as warm-up
            .map(|(order, residuals)| (order as u8, &channel[0..order], residuals))
            .unwrap()
    };

    writer.build(&SubframeHeader {
        type_: SubframeHeaderType::Fixed { order },
        wasted_bps,
    })?;

    warm_up
        .iter()
        .try_for_each(|sample: &i32| writer.write_signed_counted(bits_per_sample, *sample))?;

    write_residuals(options, writer, order.into(), residuals)
}
2743
2744fn encode_lpc_subframe<W: BitWrite>(
2745    options: &EncoderOptions,
2746    max_lpc_order: NonZero<u8>,
2747    cache: &mut LpcCache,
2748    writer: &mut W,
2749    channel: &[i32],
2750    bits_per_sample: SignedBitCount<32>,
2751    wasted_bps: u32,
2752) -> Result<(), Error> {
2753    use crate::stream::{SubframeHeader, SubframeHeaderType};
2754
2755    let LpcSubframeParameters {
2756        warm_up,
2757        residuals,
2758        parameters:
2759            LpcParameters {
2760                order,
2761                precision,
2762                shift,
2763                coefficients,
2764            },
2765    } = LpcSubframeParameters::best(options, bits_per_sample, max_lpc_order, cache, channel)?;
2766
2767    writer.build(&SubframeHeader {
2768        type_: SubframeHeaderType::Lpc { order },
2769        wasted_bps,
2770    })?;
2771
2772    for sample in warm_up {
2773        writer.write_signed_counted(bits_per_sample, *sample)?;
2774    }
2775
2776    writer.write_count::<0b1111>(
2777        precision
2778            .count()
2779            .checked_sub(1)
2780            .ok_or(Error::InvalidQlpPrecision)?,
2781    )?;
2782
2783    writer.write::<5, i32>(shift as i32)?;
2784
2785    for coeff in coefficients {
2786        writer.write_signed_counted(precision, coeff)?;
2787    }
2788
2789    write_residuals(options, writer, order.get().into(), residuals)
2790}
2791
2792struct LpcSubframeParameters<'w, 'r> {
2793    parameters: LpcParameters,
2794    warm_up: &'w [i32],
2795    residuals: &'r [i32],
2796}
2797
2798impl<'w, 'r> LpcSubframeParameters<'w, 'r> {
2799    fn best(
2800        options: &EncoderOptions,
2801        bits_per_sample: SignedBitCount<32>,
2802        max_lpc_order: NonZero<u8>,
2803        LpcCache {
2804            residuals,
2805            window,
2806            windowed,
2807        }: &'r mut LpcCache,
2808        channel: &'w [i32],
2809    ) -> Result<Self, Error> {
2810        let parameters = LpcParameters::best(
2811            options,
2812            bits_per_sample,
2813            max_lpc_order,
2814            window,
2815            windowed,
2816            channel,
2817        )?;
2818
2819        Self::encode_residuals(&parameters, channel, residuals)
2820            .map(|(warm_up, residuals)| Self {
2821                warm_up,
2822                residuals,
2823                parameters,
2824            })
2825            .map_err(|ResidualOverflow| Error::ResidualOverflow)
2826    }
2827
2828    fn encode_residuals(
2829        parameters: &LpcParameters,
2830        channel: &'w [i32],
2831        residuals: &'r mut Vec<i32>,
2832    ) -> Result<(&'w [i32], &'r [i32]), ResidualOverflow> {
2833        residuals.clear();
2834
2835        for split in usize::from(parameters.order.get())..channel.len() {
2836            let (previous, current) = channel.split_at(split);
2837
2838            residuals.push(
2839                current[0]
2840                    .checked_sub(
2841                        (previous
2842                            .iter()
2843                            .rev()
2844                            .zip(&parameters.coefficients)
2845                            .map(|(x, y)| *x as i64 * *y as i64)
2846                            .sum::<i64>()
2847                            >> parameters.shift) as i32,
2848                    )
2849                    .ok_or(ResidualOverflow)?,
2850            );
2851        }
2852
2853        Ok((
2854            &channel[0..parameters.order.get().into()],
2855            residuals.as_slice(),
2856        ))
2857    }
2858}
2859
2860#[derive(Debug)]
2861struct ResidualOverflow;
2862
2863impl From<ResidualOverflow> for Error {
2864    #[inline]
2865    fn from(_: ResidualOverflow) -> Self {
2866        Error::ResidualOverflow
2867    }
2868}
2869
#[test]
fn test_residual_encoding_1() {
    // one period of a sine-like wave, run through a 2nd-order predictor
    let samples = [
        0, 16, 31, 44, 54, 61, 64, 63, 58, 49, 38, 24, 8, -8, -24, -38, -49, -58, -63, -64, -61,
        -54, -44, -31, -16,
    ];

    let expected = [
        2, 2, 2, 3, 3, 3, 2, 2, 3, 0, 0, 0, -1, -1, -1, -3, -2, -2, -2, -1, -1, 0, 0,
    ];

    let parameters = LpcParameters {
        order: NonZero::new(2).unwrap(),
        precision: SignedBitCount::new::<7>(),
        shift: 5,
        coefficients: arrayvec![59, -30],
    };

    let mut buffer = Vec::with_capacity(expected.len());

    let (warm_up, residuals) =
        LpcSubframeParameters::encode_residuals(&parameters, &samples, &mut buffer).unwrap();

    assert_eq!(warm_up, &samples[..2]);
    assert_eq!(residuals, &expected);
}
2898
#[test]
fn test_residual_encoding_2() {
    // one period of a cosine-like wave, run through a 2nd-order predictor
    let samples = [
        64, 62, 56, 47, 34, 20, 4, -12, -27, -41, -52, -60, -63, -63, -60, -52, -41, -27, -12, 4,
        20, 34, 47, 56, 62,
    ];

    let expected = [
        2, 2, 0, 1, -1, -1, -1, -2, -2, -2, -1, -3, -2, 0, -1, 1, 0, 2, 2, 2, 4, 2, 4,
    ];

    let parameters = LpcParameters {
        order: NonZero::new(2).unwrap(),
        precision: SignedBitCount::new::<7>(),
        shift: 5,
        coefficients: arrayvec![58, -29],
    };

    let mut buffer = Vec::with_capacity(expected.len());

    let (warm_up, residuals) =
        LpcSubframeParameters::encode_residuals(&parameters, &samples, &mut buffer).unwrap();

    assert_eq!(warm_up, &samples[..2]);
    assert_eq!(residuals, &expected);
}
2927
/// Quantized linear-predictor parameters for a single LPC subframe.
#[derive(Debug)]
struct LpcParameters {
    // predictor order; also the number of warm-up samples and of coefficients
    order: NonZero<u8>,
    // bit width each quantized coefficient is written with
    precision: SignedBitCount<15>,
    // right shift applied to the weighted sum of previous samples
    // (always 0..=15 as produced by `quantize`)
    shift: u32,
    // quantized coefficients, clamped to the signed range of `precision`
    coefficients: ArrayVec<i32, MAX_LPC_COEFFS>,
}

// There isn't any particular *best* way to determine
// the ideal LPC subframe parameters (though there are
// some worst ways, like choosing them at random).
// Even the reference implementation has changed its
// defaults over time.  So long as the subframe's residuals
// are calculated correctly, decoders don't care one way or another.
//
// I'll try to use an approach similar to the reference implementation's.

impl LpcParameters {
    /// Computes quantized LPC parameters for `channel`.
    ///
    /// The steps are:
    /// 1. Window the channel and autocorrelate it up to `max_lpc_order`.
    /// 2. Derive LP coefficients for each order.
    /// 3. Pick the order with the smallest estimated subframe size.
    /// 4. Quantize that order's coefficients at a precision chosen from
    ///    the block length.
    ///
    /// # Errors
    ///
    /// Returns `Error::InsufficientLpcSamples` if `channel` has no more
    /// samples than `max_lpc_order`. Also propagates errors from order
    /// selection and quantization.
    fn best(
        options: &EncoderOptions,
        bits_per_sample: SignedBitCount<32>,
        max_lpc_order: NonZero<u8>,
        window: &mut Vec<f64>,
        windowed: &mut Vec<f64>,
        channel: &[i32],
    ) -> Result<Self, Error> {
        if channel.len() <= usize::from(max_lpc_order.get()) {
            // not enough samples in channel to calculate LPC parameters
            return Err(Error::InsufficientLpcSamples);
        }

        // coefficient precision grows with block length
        // (NOTE(review): this table is indexed only by sample count,
        // not by bits-per-sample)
        let precision = match channel.len() {
            // this shouldn't be possible
            // (the early return above guarantees at least 2 samples)
            0 => panic!("at least one sample required in channel"),
            1..=192 => SignedBitCount::new::<7>(),
            193..=384 => SignedBitCount::new::<8>(),
            385..=576 => SignedBitCount::new::<9>(),
            577..=1152 => SignedBitCount::new::<10>(),
            1153..=2304 => SignedBitCount::new::<11>(),
            2305..=4608 => SignedBitCount::new::<12>(),
            4609.. => SignedBitCount::new::<13>(),
        };

        let (order, lp_coeffs) = compute_best_order(
            bits_per_sample,
            precision,
            channel
                .len()
                .try_into()
                // this shouldn't be possible
                .expect("excessive samples for subframe"),
            lp_coefficients(autocorrelate(
                options.window.apply(window, windowed, channel),
                max_lpc_order,
            )),
        )?;

        Self::quantize(order, lp_coeffs, precision)
    }

    /// Converts floating-point LP coefficients to integers with a shared shift.
    ///
    /// The shift is chosen so the largest-magnitude coefficient uses the
    /// available `precision`, capped at 15. Each coefficient is scaled,
    /// rounded and clamped. The rounding error is carried into the next
    /// coefficient.
    ///
    /// If the ideal shift is negative (very large coefficients), the
    /// coefficients are divided down instead and a shift of 0 is stored.
    ///
    /// # Errors
    ///
    /// Returns `Error::ZeroLpCoefficients` if no coefficient has a
    /// positive magnitude. Returns `Error::LpNegativeShiftError` if the
    /// required negative shift is below -16.
    fn quantize(
        order: NonZero<u8>,
        coeffs: ArrayVec<f64, MAX_LPC_COEFFS>,
        precision: SignedBitCount<15>,
    ) -> Result<Self, Error> {
        const MAX_SHIFT: i32 = (1 << 4) - 1;
        const MIN_SHIFT: i32 = -(1 << 4);

        // verified output against reference implementation
        // See: FLAC__lpc_quantize_coefficients

        debug_assert!(coeffs.len() == usize::from(order.get()));

        // signed range representable in `precision` bits
        let max_coeff = (1 << (u32::from(precision) - 1)) - 1;
        let min_coeff = -(1 << (u32::from(precision) - 1));

        let l = coeffs
            .iter()
            .map(|c| c.abs())
            .max_by(|x, y| x.total_cmp(y))
            // f64.log2() gives unfortunate results when <= 0.0
            .filter(|l| *l > 0.0)
            .ok_or(Error::ZeroLpCoefficients)?;

        // rounding error carried from one coefficient to the next
        let mut error = 0.0;

        match ((u32::from(precision) - 1) as i32 - ((l.log2().floor()) as i32) - 1).min(MAX_SHIFT) {
            shift @ 0.. => {
                // normal, positive shift case
                let shift = shift as u32;

                Ok(Self {
                    order,
                    precision,
                    shift,
                    coefficients: coeffs
                        .into_iter()
                        .map(|lp_coeff| {
                            // lp_coeff * 2^shift + carried error
                            let sum: f64 = lp_coeff.mul_add((1 << shift) as f64, error);
                            let qlp_coeff = (sum.round() as i32).clamp(min_coeff, max_coeff);
                            error = sum - (qlp_coeff as f64);
                            qlp_coeff
                        })
                        .collect(),
                })
            }
            shift @ MIN_SHIFT..0 => {
                // unusual negative shift case
                let shift = -shift as u32;

                Ok(Self {
                    order,
                    precision,
                    // coefficients are divided down instead, so no shift is stored
                    shift: 0,
                    coefficients: coeffs
                        .into_iter()
                        .map(|lp_coeff| {
                            let sum: f64 = (lp_coeff / (1 << shift) as f64) + error;
                            let qlp_coeff = (sum.round() as i32).clamp(min_coeff, max_coeff);
                            error = sum - (qlp_coeff as f64);
                            qlp_coeff
                        })
                        .collect(),
                })
            }
            ..MIN_SHIFT => Err(Error::LpNegativeShiftError),
        }
    }
}
3057
#[test]
fn test_quantization() {
    // test against numbers generated from reference implementation

    let order = NonZero::new(4).unwrap();

    let quantized = LpcParameters::quantize(
        order,
        arrayvec![0.797774, -0.045362, -0.050136, -0.054254],
        SignedBitCount::new::<10>(),
    )
    .unwrap();

    assert_eq!(quantized.order, order);
    assert_eq!(quantized.precision, SignedBitCount::new::<10>());
    assert_eq!(quantized.shift, 9);
    assert_eq!(quantized.coefficients, arrayvec![408, -23, -25, -28]);

    // note the relationship between the un-quantized,
    // floating point parameters and the shift value (9)
    //
    // 408 / 2 ** 9 ≈ 0.796875
    // -23 / 2 ** 9 ≈ -0.044921
    // -25 / 2 ** 9 ≈ -0.048828
    // -28 / 2 ** 9 ≈ -0.054687
    //
    // quantization turns each float into a fraction
    // whose denominator is 2 ** shift

    let quantized = LpcParameters::quantize(
        order,
        arrayvec![-0.054687, -0.953216, -0.027115, 0.033537],
        SignedBitCount::new::<10>(),
    )
    .unwrap();

    assert_eq!(quantized.order, order);
    assert_eq!(quantized.precision, SignedBitCount::new::<10>());
    assert_eq!(quantized.shift, 9);
    assert_eq!(quantized.coefficients, arrayvec![-28, -488, -14, 17]);

    // all-zero coefficients are useless for prediction and must be rejected
    assert!(matches!(
        LpcParameters::quantize(
            order,
            arrayvec![0.0, 0.0, 0.0, 0.0],
            SignedBitCount::new::<10>(),
        ),
        Err(Error::ZeroLpCoefficients)
    ));

    // negative shifts should also be handled properly
    let quantized = LpcParameters::quantize(
        order,
        arrayvec![-0.1, 0.1, 10000000.0, -0.2],
        SignedBitCount::new::<10>(),
    )
    .unwrap();

    assert_eq!(quantized.order, order);
    assert_eq!(quantized.precision, SignedBitCount::new::<10>());
    assert_eq!(quantized.shift, 0);
    assert_eq!(quantized.coefficients, arrayvec![0, 0, 305, 0]);

    // and massive negative shifts must be an error
    assert!(matches!(
        LpcParameters::quantize(
            order,
            arrayvec![-0.1, 0.1, 100000000.0, -0.2],
            SignedBitCount::new::<10>(),
        ),
        Err(Error::LpNegativeShiftError)
    ));
}
3131
3132fn autocorrelate(
3133    windowed: &[f64],
3134    max_lpc_order: NonZero<u8>,
3135) -> ArrayVec<f64, { MAX_LPC_COEFFS + 1 }> {
3136    // verified output against reference implementation
3137    // See: FLAC__lpc_compute_autocorrelation
3138
3139    debug_assert!(usize::from(max_lpc_order.get()) < MAX_LPC_COEFFS);
3140
3141    let mut tail = windowed;
3142    // let mut autocorrelated = Vec::with_capacity(max_lpc_order.get().into());
3143    let mut autocorrelated = ArrayVec::default();
3144
3145    for _ in 0..=max_lpc_order.get() {
3146        if tail.is_empty() {
3147            return autocorrelated;
3148        } else {
3149            autocorrelated.push(windowed.iter().zip(tail).map(|(x, y)| x * y).sum());
3150            tail.split_off_first();
3151        }
3152    }
3153
3154    autocorrelated
3155}
3156
#[test]
fn test_autocorrelation() {
    // test against numbers generated from reference implementation

    let order_1 = NonZero::new(1).unwrap();
    let order_4 = NonZero::new(4).unwrap();

    assert_eq!(autocorrelate(&[1.0], order_1), arrayvec![1.0]);

    assert_eq!(
        autocorrelate(&[1.0, 2.0, 3.0, 4.0, 5.0], order_4),
        arrayvec![55.0, 40.0, 26.0, 14.0, 5.0],
    );

    let wave = [
        0.0, 16.0, 31.0, 44.0, 54.0, 61.0, 64.0, 63.0, 58.0, 49.0, 38.0, 24.0, 8.0, -8.0, -24.0,
        -38.0, -49.0, -58.0, -63.0, -64.0, -61.0, -54.0, -44.0, -31.0, -16.0,
    ];

    assert_eq!(
        autocorrelate(&wave, order_4),
        arrayvec![51408.0, 49792.0, 45304.0, 38466.0, 29914.0],
    )
}
3182
3183#[derive(Debug)]
3184struct LpCoeff {
3185    coeffs: ArrayVec<f64, MAX_LPC_COEFFS>,
3186    error: f64,
3187}
3188
3189// returns a Vec of (coefficients, error) pairs
3190fn lp_coefficients(
3191    autocorrelated: ArrayVec<f64, { MAX_LPC_COEFFS + 1 }>,
3192) -> ArrayVec<LpCoeff, MAX_LPC_COEFFS> {
3193    // verified output against reference implementation
3194    // See: FLAC__lpc_compute_lp_coefficients
3195
3196    match autocorrelated.len() {
3197        0 | 1 => panic!("must have at least 2 autocorrelation values"),
3198        _ => {
3199            let k = autocorrelated[1] / autocorrelated[0];
3200            let mut lp_coefficients = arrayvec![LpCoeff {
3201                coeffs: arrayvec![k],
3202                error: autocorrelated[0] * (1.0 - k.powi(2)),
3203            }];
3204
3205            for i in 1..(autocorrelated.len() - 1) {
3206                if let [prev @ .., next] = &autocorrelated[0..=i + 1] {
3207                    let LpCoeff { coeffs, error } = lp_coefficients.last().unwrap();
3208
3209                    let q = next
3210                        - prev
3211                            .iter()
3212                            .rev()
3213                            .zip(coeffs)
3214                            .map(|(x, y)| x * y)
3215                            .sum::<f64>();
3216
3217                    let k = q / error;
3218
3219                    lp_coefficients.push(LpCoeff {
3220                        coeffs: coeffs
3221                            .iter()
3222                            .zip(coeffs.iter().rev().map(|c| k * c))
3223                            .map(|(c1, c2)| (c1 - c2))
3224                            .chain(std::iter::once(k))
3225                            .collect(),
3226                        error: error * (1.0 - k.powi(2)),
3227                    });
3228                }
3229            }
3230
3231            lp_coefficients
3232        }
3233    }
3234}
3235
/// Asserts that two `f64` expressions differ by less than `1.0e-6`,
/// printing both values on failure.
#[allow(unused)] // only referenced from #[test] functions
macro_rules! assert_float_approx {
    ($a:expr, $b:expr) => {{
        let (a, b) = ($a, $b);
        let difference = (a - b).abs();
        assert!(difference < 1.0e-6, "{a} != {b}");
    }};
}
3244
#[test]
fn test_lp_coefficients_1() {
    // test against numbers generated from reference implementation

    let lp_coeffs = lp_coefficients(arrayvec![55.0, 40.0, 26.0, 14.0, 5.0]);

    // (expected error, expected coefficients) for orders 1 through 4
    let expected: [(f64, &[f64]); 4] = [
        (25.909091, &[0.727273]),
        (25.540351, &[0.814035, -0.119298]),
        (25.316142, &[0.802858, -0.043028, -0.093694]),
        (25.241623, &[0.797774, -0.045362, -0.050136, -0.054254]),
    ];

    assert_eq!(lp_coeffs.len(), expected.len());

    for (actual, (error, coeffs)) in lp_coeffs.iter().zip(expected) {
        assert_float_approx!(actual.error, error);
        assert_eq!(actual.coeffs.len(), coeffs.len());
        for (a, e) in actual.coeffs.iter().zip(coeffs) {
            assert_float_approx!(*a, *e);
        }
    }
}
3276
#[test]
fn test_lp_coefficients_2() {
    // test against numbers generated from reference implementation

    let lp_coeffs = lp_coefficients(arrayvec![51408.0, 49792.0, 45304.0, 38466.0, 29914.0]);

    // (expected error, expected coefficients) for orders 1 through 4
    let expected: [(f64, &[f64]); 4] = [
        (3181.201369, &[0.968565]),
        (495.815931, &[1.858456, -0.918772]),
        (495.161449, &[1.891837, -0.986293, 0.036332]),
        (494.604514, &[1.890618, -0.953216, -0.027115, 0.033537]),
    ];

    assert_eq!(lp_coeffs.len(), expected.len());

    for (actual, (error, coeffs)) in lp_coeffs.iter().zip(expected) {
        assert_float_approx!(actual.error, error);
        assert_eq!(actual.coeffs.len(), coeffs.len());
        for (a, e) in actual.coeffs.iter().zip(coeffs) {
            assert_float_approx!(*a, *e);
        }
    }
}
3308
3309// Returns (bits, order, coeffients) tuples
3310fn subframe_bits_by_order(
3311    bits_per_sample: SignedBitCount<32>,
3312    precision: SignedBitCount<15>,
3313    sample_count: u16,
3314    coeffs: ArrayVec<LpCoeff, MAX_LPC_COEFFS>,
3315) -> impl Iterator<Item = (f64, u8, ArrayVec<f64, MAX_LPC_COEFFS>)> {
3316    debug_assert!(sample_count > 0);
3317
3318    let error_scale = 0.5 / f64::from(sample_count);
3319
3320    coeffs
3321        .into_iter()
3322        .take_while(|coeffs| coeffs.error > 0.0)
3323        .zip(1..)
3324        .map(move |(LpCoeff { coeffs, error }, order)| {
3325            let header_bits =
3326                u32::from(order) * (u32::from(bits_per_sample) + u32::from(precision));
3327
3328            let bits_per_residual =
3329                (error * error_scale).ln() / (2.0 * std::f64::consts::LN_2).max(0.0);
3330
3331            let subframe_bits = bits_per_residual.mul_add(
3332                f64::from(sample_count - u16::from(order)),
3333                f64::from(header_bits),
3334            );
3335
3336            (subframe_bits, order, coeffs)
3337        })
3338}
3339
3340// Uses the error in the LP coefficients to determine the best order
3341// and returns that order along with the stripped-out coefficients
3342fn compute_best_order(
3343    bits_per_sample: SignedBitCount<32>,
3344    precision: SignedBitCount<15>,
3345    sample_count: u16,
3346    coeffs: ArrayVec<LpCoeff, MAX_LPC_COEFFS>,
3347) -> Result<(NonZero<u8>, ArrayVec<f64, MAX_LPC_COEFFS>), Error> {
3348    // verified output against reference implementation
3349    // See: FLAC__lpc_compute_best_order  and
3350    // See: FLAC__lpc_compute_expected_bits_per_residual_sample_with_error_scale
3351
3352    subframe_bits_by_order(bits_per_sample, precision, sample_count, coeffs)
3353        .min_by(|(x, _, _), (y, _, _)| x.total_cmp(y))
3354        .and_then(|(_, order, coeffs)| Some((NonZero::new(order)?, coeffs)))
3355        .ok_or(Error::NoBestLpcOrder)
3356}
3357
#[test]
fn test_compute_best_order() {
    // test against numbers generated from reference implementation

    // only the prediction errors feed into the size estimate
    fn from_errors(errors: [f64; 4]) -> ArrayVec<LpCoeff, MAX_LPC_COEFFS> {
        errors
            .into_iter()
            .map(|error| LpCoeff {
                coeffs: ArrayVec::default(),
                error,
            })
            .collect()
    }

    let mut bits = subframe_bits_by_order(
        SignedBitCount::new::<16>(),
        SignedBitCount::new::<5>(),
        20,
        from_errors([3181.201369, 495.815931, 495.161449, 494.604514]),
    )
    .map(|(bits, _, _)| bits);

    for expected in [80.977565, 74.685594, 93.853530, 113.025628] {
        assert_float_approx!(bits.next().unwrap(), expected);
    }

    let mut bits = subframe_bits_by_order(
        SignedBitCount::new::<16>(),
        SignedBitCount::new::<10>(),
        4096,
        from_errors([15000.0, 25000.0, 20000.0, 30000.0]),
    )
    .map(|(bits, _, _)| bits);

    for expected in [1812.801817, 3346.934051, 2713.303385, 3935.492805] {
        assert_float_approx!(bits.next().unwrap(), expected);
    }
}
3400
3401fn write_residuals<W: BitWrite>(
3402    options: &EncoderOptions,
3403    writer: &mut W,
3404    predictor_order: usize,
3405    residuals: &[i32],
3406) -> Result<(), Error> {
3407    use crate::stream::ResidualPartitionHeader;
3408    use bitstream_io::{BitCount, ToBitStream};
3409
3410    const MAX_PARTITIONS: usize = 64;
3411
3412    #[derive(Debug)]
3413    struct Partition<'r, const RICE_MAX: u32> {
3414        header: ResidualPartitionHeader<RICE_MAX>,
3415        residuals: &'r [i32],
3416    }
3417
3418    impl<'r, const RICE_MAX: u32> Partition<'r, RICE_MAX> {
3419        fn new(partition: &'r [i32], estimated_bits: &mut u32) -> Option<Self> {
3420            let partition_samples = partition.len() as u16;
3421            if partition_samples == 0 {
3422                return None;
3423            }
3424
3425            let partition_sum = partition
3426                .iter()
3427                .map(|i| u64::from(i.unsigned_abs()))
3428                .sum::<u64>();
3429
3430            if partition_sum > 0 {
3431                let rice = if partition_sum > partition_samples.into() {
3432                    let bits_needed = ((partition_sum as f64) / f64::from(partition_samples))
3433                        .log2()
3434                        .ceil() as u32;
3435
3436                    match BitCount::try_from(bits_needed).ok().filter(|rice| {
3437                        u32::from(*rice) < u32::from(BitCount::<RICE_MAX>::new::<RICE_MAX>())
3438                    }) {
3439                        Some(rice) => rice,
3440                        None => {
3441                            let escape_size = (partition
3442                                .iter()
3443                                .map(|i| u64::from(i.unsigned_abs()))
3444                                .sum::<u64>()
3445                                .ilog2()
3446                                + 2)
3447                            .try_into()
3448                            .ok()?;
3449
3450                            *estimated_bits +=
3451                                u32::from(escape_size) * u32::from(partition_samples);
3452
3453                            return Some(Self {
3454                                header: ResidualPartitionHeader::Escaped { escape_size },
3455                                residuals: partition,
3456                            });
3457                        }
3458                    }
3459                } else {
3460                    BitCount::new::<0>()
3461                };
3462
3463                let partition_size: u32 = 4u32
3464                    + ((1 + u32::from(rice)) * u32::from(partition_samples))
3465                    + if u32::from(rice) > 0 {
3466                        u32::try_from(partition_sum >> (u32::from(rice) - 1)).ok()?
3467                    } else {
3468                        u32::try_from(partition_sum << 1).ok()?
3469                    }
3470                    - (u32::from(partition_samples) / 2);
3471
3472                *estimated_bits += partition_size;
3473
3474                Some(Partition {
3475                    header: ResidualPartitionHeader::Standard { rice },
3476                    residuals: partition,
3477                })
3478            } else {
3479                // all partition residuals are 0, so use a constant
3480                Some(Partition {
3481                    header: ResidualPartitionHeader::Constant,
3482                    residuals: partition,
3483                })
3484            }
3485        }
3486    }
3487
3488    impl<const RICE_MAX: u32> ToBitStream for Partition<'_, RICE_MAX> {
3489        type Error = std::io::Error;
3490
3491        #[inline]
3492        fn to_writer<W: BitWrite + ?Sized>(&self, w: &mut W) -> Result<(), Self::Error> {
3493            w.build(&self.header)?;
3494            match self.header {
3495                ResidualPartitionHeader::Standard { rice } => {
3496                    let mask = rice.mask_lsb();
3497
3498                    self.residuals.iter().try_for_each(|s| {
3499                        let (msb, lsb) = mask(if s.is_negative() {
3500                            ((-*s as u32 - 1) << 1) + 1
3501                        } else {
3502                            (*s as u32) << 1
3503                        });
3504                        w.write_unary::<1>(msb)?;
3505                        w.write_checked(lsb)
3506                    })?;
3507                }
3508                ResidualPartitionHeader::Escaped { escape_size } => {
3509                    self.residuals
3510                        .iter()
3511                        .try_for_each(|s| w.write_signed_counted(escape_size, *s))?;
3512                }
3513                ResidualPartitionHeader::Constant => { /* nothing left to do */ }
3514            }
3515            Ok(())
3516        }
3517    }
3518
3519    fn best_partitions<'r, const RICE_MAX: u32>(
3520        options: &EncoderOptions,
3521        block_size: usize,
3522        residuals: &'r [i32],
3523    ) -> ArrayVec<Partition<'r, RICE_MAX>, MAX_PARTITIONS> {
3524        (0..=block_size.trailing_zeros().min(options.max_partition_order))
3525            .map(|partition_order| 1 << partition_order)
3526            .take_while(|partition_count: &usize| partition_count.is_power_of_two())
3527            .filter_map(|partition_count| {
3528                let mut estimated_bits = 0;
3529
3530                let partitions = residuals
3531                    .rchunks(block_size / partition_count)
3532                    .rev()
3533                    .map(|partition| Partition::new(partition, &mut estimated_bits))
3534                    .collect::<Option<ArrayVec<_, MAX_PARTITIONS>>>()
3535                    .filter(|p| !p.is_empty() && p.len().is_power_of_two())?;
3536
3537                Some((partitions, estimated_bits))
3538            })
3539            .min_by_key(|(_, estimated_bits)| *estimated_bits)
3540            .map(|(partitions, _)| partitions)
3541            .unwrap_or_else(|| {
3542                std::iter::once(Partition {
3543                    header: ResidualPartitionHeader::Escaped {
3544                        escape_size: SignedBitCount::new::<0b11111>(),
3545                    },
3546                    residuals,
3547                })
3548                .collect()
3549            })
3550    }
3551
3552    fn write_partitions<const RICE_MAX: u32, W: BitWrite>(
3553        writer: &mut W,
3554        partitions: ArrayVec<Partition<'_, RICE_MAX>, MAX_PARTITIONS>,
3555    ) -> Result<(), Error> {
3556        writer.write::<4, u32>(partitions.len().ilog2())?; // partition order
3557        for partition in partitions {
3558            writer.build(&partition)?;
3559        }
3560        Ok(())
3561    }
3562
3563    #[inline]
3564    fn try_shrink_header<const RICE_MAX: u32, const RICE_NEW_MAX: u32>(
3565        header: ResidualPartitionHeader<RICE_MAX>,
3566    ) -> Option<ResidualPartitionHeader<RICE_NEW_MAX>> {
3567        Some(match header {
3568            ResidualPartitionHeader::Standard { rice } => ResidualPartitionHeader::Standard {
3569                rice: rice.try_map(|r| (r < RICE_NEW_MAX).then_some(r))?,
3570            },
3571            ResidualPartitionHeader::Escaped { escape_size } => {
3572                ResidualPartitionHeader::Escaped { escape_size }
3573            }
3574            ResidualPartitionHeader::Constant => ResidualPartitionHeader::Constant,
3575        })
3576    }
3577
    // The residual coding method selected for a subframe, with its partitions.
    //
    // `Rice` holds partitions typed with a RICE_MAX of 0b1111; the caller
    // writes it as coding method 0. `Rice2` holds partitions typed with a
    // RICE_MAX of 0b11111; the caller writes it as coding method 1.
    enum CodingMethod<'p> {
        Rice(ArrayVec<Partition<'p, 0b1111>, MAX_PARTITIONS>),
        Rice2(ArrayVec<Partition<'p, 0b11111>, MAX_PARTITIONS>),
    }
3582
3583    fn try_reduce_rice(
3584        partitions: ArrayVec<Partition<'_, 0b11111>, MAX_PARTITIONS>,
3585    ) -> CodingMethod {
3586        match partitions
3587            .iter()
3588            .map(|Partition { header, residuals }| {
3589                try_shrink_header(*header).map(|header| Partition { header, residuals })
3590            })
3591            .collect()
3592        {
3593            Some(partitions) => CodingMethod::Rice(partitions),
3594            None => CodingMethod::Rice2(partitions),
3595        }
3596    }
3597
3598    let block_size = predictor_order + residuals.len();
3599
3600    if options.use_rice2 {
3601        match try_reduce_rice(best_partitions(options, block_size, residuals)) {
3602            CodingMethod::Rice(partitions) => {
3603                writer.write::<2, u8>(0)?; // coding method 0
3604                write_partitions(writer, partitions)
3605            }
3606            CodingMethod::Rice2(partitions) => {
3607                writer.write::<2, u8>(1)?; // coding method 1
3608                write_partitions(writer, partitions)
3609            }
3610        }
3611    } else {
3612        let partitions = best_partitions::<0b1111>(options, block_size, residuals);
3613        writer.write::<2, u8>(0)?; // coding method 0
3614        write_partitions(writer, partitions)
3615    }
3616}
3617
3618fn try_join<A, B, RA, RB, E>(oper_a: A, oper_b: B) -> Result<(RA, RB), E>
3619where
3620    A: FnOnce() -> Result<RA, E> + Send,
3621    B: FnOnce() -> Result<RB, E> + Send,
3622    RA: Send,
3623    RB: Send,
3624    E: Send,
3625{
3626    let (a, b) = join(oper_a, oper_b);
3627    Ok((a?, b?))
3628}
3629
3630#[cfg(feature = "rayon")]
3631use rayon::join;
3632
/// Sequential stand-in for `rayon::join`, used when the `rayon` feature is
/// disabled.
///
/// Runs `oper_a` to completion, then `oper_b`, and returns both results as a
/// tuple.
#[cfg(not(feature = "rayon"))]
fn join<A, B, RA, RB>(oper_a: A, oper_b: B) -> (RA, RB)
where
    A: FnOnce() -> RA + Send,
    B: FnOnce() -> RB + Send,
    RA: Send,
    RB: Send,
{
    let result_a = oper_a();
    let result_b = oper_b();
    (result_a, result_b)
}
3643
/// Applies `f` to every element of `src`, spreading the work across rayon's
/// thread pool.
///
/// Output order matches input order.
#[cfg(feature = "rayon")]
fn vec_map<T, U, F>(src: Vec<T>, f: F) -> Vec<U>
where
    T: Send,
    U: Send,
    F: Fn(T) -> U + Send + Sync,
{
    use rayon::iter::{IntoParallelIterator, ParallelIterator};

    let items = src.into_par_iter();
    items.map(f).collect::<Vec<U>>()
}

/// Applies `f` to every element of `src`, in order, on the current thread.
///
/// This is the sequential fallback used when the `rayon` feature is
/// disabled.
#[cfg(not(feature = "rayon"))]
fn vec_map<T, U, F>(src: Vec<T>, f: F) -> Vec<U>
where
    T: Send,
    U: Send,
    F: Fn(T) -> U + Send + Sync,
{
    let mut mapped = Vec::with_capacity(src.len());
    for item in src {
        mapped.push(f(item));
    }
    mapped
}
3665
/// Divides `n` by `rhs`, returning the quotient only when the division is exact.
///
/// Returns `None` in two cases:
/// - `rhs` is zero. Previously this panicked on `%` despite the `Option`
///   return type.
/// - `n` is not evenly divisible by `rhs`.
///
/// `N::default()` is assumed to be the zero value, as it is for every
/// primitive integer type.
///
/// # Panics
///
/// For signed types, `N::MIN` divided by `-1` still overflows in `%`.
fn exact_div<N>(n: N, rhs: N) -> Option<N>
where
    N: std::ops::Div<Output = N> + std::ops::Rem<Output = N> + std::cmp::PartialEq + Copy + Default,
{
    let zero = N::default();
    if rhs == zero {
        return None;
    }
    // only compute the quotient once divisibility is confirmed
    (n % rhs == zero).then(|| n / rhs)
}