flac_codec/
encode.rs

1// Copyright 2025 Brian Langenberger
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9//! For encoding PCM samples to FLAC files
10//!
11//! ## Multithreading
12//!
13//! Encoders will operate using multithreading if the optional `rayon` feature is enabled,
14//! typically boosting performance by processing channels in parallel.
15//! But because subframes must eventually be written serially, and their size cannot generally
16//! be known in advance, processing two channels across two threads will not
17//! encode twice as fast.
18
19use crate::audio::Frame;
20use crate::metadata::{
21    Application, BlockList, BlockSize, Cuesheet, Picture, SeekPoint, Streaminfo, VorbisComment,
22    write_blocks,
23};
24use crate::stream::{ChannelAssignment, FrameNumber, Independent, SampleRate};
25use crate::{Counter, Error};
26use arrayvec::ArrayVec;
27use bitstream_io::{BigEndian, BitRecorder, BitWrite, BitWriter, SignedBitCount};
28use std::collections::VecDeque;
29use std::fs::File;
30use std::io::BufWriter;
31use std::num::NonZero;
32use std::path::Path;
33
// the largest channel count a stream may carry (writers accept 1 to 8)
const MAX_CHANNELS: usize = 8;
// the largest number of LPC coefficients supported
const MAX_LPC_COEFFS: usize = 32;
37
// A `vec!`-style constructor for `ArrayVec`, which the arrayvec crate
// does not provide; each element is pushed in the order given.
macro_rules! arrayvec {
    ( $( $item:expr ),* ) => {{
        let mut array = ArrayVec::default();
        $(
            array.push($item);
        )*
        array
    }};
}
48
49/// A FLAC writer which accepts samples as bytes
50///
51/// # Example
52///
53/// ```
54/// use flac_codec::{
55///     byteorder::LittleEndian,
56///     encode::{FlacByteWriter, Options},
57///     decode::{FlacByteReader, Metadata},
58/// };
59/// use std::io::{Cursor, Read, Seek, Write};
60///
61/// let mut flac = Cursor::new(vec![]);  // a FLAC file in memory
62///
63/// let mut writer = FlacByteWriter::endian(
64///     &mut flac,           // our wrapped writer
65///     LittleEndian,        // .wav-style byte order
66///     Options::default(),  // default encoding options
67///     44100,               // sample rate
68///     16,                  // bits-per-sample
69///     1,                   // channel count
70///     Some(2000),          // total bytes
71/// ).unwrap();
72///
73/// // write 1000 samples as 16-bit, signed, little-endian bytes (2000 bytes total)
74/// let written_bytes = (0..1000).map(i16::to_le_bytes).flatten().collect::<Vec<u8>>();
75/// assert!(writer.write_all(&written_bytes).is_ok());
76///
77/// // finalize writing file
78/// assert!(writer.finalize().is_ok());
79///
80/// flac.rewind().unwrap();
81///
82/// // open reader around written FLAC file
83/// let mut reader = FlacByteReader::endian(flac, LittleEndian).unwrap();
84///
85/// // read 2000 bytes
86/// let mut read_bytes = vec![];
87/// assert!(reader.read_to_end(&mut read_bytes).is_ok());
88///
89/// // ensure MD5 sum of signed, little-endian samples matches hash in file
90/// let mut md5 = md5::Context::new();
91/// md5.consume(&read_bytes);
92/// assert_eq!(&md5.compute().0, reader.md5().unwrap());
93///
94/// // ensure input and output matches
95/// assert_eq!(read_bytes, written_bytes);
96/// ```
97pub struct FlacByteWriter<W: std::io::Write + std::io::Seek, E: crate::byteorder::Endianness> {
98    // the wrapped encoder
99    encoder: Encoder<W>,
100    // bytes that make up a partial FLAC frame
101    buf: VecDeque<u8>,
102    // a whole set of samples for a FLAC frame
103    frame: Frame,
104    // size of a single sample in bytes
105    bytes_per_sample: usize,
106    // size of single set of channel-independent samples in bytes
107    pcm_frame_size: usize,
108    // size of whole FLAC frame's samples in bytes
109    frame_byte_size: usize,
110    // whether the encoder has finalized the file
111    finalized: bool,
112    // the input bytes' endianness
113    endianness: std::marker::PhantomData<E>,
114}
115
116impl<W: std::io::Write + std::io::Seek, E: crate::byteorder::Endianness> FlacByteWriter<W, E> {
117    /// Creates new FLAC writer with the given parameters
118    ///
119    /// The writer should be positioned at the start of the file.
120    ///
121    /// `sample_rate` must be between 0 (for non-audio streams) and 2²⁰.
122    ///
123    /// `bits_per_sample` must be between 1 and 32.
124    ///
125    /// `channels` must be between 1 and 8.
126    ///
127    /// Note that if `total_bytes` is indicated,
128    /// the number of channel-independent samples written *must*
129    /// be equal to that amount or an error will occur when writing
130    /// or finalizing the stream.
131    ///
132    /// # Errors
133    ///
134    /// Returns I/O error if unable to write initial
135    /// metadata blocks.
136    /// Returns error if any of the encoding parameters are invalid.
137    pub fn new(
138        writer: W,
139        options: Options,
140        sample_rate: u32,
141        bits_per_sample: u32,
142        channels: u8,
143        total_bytes: Option<u64>,
144    ) -> Result<Self, Error> {
145        let bits_per_sample: SignedBitCount<32> = bits_per_sample
146            .try_into()
147            .map_err(|_| Error::InvalidBitsPerSample)?;
148
149        let bytes_per_sample = u32::from(bits_per_sample).div_ceil(8) as usize;
150
151        let pcm_frame_size = bytes_per_sample * channels as usize;
152
153        Ok(Self {
154            buf: VecDeque::default(),
155            frame: Frame::empty(channels.into(), bits_per_sample.into()),
156            bytes_per_sample,
157            pcm_frame_size,
158            frame_byte_size: pcm_frame_size * options.block_size as usize,
159            encoder: Encoder::new(
160                writer,
161                options,
162                sample_rate,
163                bits_per_sample,
164                channels,
165                total_bytes
166                    .map(|bytes| {
167                        exact_div(bytes, channels.into())
168                            .and_then(|s| exact_div(s, bytes_per_sample as u64))
169                            .ok_or(Error::SamplesNotDivisibleByChannels)
170                            .and_then(|b| NonZero::new(b).ok_or(Error::InvalidTotalBytes))
171                    })
172                    .transpose()?,
173            )?,
174            finalized: false,
175            endianness: std::marker::PhantomData,
176        })
177    }
178
179    /// Creates new FLAC writer with CDDA parameters
180    ///
181    /// The writer should be positioned at the start of the file.
182    ///
183    /// Sample rate is 44100 Hz, bits-per-sample is 16,
184    /// channels is 2.
185    ///
186    /// Note that if `total_bytes` is indicated,
187    /// the number of channel-independent samples written *must*
188    /// be equal to that amount or an error will occur when writing
189    /// or finalizing the stream.
190    ///
191    /// # Errors
192    ///
193    /// Returns I/O error if unable to write initial
194    /// metadata blocks.
195    /// Returns error if any of the encoding parameters are invalid.
196    pub fn new_cdda(writer: W, options: Options, total_bytes: Option<u64>) -> Result<Self, Error> {
197        Self::new(writer, options, 44100, 16, 2, total_bytes)
198    }
199
200    /// Creates new FLAC writer in the given endianness with the given parameters
201    ///
202    /// The writer should be positioned at the start of the file.
203    ///
204    /// `sample_rate` must be between 0 (for non-audio streams) and 2²⁰.
205    ///
206    /// `bits_per_sample` must be between 1 and 32.
207    ///
208    /// `channels` must be between 1 and 8.
209    ///
210    /// Note that if `total_bytes` is indicated,
211    /// the number of bytes written *must*
212    /// be equal to that amount or an error will occur when writing
213    /// or finalizing the stream.
214    ///
215    /// # Errors
216    ///
217    /// Returns I/O error if unable to write initial
218    /// metadata blocks.
219    #[inline]
220    pub fn endian(
221        writer: W,
222        _endianness: E,
223        options: Options,
224        sample_rate: u32,
225        bits_per_sample: u32,
226        channels: u8,
227        total_bytes: Option<u64>,
228    ) -> Result<Self, Error> {
229        Self::new(
230            writer,
231            options,
232            sample_rate,
233            bits_per_sample,
234            channels,
235            total_bytes,
236        )
237    }
238
239    fn finalize_inner(&mut self) -> Result<(), Error> {
240        if !self.finalized {
241            self.finalized = true;
242
243            // encode as many bytes as possible into final frame, if necessary
244            if !self.buf.is_empty() {
245                use crate::byteorder::LittleEndian;
246
247                // truncate buffer to whole PCM frames
248                let buf = self.buf.make_contiguous();
249                let buf_len = buf.len();
250                let buf = &mut buf[..(buf_len - buf_len % self.pcm_frame_size)];
251
252                // convert buffer to little-endian bytes
253                E::bytes_to_le(buf, self.bytes_per_sample);
254
255                // update MD5 sum with little-endian bytes
256                self.encoder.update_md5(buf);
257
258                self.encoder
259                    .encode(self.frame.fill_from_buf::<LittleEndian>(buf))?;
260            }
261
262            self.encoder.finalize_inner()
263        } else {
264            Ok(())
265        }
266    }
267
268    /// Attempt to finalize stream
269    ///
270    /// It is necessary to finalize the FLAC encoder
271    /// so that it will write any partially unwritten samples
272    /// to the stream and update the [`crate::metadata::Streaminfo`] and [`crate::metadata::SeekTable`] blocks
273    /// with their final values.
274    ///
275    /// Dropping the encoder will attempt to finalize the stream
276    /// automatically, but will ignore any errors that may occur.
277    pub fn finalize(mut self) -> Result<(), Error> {
278        self.finalize_inner()?;
279        Ok(())
280    }
281}
282
283impl<E: crate::byteorder::Endianness> FlacByteWriter<BufWriter<File>, E> {
284    /// Creates new FLAC file at the given path
285    ///
286    /// `sample_rate` must be between 0 (for non-audio streams) and 2²⁰.
287    ///
288    /// `bits_per_sample` must be between 1 and 32.
289    ///
290    /// `channels` must be between 1 and 8.
291    ///
292    /// Note that if `total_bytes` is indicated,
293    /// the number of bytes written *must*
294    /// be equal to that amount or an error will occur when writing
295    /// or finalizing the stream.
296    ///
297    /// # Errors
298    ///
299    /// Returns I/O error if unable to write initial
300    /// metadata blocks.
301    #[inline]
302    pub fn create<P: AsRef<Path>>(
303        path: P,
304        options: Options,
305        sample_rate: u32,
306        bits_per_sample: u32,
307        channels: u8,
308        total_bytes: Option<u64>,
309    ) -> Result<Self, Error> {
310        FlacByteWriter::new(
311            BufWriter::new(options.create(path)?),
312            options,
313            sample_rate,
314            bits_per_sample,
315            channels,
316            total_bytes,
317        )
318    }
319
320    /// Creates new FLAC file with CDDA parameters at the given path
321    ///
322    /// Sample rate is 44100 Hz, bits-per-sample is 16,
323    /// channels is 2.
324    ///
325    /// Note that if `total_bytes` is indicated,
326    /// the number of bytes written *must*
327    /// be equal to that amount or an error will occur when writing
328    /// or finalizing the stream.
329    ///
330    /// # Errors
331    ///
332    /// Returns I/O error if unable to write initial
333    /// metadata blocks.
334    pub fn create_cdda<P: AsRef<Path>>(
335        path: P,
336        options: Options,
337        total_bytes: Option<u64>,
338    ) -> Result<Self, Error> {
339        Self::create(path, options, 44100, 16, 2, total_bytes)
340    }
341}
342
343impl<W: std::io::Write + std::io::Seek, E: crate::byteorder::Endianness> std::io::Write
344    for FlacByteWriter<W, E>
345{
346    /// Writes a set of sample bytes to the FLAC file
347    ///
348    /// Samples are signed and encoded in the stream's given byte order.
349    ///
350    /// Samples are then interleaved by channel, like:
351    /// [left₀ , right₀ , left₁ , right₁ , left₂ , right₂ , …]
352    ///
353    /// This is the same format used by common PCM container
354    /// formats like WAVE and AIFF.
355    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
356        use crate::byteorder::LittleEndian;
357
358        // dump whole set of bytes into our internal buffer
359        self.buf.extend(buf);
360
361        // encode as many FLAC frames as possible (which may be 0)
362        let mut encoded_frames = 0;
363        for buf in self
364            .buf
365            .make_contiguous()
366            .chunks_exact_mut(self.frame_byte_size)
367        {
368            // convert buffer to little-endian bytes
369            E::bytes_to_le(buf, self.bytes_per_sample);
370
371            // update MD5 sum with little-endian bytes
372            self.encoder.update_md5(buf);
373
374            // encode fresh FLAC frame
375            self.encoder
376                .encode(self.frame.fill_from_buf::<LittleEndian>(buf))?;
377
378            encoded_frames += 1;
379        }
380        // TODO - use truncate_front whenever that stabilizes
381        self.buf.drain(0..self.frame_byte_size * encoded_frames);
382
383        // indicate whole buffer's been consumed
384        Ok(buf.len())
385    }
386
387    #[inline]
388    fn flush(&mut self) -> std::io::Result<()> {
389        // we don't want to flush a partial frame to disk,
390        // but we can at least flush our internal writer
391        self.encoder.writer.flush()
392    }
393}
394
395impl<W: std::io::Write + std::io::Seek, E: crate::byteorder::Endianness> Drop
396    for FlacByteWriter<W, E>
397{
398    fn drop(&mut self) {
399        let _ = self.finalize_inner();
400    }
401}
402
403/// A FLAC writer which accepts samples as signed integers
404///
405/// # Example
406///
407/// ```
408/// use flac_codec::{
409///     encode::{FlacSampleWriter, Options},
410///     decode::FlacSampleReader,
411/// };
412/// use std::io::{Cursor, Seek};
413///
414/// let mut flac = Cursor::new(vec![]);  // a FLAC file in memory
415///
416/// let mut writer = FlacSampleWriter::new(
417///     &mut flac,           // our wrapped writer
418///     Options::default(),  // default encoding options
419///     44100,               // sample rate
420///     16,                  // bits-per-sample
421///     1,                   // channel count
422///     Some(1000),          // total samples
423/// ).unwrap();
424///
425/// // write 1000 samples
426/// let written_samples = (0..1000).collect::<Vec<i32>>();
427/// assert!(writer.write(&written_samples).is_ok());
428///
429/// // finalize writing file
430/// assert!(writer.finalize().is_ok());
431///
432/// flac.rewind().unwrap();
433///
434/// // open reader around written FLAC file
435/// let mut reader = FlacSampleReader::new(flac).unwrap();
436///
437/// // read 1000 samples
438/// let mut read_samples = vec![0; 1000];
439/// assert!(matches!(reader.read(&mut read_samples), Ok(1000)));
440///
441/// // ensure they match
442/// assert_eq!(read_samples, written_samples);
443/// ```
444pub struct FlacSampleWriter<W: std::io::Write + std::io::Seek> {
445    // the wrapped encoder
446    encoder: Encoder<W>,
447    // samples that make up a partial FLAC frame
448    // in channel-interleaved order
449    // (must de-interleave later in case someone writes
450    // only partial set of channels in a single write call)
451    sample_buf: VecDeque<i32>,
452    // a whole set of samples for a FLAC frame
453    frame: Frame,
454    // size of a single frame in samples
455    frame_sample_size: usize,
456    // size of a single PCM frame in samples
457    pcm_frame_size: usize,
458    // size of a single sample in bytes
459    bytes_per_sample: usize,
460    // whether the encoder has finalized the file
461    finalized: bool,
462}
463
464impl<W: std::io::Write + std::io::Seek> FlacSampleWriter<W> {
465    /// Creates new FLAC writer with the given parameters
466    ///
467    /// `sample_rate` must be between 0 (for non-audio streams) and 2²⁰.
468    ///
469    /// `bits_per_sample` must be between 1 and 32.
470    ///
471    /// `channels` must be between 1 and 8.
472    ///
473    /// Note that if `total_samples` is indicated,
474    /// the number of samples written *must*
475    /// be equal to that amount or an error will occur when writing
476    /// or finalizing the stream.
477    ///
478    /// # Errors
479    ///
480    /// Returns I/O error if unable to write initial
481    /// metadata blocks.
482    /// Returns error if any of the encoding parameters are invalid.
483    pub fn new(
484        writer: W,
485        options: Options,
486        sample_rate: u32,
487        bits_per_sample: u32,
488        channels: u8,
489        total_samples: Option<u64>,
490    ) -> Result<Self, Error> {
491        let bits_per_sample: SignedBitCount<32> = bits_per_sample
492            .try_into()
493            .map_err(|_| Error::InvalidBitsPerSample)?;
494
495        let bytes_per_sample = u32::from(bits_per_sample).div_ceil(8) as usize;
496
497        let pcm_frame_size = usize::from(channels);
498
499        Ok(Self {
500            sample_buf: VecDeque::default(),
501            frame: Frame::empty(channels.into(), bits_per_sample.into()),
502            bytes_per_sample,
503            pcm_frame_size,
504            frame_sample_size: pcm_frame_size * options.block_size as usize,
505            encoder: Encoder::new(
506                writer,
507                options,
508                sample_rate,
509                bits_per_sample,
510                channels,
511                total_samples
512                    .map(|samples| {
513                        exact_div(samples, channels.into())
514                            .ok_or(Error::SamplesNotDivisibleByChannels)
515                            .and_then(|s| NonZero::new(s).ok_or(Error::InvalidTotalSamples))
516                    })
517                    .transpose()?,
518            )?,
519            finalized: false,
520        })
521    }
522
523    /// Creates new FLAC writer with CDDA parameters
524    ///
525    /// Sample rate is 44100 Hz, bits-per-sample is 16,
526    /// channels is 2.
527    ///
528    /// Note that if `total_samples` is indicated,
529    /// the number of samples written *must*
530    /// be equal to that amount or an error will occur when writing
531    /// or finalizing the stream.
532    ///
533    /// # Errors
534    ///
535    /// Returns I/O error if unable to write initial
536    /// metadata blocks.
537    /// Returns error if any of the encoding parameters are invalid.
538    pub fn new_cdda(
539        writer: W,
540        options: Options,
541        total_samples: Option<u64>,
542    ) -> Result<Self, Error> {
543        Self::new(writer, options, 44100, 16, 2, total_samples)
544    }
545
546    /// Given a set of samples, writes them to the FLAC file
547    ///
548    /// Samples are interleaved by channel, like:
549    /// [left₀ , right₀ , left₁ , right₁ , left₂ , right₂ , …]
550    ///
551    /// This may output 0 or more actual FLAC frames,
552    /// depending on the quantity of samples and the amount
553    /// previously written.
554    pub fn write(&mut self, samples: &[i32]) -> Result<(), Error> {
555        // dump whole set of samples into our internal buffer
556        self.sample_buf.extend(samples);
557
558        // encode as many FLAC frames as possible (which may be 0)
559        let mut encoded_frames = 0;
560        for buf in self
561            .sample_buf
562            .make_contiguous()
563            .chunks_exact_mut(self.frame_sample_size)
564        {
565            // update running MD5 sum calculation
566            // since samples are already interleaved in channel order
567            update_md5(
568                &mut self.encoder,
569                buf.iter().copied(),
570                self.bytes_per_sample,
571            );
572
573            // encode fresh FLAC frame
574            self.encoder.encode(self.frame.fill_from_samples(buf))?;
575
576            encoded_frames += 1;
577        }
578        self.sample_buf
579            .drain(0..self.frame_sample_size * encoded_frames);
580
581        Ok(())
582    }
583
584    fn finalize_inner(&mut self) -> Result<(), Error> {
585        if !self.finalized {
586            self.finalized = true;
587
588            // encode as many samples possible into final frame, if necessary
589            if !self.sample_buf.is_empty() {
590                // truncate buffer to whole PCM frames
591                let buf = self.sample_buf.make_contiguous();
592                let buf_len = buf.len();
593                let buf = &mut buf[..(buf_len - buf_len % self.pcm_frame_size)];
594
595                // update running MD5 sum calculation
596                // since samples are already interleaved in channel order
597                update_md5(
598                    &mut self.encoder,
599                    buf.iter().copied(),
600                    self.bytes_per_sample,
601                );
602
603                // encode final FLAC frame
604                self.encoder.encode(self.frame.fill_from_samples(buf))?;
605            }
606
607            self.encoder.finalize_inner()
608        } else {
609            Ok(())
610        }
611    }
612
613    /// Attempt to finalize stream
614    ///
615    /// It is necessary to finalize the FLAC encoder
616    /// so that it will write any partially unwritten samples
617    /// to the stream and update the [`crate::metadata::Streaminfo`] and [`crate::metadata::SeekTable`] blocks
618    /// with their final values.
619    ///
620    /// Dropping the encoder will attempt to finalize the stream
621    /// automatically, but will ignore any errors that may occur.
622    pub fn finalize(mut self) -> Result<(), Error> {
623        self.finalize_inner()?;
624        Ok(())
625    }
626}
627
628impl FlacSampleWriter<BufWriter<File>> {
629    /// Creates new FLAC file at the given path
630    ///
631    /// `sample_rate` must be between 0 (for non-audio streams) and 2²⁰.
632    ///
633    /// `bits_per_sample` must be between 1 and 32.
634    ///
635    /// `channels` must be between 1 and 8.
636    ///
637    /// Note that if `total_bytes` is indicated,
638    /// the number of bytes written *must*
639    /// be equal to that amount or an error will occur when writing
640    /// or finalizing the stream.
641    ///
642    /// # Errors
643    ///
644    /// Returns I/O error if unable to write initial
645    /// metadata blocks.
646    #[inline]
647    pub fn create<P: AsRef<Path>>(
648        path: P,
649        options: Options,
650        sample_rate: u32,
651        bits_per_sample: u32,
652        channels: u8,
653        total_samples: Option<u64>,
654    ) -> Result<Self, Error> {
655        FlacSampleWriter::new(
656            BufWriter::new(options.create(path)?),
657            options,
658            sample_rate,
659            bits_per_sample,
660            channels,
661            total_samples,
662        )
663    }
664
665    /// Creates new FLAC file at the given path with CDDA parameters
666    ///
667    /// Sample rate is 44100 Hz, bits-per-sample is 16,
668    /// channels is 2.
669    ///
670    /// Note that if `total_bytes` is indicated,
671    /// the number of bytes written *must*
672    /// be equal to that amount or an error will occur when writing
673    /// or finalizing the stream.
674    ///
675    /// # Errors
676    ///
677    /// Returns I/O error if unable to write initial
678    /// metadata blocks.
679    #[inline]
680    pub fn create_cdda<P: AsRef<Path>>(
681        path: P,
682        options: Options,
683        total_samples: Option<u64>,
684    ) -> Result<Self, Error> {
685        Self::create(path, options, 44100, 16, 2, total_samples)
686    }
687}
688
689/// A FLAC writer which accepts samples as channels of signed integers
690///
691/// # Example
692/// ```
693/// use flac_codec::{
694///     encode::{FlacChannelWriter, Options},
695///     decode::FlacChannelReader,
696/// };
697/// use std::io::{Cursor, Seek};
698///
699/// let mut flac = Cursor::new(vec![]);  // a FLAC file in memory
700///
701/// let mut writer = FlacChannelWriter::new(
702///     &mut flac,           // our wrapped writer
703///     Options::default(),  // default encoding options
704///     44100,               // sample rate
705///     16,                  // bits-per-sample
706///     2,                   // channel count
707///     Some(5),             // total channel-independent samples
708/// ).unwrap();
709///
710/// // write our samples, divided by channel
711/// let written_samples = vec![
712///     vec![1, 2, 3, 4, 5],
713///     vec![-1, -2, -3, -4, -5],
714/// ];
715/// assert!(writer.write(&written_samples).is_ok());
716///
717/// // finalize writing file
718/// assert!(writer.finalize().is_ok());
719///
720/// flac.rewind().unwrap();
721///
722/// // open reader around written FLAC file
723/// let mut reader = FlacChannelReader::new(flac).unwrap();
724///
725/// // read a buffer's worth of samples
726/// let read_samples = reader.fill_buf().unwrap();
727///
728/// // ensure the channels match
729/// assert_eq!(read_samples.len(), written_samples.len());
730/// assert_eq!(read_samples[0], written_samples[0]);
731/// assert_eq!(read_samples[1], written_samples[1]);
732/// ```
733pub struct FlacChannelWriter<W: std::io::Write + std::io::Seek> {
734    // the wrapped encoder
735    encoder: Encoder<W>,
736    // channels that make up a partial FLAC frame
737    channel_bufs: Vec<VecDeque<i32>>,
738    // a whole set of samples for a FLAC frame
739    frame: Frame,
740    // size of a single frame in samples
741    frame_sample_size: usize,
742    // size of a single sample in bytes
743    bytes_per_sample: usize,
744    // whether the encoder has finalized the file
745    finalized: bool,
746}
747
748impl<W: std::io::Write + std::io::Seek> FlacChannelWriter<W> {
749    /// Creates new FLAC writer with the given parameters
750    ///
751    /// `sample_rate` must be between 0 (for non-audio streams) and 2²⁰.
752    ///
753    /// `bits_per_sample` must be between 1 and 32.
754    ///
755    /// `channels` must be between 1 and 8.
756    ///
757    /// Note that if `total_samples` is indicated,
758    /// the number of samples written *must*
759    /// be equal to that amount or an error will occur when writing
760    /// or finalizing the stream.
761    ///
762    /// # Errors
763    ///
764    /// Returns I/O error if unable to write initial
765    /// metadata blocks.
766    /// Returns error if any of the encoding parameters are invalid.
767    pub fn new(
768        writer: W,
769        options: Options,
770        sample_rate: u32,
771        bits_per_sample: u32,
772        channels: u8,
773        total_samples: Option<u64>,
774    ) -> Result<Self, Error> {
775        let bits_per_sample: SignedBitCount<32> = bits_per_sample
776            .try_into()
777            .map_err(|_| Error::InvalidBitsPerSample)?;
778
779        let bytes_per_sample = u32::from(bits_per_sample).div_ceil(8) as usize;
780
781        Ok(Self {
782            channel_bufs: vec![VecDeque::default(); channels.into()],
783            frame: Frame::empty(channels.into(), bits_per_sample.into()),
784            bytes_per_sample,
785            frame_sample_size: options.block_size as usize,
786            encoder: Encoder::new(
787                writer,
788                options,
789                sample_rate,
790                bits_per_sample,
791                channels,
792                total_samples.and_then(NonZero::new),
793            )?,
794            finalized: false,
795        })
796    }
797
798    /// Creates new FLAC writer with CDDA parameters
799    ///
800    /// Sample rate is 44100 Hz, bits-per-sample is 16,
801    /// channels is 2.
802    ///
803    /// Note that if `total_samples` is indicated,
804    /// the number of samples written *must*
805    /// be equal to that amount or an error will occur when writing
806    /// or finalizing the stream.
807    ///
808    /// # Errors
809    ///
810    /// Returns I/O error if unable to write initial
811    /// metadata blocks.
812    /// Returns error if any of the encoding parameters are invalid.
813    pub fn new_cdda(
814        writer: W,
815        options: Options,
816        total_samples: Option<u64>,
817    ) -> Result<Self, Error> {
818        Self::new(writer, options, 44100, 16, 2, total_samples)
819    }
820
821    /// Given a set of channels containing samples, writes them to the FLAC file
822    ///
823    /// Channels should be a slice-able set of sample slices, like:
824    /// [[left₀ , left₁ , left₂ , …] , [right₀ , right₁ , right₂ , …]]
825    ///
826    /// The number of channels must be identical to the
827    /// channel count indicated when intializing the encoder.
828    ///
829    /// The number of samples in each channel must also be identical.
830    pub fn write<C, S>(&mut self, channels: C) -> Result<(), Error>
831    where
832        C: AsRef<[S]>,
833        S: AsRef<[i32]>,
834    {
835        use crate::audio::MultiZip;
836
837        // sanity-check our channel inputs
838        let channels = channels.as_ref();
839
840        match channels {
841            whole @ [first, rest @ ..]
842                if whole.len() == usize::from(self.encoder.channel_count().get()) =>
843            {
844                if rest
845                    .iter()
846                    .any(|c| c.as_ref().len() != first.as_ref().len())
847                {
848                    return Err(Error::ChannelLengthMismatch);
849                }
850            }
851            _ => {
852                return Err(Error::ChannelCountMismatch);
853            }
854        }
855
856        // dump whole set of samples into our internal channel buffers
857        for (buf, channel) in self.channel_bufs.iter_mut().zip(channels) {
858            buf.extend(channel.as_ref());
859        }
860
861        // encode as many FLAC frames as possible (which may be 0)
862        let mut encoded_frames = 0;
863        for bufs in self
864            .channel_bufs
865            .iter_mut()
866            .map(|v| v.make_contiguous().chunks_exact_mut(self.frame_sample_size))
867            .collect::<MultiZip<_>>()
868        {
869            // update running MD5 sum calculation
870            update_md5(
871                &mut self.encoder,
872                bufs.iter()
873                    .map(|c| c.iter().copied())
874                    .collect::<MultiZip<_>>()
875                    .flatten(),
876                self.bytes_per_sample,
877            );
878
879            // encode fresh FLAC frame
880            self.encoder.encode(self.frame.fill_from_channels(&bufs))?;
881
882            encoded_frames += 1;
883        }
884
885        // drain encoded samples from buffers
886        for channel in self.channel_bufs.iter_mut() {
887            channel.drain(0..self.frame_sample_size * encoded_frames);
888        }
889
890        Ok(())
891    }
892
893    fn finalize_inner(&mut self) -> Result<(), Error> {
894        use crate::audio::MultiZip;
895
896        if !self.finalized {
897            self.finalized = true;
898
899            // encode as many samples possible into final frame, if necessary
900            if !self.channel_bufs[0].is_empty() {
901                // update running MD5 sum calculation
902                update_md5(
903                    &mut self.encoder,
904                    self.channel_bufs
905                        .iter()
906                        .map(|c| c.iter().copied())
907                        .collect::<MultiZip<_>>()
908                        .flatten(),
909                    self.bytes_per_sample,
910                );
911
912                // encode final FLAC frame
913                self.encoder.encode(
914                    self.frame.fill_from_channels(
915                        self.channel_bufs
916                            .iter_mut()
917                            .map(|v| v.make_contiguous())
918                            .collect::<ArrayVec<_, MAX_CHANNELS>>()
919                            .as_slice(),
920                    ),
921                )?;
922            }
923
924            self.encoder.finalize_inner()
925        } else {
926            Ok(())
927        }
928    }
929
930    /// Attempt to finalize stream
931    ///
932    /// It is necessary to finalize the FLAC encoder
933    /// so that it will write any partially unwritten samples
934    /// to the stream and update the [`crate::metadata::Streaminfo`] and [`crate::metadata::SeekTable`] blocks
935    /// with their final values.
936    ///
937    /// Dropping the encoder will attempt to finalize the stream
938    /// automatically, but will ignore any errors that may occur.
939    pub fn finalize(mut self) -> Result<(), Error> {
940        self.finalize_inner()?;
941        Ok(())
942    }
943}
944
945impl FlacChannelWriter<BufWriter<File>> {
946    /// Creates new FLAC file at the given path
947    ///
948    /// `sample_rate` must be between 0 (for non-audio streams) and 2²⁰.
949    ///
950    /// `bits_per_sample` must be between 1 and 32.
951    ///
952    /// `channels` must be between 1 and 8.
953    ///
954    /// Note that if `total_bytes` is indicated,
955    /// the number of bytes written *must*
956    /// be equal to that amount or an error will occur when writing
957    /// or finalizing the stream.
958    ///
959    /// # Errors
960    ///
961    /// Returns I/O error if unable to write initial
962    /// metadata blocks.
963    #[inline]
964    pub fn create<P: AsRef<Path>>(
965        path: P,
966        options: Options,
967        sample_rate: u32,
968        bits_per_sample: u32,
969        channels: u8,
970        total_samples: Option<u64>,
971    ) -> Result<Self, Error> {
972        FlacChannelWriter::new(
973            BufWriter::new(options.create(path)?),
974            options,
975            sample_rate,
976            bits_per_sample,
977            channels,
978            total_samples,
979        )
980    }
981
982    /// Creates new FLAC file at the given path with CDDA parameters
983    ///
984    /// Sample rate is 44100 Hz, bits-per-sample is 16,
985    /// channels is 2.
986    ///
987    /// Note that if `total_bytes` is indicated,
988    /// the number of bytes written *must*
989    /// be equal to that amount or an error will occur when writing
990    /// or finalizing the stream.
991    ///
992    /// # Errors
993    ///
994    /// Returns I/O error if unable to write initial
995    /// metadata blocks.
996    #[inline]
997    pub fn create_cdda<P: AsRef<Path>>(
998        path: P,
999        options: Options,
1000        total_samples: Option<u64>,
1001    ) -> Result<Self, Error> {
1002        Self::create(path, options, 44100, 16, 2, total_samples)
1003    }
1004}
1005
1006/// A FLAC writer which operates on streamed output
1007///
1008/// Because this encodes FLAC frames without any metadata
1009/// blocks or finalizing, it does not need to be seekable.
1010///
1011/// # Example
1012///
1013/// ```
1014/// use flac_codec::{
1015///     decode::{FlacStreamReader, FrameBuf},
1016///     encode::{FlacStreamWriter, Options},
1017/// };
1018/// use std::io::{Cursor, Seek};
1019/// use bitstream_io::SignedBitCount;
1020///
1021/// let mut flac = Cursor::new(vec![]);
1022///
1023/// let samples = (0..100).collect::<Vec<i32>>();
1024///
1025/// let mut w = FlacStreamWriter::new(&mut flac, Options::default());
1026///
1027/// // write a single FLAC frame with some samples
1028/// w.write(
1029///     44100,  // sample rate
1030///     1,      // channels
1031///     16,     // bits-per-sample
1032///     &samples,
1033/// ).unwrap();
1034///
1035/// flac.rewind().unwrap();
1036///
1037/// let mut r = FlacStreamReader::new(&mut flac);
1038///
1039/// // read a single FLAC frame with some samples
1040/// assert_eq!(
1041///     r.read().unwrap(),
1042///     FrameBuf {
1043///         samples: &samples,
1044///         sample_rate: 44100,
1045///         channels: 1,
1046///         bits_per_sample: 16,
1047///     },
1048/// );
1049/// ```
1050pub struct FlacStreamWriter<W> {
1051    // the writer we're outputting to
1052    writer: W,
1053    // various encoding optins
1054    options: EncoderOptions,
1055    // various encoder caches
1056    caches: EncodingCaches,
1057    // a whole set of samples for a FLAC frame
1058    frame: Frame,
1059    // the current frame number
1060    frame_number: FrameNumber,
1061}
1062
1063impl<W: std::io::Write> FlacStreamWriter<W> {
1064    /// Creates new stream writer
1065    pub fn new(writer: W, options: Options) -> Self {
1066        Self {
1067            writer,
1068            options: EncoderOptions {
1069                max_partition_order: options.max_partition_order,
1070                mid_side: options.mid_side,
1071                seektable_interval: options.seektable_interval,
1072                max_lpc_order: options.max_lpc_order,
1073                window: options.window,
1074                exhaustive_channel_correlation: options.exhaustive_channel_correlation,
1075                use_rice2: false,
1076            },
1077            caches: EncodingCaches::default(),
1078            frame: Frame::default(),
1079            frame_number: FrameNumber::default(),
1080        }
1081    }
1082
1083    /// Writes a whole FLAC frame to our output stream
1084    ///
1085    /// Samples are interleaved by channel, like:
1086    /// [left₀ , right₀ , left₁ , right₁ , left₂ , right₂ , …]
1087    ///
1088    /// This writes a whole FLAC frame to the output stream on each call.
1089    ///
1090    /// # Errors
1091    ///
1092    /// Returns an error of any of the parameters are invalid
1093    /// or if an I/O error occurs when writing to the stream.
1094    pub fn write(
1095        &mut self,
1096        sample_rate: u32,
1097        channels: u8,
1098        bits_per_sample: u32,
1099        samples: &[i32],
1100    ) -> Result<(), Error> {
1101        use crate::crc::{Crc16, CrcWriter};
1102        use crate::stream::{BitsPerSample, FrameHeader, SampleRate};
1103
1104        let bits_per_sample: SignedBitCount<32> = bits_per_sample
1105            .try_into()
1106            .map_err(|_| Error::NonSubsetBitsPerSample)?;
1107
1108        // samples must divide evenly into channels
1109        if samples.len() % usize::from(channels) != 0 {
1110            return Err(Error::SamplesNotDivisibleByChannels);
1111        } else if !(1..=8).contains(&channels) {
1112            return Err(Error::ExcessiveChannels);
1113        }
1114
1115        self.options.use_rice2 = u32::from(bits_per_sample) > 16;
1116
1117        self.frame
1118            .resize(bits_per_sample.into(), channels.into(), 0);
1119        self.frame.fill_from_samples(samples);
1120
1121        // block size must be valid
1122        let block_size: crate::stream::BlockSize<u16> = crate::stream::BlockSize::try_from(
1123            u16::try_from(self.frame.pcm_frames()).map_err(|_| Error::InvalidBlockSize)?,
1124        )
1125        .map_err(|_| Error::InvalidBlockSize)?;
1126
1127        // sample rate must be valid for subset streams
1128        let sample_rate: SampleRate<u32> = sample_rate.try_into().and_then(|rate| match rate {
1129            SampleRate::Streaminfo(_) => Err(Error::NonSubsetSampleRate),
1130            rate => Ok(rate),
1131        })?;
1132
1133        // bits-per-sample must be valid for subset streams
1134        let header_bits_per_sample = match BitsPerSample::from(bits_per_sample) {
1135            BitsPerSample::Streaminfo(_) => return Err(Error::NonSubsetBitsPerSample),
1136            bps => bps,
1137        };
1138
1139        let mut w: CrcWriter<_, Crc16> = CrcWriter::new(&mut self.writer);
1140        let mut bw: BitWriter<CrcWriter<&mut W, Crc16>, BigEndian>;
1141
1142        match self
1143            .frame
1144            .channels()
1145            .collect::<ArrayVec<&[i32], MAX_CHANNELS>>()
1146            .as_slice()
1147        {
1148            [channel] => {
1149                FrameHeader {
1150                    blocking_strategy: false,
1151                    frame_number: self.frame_number,
1152                    block_size: (channel.len() as u16)
1153                        .try_into()
1154                        .expect("frame cannot be empty"),
1155                    sample_rate,
1156                    bits_per_sample: header_bits_per_sample,
1157                    channel_assignment: ChannelAssignment::Independent(Independent::Mono),
1158                }
1159                .write_subset(&mut w)?;
1160
1161                bw = BitWriter::new(w);
1162
1163                self.caches.channels.resize_with(1, ChannelCache::default);
1164
1165                encode_subframe(
1166                    &self.options,
1167                    &mut self.caches.channels[0],
1168                    CorrelatedChannel::independent(bits_per_sample, channel),
1169                )?
1170                .playback(&mut bw)?;
1171            }
1172            [left, right] if self.options.exhaustive_channel_correlation => {
1173                let Correlated {
1174                    channel_assignment,
1175                    channels: [channel_0, channel_1],
1176                } = correlate_channels_exhaustive(
1177                    &self.options,
1178                    &mut self.caches.correlated,
1179                    [left, right],
1180                    bits_per_sample,
1181                )?;
1182
1183                FrameHeader {
1184                    blocking_strategy: false,
1185                    frame_number: self.frame_number,
1186                    block_size,
1187                    sample_rate,
1188                    bits_per_sample: header_bits_per_sample,
1189                    channel_assignment,
1190                }
1191                .write_subset(&mut w)?;
1192
1193                bw = BitWriter::new(w);
1194
1195                channel_0.playback(&mut bw)?;
1196                channel_1.playback(&mut bw)?;
1197            }
1198            [left, right] => {
1199                let Correlated {
1200                    channel_assignment,
1201                    channels: [channel_0, channel_1],
1202                } = correlate_channels(
1203                    &self.options,
1204                    &mut self.caches.correlated,
1205                    [left, right],
1206                    bits_per_sample,
1207                );
1208
1209                FrameHeader {
1210                    blocking_strategy: false,
1211                    frame_number: self.frame_number,
1212                    block_size,
1213                    sample_rate,
1214                    bits_per_sample: header_bits_per_sample,
1215                    channel_assignment,
1216                }
1217                .write_subset(&mut w)?;
1218
1219                self.caches.channels.resize_with(2, ChannelCache::default);
1220                let [cache_0, cache_1] = self.caches.channels.get_disjoint_mut([0, 1]).unwrap();
1221                let (channel_0, channel_1) = join(
1222                    || encode_subframe(&self.options, cache_0, channel_0),
1223                    || encode_subframe(&self.options, cache_1, channel_1),
1224                );
1225
1226                bw = BitWriter::new(w);
1227
1228                channel_0?.playback(&mut bw)?;
1229                channel_1?.playback(&mut bw)?;
1230            }
1231            channels => {
1232                // non-stereo frames are always encoded independently
1233                FrameHeader {
1234                    blocking_strategy: false,
1235                    frame_number: self.frame_number,
1236                    block_size,
1237                    sample_rate,
1238                    bits_per_sample: header_bits_per_sample,
1239                    channel_assignment: ChannelAssignment::Independent(
1240                        channels.len().try_into().expect("invalid channel count"),
1241                    ),
1242                }
1243                .write_subset(&mut w)?;
1244
1245                bw = BitWriter::new(w);
1246
1247                self.caches
1248                    .channels
1249                    .resize_with(channels.len(), ChannelCache::default);
1250
1251                vec_map(
1252                    self.caches.channels.iter_mut().zip(channels).collect(),
1253                    |(cache, channel)| {
1254                        encode_subframe(
1255                            &self.options,
1256                            cache,
1257                            CorrelatedChannel::independent(bits_per_sample, channel),
1258                        )
1259                    },
1260                )
1261                .into_iter()
1262                .try_for_each(|r| r.and_then(|r| r.playback(bw.by_ref()).map_err(Error::Io)))?;
1263            }
1264        }
1265
1266        let crc16: u16 = bw.aligned_writer()?.checksum().into();
1267        bw.write_from(crc16)?;
1268
1269        if self.frame_number.try_increment().is_err() {
1270            self.frame_number = FrameNumber::default();
1271        }
1272
1273        Ok(())
1274    }
1275
1276    /// Writes a whole FLAC frame to our output stream with CDDA parameters
1277    ///
1278    /// Samples are interleaved by channel, like:
1279    /// [left₀ , right₀ , left₁ , right₁ , left₂ , right₂ , …]
1280    ///
1281    /// This writes a whole FLAC frame to the output stream on each call.
1282    ///
1283    /// # Errors
1284    ///
1285    /// Returns an error of any of the parameters are invalid
1286    /// or if an I/O error occurs when writing to the stream.
1287    pub fn write_cdda(&mut self, samples: &[i32]) -> Result<(), Error> {
1288        self.write(44100, 2, 16, samples)
1289    }
1290}
1291
1292fn update_md5<W: std::io::Write + std::io::Seek>(
1293    encoder: &mut Encoder<W>,
1294    samples: impl Iterator<Item = i32>,
1295    bytes_per_sample: usize,
1296) {
1297    use crate::byteorder::{Endianness, LittleEndian};
1298
1299    match bytes_per_sample {
1300        1 => {
1301            for s in samples {
1302                encoder.update_md5(&LittleEndian::i8_to_bytes(s as i8));
1303            }
1304        }
1305        2 => {
1306            for s in samples {
1307                encoder.update_md5(&LittleEndian::i16_to_bytes(s as i16));
1308            }
1309        }
1310        3 => {
1311            for s in samples {
1312                encoder.update_md5(&LittleEndian::i24_to_bytes(s));
1313            }
1314        }
1315        4 => {
1316            for s in samples {
1317                encoder.update_md5(&LittleEndian::i32_to_bytes(s));
1318            }
1319        }
1320        _ => panic!("unsupported number of bytes per sample"),
1321    }
1322}
1323
1324/// The interval of seek points to generate
1325#[derive(Copy, Clone, Debug)]
1326pub enum SeekTableInterval {
1327    ///Generate seekpoint every nth seconds
1328    Seconds(NonZero<u8>),
1329    /// Generate seekpoint every nth frames
1330    Frames(NonZero<usize>),
1331}
1332
1333impl Default for SeekTableInterval {
1334    fn default() -> Self {
1335        Self::Seconds(NonZero::new(10).unwrap())
1336    }
1337}
1338
1339impl SeekTableInterval {
1340    // decimates full set of seekpoints based on the requested
1341    // seektable style, or returns None if no seektable is requested
1342    fn filter<'s>(
1343        self,
1344        sample_rate: u32,
1345        seekpoints: impl IntoIterator<Item = EncoderSeekPoint> + 's,
1346    ) -> Box<dyn Iterator<Item = EncoderSeekPoint> + 's> {
1347        match self {
1348            Self::Seconds(seconds) => {
1349                let nth_sample = u64::from(u32::from(seconds.get()) * sample_rate);
1350                let mut offset = 0;
1351                Box::new(seekpoints.into_iter().filter(move |point| {
1352                    if point.range().contains(&offset) {
1353                        offset += nth_sample;
1354                        true
1355                    } else {
1356                        false
1357                    }
1358                }))
1359            }
1360            Self::Frames(frames) => Box::new(seekpoints.into_iter().step_by(frames.get())),
1361        }
1362    }
1363}
1364
1365/// FLAC encoding options
1366#[derive(Clone, Debug)]
1367pub struct Options {
1368    // whether to clobber existing file
1369    clobber: bool,
1370    block_size: u16,
1371    max_partition_order: u32,
1372    mid_side: bool,
1373    metadata: BlockList,
1374    seektable_interval: Option<SeekTableInterval>,
1375    max_lpc_order: Option<NonZero<u8>>,
1376    window: Window,
1377    exhaustive_channel_correlation: bool,
1378}
1379
1380impl Default for Options {
1381    fn default() -> Self {
1382        // a dummy placeholder value
1383        // since we can't know the stream parameters yet
1384        let mut metadata = BlockList::new(Streaminfo {
1385            minimum_block_size: 0,
1386            maximum_block_size: 0,
1387            minimum_frame_size: None,
1388            maximum_frame_size: None,
1389            sample_rate: 0,
1390            channels: NonZero::new(1).unwrap(),
1391            bits_per_sample: SignedBitCount::new::<4>(),
1392            total_samples: None,
1393            md5: None,
1394        });
1395
1396        metadata.insert(crate::metadata::Padding {
1397            size: 4096u16.into(),
1398        });
1399
1400        Self {
1401            clobber: false,
1402            block_size: 4096,
1403            mid_side: true,
1404            max_partition_order: 5,
1405            metadata,
1406            seektable_interval: Some(SeekTableInterval::default()),
1407            max_lpc_order: NonZero::new(8),
1408            window: Window::default(),
1409            exhaustive_channel_correlation: true,
1410        }
1411    }
1412}
1413
1414impl Options {
1415    /// Sets new block size
1416    ///
1417    /// Block size must be ≥ 16
1418    ///
1419    /// For subset streams, this must be ≤ 4608
1420    /// if the sample rate is ≤ 48 kHz -
1421    /// or ≤ 16384 for higher sample rates.
1422    pub fn block_size(self, block_size: u16) -> Result<Self, OptionsError> {
1423        match block_size {
1424            0..16 => Err(OptionsError::InvalidBlockSize),
1425            16.. => Ok(Self { block_size, ..self }),
1426        }
1427    }
1428
1429    /// Sets new maximum LPC order
1430    ///
1431    /// The valid range is: 0 < `max_lpc_order` ≤ 32
1432    ///
1433    /// A value of `None` means that no LPC subframes will be encoded.
1434    pub fn max_lpc_order(self, max_lpc_order: Option<u8>) -> Result<Self, OptionsError> {
1435        Ok(Self {
1436            max_lpc_order: max_lpc_order
1437                .map(|o| {
1438                    o.try_into()
1439                        .ok()
1440                        .filter(|o| *o <= NonZero::new(32).unwrap())
1441                        .ok_or(OptionsError::InvalidLpcOrder)
1442                })
1443                .transpose()?,
1444            ..self
1445        })
1446    }
1447
1448    /// Sets maximum residual partion order.
1449    ///
1450    /// The valid range is: 0 ≤ `max_partition_order` ≤ 15
1451    pub fn max_partition_order(self, max_partition_order: u32) -> Result<Self, OptionsError> {
1452        match max_partition_order {
1453            0..=15 => Ok(Self {
1454                max_partition_order,
1455                ..self
1456            }),
1457            16.. => Err(OptionsError::InvalidMaxPartitions),
1458        }
1459    }
1460
1461    /// Whether to use mid-side encoding
1462    ///
1463    /// The default is `true`.
1464    pub fn mid_side(self, mid_side: bool) -> Self {
1465        Self { mid_side, ..self }
1466    }
1467
1468    /// The windowing function to use for input samples
1469    pub fn window(self, window: Window) -> Self {
1470        Self { window, ..self }
1471    }
1472
1473    /// Whether to calculate the best channel correlation quickly
1474    ///
1475    /// The default is `false`
1476    pub fn fast_channel_correlation(self, fast: bool) -> Self {
1477        Self {
1478            exhaustive_channel_correlation: !fast,
1479            ..self
1480        }
1481    }
1482
1483    /// Updates size of padding block
1484    ///
1485    /// `size` must be < 2²⁴
1486    ///
1487    /// If `size` is set to 0, removes the block entirely.
1488    ///
1489    /// The default is to add a 4096 byte padding block.
1490    pub fn padding(mut self, size: u32) -> Result<Self, OptionsError> {
1491        use crate::metadata::Padding;
1492
1493        match size
1494            .try_into()
1495            .map_err(|_| OptionsError::ExcessivePadding)?
1496        {
1497            BlockSize::ZERO => self.metadata.remove::<Padding>(),
1498            size => self.metadata.update::<Padding>(|p| {
1499                p.size = size;
1500            }),
1501        }
1502        Ok(self)
1503    }
1504
1505    /// Remove any padding blocks from metadata
1506    ///
1507    /// This makes the file smaller, but will likely require
1508    /// rewriting it if any metadata needs to be modified later.
1509    pub fn no_padding(mut self) -> Self {
1510        self.metadata.remove::<crate::metadata::Padding>();
1511        self
1512    }
1513
1514    /// Adds new tag to comment metadata block
1515    ///
1516    /// Creates new [`crate::metadata::VorbisComment`] block if not already present.
1517    pub fn tag<S>(mut self, field: &str, value: S) -> Self
1518    where
1519        S: std::fmt::Display,
1520    {
1521        self.metadata
1522            .update::<VorbisComment>(|vc| vc.insert(field, value));
1523        self
1524    }
1525
1526    /// Replaces entire [`crate::metadata::VorbisComment`] metadata block
1527    ///
1528    /// This may be more convenient when adding many fields at once.
1529    pub fn comment(mut self, comment: VorbisComment) -> Self {
1530        self.metadata.insert(comment);
1531        self
1532    }
1533
1534    /// Add new [`crate::metadata::Picture`] block to metadata
1535    ///
1536    /// Files may contain multiple [`crate::metadata::Picture`] blocks,
1537    /// and this adds a new block each time it is used.
1538    pub fn picture(mut self, picture: Picture) -> Self {
1539        self.metadata.insert(picture);
1540        self
1541    }
1542
1543    /// Add new [`crate::metadata::Cuesheet`] block to metadata
1544    ///
1545    /// Files may (theoretically) contain multiple [`crate::metadata::Cuesheet`] blocks,
1546    /// and this adds a new block each time it is used.
1547    ///
1548    /// In practice, CD images almost always use only a single
1549    /// cue sheet.
1550    pub fn cuesheet(mut self, cuesheet: Cuesheet) -> Self {
1551        self.metadata.insert(cuesheet);
1552        self
1553    }
1554
1555    /// Add new [`crate::metadata::Application`] block to metadata
1556    ///
1557    /// Files may contain multiple [`crate::metadata::Application`] blocks,
1558    /// and this adds a new block each time it is used.
1559    pub fn application(mut self, application: Application) -> Self {
1560        self.metadata.insert(application);
1561        self
1562    }
1563
1564    /// Generate [`crate::metadata::SeekTable`] with the given number of seconds between seek points
1565    ///
1566    /// The default is to generate a SEEKTABLE with 10 seconds between seek points.
1567    ///
1568    /// If `seconds` is 0, removes the SEEKTABLE block.
1569    ///
1570    /// The interval between seek points may be larger than requested
1571    /// if the encoder's block size is larger than the seekpoint interval.
1572    pub fn seektable_seconds(mut self, seconds: u8) -> Self {
1573        // note that we can't drop a placeholder seektable
1574        // into the metadata blocks until we know
1575        // the sample rate and total samples of our stream
1576        self.seektable_interval = NonZero::new(seconds).map(SeekTableInterval::Seconds);
1577        self
1578    }
1579
1580    /// Generate [`crate::metadata::SeekTable`] with the given number of FLAC frames between seek points
1581    ///
1582    /// If `frames` is 0, removes the SEEKTABLE block
1583    pub fn seektable_frames(mut self, frames: usize) -> Self {
1584        self.seektable_interval = NonZero::new(frames).map(SeekTableInterval::Frames);
1585        self
1586    }
1587
1588    /// Do not generate a seektable in our encoded file
1589    pub fn no_seektable(self) -> Self {
1590        Self {
1591            seektable_interval: None,
1592            ..self
1593        }
1594    }
1595
1596    /// Overwrites existing file if it already exists
1597    ///
1598    /// This only applies to encoding functions which
1599    /// create files from paths.
1600    ///
1601    /// The default is to not overwrite files
1602    /// if they already exist.
1603    pub fn overwrite(mut self) -> Self {
1604        self.clobber = true;
1605        self
1606    }
1607
1608    /// Returns the fastest encoding options
1609    ///
1610    /// These are tuned to encode as quickly as possible.
1611    pub fn fast() -> Self {
1612        Self {
1613            block_size: 1152,
1614            mid_side: false,
1615            max_partition_order: 3,
1616            max_lpc_order: None,
1617            exhaustive_channel_correlation: false,
1618            ..Self::default()
1619        }
1620    }
1621
1622    /// Returns the fastest encoding options
1623    ///
1624    /// These are tuned to encode files as small as possible.
1625    pub fn best() -> Self {
1626        Self {
1627            block_size: 4096,
1628            mid_side: true,
1629            max_partition_order: 6,
1630            max_lpc_order: NonZero::new(12),
1631            ..Self::default()
1632        }
1633    }
1634
1635    /// Creates files according to whether clobber is set or not
1636    fn create<P: AsRef<Path>>(&self, path: P) -> std::io::Result<File> {
1637        if self.clobber {
1638            File::create(path)
1639        } else {
1640            use std::fs::OpenOptions;
1641
1642            OpenOptions::new()
1643                .write(true)
1644                .create_new(true)
1645                .open(path.as_ref())
1646        }
1647    }
1648}
1649
/// An error when specifying encoding options
#[derive(Debug)]
pub enum OptionsError {
    /// Selected block size is below the minimum of 16
    InvalidBlockSize,
    /// Maximum LPC order is 0 or larger than 32
    InvalidLpcOrder,
    /// Maximum residual partition order is larger than 15
    InvalidMaxPartitions,
    /// Selected padding size is too large for a metadata block
    ExcessivePadding,
}

impl std::error::Error for OptionsError {}

impl std::fmt::Display for OptionsError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let message = match self {
            Self::InvalidBlockSize => "block size must be >= 16",
            Self::InvalidLpcOrder => "maximum LPC order must be <= 32",
            Self::InvalidMaxPartitions => "max partition order must be <= 15",
            Self::ExcessivePadding => "padding size is too large for block",
        };
        // `pad` honors width/alignment flags exactly as `str`'s Display does
        f.pad(message)
    }
}
1675
1676/// A cut-down version of Options without the metadata blocks
1677struct EncoderOptions {
1678    max_partition_order: u32,
1679    mid_side: bool,
1680    seektable_interval: Option<SeekTableInterval>,
1681    max_lpc_order: Option<NonZero<u8>>,
1682    window: Window,
1683    exhaustive_channel_correlation: bool,
1684    use_rice2: bool,
1685}
1686
/// The method to use for windowing the input signal
///
/// The default is [`Window::Tukey`] with a parameter of 0.5.
#[derive(Copy, Clone, Debug)]
pub enum Window {
    /// Basic rectangular window
    Rectangle,
    /// Hann window
    Hann,
    /// Tukey window
    ///
    /// The parameter is roughly the fraction of the window that is tapered:
    /// values ≤ 0.0 behave like [`Window::Rectangle`],
    /// values ≥ 1.0 behave like [`Window::Hann`]
    /// and NaN falls back to a parameter of 0.5.
    Tukey(f32),
}

// TODO - add more windowing options

impl Window {
    // Fills `window` with this function's coefficients, one per sample.
    fn generate(&self, window: &mut [f64]) {
        use std::f64::consts::PI;

        match self {
            Self::Rectangle => window.fill(1.0),
            // The Hann formula divides by `len - 1`, so a single point
            // would be 0/0 = NaN and poison every windowed sample.
            // Degenerate windows are left flat instead.
            Self::Hann if window.len() < 2 => window.fill(1.0),
            Self::Hann => {
                // verified output against reference implementation
                // See: FLAC__window_hann()

                let np =
                    f64::from(u16::try_from(window.len()).expect("window size too large")) - 1.0;

                window.iter_mut().zip(0u16..).for_each(|(w, n)| {
                    *w = 0.5 - 0.5 * (2.0 * PI * f64::from(n) / np).cos();
                });
            }
            Self::Tukey(p) => match p {
                // verified output against reference implementation
                // See: FLAC__window_tukey()
                ..=0.0 => {
                    window.fill(1.0);
                }
                1.0.. => {
                    Self::Hann.generate(window);
                }
                0.0..1.0 => {
                    // `np` tapered points at each end, flat in between
                    match ((f64::from(*p) / 2.0 * window.len() as f64) as usize).checked_sub(1) {
                        Some(np) => match window.get_disjoint_mut([
                            0..np,
                            np..window.len() - np,
                            window.len() - np..window.len(),
                        ]) {
                            Ok([first, mid, last]) => {
                                // u16 is maximum block size
                                let np = u16::try_from(np).expect("window size too large");

                                // the tail is a mirror image of the head
                                for ((x, y), n) in
                                    first.iter_mut().zip(last.iter_mut().rev()).zip(0u16..)
                                {
                                    *x = 0.5 - 0.5 * (PI * f64::from(n) / f64::from(np)).cos();
                                    *y = *x;
                                }
                                mid.fill(1.0);
                            }
                            Err(_) => {
                                window.fill(1.0);
                            }
                        },
                        None => {
                            window.fill(1.0);
                        }
                    }
                }
                // NaN matches none of the ranges above
                _ => {
                    Self::Tukey(0.5).generate(window);
                }
            },
        }
    }

    // Multiplies `samples` by this window into `cache` and returns it.
    //
    // `window` holds the generated coefficients and is only regenerated
    // when its length differs from the number of samples.
    fn apply<'w>(
        &self,
        window: &mut Vec<f64>,
        cache: &'w mut Vec<f64>,
        samples: &[i32],
    ) -> &'w [f64] {
        if window.len() != samples.len() {
            // need to re-generate window to fit samples
            window.resize(samples.len(), 0.0);
            self.generate(window);
        }

        // window signal into cache and return cached slice
        cache.clear();
        cache.extend(samples.iter().zip(window).map(|(s, w)| f64::from(*s) * *w));
        cache.as_slice()
    }
}

impl Default for Window {
    fn default() -> Self {
        Self::Tukey(0.5)
    }
}
1785
/// Scratch state reused across encoded frames
#[derive(Default)]
struct EncodingCaches {
    // per-channel caches; encode_frame resizes this to the number of
    // channels it encodes (the exhaustive stereo path uses `correlated` instead)
    channels: Vec<ChannelCache>,
    // buffers used while choosing a stereo channel assignment
    correlated: CorrelationCache,
}
1791
/// Scratch buffers for stereo channel decorrelation
#[derive(Default)]
struct CorrelationCache {
    // the average channel samples, `(left + right) >> 1`
    average_samples: Vec<i32>,
    // the difference channel samples, `left - right`
    difference_samples: Vec<i32>,

    // per-candidate-channel encoding caches; presumably used by
    // correlate_channels_exhaustive to encode every candidate channel
    // before picking an assignment (TODO confirm)
    left_cache: ChannelCache,
    right_cache: ChannelCache,
    average_cache: ChannelCache,
    difference_cache: ChannelCache,
}
1804
/// Per-channel scratch state for encoding a single subframe
///
/// NOTE(review): the field descriptions below are inferred from names and
/// types; encode_subframe (not shown here) is the authority on their use.
#[derive(Default)]
struct ChannelCache {
    // working buffers for FIXED subframe encoding
    fixed: FixedCache,
    // recorded bits of a FIXED subframe encoding
    fixed_output: BitRecorder<u32, BigEndian>,
    // working buffers for LPC subframe encoding
    lpc: LpcCache,
    // recorded bits of an LPC subframe encoding
    lpc_output: BitRecorder<u32, BigEndian>,
    // recorded bits of a CONSTANT subframe encoding
    constant_output: BitRecorder<u32, BigEndian>,
    // recorded bits of a VERBATIM subframe encoding
    verbatim_output: BitRecorder<u32, BigEndian>,
    // presumably samples with wasted low bits shifted out (confirm)
    wasted: Vec<i32>,
}
1815
/// Working buffers for FIXED subframe encoding
#[derive(Default)]
struct FixedCache {
    // FIXED subframe buffers, one per order 1-4
    fixed_buffers: [Vec<i32>; 4],
}
1821
/// Working buffers for LPC subframe encoding
#[derive(Default)]
struct LpcCache {
    // window coefficients; presumably passed to Window::apply as `window`,
    // which regenerates them whenever the block length changes
    window: Vec<f64>,
    // windowed samples; presumably passed to Window::apply as `cache`
    windowed: Vec<f64>,
    // residual scratch buffer (TODO confirm exact usage)
    residuals: Vec<i32>,
}
1828
/// A FLAC encoder
///
/// Writes the metadata blocks when created, one FLAC frame per call
/// to `encode`, and rewrites the metadata blocks with final values
/// (sample count, MD5, seek points) when finalized, which also
/// happens when the encoder is dropped.
struct Encoder<W: std::io::Write + std::io::Seek> {
    // the writer we're outputting to, wrapped in a byte counter after
    // the initial metadata blocks are written; its count supplies
    // each seekpoint's byte offset
    writer: Counter<W>,
    // the stream's starting offset in the writer, in bytes;
    // finalization seeks back here to rewrite the metadata blocks
    start: u64,
    // various encoding options
    options: EncoderOptions,
    // various encoder caches
    caches: EncodingCaches,
    // our metadata blocks
    blocks: BlockList,
    // our stream's sample rate
    sample_rate: SampleRate<u32>,
    // the current frame number
    frame_number: FrameNumber,
    // the number of channel-independent samples written
    samples_written: u64,
    // all seekpoints, one pushed per encoded frame and
    // filtered by the seektable interval during finalization
    seekpoints: Vec<EncoderSeekPoint>,
    // our running MD5 calculation, fed by update_md5
    md5: md5::Context,
    // whether the encoder has finalized the file;
    // keeps finalization from running more than once
    finalized: bool,
}
1854
1855impl<W: std::io::Write + std::io::Seek> Encoder<W> {
1856    const MAX_SAMPLES: u64 = 68_719_476_736;
1857
1858    fn new(
1859        mut writer: W,
1860        options: Options,
1861        sample_rate: u32,
1862        bits_per_sample: SignedBitCount<32>,
1863        channels: u8,
1864        total_samples: Option<NonZero<u64>>,
1865    ) -> Result<Self, Error> {
1866        use crate::metadata::OptionalBlockType;
1867
1868        let mut blocks = options.metadata;
1869
1870        *blocks.streaminfo_mut() = Streaminfo {
1871            minimum_block_size: options.block_size,
1872            maximum_block_size: options.block_size,
1873            minimum_frame_size: None,
1874            maximum_frame_size: None,
1875            sample_rate: (0..1048576)
1876                .contains(&sample_rate)
1877                .then_some(sample_rate)
1878                .ok_or(Error::InvalidSampleRate)?,
1879            bits_per_sample,
1880            channels: (1..=8)
1881                .contains(&channels)
1882                .then_some(channels)
1883                .and_then(NonZero::new)
1884                .ok_or(Error::ExcessiveChannels)?,
1885            total_samples: match total_samples {
1886                None => None,
1887                total_samples @ Some(samples) => match samples.get() {
1888                    0..Self::MAX_SAMPLES => total_samples,
1889                    _ => return Err(Error::ExcessiveTotalSamples),
1890                },
1891            },
1892            md5: None,
1893        };
1894
1895        // insert a dummy SeekTable to be populated later
1896        if let Some(total_samples) = total_samples {
1897            if let Some(placeholders) = options.seektable_interval.map(|s| {
1898                s.filter(
1899                    sample_rate,
1900                    EncoderSeekPoint::placeholders(total_samples.get(), options.block_size),
1901                )
1902            }) {
1903                use crate::metadata::SeekTable;
1904
1905                blocks.insert(SeekTable {
1906                    // placeholder points should always be contiguous
1907                    points: placeholders
1908                        .take(SeekTable::MAX_POINTS)
1909                        .map(|p| p.into())
1910                        .collect::<Vec<_>>()
1911                        .try_into()
1912                        .unwrap(),
1913                });
1914            }
1915        }
1916
1917        let start = writer.stream_position()?;
1918
1919        // sort blocks to put more relevant items at the front
1920        blocks.sort_by(|block| match block {
1921            OptionalBlockType::VorbisComment => 0,
1922            OptionalBlockType::SeekTable => 1,
1923            OptionalBlockType::Picture => 2,
1924            OptionalBlockType::Application => 3,
1925            OptionalBlockType::Cuesheet => 4,
1926            OptionalBlockType::Padding => 5,
1927        });
1928
1929        write_blocks(writer.by_ref(), blocks.blocks())?;
1930
1931        Ok(Self {
1932            start,
1933            writer: Counter::new(writer),
1934            options: EncoderOptions {
1935                max_partition_order: options.max_partition_order,
1936                mid_side: options.mid_side,
1937                seektable_interval: options.seektable_interval,
1938                max_lpc_order: options.max_lpc_order,
1939                window: options.window,
1940                exhaustive_channel_correlation: options.exhaustive_channel_correlation,
1941                use_rice2: u32::from(bits_per_sample) > 16,
1942            },
1943            caches: EncodingCaches::default(),
1944            sample_rate: blocks
1945                .streaminfo()
1946                .sample_rate
1947                .try_into()
1948                .expect("invalid sample rate"),
1949            blocks,
1950            frame_number: FrameNumber::default(),
1951            samples_written: 0,
1952            seekpoints: Vec::new(),
1953            md5: md5::Context::new(),
1954            finalized: false,
1955        })
1956    }
1957
1958    /// The encoder's channel count
1959    fn channel_count(&self) -> NonZero<u8> {
1960        self.blocks.streaminfo().channels
1961    }
1962
1963    /// Updates running MD5 calculation with signed, little-endian bytes
1964    fn update_md5(&mut self, frame_bytes: &[u8]) {
1965        self.md5.consume(frame_bytes)
1966    }
1967
1968    /// Encodes an audio frame of PCM samples
1969    ///
1970    /// Depending on the encoder's chosen block size,
1971    /// this may encode zero or more FLAC frames to disk.
1972    ///
1973    /// # Errors
1974    ///
1975    /// Returns an I/O error from the underlying stream,
1976    /// or if the frame's parameters are not a match
1977    /// for the encoder's.
1978    fn encode(&mut self, frame: &Frame) -> Result<(), Error> {
1979        // drop in a new seekpoint
1980        self.seekpoints.push(EncoderSeekPoint {
1981            sample_offset: self.samples_written,
1982            byte_offset: Some(self.writer.count),
1983            frame_samples: frame.pcm_frames() as u16,
1984        });
1985
1986        // update running total of samples written
1987        self.samples_written += frame.pcm_frames() as u64;
1988        if let Some(total_samples) = self.blocks.streaminfo().total_samples {
1989            if self.samples_written > total_samples.get() {
1990                return Err(Error::ExcessiveTotalSamples);
1991            }
1992        }
1993
1994        encode_frame(
1995            &self.options,
1996            &mut self.caches,
1997            &mut self.writer,
1998            self.blocks.streaminfo_mut(),
1999            &mut self.frame_number,
2000            self.sample_rate,
2001            frame.channels().collect(),
2002        )
2003    }
2004
2005    fn finalize_inner(&mut self) -> Result<(), Error> {
2006        if !self.finalized {
2007            use crate::metadata::SeekTable;
2008
2009            self.finalized = true;
2010
2011            // update SEEKTABLE metadata block with final values
2012            if let Some(encoded_points) = self
2013                .options
2014                .seektable_interval
2015                .map(|s| s.filter(self.sample_rate.into(), self.seekpoints.iter().cloned()))
2016            {
2017                match self.blocks.get_pair_mut() {
2018                    (Some(SeekTable { points }), _) => {
2019                        // placeholder SEEKTABLE already in place,
2020                        // so no need to adjust PADDING to fit
2021
2022                        // ensure points count is the same
2023
2024                        let points_len = points.len();
2025                        points.clear();
2026                        points
2027                            .try_extend(
2028                                encoded_points
2029                                    .into_iter()
2030                                    .map(|p| p.into())
2031                                    .chain(std::iter::repeat(SeekPoint::Placeholder))
2032                                    .take(points_len),
2033                            )
2034                            .unwrap();
2035                    }
2036                    (None, Some(crate::metadata::Padding { size: padding_size })) => {
2037                        // no SEEKTABLE, but there is a PADDING block,
2038                        // so try to shrink PADDING to fit SEEKTABLE
2039
2040                        use crate::metadata::MetadataBlock;
2041
2042                        let seektable = SeekTable {
2043                            points: encoded_points
2044                                .map(|p| p.into())
2045                                .collect::<Vec<_>>()
2046                                .try_into()
2047                                .unwrap(),
2048                        };
2049                        if let Some(new_padding_size) = seektable
2050                            .total_size()
2051                            .and_then(|seektable_size| padding_size.checked_sub(seektable_size))
2052                        {
2053                            *padding_size = new_padding_size;
2054                            self.blocks.insert(seektable);
2055                        }
2056                    }
2057                    (None, None) => { /* no seektable or padding, so nothing to do */ }
2058                }
2059            }
2060
2061            // verify or update final sample count
2062            match &mut self.blocks.streaminfo_mut().total_samples {
2063                Some(expected) => {
2064                    // ensure final sample count matches
2065                    if expected.get() != self.samples_written {
2066                        return Err(Error::SampleCountMismatch);
2067                    }
2068                }
2069                expected @ None => {
2070                    // update final sample count if possible
2071                    if self.samples_written < Self::MAX_SAMPLES {
2072                        *expected =
2073                            Some(NonZero::new(self.samples_written).ok_or(Error::NoSamples)?);
2074                    } else {
2075                        // TODO - should I just leave this blank
2076                        // if too many samples are written?
2077                        return Err(Error::ExcessiveTotalSamples);
2078                    }
2079                }
2080            }
2081
2082            // update STREAMINFO MD5 sum
2083            self.blocks.streaminfo_mut().md5 = Some(self.md5.clone().compute().0);
2084
2085            // rewrite metadata blocks, relative to the beginning
2086            // of the stream
2087            let writer = self.writer.stream();
2088            writer.seek(std::io::SeekFrom::Start(self.start))?;
2089            write_blocks(writer.by_ref(), self.blocks.blocks())
2090        } else {
2091            Ok(())
2092        }
2093    }
2094}
2095
2096impl<W: std::io::Write + std::io::Seek> Drop for Encoder<W> {
2097    fn drop(&mut self) {
2098        let _ = self.finalize_inner();
2099    }
2100}
2101
// Unlike regular SeekPoints, which can have placeholder variants,
// these always carry a sample offset and frame length.  A byte offset
// of None indicates a dummy (placeholder) encoder point, which
// converts into SeekPoint::Placeholder.
#[derive(Debug, Clone)]
struct EncoderSeekPoint {
    // the first sample of the frame this point refers to
    sample_offset: u64,
    // the frame's byte offset, measured from the end of the metadata
    // blocks by both encode and generate_seektable; None for placeholders
    byte_offset: Option<u64>,
    // the number of samples in the frame
    frame_samples: u16,
}
2111
2112impl EncoderSeekPoint {
2113    // generates set of placeholder points
2114    fn placeholders(total_samples: u64, block_size: u16) -> impl Iterator<Item = EncoderSeekPoint> {
2115        (0..total_samples)
2116            .step_by(usize::from(block_size))
2117            .map(move |sample_offset| EncoderSeekPoint {
2118                sample_offset,
2119                byte_offset: None,
2120                frame_samples: u16::try_from(total_samples - sample_offset)
2121                    .map(|s| s.min(block_size))
2122                    .unwrap_or(block_size),
2123            })
2124    }
2125
2126    // returns sample range of point
2127    fn range(&self) -> std::ops::Range<u64> {
2128        self.sample_offset..(self.sample_offset + u64::from(self.frame_samples))
2129    }
2130}
2131
2132impl From<EncoderSeekPoint> for SeekPoint {
2133    fn from(p: EncoderSeekPoint) -> Self {
2134        match p.byte_offset {
2135            Some(byte_offset) => Self::Defined {
2136                sample_offset: p.sample_offset,
2137                byte_offset,
2138                frame_samples: p.frame_samples,
2139            },
2140            None => Self::Placeholder,
2141        }
2142    }
2143}
2144
2145/// Given a FLAC stream, generates new seek table
2146///
2147/// Though encoders should add seek tables by default,
2148/// sometimes one isn't present.  This function takes
2149/// an existing FLAC file stream and generates a new
2150/// seek table suitable for adding to the file's metadata
2151/// via the [`crate::metadata::update`] function.
2152///
2153/// The stream should be rewound to the beginning of the file.
2154///
2155/// # Errors
2156///
2157/// Returns any error from the underlying stream.
2158///
2159/// # Example
2160/// ```
2161/// use flac_codec::{
2162///     encode::{FlacSampleWriter, Options, SeekTableInterval, generate_seektable},
2163///     metadata::{SeekTable, read_block},
2164/// };
2165/// use std::io::{Cursor, Seek};
2166///
2167/// let mut flac = Cursor::new(vec![]);  // a FLAC file in memory
2168///
2169/// // add a seekpoint every second
2170/// let options = Options::default().seektable_seconds(1);
2171///
2172/// let mut writer = FlacSampleWriter::new(
2173///     &mut flac,         // our wrapped writer
2174///     options,           // our seektable options
2175///     44100,             // sample rate
2176///     16,                // bits-per-sample
2177///     1,                 // channel count
2178///     Some(60 * 44100),  // one minute's worth of samples
2179/// ).unwrap();
2180///
2181/// // write one minute's worth of samples
2182/// writer.write(vec![0; 60 * 44100].as_slice()).unwrap();
2183///
2184/// // finalize writing file
2185/// assert!(writer.finalize().is_ok());
2186///
2187/// flac.rewind().unwrap();
2188///
2189/// // get existing seektable
2190/// let original_seektable = match read_block::<_, SeekTable>(&mut flac) {
2191///     Ok(Some(seektable)) => seektable,
2192///     _ => panic!("seektable not found"),
2193/// };
2194///
2195/// flac.rewind().unwrap();
2196///
2197/// // generate new seektable, also with seekpoints every second
2198/// let new_seektable = generate_seektable(
2199///     flac,
2200///     SeekTableInterval::Seconds(1.try_into().unwrap())
2201/// ).unwrap();
2202///
2203/// // ensure both seektables are identical
2204/// assert_eq!(original_seektable, new_seektable);
2205/// ```
2206pub fn generate_seektable<R: std::io::Read>(
2207    r: R,
2208    interval: SeekTableInterval,
2209) -> Result<crate::metadata::SeekTable, Error> {
2210    use crate::{
2211        metadata::{Metadata, SeekTable},
2212        stream::FrameIterator,
2213    };
2214
2215    let iter = FrameIterator::new(r)?;
2216    let metadata_len = iter.metadata_len();
2217    let sample_rate = iter.sample_rate();
2218    let mut sample_offset = 0;
2219
2220    iter.map(|r| {
2221        r.map(|(frame, offset)| EncoderSeekPoint {
2222            sample_offset,
2223            byte_offset: Some(offset - metadata_len),
2224            frame_samples: frame.header.block_size.into(),
2225        })
2226        .inspect(|p| {
2227            sample_offset += u64::from(p.frame_samples);
2228        })
2229    })
2230    .collect::<Result<Vec<_>, _>>()
2231    .map(|seekpoints| SeekTable {
2232        points: interval
2233            .filter(sample_rate, seekpoints)
2234            .take(SeekTable::MAX_POINTS)
2235            .map(|p| p.into())
2236            .collect::<Vec<_>>()
2237            .try_into()
2238            .unwrap(),
2239    })
2240}
2241
/// Encodes a single FLAC frame from per-channel sample slices
///
/// Writes the frame header, one subframe per channel, and the trailing
/// CRC-16 to `writer`.  It then advances `frame_number` and widens
/// `streaminfo`'s minimum/maximum frame sizes to include this frame.
///
/// Mono frames and frames with three or more channels are always encoded
/// with independent channels.  Stereo frames pick a channel assignment via
/// `correlate_channels_exhaustive` when `options.exhaustive_channel_correlation`
/// is set, or via `correlate_channels` otherwise.
///
/// # Errors
///
/// Returns any error from writing the header, encoding or writing
/// subframes, writing the CRC-16, or incrementing the frame number.
///
/// # Panics
///
/// Panics (via `expect("frame cannot be empty")`) if the first channel's
/// length, truncated to `u16`, converts to an invalid block size,
/// presumably a zero length.
fn encode_frame<W>(
    options: &EncoderOptions,
    cache: &mut EncodingCaches,
    mut writer: W,
    streaminfo: &mut Streaminfo,
    frame_number: &mut FrameNumber,
    sample_rate: SampleRate<u32>,
    frame: ArrayVec<&[i32], MAX_CHANNELS>,
) -> Result<(), Error>
where
    W: std::io::Write,
{
    use crate::Counter;
    use crate::crc::{Crc16, CrcWriter};
    use crate::stream::FrameHeader;
    use bitstream_io::BigEndian;

    debug_assert!(!frame.is_empty());

    // the Counter's byte count becomes this frame's size (read back below);
    // the CrcWriter's checksum becomes the frame's trailing CRC-16
    let size = Counter::new(writer.by_ref());
    let mut w: CrcWriter<_, Crc16> = CrcWriter::new(size);
    // assigned in each match arm, once that arm's frame header is written
    let mut bw: BitWriter<CrcWriter<Counter<&mut W>, Crc16>, BigEndian>;

    match frame.as_slice() {
        [channel] => {
            FrameHeader {
                blocking_strategy: false,
                frame_number: *frame_number,
                block_size: (channel.len() as u16)
                    .try_into()
                    .expect("frame cannot be empty"),
                sample_rate,
                bits_per_sample: streaminfo.bits_per_sample.into(),
                channel_assignment: ChannelAssignment::Independent(Independent::Mono),
            }
            .write(&mut w, streaminfo)?;

            bw = BitWriter::new(w);

            cache.channels.resize_with(1, ChannelCache::default);

            encode_subframe(
                options,
                &mut cache.channels[0],
                CorrelatedChannel::independent(streaminfo.bits_per_sample, channel),
            )?
            .playback(&mut bw)?;
        }
        [left, right] if options.exhaustive_channel_correlation => {
            // the exhaustive search yields already-encoded channels,
            // so the header's channel assignment is known before writing it
            let Correlated {
                channel_assignment,
                channels: [channel_0, channel_1],
            } = correlate_channels_exhaustive(
                options,
                &mut cache.correlated,
                [left, right],
                streaminfo.bits_per_sample,
            )?;

            FrameHeader {
                blocking_strategy: false,
                frame_number: *frame_number,
                block_size: (frame[0].len() as u16)
                    .try_into()
                    .expect("frame cannot be empty"),
                sample_rate,
                bits_per_sample: streaminfo.bits_per_sample.into(),
                channel_assignment,
            }
            .write(&mut w, streaminfo)?;

            bw = BitWriter::new(w);

            channel_0.playback(&mut bw)?;
            channel_1.playback(&mut bw)?;
        }
        [left, right] => {
            let Correlated {
                channel_assignment,
                channels: [channel_0, channel_1],
            } = correlate_channels(
                options,
                &mut cache.correlated,
                [left, right],
                streaminfo.bits_per_sample,
            );

            FrameHeader {
                blocking_strategy: false,
                frame_number: *frame_number,
                block_size: (frame[0].len() as u16)
                    .try_into()
                    .expect("frame cannot be empty"),
                sample_rate,
                bits_per_sample: streaminfo.bits_per_sample.into(),
                channel_assignment,
            }
            .write(&mut w, streaminfo)?;

            // encode both subframes (via join, possibly in parallel)
            // before either is played back into the frame
            cache.channels.resize_with(2, ChannelCache::default);
            let [cache_0, cache_1] = cache.channels.get_disjoint_mut([0, 1]).unwrap();
            let (channel_0, channel_1) = join(
                || encode_subframe(options, cache_0, channel_0),
                || encode_subframe(options, cache_1, channel_1),
            );

            bw = BitWriter::new(w);

            channel_0?.playback(&mut bw)?;
            channel_1?.playback(&mut bw)?;
        }
        channels => {
            // non-stereo frames are always encoded independently

            FrameHeader {
                blocking_strategy: false,
                frame_number: *frame_number,
                block_size: (channels[0].len() as u16)
                    .try_into()
                    .expect("frame cannot be empty"),
                sample_rate,
                bits_per_sample: streaminfo.bits_per_sample.into(),
                channel_assignment: ChannelAssignment::Independent(
                    frame.len().try_into().expect("invalid channel count"),
                ),
            }
            .write(&mut w, streaminfo)?;

            bw = BitWriter::new(w);

            cache
                .channels
                .resize_with(channels.len(), ChannelCache::default);

            // encode every channel's subframe, then play them back
            // in channel order, stopping at the first error
            vec_map(
                cache.channels.iter_mut().zip(channels).collect(),
                |(cache, channel)| {
                    encode_subframe(
                        options,
                        cache,
                        CorrelatedChannel::independent(streaminfo.bits_per_sample, channel),
                    )
                },
            )
            .into_iter()
            .try_for_each(|r| r.and_then(|r| r.playback(bw.by_ref()).map_err(Error::Io)))?;
        }
    }

    // byte-align via aligned_writer, then append the CRC-16
    // of everything written so far (header and subframes)
    let crc16: u16 = bw.aligned_writer()?.checksum().into();
    bw.write_from(crc16)?;

    frame_number.try_increment()?;

    // update minimum and maximum frame size values
    // (sizes that don't fit in u32, aren't below Streaminfo::MAX_FRAME_SIZE,
    // or are 0 are not recorded)
    if let s @ Some(size) = u32::try_from(bw.into_writer().into_writer().count)
        .ok()
        .filter(|size| *size < Streaminfo::MAX_FRAME_SIZE)
        .and_then(NonZero::new)
    {
        match &mut streaminfo.minimum_frame_size {
            Some(min_size) => {
                *min_size = size.min(*min_size);
            }
            min_size @ None => {
                *min_size = s;
            }
        }

        match &mut streaminfo.maximum_frame_size {
            Some(max_size) => {
                *max_size = size.max(*max_size);
            }
            max_size @ None => {
                *max_size = s;
            }
        }
    }

    Ok(())
}
2423
/// The outcome of stereo channel correlation
///
/// `channel_assignment` goes in the frame header, and `channels`
/// holds its two channels in the order their subframes are written.
struct Correlated<C> {
    channel_assignment: ChannelAssignment,
    channels: [C; 2],
}
2428
/// One channel to be encoded as a subframe
///
/// This may be an input channel or a derived (average or difference) channel.
struct CorrelatedChannel<'c> {
    // the samples to encode
    samples: &'c [i32],
    // the channel's sample width; difference channels are one bit
    // wider than the stream's bits-per-sample
    bits_per_sample: SignedBitCount<32>,
    // whether all samples are known to be 0
    all_0: bool,
}
2435
2436impl<'c> CorrelatedChannel<'c> {
2437    fn independent(bits_per_sample: SignedBitCount<32>, samples: &'c [i32]) -> Self {
2438        Self {
2439            all_0: samples.iter().all(|s| *s == 0),
2440            bits_per_sample,
2441            samples,
2442        }
2443    }
2444}
2445
2446fn correlate_channels<'c>(
2447    options: &EncoderOptions,
2448    CorrelationCache {
2449        average_samples,
2450        difference_samples,
2451        ..
2452    }: &'c mut CorrelationCache,
2453    [left, right]: [&'c [i32]; 2],
2454    bits_per_sample: SignedBitCount<32>,
2455) -> Correlated<CorrelatedChannel<'c>> {
2456    match bits_per_sample.checked_add::<32>(1) {
2457        Some(difference_bits_per_sample) if options.mid_side => {
2458            let mut left_abs_sum = 0;
2459            let mut right_abs_sum = 0;
2460            let mut mid_abs_sum = 0;
2461            let mut side_abs_sum = 0;
2462
2463            join(
2464                || {
2465                    average_samples.clear();
2466                    average_samples.extend(
2467                        left.iter()
2468                            .inspect(|s| left_abs_sum += u64::from(s.unsigned_abs()))
2469                            .zip(
2470                                right
2471                                    .iter()
2472                                    .inspect(|s| right_abs_sum += u64::from(s.unsigned_abs())),
2473                            )
2474                            .map(|(l, r)| (l + r) >> 1)
2475                            .inspect(|s| mid_abs_sum += u64::from(s.unsigned_abs())),
2476                    );
2477                },
2478                || {
2479                    difference_samples.clear();
2480                    difference_samples.extend(
2481                        left.iter()
2482                            .zip(right)
2483                            .map(|(l, r)| l - r)
2484                            .inspect(|s| side_abs_sum += u64::from(s.unsigned_abs())),
2485                    );
2486                },
2487            );
2488
2489            match [
2490                (
2491                    ChannelAssignment::Independent(Independent::Stereo),
2492                    left_abs_sum + right_abs_sum,
2493                ),
2494                (ChannelAssignment::LeftSide, left_abs_sum + side_abs_sum),
2495                (ChannelAssignment::SideRight, side_abs_sum + right_abs_sum),
2496                (ChannelAssignment::MidSide, mid_abs_sum + side_abs_sum),
2497            ]
2498            .into_iter()
2499            .min_by_key(|(_, total)| *total)
2500            .unwrap()
2501            .0
2502            {
2503                channel_assignment @ ChannelAssignment::LeftSide => Correlated {
2504                    channel_assignment,
2505                    channels: [
2506                        CorrelatedChannel {
2507                            samples: left,
2508                            bits_per_sample,
2509                            all_0: left_abs_sum == 0,
2510                        },
2511                        CorrelatedChannel {
2512                            samples: difference_samples,
2513                            bits_per_sample: difference_bits_per_sample,
2514                            all_0: side_abs_sum == 0,
2515                        },
2516                    ],
2517                },
2518                channel_assignment @ ChannelAssignment::SideRight => Correlated {
2519                    channel_assignment,
2520                    channels: [
2521                        CorrelatedChannel {
2522                            samples: difference_samples,
2523                            bits_per_sample: difference_bits_per_sample,
2524                            all_0: side_abs_sum == 0,
2525                        },
2526                        CorrelatedChannel {
2527                            samples: right,
2528                            bits_per_sample,
2529                            all_0: right_abs_sum == 0,
2530                        },
2531                    ],
2532                },
2533                channel_assignment @ ChannelAssignment::MidSide => Correlated {
2534                    channel_assignment,
2535                    channels: [
2536                        CorrelatedChannel {
2537                            samples: average_samples,
2538                            bits_per_sample,
2539                            all_0: mid_abs_sum == 0,
2540                        },
2541                        CorrelatedChannel {
2542                            samples: difference_samples,
2543                            bits_per_sample: difference_bits_per_sample,
2544                            all_0: side_abs_sum == 0,
2545                        },
2546                    ],
2547                },
2548                channel_assignment @ ChannelAssignment::Independent(_) => Correlated {
2549                    channel_assignment,
2550                    channels: [
2551                        CorrelatedChannel {
2552                            samples: left,
2553                            bits_per_sample,
2554                            all_0: left_abs_sum == 0,
2555                        },
2556                        CorrelatedChannel {
2557                            samples: right,
2558                            bits_per_sample,
2559                            all_0: right_abs_sum == 0,
2560                        },
2561                    ],
2562                },
2563            }
2564        }
2565        Some(difference_bits_per_sample) => {
2566            let mut left_abs_sum = 0;
2567            let mut right_abs_sum = 0;
2568            let mut side_abs_sum = 0;
2569
2570            difference_samples.clear();
2571            difference_samples.extend(
2572                left.iter()
2573                    .inspect(|s| left_abs_sum += u64::from(s.unsigned_abs()))
2574                    .zip(
2575                        right
2576                            .iter()
2577                            .inspect(|s| right_abs_sum += u64::from(s.unsigned_abs())),
2578                    )
2579                    .map(|(l, r)| l - r)
2580                    .inspect(|s| side_abs_sum += u64::from(s.unsigned_abs())),
2581            );
2582
2583            match [
2584                (ChannelAssignment::LeftSide, left_abs_sum + side_abs_sum),
2585                (ChannelAssignment::SideRight, side_abs_sum + right_abs_sum),
2586                (
2587                    ChannelAssignment::Independent(Independent::Stereo),
2588                    left_abs_sum + right_abs_sum,
2589                ),
2590            ]
2591            .into_iter()
2592            .min_by_key(|(_, total)| *total)
2593            .unwrap()
2594            .0
2595            {
2596                channel_assignment @ ChannelAssignment::LeftSide => Correlated {
2597                    channel_assignment,
2598                    channels: [
2599                        CorrelatedChannel {
2600                            samples: left,
2601                            bits_per_sample,
2602                            all_0: left_abs_sum == 0,
2603                        },
2604                        CorrelatedChannel {
2605                            samples: difference_samples,
2606                            bits_per_sample: difference_bits_per_sample,
2607                            all_0: side_abs_sum == 0,
2608                        },
2609                    ],
2610                },
2611                channel_assignment @ ChannelAssignment::SideRight => Correlated {
2612                    channel_assignment,
2613                    channels: [
2614                        CorrelatedChannel {
2615                            samples: difference_samples,
2616                            bits_per_sample: difference_bits_per_sample,
2617                            all_0: side_abs_sum == 0,
2618                        },
2619                        CorrelatedChannel {
2620                            samples: right,
2621                            bits_per_sample,
2622                            all_0: right_abs_sum == 0,
2623                        },
2624                    ],
2625                },
2626                ChannelAssignment::MidSide => unreachable!(),
2627                channel_assignment @ ChannelAssignment::Independent(_) => Correlated {
2628                    channel_assignment,
2629                    channels: [
2630                        CorrelatedChannel {
2631                            samples: left,
2632                            bits_per_sample,
2633                            all_0: left_abs_sum == 0,
2634                        },
2635                        CorrelatedChannel {
2636                            samples: right,
2637                            bits_per_sample,
2638                            all_0: right_abs_sum == 0,
2639                        },
2640                    ],
2641                },
2642            }
2643        }
2644        None => {
2645            // 32 bps stream, so forego difference channel
2646            // and encode them both indepedently
2647
2648            Correlated {
2649                channel_assignment: ChannelAssignment::Independent(Independent::Stereo),
2650                channels: [
2651                    CorrelatedChannel::independent(bits_per_sample, left),
2652                    CorrelatedChannel::independent(bits_per_sample, right),
2653                ],
2654            }
2655        }
2656    }
2657}
2658
/// Encodes a stereo pair under every permitted channel assignment and
/// returns whichever assignment produces the fewest bits.
///
/// Left and right are always encoded as independent subframes.
/// If `bits_per_sample + 1` still fits in 32 bits, the side channel
/// (`left - right`, one bit wider than the input) is also encoded,
/// making left-side and side-right candidates. When `options.mid_side`
/// is set, the mid channel (`(left + right) >> 1`) is encoded as well,
/// adding a mid-side candidate. For 32 bps input no wider difference
/// channel is possible, so the pair is always stored independently.
///
/// Subframes are produced with `try_join`. NOTE(review): presumably this
/// runs both closures in parallel when the `rayon` feature is enabled;
/// confirm against its definition.
///
/// # Errors
///
/// Propagates any error returned by `encode_subframe`.
fn correlate_channels_exhaustive<'c>(
    options: &EncoderOptions,
    CorrelationCache {
        average_samples,
        difference_samples,
        left_cache,
        right_cache,
        average_cache,
        difference_cache,
        ..
    }: &'c mut CorrelationCache,
    [left, right]: [&'c [i32]; 2],
    bits_per_sample: SignedBitCount<32>,
) -> Result<Correlated<&'c BitRecorder<u32, BigEndian>>, Error> {
    // the independent encodings are needed by every branch below
    let (left_recorder, right_recorder) = try_join(
        || {
            encode_subframe(
                options,
                left_cache,
                CorrelatedChannel {
                    samples: left,
                    bits_per_sample,
                    all_0: false,
                },
            )
        },
        || {
            encode_subframe(
                options,
                right_cache,
                CorrelatedChannel {
                    samples: right,
                    bits_per_sample,
                    all_0: false,
                },
            )
        },
    )?;

    // the side channel needs one extra bit, which is unavailable at 32 bps
    match bits_per_sample.checked_add::<32>(1) {
        Some(difference_bits_per_sample) if options.mid_side => {
            let (average_recorder, difference_recorder) = try_join(
                || {
                    // mid channel keeps the input width
                    average_samples.clear();
                    average_samples
                        .extend(left.iter().zip(right.iter()).map(|(l, r)| (l + r) >> 1));
                    encode_subframe(
                        options,
                        average_cache,
                        CorrelatedChannel {
                            samples: average_samples,
                            bits_per_sample,
                            all_0: false,
                        },
                    )
                },
                || {
                    difference_samples.clear();
                    difference_samples.extend(left.iter().zip(right).map(|(l, r)| l - r));
                    encode_subframe(
                        options,
                        difference_cache,
                        CorrelatedChannel {
                            samples: difference_samples,
                            bits_per_sample: difference_bits_per_sample,
                            all_0: false,
                        },
                    )
                },
            )?;

            // `min_by_key` keeps the first of several equal minima,
            // so ties go to the candidate listed first (independent stereo)
            match [
                (
                    ChannelAssignment::Independent(Independent::Stereo),
                    left_recorder.written() + right_recorder.written(),
                ),
                (
                    ChannelAssignment::LeftSide,
                    left_recorder.written() + difference_recorder.written(),
                ),
                (
                    ChannelAssignment::SideRight,
                    difference_recorder.written() + right_recorder.written(),
                ),
                (
                    ChannelAssignment::MidSide,
                    average_recorder.written() + difference_recorder.written(),
                ),
            ]
            .into_iter()
            .min_by_key(|(_, total)| *total)
            .unwrap()
            .0
            {
                channel_assignment @ ChannelAssignment::LeftSide => Ok(Correlated {
                    channel_assignment,
                    channels: [left_recorder, difference_recorder],
                }),
                channel_assignment @ ChannelAssignment::SideRight => Ok(Correlated {
                    channel_assignment,
                    channels: [difference_recorder, right_recorder],
                }),
                channel_assignment @ ChannelAssignment::MidSide => Ok(Correlated {
                    channel_assignment,
                    channels: [average_recorder, difference_recorder],
                }),
                channel_assignment @ ChannelAssignment::Independent(_) => Ok(Correlated {
                    channel_assignment,
                    channels: [left_recorder, right_recorder],
                }),
            }
        }
        Some(difference_bits_per_sample) => {
            // mid-side disabled, so only the side channel is needed
            let difference_recorder = {
                difference_samples.clear();
                difference_samples.extend(left.iter().zip(right).map(|(l, r)| l - r));
                encode_subframe(
                    options,
                    difference_cache,
                    CorrelatedChannel {
                        samples: difference_samples,
                        bits_per_sample: difference_bits_per_sample,
                        all_0: false,
                    },
                )?
            };

            // ties go to the first-listed candidate (independent stereo)
            match [
                (
                    ChannelAssignment::Independent(Independent::Stereo),
                    left_recorder.written() + right_recorder.written(),
                ),
                (
                    ChannelAssignment::LeftSide,
                    left_recorder.written() + difference_recorder.written(),
                ),
                (
                    ChannelAssignment::SideRight,
                    difference_recorder.written() + right_recorder.written(),
                ),
            ]
            .into_iter()
            .min_by_key(|(_, total)| *total)
            .unwrap()
            .0
            {
                channel_assignment @ ChannelAssignment::LeftSide => Ok(Correlated {
                    channel_assignment,
                    channels: [left_recorder, difference_recorder],
                }),
                channel_assignment @ ChannelAssignment::SideRight => Ok(Correlated {
                    channel_assignment,
                    channels: [difference_recorder, right_recorder],
                }),
                // MidSide is not among the candidates above
                ChannelAssignment::MidSide => unreachable!(),
                channel_assignment @ ChannelAssignment::Independent(_) => Ok(Correlated {
                    channel_assignment,
                    channels: [left_recorder, right_recorder],
                }),
            }
        }
        None => {
            // 32 bps stream, so forego difference channel
            // and encode them both independently

            Ok(Correlated {
                channel_assignment: ChannelAssignment::Independent(Independent::Stereo),
                channels: [left_recorder, right_recorder],
            })
        }
    }
}
2831
/// Encodes one channel as the smallest subframe this encoder produces.
///
/// An all-zero channel becomes a CONSTANT subframe. This happens when it is
/// flagged by `all_0`, or detected when every sample is zero. Otherwise any
/// common trailing zero bits ("wasted bits") are shifted out into the
/// `wasted` cache buffer. Then a FIXED subframe is encoded, plus an LPC
/// subframe when `options.max_lpc_order` is set, and the smaller is kept.
/// A VERBATIM subframe is used instead if every candidate fails, or if the
/// best candidate is not smaller than the raw sample bits.
///
/// The returned recorder borrows from the `ChannelCache` and holds the
/// complete subframe, header included.
///
/// # Errors
///
/// Returns an error only if writing a CONSTANT or VERBATIM subframe fails.
/// Failures of the FIXED/LPC encoders fall back to VERBATIM instead.
fn encode_subframe<'c>(
    options: &EncoderOptions,
    ChannelCache {
        fixed: fixed_cache,
        fixed_output,
        lpc: lpc_cache,
        lpc_output,
        constant_output,
        verbatim_output,
        wasted,
    }: &'c mut ChannelCache,
    CorrelatedChannel {
        samples: channel,
        bits_per_sample,
        all_0,
    }: CorrelatedChannel,
) -> Result<&'c BitRecorder<u32, BigEndian>, Error> {
    // one more than the largest trailing-zero count of a non-zero i32;
    // only a channel of all zeros reaches it (0i32.trailing_zeros() == 32)
    const WASTED_MAX: NonZero<u32> = NonZero::new(32).unwrap();

    debug_assert!(!channel.is_empty());

    if all_0 {
        // all samples are 0
        constant_output.clear();
        encode_constant_subframe(constant_output, channel[0], bits_per_sample, 0)?;
        return Ok(constant_output);
    }

    // determine any wasted bits: the smallest trailing-zero count
    // across all samples, aborting (None) as soon as a sample has none
    let (channel, bits_per_sample, wasted_bps) =
        match channel.iter().try_fold(WASTED_MAX, |acc, sample| {
            NonZero::new(sample.trailing_zeros()).map(|sample| sample.min(acc))
        }) {
            None => (channel, bits_per_sample, 0),
            Some(WASTED_MAX) => {
                // every sample is 0
                constant_output.clear();
                encode_constant_subframe(constant_output, channel[0], bits_per_sample, 0)?;
                return Ok(constant_output);
            }
            Some(wasted_bps) => {
                // shift the wasted bits out and narrow the sample width to match
                let wasted_bps = wasted_bps.get();
                wasted.clear();
                wasted.extend(channel.iter().map(|sample| sample >> wasted_bps));
                (
                    wasted.as_slice(),
                    bits_per_sample.checked_sub(wasted_bps).unwrap(),
                    wasted_bps,
                )
            }
        };

    fixed_output.clear();

    let best = match options.max_lpc_order {
        Some(max_lpc_order) => {
            lpc_output.clear();

            // try FIXED and LPC and keep the smaller;
            // `min_by_key` keeps the first of equal minima, so ties favor FIXED
            match join(
                || {
                    encode_fixed_subframe(
                        options,
                        fixed_cache,
                        fixed_output,
                        channel,
                        bits_per_sample,
                        wasted_bps,
                    )
                },
                || {
                    encode_lpc_subframe(
                        options,
                        max_lpc_order,
                        lpc_cache,
                        lpc_output,
                        channel,
                        bits_per_sample,
                        wasted_bps,
                    )
                },
            ) {
                (Ok(()), Ok(())) => [fixed_output, lpc_output]
                    .into_iter()
                    .min_by_key(|c| c.written())
                    .unwrap(),
                (Err(_), Ok(())) => lpc_output,
                (Ok(()), Err(_)) => fixed_output,
                (Err(_), Err(_)) => {
                    verbatim_output.clear();
                    encode_verbatim_subframe(
                        verbatim_output,
                        channel,
                        bits_per_sample,
                        wasted_bps,
                    )?;
                    return Ok(verbatim_output);
                }
            }
        }
        _ => {
            // LPC disabled: FIXED is the only predictive candidate
            match encode_fixed_subframe(
                options,
                fixed_cache,
                fixed_output,
                channel,
                bits_per_sample,
                wasted_bps,
            ) {
                Ok(()) => fixed_output,
                Err(_) => {
                    verbatim_output.clear();
                    encode_verbatim_subframe(
                        verbatim_output,
                        channel,
                        bits_per_sample,
                        wasted_bps,
                    )?;
                    return Ok(verbatim_output);
                }
            }
        }
    };

    // NOTE(review): `best.written()` includes the subframe header, but
    // `verbatim_len` counts only the sample bits, so VERBATIM may be picked
    // when it is marginally larger than `best`.
    let verbatim_len = channel.len() as u32 * u32::from(bits_per_sample);

    if best.written() < verbatim_len {
        Ok(best)
    } else {
        verbatim_output.clear();
        encode_verbatim_subframe(verbatim_output, channel, bits_per_sample, wasted_bps)?;
        Ok(verbatim_output)
    }
}
2964
2965fn encode_constant_subframe<W: BitWrite>(
2966    writer: &mut W,
2967    sample: i32,
2968    bits_per_sample: SignedBitCount<32>,
2969    wasted_bps: u32,
2970) -> Result<(), Error> {
2971    use crate::stream::{SubframeHeader, SubframeHeaderType};
2972
2973    writer.build(&SubframeHeader {
2974        type_: SubframeHeaderType::Constant,
2975        wasted_bps,
2976    })?;
2977
2978    writer
2979        .write_signed_counted(bits_per_sample, sample)
2980        .map_err(Error::Io)
2981}
2982
2983fn encode_verbatim_subframe<W: BitWrite>(
2984    writer: &mut W,
2985    channel: &[i32],
2986    bits_per_sample: SignedBitCount<32>,
2987    wasted_bps: u32,
2988) -> Result<(), Error> {
2989    use crate::stream::{SubframeHeader, SubframeHeaderType};
2990
2991    writer.build(&SubframeHeader {
2992        type_: SubframeHeaderType::Verbatim,
2993        wasted_bps,
2994    })?;
2995
2996    channel
2997        .iter()
2998        .try_for_each(|i| writer.write_signed_counted(bits_per_sample, *i))?;
2999
3000    Ok(())
3001}
3002
3003fn encode_fixed_subframe<W: BitWrite>(
3004    options: &EncoderOptions,
3005    FixedCache {
3006        fixed_buffers: buffers,
3007    }: &mut FixedCache,
3008    writer: &mut W,
3009    channel: &[i32],
3010    bits_per_sample: SignedBitCount<32>,
3011    wasted_bps: u32,
3012) -> Result<(), Error> {
3013    use crate::stream::{SubframeHeader, SubframeHeaderType};
3014
3015    // calculate residuals for FIXED subframe orders 0-4
3016    // (or fewer, if we don't have enough samples)
3017    let (order, warm_up, residuals) = {
3018        let mut fixed_orders = ArrayVec::<&[i32], 5>::new();
3019        fixed_orders.push(channel);
3020
3021        // accumulate a set of FIXED diffs
3022        'outer: for buf in buffers.iter_mut() {
3023            let prev_order = fixed_orders.last().unwrap();
3024            match prev_order.split_at_checked(1) {
3025                Some((_, r)) => {
3026                    buf.clear();
3027                    for (n, p) in r.iter().zip(*prev_order) {
3028                        match n.checked_sub(*p) {
3029                            Some(v) => {
3030                                buf.push(v);
3031                            }
3032                            None => break 'outer,
3033                        }
3034                    }
3035                    if buf.is_empty() {
3036                        break;
3037                    } else {
3038                        fixed_orders.push(buf.as_slice());
3039                    }
3040                }
3041                None => break,
3042            }
3043        }
3044
3045        let min_fixed = fixed_orders.last().unwrap().len();
3046
3047        // choose diff with the smallest abs sum
3048        fixed_orders
3049            .into_iter()
3050            .enumerate()
3051            .min_by_key(|(_, residuals)| {
3052                residuals[(residuals.len() - min_fixed)..]
3053                    .iter()
3054                    .map(|r| u64::from(r.unsigned_abs()))
3055                    .sum::<u64>()
3056            })
3057            .map(|(order, residuals)| (order as u8, &channel[0..order], residuals))
3058            .unwrap()
3059    };
3060
3061    writer.build(&SubframeHeader {
3062        type_: SubframeHeaderType::Fixed { order },
3063        wasted_bps,
3064    })?;
3065
3066    warm_up
3067        .iter()
3068        .try_for_each(|sample: &i32| writer.write_signed_counted(bits_per_sample, *sample))?;
3069
3070    write_residuals(options, writer, order.into(), residuals)
3071}
3072
3073fn encode_lpc_subframe<W: BitWrite>(
3074    options: &EncoderOptions,
3075    max_lpc_order: NonZero<u8>,
3076    cache: &mut LpcCache,
3077    writer: &mut W,
3078    channel: &[i32],
3079    bits_per_sample: SignedBitCount<32>,
3080    wasted_bps: u32,
3081) -> Result<(), Error> {
3082    use crate::stream::{SubframeHeader, SubframeHeaderType};
3083
3084    let LpcSubframeParameters {
3085        warm_up,
3086        residuals,
3087        parameters:
3088            LpcParameters {
3089                order,
3090                precision,
3091                shift,
3092                coefficients,
3093            },
3094    } = LpcSubframeParameters::best(options, bits_per_sample, max_lpc_order, cache, channel)?;
3095
3096    writer.build(&SubframeHeader {
3097        type_: SubframeHeaderType::Lpc { order },
3098        wasted_bps,
3099    })?;
3100
3101    for sample in warm_up {
3102        writer.write_signed_counted(bits_per_sample, *sample)?;
3103    }
3104
3105    writer.write_count::<0b1111>(
3106        precision
3107            .count()
3108            .checked_sub(1)
3109            .ok_or(Error::InvalidQlpPrecision)?,
3110    )?;
3111
3112    writer.write::<5, i32>(shift as i32)?;
3113
3114    for coeff in coefficients {
3115        writer.write_signed_counted(precision, coeff)?;
3116    }
3117
3118    write_residuals(options, writer, order.get().into(), residuals)
3119}
3120
3121struct LpcSubframeParameters<'w, 'r> {
3122    parameters: LpcParameters,
3123    warm_up: &'w [i32],
3124    residuals: &'r [i32],
3125}
3126
3127impl<'w, 'r> LpcSubframeParameters<'w, 'r> {
3128    fn best(
3129        options: &EncoderOptions,
3130        bits_per_sample: SignedBitCount<32>,
3131        max_lpc_order: NonZero<u8>,
3132        LpcCache {
3133            residuals,
3134            window,
3135            windowed,
3136        }: &'r mut LpcCache,
3137        channel: &'w [i32],
3138    ) -> Result<Self, Error> {
3139        let parameters = LpcParameters::best(
3140            options,
3141            bits_per_sample,
3142            max_lpc_order,
3143            window,
3144            windowed,
3145            channel,
3146        )?;
3147
3148        Self::encode_residuals(&parameters, channel, residuals)
3149            .map(|(warm_up, residuals)| Self {
3150                warm_up,
3151                residuals,
3152                parameters,
3153            })
3154            .map_err(|ResidualOverflow| Error::ResidualOverflow)
3155    }
3156
3157    fn encode_residuals(
3158        parameters: &LpcParameters,
3159        channel: &'w [i32],
3160        residuals: &'r mut Vec<i32>,
3161    ) -> Result<(&'w [i32], &'r [i32]), ResidualOverflow> {
3162        residuals.clear();
3163
3164        for split in usize::from(parameters.order.get())..channel.len() {
3165            let (previous, current) = channel.split_at(split);
3166
3167            residuals.push(
3168                current[0]
3169                    .checked_sub(
3170                        (previous
3171                            .iter()
3172                            .rev()
3173                            .zip(&parameters.coefficients)
3174                            .map(|(x, y)| *x as i64 * *y as i64)
3175                            .sum::<i64>()
3176                            >> parameters.shift) as i32,
3177                    )
3178                    .ok_or(ResidualOverflow)?,
3179            );
3180        }
3181
3182        Ok((
3183            &channel[0..parameters.order.get().into()],
3184            residuals.as_slice(),
3185        ))
3186    }
3187}
3188
3189#[derive(Debug)]
3190struct ResidualOverflow;
3191
3192impl From<ResidualOverflow> for Error {
3193    #[inline]
3194    fn from(_: ResidualOverflow) -> Self {
3195        Error::ResidualOverflow
3196    }
3197}
3198
#[test]
fn test_residual_encoding_1() {
    // order-2 predictor applied to a sine-like test signal
    let samples = [
        0, 16, 31, 44, 54, 61, 64, 63, 58, 49, 38, 24, 8, -8, -24, -38, -49, -58, -63, -64, -61,
        -54, -44, -31, -16,
    ];
    let expected_residuals = [
        2, 2, 2, 3, 3, 3, 2, 2, 3, 0, 0, 0, -1, -1, -1, -3, -2, -2, -2, -1, -1, 0, 0,
    ];

    let parameters = LpcParameters {
        order: NonZero::new(2).unwrap(),
        precision: SignedBitCount::new::<7>(),
        shift: 5,
        coefficients: arrayvec![59, -30],
    };
    let mut buffer = Vec::with_capacity(expected_residuals.len());

    let (warm_up, residuals) =
        LpcSubframeParameters::encode_residuals(&parameters, &samples, &mut buffer).unwrap();

    assert_eq!(warm_up, &samples[..2]);
    assert_eq!(residuals, &expected_residuals);
}
3227
#[test]
fn test_residual_encoding_2() {
    // order-2 predictor applied to a cosine-like test signal
    let samples = [
        64, 62, 56, 47, 34, 20, 4, -12, -27, -41, -52, -60, -63, -63, -60, -52, -41, -27, -12, 4,
        20, 34, 47, 56, 62,
    ];
    let expected_residuals = [
        2, 2, 0, 1, -1, -1, -1, -2, -2, -2, -1, -3, -2, 0, -1, 1, 0, 2, 2, 2, 4, 2, 4,
    ];

    let parameters = LpcParameters {
        order: NonZero::new(2).unwrap(),
        precision: SignedBitCount::new::<7>(),
        shift: 5,
        coefficients: arrayvec![58, -29],
    };
    let mut buffer = Vec::with_capacity(expected_residuals.len());

    let (warm_up, residuals) =
        LpcSubframeParameters::encode_residuals(&parameters, &samples, &mut buffer).unwrap();

    assert_eq!(warm_up, &samples[..2]);
    assert_eq!(residuals, &expected_residuals);
}
3256
/// Quantized LPC predictor parameters for one subframe.
#[derive(Debug)]
struct LpcParameters {
    /// predictor order: the number of coefficients and of warm-up samples
    order: NonZero<u8>,
    /// signed bit width of each quantized coefficient
    precision: SignedBitCount<15>,
    /// right shift applied to each 64-bit prediction sum
    shift: u32,
    /// quantized coefficients; the first one multiplies the most recent sample
    coefficients: ArrayVec<i32, MAX_LPC_COEFFS>,
}
3264
3265// There isn't any particular *best* way to determine
3266// the ideal LPC subframe parameters (though there are
3267// some worst ways, like choosing them at random).
3268// Even the reference implementation has changed its
3269// defaults over time.  So long as the subframe's residuals
3270// are calculated correctly, decoders don't care one way or another.
3271//
3272// I'll try to use an approach similar to the reference implementation's.
3273
impl LpcParameters {
    /// Chooses quantized LPC parameters for `channel`.
    ///
    /// The steps are:
    /// 1. Pick the QLP coefficient precision from the number of samples.
    /// 2. Apply `options.window` to the samples.
    /// 3. Autocorrelate the windowed signal up to `max_lpc_order`.
    /// 4. Derive LP coefficients from that autocorrelation.
    /// 5. Let `compute_best_order` select an order.
    /// 6. Quantize the chosen coefficients.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InsufficientLpcSamples`] if `channel` has no more
    /// samples than `max_lpc_order`. Otherwise propagates errors from
    /// `compute_best_order` and [`Self::quantize`].
    fn best(
        options: &EncoderOptions,
        bits_per_sample: SignedBitCount<32>,
        max_lpc_order: NonZero<u8>,
        window: &mut Vec<f64>,
        windowed: &mut Vec<f64>,
        channel: &[i32],
    ) -> Result<Self, Error> {
        if channel.len() <= usize::from(max_lpc_order.get()) {
            // not enough samples in channel to calculate LPC parameters
            return Err(Error::InsufficientLpcSamples);
        }

        // coefficient precision grows with the number of samples
        let precision = match channel.len() {
            // this shouldn't be possible
            0 => panic!("at least one sample required in channel"),
            1..=192 => SignedBitCount::new::<7>(),
            193..=384 => SignedBitCount::new::<8>(),
            385..=576 => SignedBitCount::new::<9>(),
            577..=1152 => SignedBitCount::new::<10>(),
            1153..=2304 => SignedBitCount::new::<11>(),
            2305..=4608 => SignedBitCount::new::<12>(),
            4609.. => SignedBitCount::new::<13>(),
        };

        let (order, lp_coeffs) = compute_best_order(
            bits_per_sample,
            precision,
            channel
                .len()
                .try_into()
                // this shouldn't be possible
                .expect("excessive samples for subframe"),
            lp_coefficients(autocorrelate(
                options.window.apply(window, windowed, channel),
                max_lpc_order,
            )),
        )?;

        Self::quantize(order, lp_coeffs, precision)
    }

    /// Quantizes floating-point LP coefficients into `precision`-bit signed
    /// integers plus a right-shift amount.
    ///
    /// The shift is chosen so that the largest-magnitude coefficient, scaled
    /// by `2^shift`, stays below `2^(precision - 1)`. It is capped at 15.
    /// Each coefficient's rounding error is carried into the next one before
    /// rounding, and every result is clamped to the signed `precision`-bit
    /// range. If the computed shift is negative (down to -16), coefficients
    /// are divided by `2^-shift` instead and the stored shift is 0.
    ///
    /// # Errors
    ///
    /// - [`Error::ZeroLpCoefficients`] if the largest magnitude (by
    ///   `f64::total_cmp`) is not greater than zero
    /// - [`Error::LpNegativeShiftError`] if the computed shift is below -16
    fn quantize(
        order: NonZero<u8>,
        coeffs: ArrayVec<f64, MAX_LPC_COEFFS>,
        precision: SignedBitCount<15>,
    ) -> Result<Self, Error> {
        // the shift is written as a 5-bit signed field
        const MAX_SHIFT: i32 = (1 << 4) - 1;
        const MIN_SHIFT: i32 = -(1 << 4);

        // verified output against reference implementation
        // See: FLAC__lpc_quantize_coefficients

        debug_assert!(coeffs.len() == usize::from(order.get()));

        // one bit of the precision is reserved for the sign
        let max_coeff = (1 << (u32::from(precision) - 1)) - 1;
        let min_coeff = -(1 << (u32::from(precision) - 1));

        let l = coeffs
            .iter()
            .map(|c| c.abs())
            .max_by(|x, y| x.total_cmp(y))
            // f64.log2() gives unfortunate results when <= 0.0
            .filter(|l| *l > 0.0)
            .ok_or(Error::ZeroLpCoefficients)?;

        // quantization error carried from one coefficient to the next
        let mut error = 0.0;

        match ((u32::from(precision) - 1) as i32 - ((l.log2().floor()) as i32) - 1).min(MAX_SHIFT) {
            shift @ 0.. => {
                // normal, positive shift case
                let shift = shift as u32;

                Ok(Self {
                    order,
                    precision,
                    shift,
                    coefficients: coeffs
                        .into_iter()
                        .map(|lp_coeff| {
                            let sum: f64 = lp_coeff.mul_add((1 << shift) as f64, error);
                            let qlp_coeff = (sum.round() as i32).clamp(min_coeff, max_coeff);
                            error = sum - (qlp_coeff as f64);
                            qlp_coeff
                        })
                        .collect(),
                })
            }
            shift @ MIN_SHIFT..0 => {
                // unusual negative shift case:
                // scale coefficients down instead and store a shift of 0
                let shift = -shift as u32;

                Ok(Self {
                    order,
                    precision,
                    shift: 0,
                    coefficients: coeffs
                        .into_iter()
                        .map(|lp_coeff| {
                            let sum: f64 = (lp_coeff / (1 << shift) as f64) + error;
                            let qlp_coeff = (sum.round() as i32).clamp(min_coeff, max_coeff);
                            error = sum - (qlp_coeff as f64);
                            qlp_coeff
                        })
                        .collect(),
                })
            }
            ..MIN_SHIFT => Err(Error::LpNegativeShiftError),
        }
    }
}
3386
#[test]
fn test_quantization() {
    // test against numbers generated from reference implementation

    let order = NonZero::new(4).unwrap();

    let quantized = LpcParameters::quantize(
        order,
        arrayvec![0.797774, -0.045362, -0.050136, -0.054254],
        SignedBitCount::new::<10>(),
    )
    .unwrap();

    assert_eq!(quantized.order, order);
    assert_eq!(quantized.precision, SignedBitCount::new::<10>());
    assert_eq!(quantized.shift, 9);
    assert_eq!(quantized.coefficients, arrayvec![408, -23, -25, -28]);

    // note the relationship between the un-quantized,
    // floating point parameters and the shift value (9)
    //
    // 408 / 2 ** 9 ≈ 0.796875
    // -23 / 2 ** 9 ≈ -0.044921
    // -25 / 2 ** 9 ≈ -0.048828
    // -28 / 2 ** 9 ≈ -0.054687
    //
    // we're converting floats to fractions

    let quantized = LpcParameters::quantize(
        order,
        arrayvec![-0.054687, -0.953216, -0.027115, 0.033537],
        SignedBitCount::new::<10>(),
    )
    .unwrap();

    assert_eq!(quantized.order, order);
    assert_eq!(quantized.precision, SignedBitCount::new::<10>());
    assert_eq!(quantized.shift, 9);
    assert_eq!(quantized.coefficients, arrayvec![-28, -488, -14, 17]);

    // all-zero coefficients cannot be quantized and must be rejected
    assert!(matches!(
        LpcParameters::quantize(
            order,
            arrayvec![0.0, 0.0, 0.0, 0.0],
            SignedBitCount::new::<10>(),
        ),
        Err(Error::ZeroLpCoefficients)
    ));

    // negative shifts should also be handled properly
    let quantized = LpcParameters::quantize(
        order,
        arrayvec![-0.1, 0.1, 10000000.0, -0.2],
        SignedBitCount::new::<10>(),
    )
    .unwrap();

    assert_eq!(quantized.order, order);
    assert_eq!(quantized.precision, SignedBitCount::new::<10>());
    assert_eq!(quantized.shift, 0);
    assert_eq!(quantized.coefficients, arrayvec![0, 0, 305, 0]);

    // and massive negative shifts must be an error
    assert!(matches!(
        LpcParameters::quantize(
            order,
            arrayvec![-0.1, 0.1, 100000000.0, -0.2],
            SignedBitCount::new::<10>(),
        ),
        Err(Error::LpNegativeShiftError)
    ));
}
3460
3461fn autocorrelate(
3462    windowed: &[f64],
3463    max_lpc_order: NonZero<u8>,
3464) -> ArrayVec<f64, { MAX_LPC_COEFFS + 1 }> {
3465    // verified output against reference implementation
3466    // See: FLAC__lpc_compute_autocorrelation
3467
3468    debug_assert!(usize::from(max_lpc_order.get()) < MAX_LPC_COEFFS);
3469
3470    let mut tail = windowed;
3471    // let mut autocorrelated = Vec::with_capacity(max_lpc_order.get().into());
3472    let mut autocorrelated = ArrayVec::default();
3473
3474    for _ in 0..=max_lpc_order.get() {
3475        if tail.is_empty() {
3476            return autocorrelated;
3477        } else {
3478            autocorrelated.push(windowed.iter().zip(tail).map(|(x, y)| x * y).sum());
3479            tail.split_off_first();
3480        }
3481    }
3482
3483    autocorrelated
3484}
3485
#[test]
fn test_autocorrelation() {
    // expected values were generated with the reference implementation

    let single = autocorrelate(&[1.0], NonZero::new(1).unwrap());
    assert_eq!(single, arrayvec![1.0]);

    let ramp = autocorrelate(&[1.0, 2.0, 3.0, 4.0, 5.0], NonZero::new(4).unwrap());
    assert_eq!(ramp, arrayvec![55.0, 40.0, 26.0, 14.0, 5.0]);

    let wave = [
        0.0, 16.0, 31.0, 44.0, 54.0, 61.0, 64.0, 63.0, 58.0, 49.0, 38.0, 24.0, 8.0, -8.0,
        -24.0, -38.0, -49.0, -58.0, -63.0, -64.0, -61.0, -54.0, -44.0, -31.0, -16.0,
    ];
    let wave_lags = autocorrelate(&wave, NonZero::new(4).unwrap());
    assert_eq!(
        wave_lags,
        arrayvec![51408.0, 49792.0, 45304.0, 38466.0, 29914.0]
    );
}
3511
/// Linear predictor coefficients for one predictor order,
/// as produced by the Levinson-Durbin recursion in `lp_coefficients`.
#[derive(Debug)]
struct LpCoeff {
    /// Unquantized predictor coefficients; there is one per order,
    /// so this holds `order` values.
    coeffs: ArrayVec<f64, MAX_LPC_COEFFS>,
    /// Prediction error remaining at this order.
    error: f64,
}
3517
3518// returns a Vec of (coefficients, error) pairs
3519fn lp_coefficients(
3520    autocorrelated: ArrayVec<f64, { MAX_LPC_COEFFS + 1 }>,
3521) -> ArrayVec<LpCoeff, MAX_LPC_COEFFS> {
3522    // verified output against reference implementation
3523    // See: FLAC__lpc_compute_lp_coefficients
3524
3525    match autocorrelated.len() {
3526        0 | 1 => panic!("must have at least 2 autocorrelation values"),
3527        _ => {
3528            let k = autocorrelated[1] / autocorrelated[0];
3529            let mut lp_coefficients = arrayvec![LpCoeff {
3530                coeffs: arrayvec![k],
3531                error: autocorrelated[0] * (1.0 - k.powi(2)),
3532            }];
3533
3534            for i in 1..(autocorrelated.len() - 1) {
3535                if let [prev @ .., next] = &autocorrelated[0..=i + 1] {
3536                    let LpCoeff { coeffs, error } = lp_coefficients.last().unwrap();
3537
3538                    let q = next
3539                        - prev
3540                            .iter()
3541                            .rev()
3542                            .zip(coeffs)
3543                            .map(|(x, y)| x * y)
3544                            .sum::<f64>();
3545
3546                    let k = q / error;
3547
3548                    lp_coefficients.push(LpCoeff {
3549                        coeffs: coeffs
3550                            .iter()
3551                            .zip(coeffs.iter().rev().map(|c| k * c))
3552                            .map(|(c1, c2)| (c1 - c2))
3553                            .chain(std::iter::once(k))
3554                            .collect(),
3555                        error: error * (1.0 - k.powi(2)),
3556                    });
3557                }
3558            }
3559
3560            lp_coefficients
3561        }
3562    }
3563}
3564
// Asserts that two floating-point expressions agree to within 1.0e-6,
// reporting both values on failure.
//
// Unused-lint suppressed: presumably only #[test] functions invoke this
// macro, so non-test builds would otherwise warn about it.
#[allow(unused)]
macro_rules! assert_float_approx {
    ($a:expr, $b:expr) => {{
        let (a, b) = ($a, $b);
        let difference = (a - b).abs();
        assert!(difference < 1.0e-6, "{a} != {b}");
    }};
}
3573
#[test]
fn test_lp_coefficients_1() {
    // expected values were generated by the reference implementation
    let expected_errors = [25.909091, 25.540351, 25.316142, 25.241623];
    let expected_coeffs: [&[f64]; 4] = [
        &[0.727273],
        &[0.814035, -0.119298],
        &[0.802858, -0.043028, -0.093694],
        &[0.797774, -0.045362, -0.050136, -0.054254],
    ];

    let lp_coeffs = lp_coefficients(arrayvec![55.0, 40.0, 26.0, 14.0, 5.0]);
    assert_eq!(lp_coeffs.len(), expected_errors.len());

    for ((actual, want_error), want_coeffs) in lp_coeffs
        .iter()
        .zip(expected_errors)
        .zip(expected_coeffs)
    {
        assert_float_approx!(actual.error, want_error);
        assert_eq!(actual.coeffs.len(), want_coeffs.len());
        for (got, want) in actual.coeffs.iter().zip(want_coeffs) {
            assert_float_approx!(*got, *want);
        }
    }
}
3605
#[test]
fn test_lp_coefficients_2() {
    // expected values were generated by the reference implementation
    let expected_errors = [3181.201369, 495.815931, 495.161449, 494.604514];
    let expected_coeffs: [&[f64]; 4] = [
        &[0.968565],
        &[1.858456, -0.918772],
        &[1.891837, -0.986293, 0.036332],
        &[1.890618, -0.953216, -0.027115, 0.033537],
    ];

    let lp_coeffs = lp_coefficients(arrayvec![51408.0, 49792.0, 45304.0, 38466.0, 29914.0]);
    assert_eq!(lp_coeffs.len(), expected_errors.len());

    for ((actual, want_error), want_coeffs) in lp_coeffs
        .iter()
        .zip(expected_errors)
        .zip(expected_coeffs)
    {
        assert_float_approx!(actual.error, want_error);
        assert_eq!(actual.coeffs.len(), want_coeffs.len());
        for (got, want) in actual.coeffs.iter().zip(want_coeffs) {
            assert_float_approx!(*got, *want);
        }
    }
}
3637
3638// Returns (bits, order, coeffients) tuples
3639fn subframe_bits_by_order(
3640    bits_per_sample: SignedBitCount<32>,
3641    precision: SignedBitCount<15>,
3642    sample_count: u16,
3643    coeffs: ArrayVec<LpCoeff, MAX_LPC_COEFFS>,
3644) -> impl Iterator<Item = (f64, u8, ArrayVec<f64, MAX_LPC_COEFFS>)> {
3645    debug_assert!(sample_count > 0);
3646
3647    let error_scale = 0.5 / f64::from(sample_count);
3648
3649    coeffs
3650        .into_iter()
3651        .take_while(|coeffs| coeffs.error > 0.0)
3652        .zip(1..)
3653        .map(move |(LpCoeff { coeffs, error }, order)| {
3654            let header_bits =
3655                u32::from(order) * (u32::from(bits_per_sample) + u32::from(precision));
3656
3657            let bits_per_residual =
3658                (error * error_scale).ln() / (2.0 * std::f64::consts::LN_2).max(0.0);
3659
3660            let subframe_bits = bits_per_residual.mul_add(
3661                f64::from(sample_count - u16::from(order)),
3662                f64::from(header_bits),
3663            );
3664
3665            (subframe_bits, order, coeffs)
3666        })
3667}
3668
3669// Uses the error in the LP coefficients to determine the best order
3670// and returns that order along with the stripped-out coefficients
3671fn compute_best_order(
3672    bits_per_sample: SignedBitCount<32>,
3673    precision: SignedBitCount<15>,
3674    sample_count: u16,
3675    coeffs: ArrayVec<LpCoeff, MAX_LPC_COEFFS>,
3676) -> Result<(NonZero<u8>, ArrayVec<f64, MAX_LPC_COEFFS>), Error> {
3677    // verified output against reference implementation
3678    // See: FLAC__lpc_compute_best_order  and
3679    // See: FLAC__lpc_compute_expected_bits_per_residual_sample_with_error_scale
3680
3681    subframe_bits_by_order(bits_per_sample, precision, sample_count, coeffs)
3682        .min_by(|(x, _, _), (y, _, _)| x.total_cmp(y))
3683        .and_then(|(_, order, coeffs)| Some((NonZero::new(order)?, coeffs)))
3684        .ok_or(Error::NoBestLpcOrder)
3685}
3686
#[test]
fn test_compute_best_order() {
    // Per-order bit estimates are compared against numbers
    // generated by the reference implementation.
    fn check_estimates(
        precision: SignedBitCount<15>,
        sample_count: u16,
        errors: [f64; 4],
        expected_bits: [f64; 4],
    ) {
        let candidates: ArrayVec<LpCoeff, MAX_LPC_COEFFS> = errors
            .into_iter()
            .map(|error| LpCoeff {
                coeffs: ArrayVec::default(),
                error,
            })
            .collect();

        let mut estimates = subframe_bits_by_order(
            SignedBitCount::new::<16>(),
            precision,
            sample_count,
            candidates,
        )
        .map(|(bits, _, _)| bits);

        for want in expected_bits {
            assert_float_approx!(estimates.next().unwrap(), want);
        }
    }

    check_estimates(
        SignedBitCount::new::<5>(),
        20,
        [3181.201369, 495.815931, 495.161449, 494.604514],
        [80.977565, 74.685594, 93.853530, 113.025628],
    );

    check_estimates(
        SignedBitCount::new::<10>(),
        4096,
        [15000.0, 25000.0, 20000.0, 30000.0],
        [1812.801817, 3346.934051, 2713.303385, 3935.492805],
    );
}
3729
3730fn write_residuals<W: BitWrite>(
3731    options: &EncoderOptions,
3732    writer: &mut W,
3733    predictor_order: usize,
3734    residuals: &[i32],
3735) -> Result<(), Error> {
3736    use crate::stream::ResidualPartitionHeader;
3737    use bitstream_io::{BitCount, ToBitStream};
3738
3739    const MAX_PARTITIONS: usize = 64;
3740
3741    #[derive(Debug)]
3742    struct Partition<'r, const RICE_MAX: u32> {
3743        header: ResidualPartitionHeader<RICE_MAX>,
3744        residuals: &'r [i32],
3745    }
3746
3747    impl<'r, const RICE_MAX: u32> Partition<'r, RICE_MAX> {
3748        fn new(partition: &'r [i32], estimated_bits: &mut u32) -> Option<Self> {
3749            let partition_samples = partition.len() as u16;
3750            if partition_samples == 0 {
3751                return None;
3752            }
3753
3754            let partition_sum = partition
3755                .iter()
3756                .map(|i| u64::from(i.unsigned_abs()))
3757                .sum::<u64>();
3758
3759            if partition_sum > 0 {
3760                let rice = if partition_sum > partition_samples.into() {
3761                    let bits_needed = ((partition_sum as f64) / f64::from(partition_samples))
3762                        .log2()
3763                        .ceil() as u32;
3764
3765                    match BitCount::try_from(bits_needed).ok().filter(|rice| {
3766                        u32::from(*rice) < u32::from(BitCount::<RICE_MAX>::new::<RICE_MAX>())
3767                    }) {
3768                        Some(rice) => rice,
3769                        None => {
3770                            let escape_size = (partition
3771                                .iter()
3772                                .map(|i| u64::from(i.unsigned_abs()))
3773                                .sum::<u64>()
3774                                .ilog2()
3775                                + 2)
3776                            .try_into()
3777                            .ok()?;
3778
3779                            *estimated_bits +=
3780                                u32::from(escape_size) * u32::from(partition_samples);
3781
3782                            return Some(Self {
3783                                header: ResidualPartitionHeader::Escaped { escape_size },
3784                                residuals: partition,
3785                            });
3786                        }
3787                    }
3788                } else {
3789                    BitCount::new::<0>()
3790                };
3791
3792                let partition_size: u32 = 4u32
3793                    + ((1 + u32::from(rice)) * u32::from(partition_samples))
3794                    + if u32::from(rice) > 0 {
3795                        u32::try_from(partition_sum >> (u32::from(rice) - 1)).ok()?
3796                    } else {
3797                        u32::try_from(partition_sum << 1).ok()?
3798                    }
3799                    - (u32::from(partition_samples) / 2);
3800
3801                *estimated_bits += partition_size;
3802
3803                Some(Partition {
3804                    header: ResidualPartitionHeader::Standard { rice },
3805                    residuals: partition,
3806                })
3807            } else {
3808                // all partition residuals are 0, so use a constant
3809                Some(Partition {
3810                    header: ResidualPartitionHeader::Constant,
3811                    residuals: partition,
3812                })
3813            }
3814        }
3815    }
3816
3817    impl<const RICE_MAX: u32> ToBitStream for Partition<'_, RICE_MAX> {
3818        type Error = std::io::Error;
3819
3820        #[inline]
3821        fn to_writer<W: BitWrite + ?Sized>(&self, w: &mut W) -> Result<(), Self::Error> {
3822            w.build(&self.header)?;
3823            match self.header {
3824                ResidualPartitionHeader::Standard { rice } => {
3825                    let mask = rice.mask_lsb();
3826
3827                    self.residuals.iter().try_for_each(|s| {
3828                        let (msb, lsb) = mask(if s.is_negative() {
3829                            ((-*s as u32 - 1) << 1) + 1
3830                        } else {
3831                            (*s as u32) << 1
3832                        });
3833                        w.write_unary::<1>(msb)?;
3834                        w.write_checked(lsb)
3835                    })?;
3836                }
3837                ResidualPartitionHeader::Escaped { escape_size } => {
3838                    self.residuals
3839                        .iter()
3840                        .try_for_each(|s| w.write_signed_counted(escape_size, *s))?;
3841                }
3842                ResidualPartitionHeader::Constant => { /* nothing left to do */ }
3843            }
3844            Ok(())
3845        }
3846    }
3847
3848    fn best_partitions<'r, const RICE_MAX: u32>(
3849        options: &EncoderOptions,
3850        block_size: usize,
3851        residuals: &'r [i32],
3852    ) -> ArrayVec<Partition<'r, RICE_MAX>, MAX_PARTITIONS> {
3853        (0..=block_size.trailing_zeros().min(options.max_partition_order))
3854            .map(|partition_order| 1 << partition_order)
3855            .take_while(|partition_count: &usize| partition_count.is_power_of_two())
3856            .filter_map(|partition_count| {
3857                let mut estimated_bits = 0;
3858
3859                let partitions = residuals
3860                    .rchunks(block_size / partition_count)
3861                    .rev()
3862                    .map(|partition| Partition::new(partition, &mut estimated_bits))
3863                    .collect::<Option<ArrayVec<_, MAX_PARTITIONS>>>()
3864                    .filter(|p| !p.is_empty() && p.len().is_power_of_two())?;
3865
3866                Some((partitions, estimated_bits))
3867            })
3868            .min_by_key(|(_, estimated_bits)| *estimated_bits)
3869            .map(|(partitions, _)| partitions)
3870            .unwrap_or_else(|| {
3871                std::iter::once(Partition {
3872                    header: ResidualPartitionHeader::Escaped {
3873                        escape_size: SignedBitCount::new::<0b11111>(),
3874                    },
3875                    residuals,
3876                })
3877                .collect()
3878            })
3879    }
3880
3881    fn write_partitions<const RICE_MAX: u32, W: BitWrite>(
3882        writer: &mut W,
3883        partitions: ArrayVec<Partition<'_, RICE_MAX>, MAX_PARTITIONS>,
3884    ) -> Result<(), Error> {
3885        writer.write::<4, u32>(partitions.len().ilog2())?; // partition order
3886        for partition in partitions {
3887            writer.build(&partition)?;
3888        }
3889        Ok(())
3890    }
3891
3892    #[inline]
3893    fn try_shrink_header<const RICE_MAX: u32, const RICE_NEW_MAX: u32>(
3894        header: ResidualPartitionHeader<RICE_MAX>,
3895    ) -> Option<ResidualPartitionHeader<RICE_NEW_MAX>> {
3896        Some(match header {
3897            ResidualPartitionHeader::Standard { rice } => ResidualPartitionHeader::Standard {
3898                rice: rice.try_map(|r| (r < RICE_NEW_MAX).then_some(r))?,
3899            },
3900            ResidualPartitionHeader::Escaped { escape_size } => {
3901                ResidualPartitionHeader::Escaped { escape_size }
3902            }
3903            ResidualPartitionHeader::Constant => ResidualPartitionHeader::Constant,
3904        })
3905    }
3906
3907    enum CodingMethod<'p> {
3908        Rice(ArrayVec<Partition<'p, 0b1111>, MAX_PARTITIONS>),
3909        Rice2(ArrayVec<Partition<'p, 0b11111>, MAX_PARTITIONS>),
3910    }
3911
3912    fn try_reduce_rice(
3913        partitions: ArrayVec<Partition<'_, 0b11111>, MAX_PARTITIONS>,
3914    ) -> CodingMethod {
3915        match partitions
3916            .iter()
3917            .map(|Partition { header, residuals }| {
3918                try_shrink_header(*header).map(|header| Partition { header, residuals })
3919            })
3920            .collect()
3921        {
3922            Some(partitions) => CodingMethod::Rice(partitions),
3923            None => CodingMethod::Rice2(partitions),
3924        }
3925    }
3926
3927    let block_size = predictor_order + residuals.len();
3928
3929    if options.use_rice2 {
3930        match try_reduce_rice(best_partitions(options, block_size, residuals)) {
3931            CodingMethod::Rice(partitions) => {
3932                writer.write::<2, u8>(0)?; // coding method 0
3933                write_partitions(writer, partitions)
3934            }
3935            CodingMethod::Rice2(partitions) => {
3936                writer.write::<2, u8>(1)?; // coding method 1
3937                write_partitions(writer, partitions)
3938            }
3939        }
3940    } else {
3941        let partitions = best_partitions::<0b1111>(options, block_size, residuals);
3942        writer.write::<2, u8>(0)?; // coding method 0
3943        write_partitions(writer, partitions)
3944    }
3945}
3946
3947fn try_join<A, B, RA, RB, E>(oper_a: A, oper_b: B) -> Result<(RA, RB), E>
3948where
3949    A: FnOnce() -> Result<RA, E> + Send,
3950    B: FnOnce() -> Result<RB, E> + Send,
3951    RA: Send,
3952    RB: Send,
3953    E: Send,
3954{
3955    let (a, b) = join(oper_a, oper_b);
3956    Ok((a?, b?))
3957}
3958
3959#[cfg(feature = "rayon")]
3960use rayon::join;
3961
/// Sequential stand-in for `rayon::join`, used when the `rayon` feature is
/// disabled. It runs `oper_a` to completion, then `oper_b`, and returns
/// both results.
#[cfg(not(feature = "rayon"))]
fn join<A, B, RA, RB>(oper_a: A, oper_b: B) -> (RA, RB)
where
    A: FnOnce() -> RA + Send,
    B: FnOnce() -> RB + Send,
    RA: Send,
    RB: Send,
{
    let first = oper_a();
    let second = oper_b();
    (first, second)
}
3972
/// Maps `f` over every element of `src` using rayon's parallel iterators,
/// collecting the results into a new `Vec`.
#[cfg(feature = "rayon")]
fn vec_map<T, U, F>(src: Vec<T>, f: F) -> Vec<U>
where
    T: Send,
    U: Send,
    F: Fn(T) -> U + Send + Sync,
{
    use rayon::iter::{IntoParallelIterator, ParallelIterator};

    let mapped = src.into_par_iter().map(f);
    mapped.collect::<Vec<U>>()
}
3984
/// Maps `f` over every element of `src` in order, on the current thread.
///
/// This is the fallback used when the `rayon` feature is disabled.
#[cfg(not(feature = "rayon"))]
fn vec_map<T, U, F>(src: Vec<T>, f: F) -> Vec<U>
where
    T: Send,
    U: Send,
    F: Fn(T) -> U + Send + Sync,
{
    let mut mapped = Vec::with_capacity(src.len());
    mapped.extend(src.into_iter().map(f));
    mapped
}
3994
/// Divides `n` by `rhs`, returning the quotient only when the division
/// leaves no remainder.
///
/// Panics under the same conditions as `%` on `N` (for integers, a zero `rhs`).
fn exact_div<N>(n: N, rhs: N) -> Option<N>
where
    N: std::ops::Div<Output = N> + std::ops::Rem<Output = N> + std::cmp::PartialEq + Copy + Default,
{
    let remainder = n % rhs;
    if remainder == N::default() {
        Some(n / rhs)
    } else {
        None
    }
}