flac_codec/
encode.rs

1// Copyright 2025 Brian Langenberger
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9//! For encoding PCM samples to FLAC files
10//!
11//! ## Multithreading
12//!
13//! Encoders will operate using multithreading if the optional `rayon` feature is enabled,
14//! typically boosting performance by processing channels in parallel.
15//! But because subframes must eventually be written serially, and their size cannot generally
16//! be known in advance, processing two channels across two threads will not
17//! encode twice as fast.
18
19use crate::audio::Frame;
20use crate::metadata::{
21    Application, BlockList, BlockSize, Cuesheet, Picture, PortableMetadataBlock, SeekPoint,
22    Streaminfo, VorbisComment, write_blocks,
23};
24use crate::stream::{ChannelAssignment, FrameNumber, Independent, SampleRate};
25use crate::{Counter, Error};
26use arrayvec::ArrayVec;
27use bitstream_io::{BigEndian, BitRecorder, BitWrite, BitWriter, SignedBitCount};
28use std::collections::VecDeque;
29use std::fs::File;
30use std::io::BufWriter;
31use std::num::NonZero;
32use std::path::Path;
33
/// Largest channel count a FLAC stream can carry.
const MAX_CHANNELS: usize = 8;
/// Upper bound on the number of LPC coefficients used by the encoder.
const MAX_LPC_COEFFS: usize = 32;
37
// `vec!`-style constructor for `ArrayVec`, which the arrayvec crate does not provide.
//
// Each expression is pushed in order into a default (empty) `ArrayVec`;
// pushing past the vector's fixed capacity panics, just like `ArrayVec::push`.
macro_rules! arrayvec {
    ( $( $x:expr ),* ) => {{
        let mut collected = ArrayVec::default();
        $(
            collected.push($x);
        )*
        collected
    }};
}
48
49/// A FLAC writer which accepts samples as bytes
50///
51/// # Example
52///
53/// ```
54/// use flac_codec::{
55///     byteorder::LittleEndian,
56///     encode::{FlacByteWriter, Options},
57///     decode::{FlacByteReader, Metadata},
58/// };
59/// use std::io::{Cursor, Read, Seek, Write};
60///
61/// let mut flac = Cursor::new(vec![]);  // a FLAC file in memory
62///
63/// let mut writer = FlacByteWriter::endian(
64///     &mut flac,           // our wrapped writer
65///     LittleEndian,        // .wav-style byte order
66///     Options::default(),  // default encoding options
67///     44100,               // sample rate
68///     16,                  // bits-per-sample
69///     1,                   // channel count
70///     Some(2000),          // total bytes
71/// ).unwrap();
72///
73/// // write 1000 samples as 16-bit, signed, little-endian bytes (2000 bytes total)
74/// let written_bytes = (0..1000).map(i16::to_le_bytes).flatten().collect::<Vec<u8>>();
75/// assert!(writer.write_all(&written_bytes).is_ok());
76///
77/// // finalize writing file
78/// assert!(writer.finalize().is_ok());
79///
80/// flac.rewind().unwrap();
81///
82/// // open reader around written FLAC file
83/// let mut reader = FlacByteReader::endian(flac, LittleEndian).unwrap();
84///
85/// // read 2000 bytes
86/// let mut read_bytes = vec![];
87/// assert!(reader.read_to_end(&mut read_bytes).is_ok());
88///
89/// // ensure MD5 sum of signed, little-endian samples matches hash in file
90/// let mut md5 = md5::Context::new();
91/// md5.consume(&read_bytes);
92/// assert_eq!(&md5.compute().0, reader.md5().unwrap());
93///
94/// // ensure input and output matches
95/// assert_eq!(read_bytes, written_bytes);
96/// ```
97pub struct FlacByteWriter<W: std::io::Write + std::io::Seek, E: crate::byteorder::Endianness> {
98    // the wrapped encoder
99    encoder: Encoder<W>,
100    // bytes that make up a partial FLAC frame
101    buf: VecDeque<u8>,
102    // a whole set of samples for a FLAC frame
103    frame: Frame,
104    // size of a single sample in bytes
105    bytes_per_sample: usize,
106    // size of single set of channel-independent samples in bytes
107    pcm_frame_size: usize,
108    // size of whole FLAC frame's samples in bytes
109    frame_byte_size: usize,
110    // whether the encoder has finalized the file
111    finalized: bool,
112    // the input bytes' endianness
113    endianness: std::marker::PhantomData<E>,
114}
115
116impl<W: std::io::Write + std::io::Seek, E: crate::byteorder::Endianness> FlacByteWriter<W, E> {
117    /// Creates new FLAC writer with the given parameters
118    ///
119    /// The writer should be positioned at the start of the file.
120    ///
121    /// `sample_rate` must be between 0 (for non-audio streams) and 2²⁰.
122    ///
123    /// `bits_per_sample` must be between 1 and 32.
124    ///
125    /// `channels` must be between 1 and 8.
126    ///
127    /// Note that if `total_bytes` is indicated,
128    /// the number of channel-independent samples written *must*
129    /// be equal to that amount or an error will occur when writing
130    /// or finalizing the stream.
131    ///
132    /// # Errors
133    ///
134    /// Returns I/O error if unable to write initial
135    /// metadata blocks.
136    /// Returns error if any of the encoding parameters are invalid.
137    pub fn new(
138        writer: W,
139        options: Options,
140        sample_rate: u32,
141        bits_per_sample: u32,
142        channels: u8,
143        total_bytes: Option<u64>,
144    ) -> Result<Self, Error> {
145        let bits_per_sample: SignedBitCount<32> = bits_per_sample
146            .try_into()
147            .map_err(|_| Error::InvalidBitsPerSample)?;
148
149        let bytes_per_sample = u32::from(bits_per_sample).div_ceil(8) as usize;
150
151        let pcm_frame_size = bytes_per_sample * channels as usize;
152
153        Ok(Self {
154            buf: VecDeque::default(),
155            frame: Frame::empty(channels.into(), bits_per_sample.into()),
156            bytes_per_sample,
157            pcm_frame_size,
158            frame_byte_size: pcm_frame_size * options.block_size as usize,
159            encoder: Encoder::new(
160                writer,
161                options,
162                sample_rate,
163                bits_per_sample,
164                channels,
165                total_bytes
166                    .map(|bytes| {
167                        exact_div(bytes, channels.into())
168                            .and_then(|s| exact_div(s, bytes_per_sample as u64))
169                            .ok_or(Error::SamplesNotDivisibleByChannels)
170                            .and_then(|b| NonZero::new(b).ok_or(Error::InvalidTotalBytes))
171                    })
172                    .transpose()?,
173            )?,
174            finalized: false,
175            endianness: std::marker::PhantomData,
176        })
177    }
178
179    /// Creates new FLAC writer with CDDA parameters
180    ///
181    /// The writer should be positioned at the start of the file.
182    ///
183    /// Sample rate is 44100 Hz, bits-per-sample is 16,
184    /// channels is 2.
185    ///
186    /// Note that if `total_bytes` is indicated,
187    /// the number of channel-independent samples written *must*
188    /// be equal to that amount or an error will occur when writing
189    /// or finalizing the stream.
190    ///
191    /// # Errors
192    ///
193    /// Returns I/O error if unable to write initial
194    /// metadata blocks.
195    /// Returns error if any of the encoding parameters are invalid.
196    pub fn new_cdda(writer: W, options: Options, total_bytes: Option<u64>) -> Result<Self, Error> {
197        Self::new(writer, options, 44100, 16, 2, total_bytes)
198    }
199
200    /// Creates new FLAC writer in the given endianness with the given parameters
201    ///
202    /// The writer should be positioned at the start of the file.
203    ///
204    /// `sample_rate` must be between 0 (for non-audio streams) and 2²⁰.
205    ///
206    /// `bits_per_sample` must be between 1 and 32.
207    ///
208    /// `channels` must be between 1 and 8.
209    ///
210    /// Note that if `total_bytes` is indicated,
211    /// the number of bytes written *must*
212    /// be equal to that amount or an error will occur when writing
213    /// or finalizing the stream.
214    ///
215    /// # Errors
216    ///
217    /// Returns I/O error if unable to write initial
218    /// metadata blocks.
219    #[inline]
220    pub fn endian(
221        writer: W,
222        _endianness: E,
223        options: Options,
224        sample_rate: u32,
225        bits_per_sample: u32,
226        channels: u8,
227        total_bytes: Option<u64>,
228    ) -> Result<Self, Error> {
229        Self::new(
230            writer,
231            options,
232            sample_rate,
233            bits_per_sample,
234            channels,
235            total_bytes,
236        )
237    }
238
239    fn finalize_inner(&mut self) -> Result<(), Error> {
240        if !self.finalized {
241            self.finalized = true;
242
243            // encode as many bytes as possible into final frame, if necessary
244            if !self.buf.is_empty() {
245                use crate::byteorder::LittleEndian;
246
247                // truncate buffer to whole PCM frames
248                let buf = self.buf.make_contiguous();
249                let buf_len = buf.len();
250                let buf = &mut buf[..(buf_len - buf_len % self.pcm_frame_size)];
251
252                // convert buffer to little-endian bytes
253                E::bytes_to_le(buf, self.bytes_per_sample);
254
255                // update MD5 sum with little-endian bytes
256                self.encoder.md5.consume(&buf);
257
258                self.encoder
259                    .encode(self.frame.fill_from_buf::<LittleEndian>(buf))?;
260            }
261
262            self.encoder.finalize_inner()
263        } else {
264            Ok(())
265        }
266    }
267
268    /// Attempt to finalize stream
269    ///
270    /// It is necessary to finalize the FLAC encoder
271    /// so that it will write any partially unwritten samples
272    /// to the stream and update the [`crate::metadata::Streaminfo`] and [`crate::metadata::SeekTable`] blocks
273    /// with their final values.
274    ///
275    /// Dropping the encoder will attempt to finalize the stream
276    /// automatically, but will ignore any errors that may occur.
277    pub fn finalize(mut self) -> Result<(), Error> {
278        self.finalize_inner()?;
279        Ok(())
280    }
281}
282
283impl<E: crate::byteorder::Endianness> FlacByteWriter<BufWriter<File>, E> {
284    /// Creates new FLAC file at the given path
285    ///
286    /// `sample_rate` must be between 0 (for non-audio streams) and 2²⁰.
287    ///
288    /// `bits_per_sample` must be between 1 and 32.
289    ///
290    /// `channels` must be between 1 and 8.
291    ///
292    /// Note that if `total_bytes` is indicated,
293    /// the number of bytes written *must*
294    /// be equal to that amount or an error will occur when writing
295    /// or finalizing the stream.
296    ///
297    /// # Errors
298    ///
299    /// Returns I/O error if unable to write initial
300    /// metadata blocks.
301    #[inline]
302    pub fn create<P: AsRef<Path>>(
303        path: P,
304        options: Options,
305        sample_rate: u32,
306        bits_per_sample: u32,
307        channels: u8,
308        total_bytes: Option<u64>,
309    ) -> Result<Self, Error> {
310        FlacByteWriter::new(
311            BufWriter::new(options.create(path)?),
312            options,
313            sample_rate,
314            bits_per_sample,
315            channels,
316            total_bytes,
317        )
318    }
319
320    /// Creates new FLAC file with CDDA parameters at the given path
321    ///
322    /// Sample rate is 44100 Hz, bits-per-sample is 16,
323    /// channels is 2.
324    ///
325    /// Note that if `total_bytes` is indicated,
326    /// the number of bytes written *must*
327    /// be equal to that amount or an error will occur when writing
328    /// or finalizing the stream.
329    ///
330    /// # Errors
331    ///
332    /// Returns I/O error if unable to write initial
333    /// metadata blocks.
334    pub fn create_cdda<P: AsRef<Path>>(
335        path: P,
336        options: Options,
337        total_bytes: Option<u64>,
338    ) -> Result<Self, Error> {
339        Self::create(path, options, 44100, 16, 2, total_bytes)
340    }
341}
342
343impl<W: std::io::Write + std::io::Seek, E: crate::byteorder::Endianness> std::io::Write
344    for FlacByteWriter<W, E>
345{
346    /// Writes a set of sample bytes to the FLAC file
347    ///
348    /// Samples are signed and encoded in the stream's given byte order.
349    ///
350    /// Samples are then interleaved by channel, like:
351    /// [left₀ , right₀ , left₁ , right₁ , left₂ , right₂ , …]
352    ///
353    /// This is the same format used by common PCM container
354    /// formats like WAVE and AIFF.
355    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
356        use crate::byteorder::LittleEndian;
357
358        // dump whole set of bytes into our internal buffer
359        self.buf.extend(buf);
360
361        // encode as many FLAC frames as possible (which may be 0)
362        let mut encoded_frames = 0;
363        for buf in self
364            .buf
365            .make_contiguous()
366            .chunks_exact_mut(self.frame_byte_size)
367        {
368            // convert buffer to little-endian bytes
369            E::bytes_to_le(buf, self.bytes_per_sample);
370
371            // update MD5 sum with little-endian bytes
372            self.encoder.md5.consume(&buf);
373
374            // encode fresh FLAC frame
375            self.encoder
376                .encode(self.frame.fill_from_buf::<LittleEndian>(buf))?;
377
378            encoded_frames += 1;
379        }
380        // TODO - use truncate_front whenever that stabilizes
381        self.buf.drain(0..self.frame_byte_size * encoded_frames);
382
383        // indicate whole buffer's been consumed
384        Ok(buf.len())
385    }
386
387    #[inline]
388    fn flush(&mut self) -> std::io::Result<()> {
389        // we don't want to flush a partial frame to disk,
390        // but we can at least flush our internal writer
391        self.encoder.writer.flush()
392    }
393}
394
395impl<W: std::io::Write + std::io::Seek, E: crate::byteorder::Endianness> Drop
396    for FlacByteWriter<W, E>
397{
398    fn drop(&mut self) {
399        let _ = self.finalize_inner();
400    }
401}
402
403/// A FLAC writer which accepts samples as signed integers
404///
405/// # Example
406///
407/// ```
408/// use flac_codec::{
409///     encode::{FlacSampleWriter, Options},
410///     decode::FlacSampleReader,
411/// };
412/// use std::io::{Cursor, Seek};
413///
414/// let mut flac = Cursor::new(vec![]);  // a FLAC file in memory
415///
416/// let mut writer = FlacSampleWriter::new(
417///     &mut flac,           // our wrapped writer
418///     Options::default(),  // default encoding options
419///     44100,               // sample rate
420///     16,                  // bits-per-sample
421///     1,                   // channel count
422///     Some(1000),          // total samples
423/// ).unwrap();
424///
425/// // write 1000 samples
426/// let written_samples = (0..1000).collect::<Vec<i32>>();
427/// assert!(writer.write(&written_samples).is_ok());
428///
429/// // finalize writing file
430/// assert!(writer.finalize().is_ok());
431///
432/// flac.rewind().unwrap();
433///
434/// // open reader around written FLAC file
435/// let mut reader = FlacSampleReader::new(flac).unwrap();
436///
437/// // read 1000 samples
438/// let mut read_samples = vec![0; 1000];
439/// assert!(matches!(reader.read(&mut read_samples), Ok(1000)));
440///
441/// // ensure they match
442/// assert_eq!(read_samples, written_samples);
443/// ```
444pub struct FlacSampleWriter<W: std::io::Write + std::io::Seek> {
445    // the wrapped encoder
446    encoder: Encoder<W>,
447    // samples that make up a partial FLAC frame
448    // in channel-interleaved order
449    // (must de-interleave later in case someone writes
450    // only partial set of channels in a single write call)
451    sample_buf: VecDeque<i32>,
452    // a whole set of samples for a FLAC frame
453    frame: Frame,
454    // size of a single frame in samples
455    frame_sample_size: usize,
456    // size of a single PCM frame in samples
457    pcm_frame_size: usize,
458    // size of a single sample in bytes
459    bytes_per_sample: usize,
460    // whether the encoder has finalized the file
461    finalized: bool,
462}
463
464impl<W: std::io::Write + std::io::Seek> FlacSampleWriter<W> {
465    /// Creates new FLAC writer with the given parameters
466    ///
467    /// `sample_rate` must be between 0 (for non-audio streams) and 2²⁰.
468    ///
469    /// `bits_per_sample` must be between 1 and 32.
470    ///
471    /// `channels` must be between 1 and 8.
472    ///
473    /// Note that if `total_samples` is indicated,
474    /// the number of samples written *must*
475    /// be equal to that amount or an error will occur when writing
476    /// or finalizing the stream.
477    ///
478    /// # Errors
479    ///
480    /// Returns I/O error if unable to write initial
481    /// metadata blocks.
482    /// Returns error if any of the encoding parameters are invalid.
483    pub fn new(
484        writer: W,
485        options: Options,
486        sample_rate: u32,
487        bits_per_sample: u32,
488        channels: u8,
489        total_samples: Option<u64>,
490    ) -> Result<Self, Error> {
491        let bits_per_sample: SignedBitCount<32> = bits_per_sample
492            .try_into()
493            .map_err(|_| Error::InvalidBitsPerSample)?;
494
495        let bytes_per_sample = u32::from(bits_per_sample).div_ceil(8) as usize;
496
497        let pcm_frame_size = usize::from(channels);
498
499        Ok(Self {
500            sample_buf: VecDeque::default(),
501            frame: Frame::empty(channels.into(), bits_per_sample.into()),
502            bytes_per_sample,
503            pcm_frame_size,
504            frame_sample_size: pcm_frame_size * options.block_size as usize,
505            encoder: Encoder::new(
506                writer,
507                options,
508                sample_rate,
509                bits_per_sample,
510                channels,
511                total_samples
512                    .map(|samples| {
513                        exact_div(samples, channels.into())
514                            .ok_or(Error::SamplesNotDivisibleByChannels)
515                            .and_then(|s| NonZero::new(s).ok_or(Error::InvalidTotalSamples))
516                    })
517                    .transpose()?,
518            )?,
519            finalized: false,
520        })
521    }
522
523    /// Creates new FLAC writer with CDDA parameters
524    ///
525    /// Sample rate is 44100 Hz, bits-per-sample is 16,
526    /// channels is 2.
527    ///
528    /// Note that if `total_samples` is indicated,
529    /// the number of samples written *must*
530    /// be equal to that amount or an error will occur when writing
531    /// or finalizing the stream.
532    ///
533    /// # Errors
534    ///
535    /// Returns I/O error if unable to write initial
536    /// metadata blocks.
537    /// Returns error if any of the encoding parameters are invalid.
538    pub fn new_cdda(
539        writer: W,
540        options: Options,
541        total_samples: Option<u64>,
542    ) -> Result<Self, Error> {
543        Self::new(writer, options, 44100, 16, 2, total_samples)
544    }
545
546    /// Given a set of samples, writes them to the FLAC file
547    ///
548    /// Samples are interleaved by channel, like:
549    /// [left₀ , right₀ , left₁ , right₁ , left₂ , right₂ , …]
550    ///
551    /// This may output 0 or more actual FLAC frames,
552    /// depending on the quantity of samples and the amount
553    /// previously written.
554    pub fn write(&mut self, samples: &[i32]) -> Result<(), Error> {
555        // dump whole set of samples into our internal buffer
556        self.sample_buf.extend(samples);
557
558        // encode as many FLAC frames as possible (which may be 0)
559        let mut encoded_frames = 0;
560        for buf in self
561            .sample_buf
562            .make_contiguous()
563            .chunks_exact_mut(self.frame_sample_size)
564        {
565            // update running MD5 sum calculation
566            // since samples are already interleaved in channel order
567            update_md5(
568                &mut self.encoder.md5,
569                buf.iter().copied(),
570                self.bytes_per_sample,
571            );
572
573            // encode fresh FLAC frame
574            self.encoder.encode(self.frame.fill_from_samples(buf))?;
575
576            encoded_frames += 1;
577        }
578        self.sample_buf
579            .drain(0..self.frame_sample_size * encoded_frames);
580
581        Ok(())
582    }
583
584    fn finalize_inner(&mut self) -> Result<(), Error> {
585        if !self.finalized {
586            self.finalized = true;
587
588            // encode as many samples possible into final frame, if necessary
589            if !self.sample_buf.is_empty() {
590                // truncate buffer to whole PCM frames
591                let buf = self.sample_buf.make_contiguous();
592                let buf_len = buf.len();
593                let buf = &mut buf[..(buf_len - buf_len % self.pcm_frame_size)];
594
595                // update running MD5 sum calculation
596                // since samples are already interleaved in channel order
597                update_md5(
598                    &mut self.encoder.md5,
599                    buf.iter().copied(),
600                    self.bytes_per_sample,
601                );
602
603                // encode final FLAC frame
604                self.encoder.encode(self.frame.fill_from_samples(buf))?;
605            }
606
607            self.encoder.finalize_inner()
608        } else {
609            Ok(())
610        }
611    }
612
613    /// Attempt to finalize stream
614    ///
615    /// It is necessary to finalize the FLAC encoder
616    /// so that it will write any partially unwritten samples
617    /// to the stream and update the [`crate::metadata::Streaminfo`] and [`crate::metadata::SeekTable`] blocks
618    /// with their final values.
619    ///
620    /// Dropping the encoder will attempt to finalize the stream
621    /// automatically, but will ignore any errors that may occur.
622    pub fn finalize(mut self) -> Result<(), Error> {
623        self.finalize_inner()?;
624        Ok(())
625    }
626}
627
628impl FlacSampleWriter<BufWriter<File>> {
629    /// Creates new FLAC file at the given path
630    ///
631    /// `sample_rate` must be between 0 (for non-audio streams) and 2²⁰.
632    ///
633    /// `bits_per_sample` must be between 1 and 32.
634    ///
635    /// `channels` must be between 1 and 8.
636    ///
637    /// Note that if `total_bytes` is indicated,
638    /// the number of bytes written *must*
639    /// be equal to that amount or an error will occur when writing
640    /// or finalizing the stream.
641    ///
642    /// # Errors
643    ///
644    /// Returns I/O error if unable to write initial
645    /// metadata blocks.
646    #[inline]
647    pub fn create<P: AsRef<Path>>(
648        path: P,
649        options: Options,
650        sample_rate: u32,
651        bits_per_sample: u32,
652        channels: u8,
653        total_samples: Option<u64>,
654    ) -> Result<Self, Error> {
655        FlacSampleWriter::new(
656            BufWriter::new(options.create(path)?),
657            options,
658            sample_rate,
659            bits_per_sample,
660            channels,
661            total_samples,
662        )
663    }
664
665    /// Creates new FLAC file at the given path with CDDA parameters
666    ///
667    /// Sample rate is 44100 Hz, bits-per-sample is 16,
668    /// channels is 2.
669    ///
670    /// Note that if `total_bytes` is indicated,
671    /// the number of bytes written *must*
672    /// be equal to that amount or an error will occur when writing
673    /// or finalizing the stream.
674    ///
675    /// # Errors
676    ///
677    /// Returns I/O error if unable to write initial
678    /// metadata blocks.
679    #[inline]
680    pub fn create_cdda<P: AsRef<Path>>(
681        path: P,
682        options: Options,
683        total_samples: Option<u64>,
684    ) -> Result<Self, Error> {
685        Self::create(path, options, 44100, 16, 2, total_samples)
686    }
687}
688
689/// A FLAC writer which accepts samples as channels of signed integers
690///
691/// # Example
692/// ```
693/// use flac_codec::{
694///     encode::{FlacChannelWriter, Options},
695///     decode::FlacChannelReader,
696/// };
697/// use std::io::{Cursor, Seek};
698///
699/// let mut flac = Cursor::new(vec![]);  // a FLAC file in memory
700///
701/// let mut writer = FlacChannelWriter::new(
702///     &mut flac,           // our wrapped writer
703///     Options::default(),  // default encoding options
704///     44100,               // sample rate
705///     16,                  // bits-per-sample
706///     2,                   // channel count
707///     Some(5),             // total channel-independent samples
708/// ).unwrap();
709///
710/// // write our samples, divided by channel
711/// let written_samples = vec![
712///     vec![1, 2, 3, 4, 5],
713///     vec![-1, -2, -3, -4, -5],
714/// ];
715/// assert!(writer.write(&written_samples).is_ok());
716///
717/// // finalize writing file
718/// assert!(writer.finalize().is_ok());
719///
720/// flac.rewind().unwrap();
721///
722/// // open reader around written FLAC file
723/// let mut reader = FlacChannelReader::new(flac).unwrap();
724///
725/// // read a buffer's worth of samples
726/// let read_samples = reader.fill_buf().unwrap();
727///
728/// // ensure the channels match
729/// assert_eq!(read_samples.len(), written_samples.len());
730/// assert_eq!(read_samples[0], written_samples[0]);
731/// assert_eq!(read_samples[1], written_samples[1]);
732/// ```
733pub struct FlacChannelWriter<W: std::io::Write + std::io::Seek> {
734    // the wrapped encoder
735    encoder: Encoder<W>,
736    // channels that make up a partial FLAC frame
737    channel_bufs: Vec<VecDeque<i32>>,
738    // a whole set of samples for a FLAC frame
739    frame: Frame,
740    // size of a single frame in samples
741    frame_sample_size: usize,
742    // size of a single sample in bytes
743    bytes_per_sample: usize,
744    // whether the encoder has finalized the file
745    finalized: bool,
746}
747
748impl<W: std::io::Write + std::io::Seek> FlacChannelWriter<W> {
749    /// Creates new FLAC writer with the given parameters
750    ///
751    /// `sample_rate` must be between 0 (for non-audio streams) and 2²⁰.
752    ///
753    /// `bits_per_sample` must be between 1 and 32.
754    ///
755    /// `channels` must be between 1 and 8.
756    ///
757    /// Note that if `total_samples` is indicated,
758    /// the number of samples written *must*
759    /// be equal to that amount or an error will occur when writing
760    /// or finalizing the stream.
761    ///
762    /// # Errors
763    ///
764    /// Returns I/O error if unable to write initial
765    /// metadata blocks.
766    /// Returns error if any of the encoding parameters are invalid.
767    pub fn new(
768        writer: W,
769        options: Options,
770        sample_rate: u32,
771        bits_per_sample: u32,
772        channels: u8,
773        total_samples: Option<u64>,
774    ) -> Result<Self, Error> {
775        let bits_per_sample: SignedBitCount<32> = bits_per_sample
776            .try_into()
777            .map_err(|_| Error::InvalidBitsPerSample)?;
778
779        let bytes_per_sample = u32::from(bits_per_sample).div_ceil(8) as usize;
780
781        Ok(Self {
782            channel_bufs: vec![VecDeque::default(); channels.into()],
783            frame: Frame::empty(channels.into(), bits_per_sample.into()),
784            bytes_per_sample,
785            frame_sample_size: options.block_size as usize,
786            encoder: Encoder::new(
787                writer,
788                options,
789                sample_rate,
790                bits_per_sample,
791                channels,
792                total_samples.and_then(NonZero::new),
793            )?,
794            finalized: false,
795        })
796    }
797
798    /// Creates new FLAC writer with CDDA parameters
799    ///
800    /// Sample rate is 44100 Hz, bits-per-sample is 16,
801    /// channels is 2.
802    ///
803    /// Note that if `total_samples` is indicated,
804    /// the number of samples written *must*
805    /// be equal to that amount or an error will occur when writing
806    /// or finalizing the stream.
807    ///
808    /// # Errors
809    ///
810    /// Returns I/O error if unable to write initial
811    /// metadata blocks.
812    /// Returns error if any of the encoding parameters are invalid.
813    pub fn new_cdda(
814        writer: W,
815        options: Options,
816        total_samples: Option<u64>,
817    ) -> Result<Self, Error> {
818        Self::new(writer, options, 44100, 16, 2, total_samples)
819    }
820
821    /// Given a set of channels containing samples, writes them to the FLAC file
822    ///
823    /// Channels should be a slice-able set of sample slices, like:
824    /// [[left₀ , left₁ , left₂ , …] , [right₀ , right₁ , right₂ , …]]
825    ///
826    /// The number of channels must be identical to the
827    /// channel count indicated when intializing the encoder.
828    ///
829    /// The number of samples in each channel must also be identical.
830    pub fn write<C, S>(&mut self, channels: C) -> Result<(), Error>
831    where
832        C: AsRef<[S]>,
833        S: AsRef<[i32]>,
834    {
835        use crate::audio::MultiZip;
836
837        // sanity-check our channel inputs
838        let channels = channels.as_ref();
839
840        match channels {
841            whole @ [first, rest @ ..]
842                if whole.len() == usize::from(self.encoder.channel_count().get()) =>
843            {
844                if rest
845                    .iter()
846                    .any(|c| c.as_ref().len() != first.as_ref().len())
847                {
848                    return Err(Error::ChannelLengthMismatch);
849                }
850            }
851            _ => {
852                return Err(Error::ChannelCountMismatch);
853            }
854        }
855
856        // dump whole set of samples into our internal channel buffers
857        for (buf, channel) in self.channel_bufs.iter_mut().zip(channels) {
858            buf.extend(channel.as_ref());
859        }
860
861        // encode as many FLAC frames as possible (which may be 0)
862        let mut encoded_frames = 0;
863        for bufs in self
864            .channel_bufs
865            .iter_mut()
866            .map(|v| v.make_contiguous().chunks_exact_mut(self.frame_sample_size))
867            .collect::<MultiZip<_>>()
868        {
869            // update running MD5 sum calculation
870            update_md5(
871                &mut self.encoder.md5,
872                bufs.iter()
873                    .map(|c| c.iter().copied())
874                    .collect::<MultiZip<_>>()
875                    .flatten(),
876                self.bytes_per_sample,
877            );
878
879            // encode fresh FLAC frame
880            self.encoder.encode(self.frame.fill_from_channels(&bufs))?;
881
882            encoded_frames += 1;
883        }
884
885        // drain encoded samples from buffers
886        for channel in self.channel_bufs.iter_mut() {
887            channel.drain(0..self.frame_sample_size * encoded_frames);
888        }
889
890        Ok(())
891    }
892
893    fn finalize_inner(&mut self) -> Result<(), Error> {
894        use crate::audio::MultiZip;
895
896        if !self.finalized {
897            self.finalized = true;
898
899            // encode as many samples possible into final frame, if necessary
900            if !self.channel_bufs[0].is_empty() {
901                // update running MD5 sum calculation
902                update_md5(
903                    &mut self.encoder.md5,
904                    self.channel_bufs
905                        .iter()
906                        .map(|c| c.iter().copied())
907                        .collect::<MultiZip<_>>()
908                        .flatten(),
909                    self.bytes_per_sample,
910                );
911
912                // encode final FLAC frame
913                self.encoder.encode(
914                    self.frame.fill_from_channels(
915                        self.channel_bufs
916                            .iter_mut()
917                            .map(|v| v.make_contiguous())
918                            .collect::<ArrayVec<_, MAX_CHANNELS>>()
919                            .as_slice(),
920                    ),
921                )?;
922            }
923
924            self.encoder.finalize_inner()
925        } else {
926            Ok(())
927        }
928    }
929
930    /// Attempt to finalize stream
931    ///
932    /// It is necessary to finalize the FLAC encoder
933    /// so that it will write any partially unwritten samples
934    /// to the stream and update the [`crate::metadata::Streaminfo`] and [`crate::metadata::SeekTable`] blocks
935    /// with their final values.
936    ///
937    /// Dropping the encoder will attempt to finalize the stream
938    /// automatically, but will ignore any errors that may occur.
939    pub fn finalize(mut self) -> Result<(), Error> {
940        self.finalize_inner()?;
941        Ok(())
942    }
943}
944
945impl FlacChannelWriter<BufWriter<File>> {
946    /// Creates new FLAC file at the given path
947    ///
948    /// `sample_rate` must be between 0 (for non-audio streams) and 2²⁰.
949    ///
950    /// `bits_per_sample` must be between 1 and 32.
951    ///
952    /// `channels` must be between 1 and 8.
953    ///
954    /// Note that if `total_bytes` is indicated,
955    /// the number of bytes written *must*
956    /// be equal to that amount or an error will occur when writing
957    /// or finalizing the stream.
958    ///
959    /// # Errors
960    ///
961    /// Returns I/O error if unable to write initial
962    /// metadata blocks.
963    #[inline]
964    pub fn create<P: AsRef<Path>>(
965        path: P,
966        options: Options,
967        sample_rate: u32,
968        bits_per_sample: u32,
969        channels: u8,
970        total_samples: Option<u64>,
971    ) -> Result<Self, Error> {
972        FlacChannelWriter::new(
973            BufWriter::new(options.create(path)?),
974            options,
975            sample_rate,
976            bits_per_sample,
977            channels,
978            total_samples,
979        )
980    }
981
982    /// Creates new FLAC file at the given path with CDDA parameters
983    ///
984    /// Sample rate is 44100 Hz, bits-per-sample is 16,
985    /// channels is 2.
986    ///
987    /// Note that if `total_bytes` is indicated,
988    /// the number of bytes written *must*
989    /// be equal to that amount or an error will occur when writing
990    /// or finalizing the stream.
991    ///
992    /// # Errors
993    ///
994    /// Returns I/O error if unable to write initial
995    /// metadata blocks.
996    #[inline]
997    pub fn create_cdda<P: AsRef<Path>>(
998        path: P,
999        options: Options,
1000        total_samples: Option<u64>,
1001    ) -> Result<Self, Error> {
1002        Self::create(path, options, 44100, 16, 2, total_samples)
1003    }
1004}
1005
1006/// A FLAC writer which operates on streamed output
1007///
1008/// Because this encodes FLAC frames without any metadata
1009/// blocks or finalizing, it does not need to be seekable.
1010///
1011/// # Example
1012///
1013/// ```
1014/// use flac_codec::{
1015///     decode::{FlacStreamReader, FrameBuf},
1016///     encode::{FlacStreamWriter, Options},
1017/// };
1018/// use std::io::{Cursor, Seek};
1019/// use bitstream_io::SignedBitCount;
1020///
1021/// let mut flac = Cursor::new(vec![]);
1022///
1023/// let samples = (0..100).collect::<Vec<i32>>();
1024///
1025/// let mut w = FlacStreamWriter::new(&mut flac, Options::default());
1026///
1027/// // write a single FLAC frame with some samples
1028/// w.write(
1029///     44100,  // sample rate
1030///     1,      // channels
1031///     16,     // bits-per-sample
1032///     &samples,
1033/// ).unwrap();
1034///
1035/// flac.rewind().unwrap();
1036///
1037/// let mut r = FlacStreamReader::new(&mut flac);
1038///
1039/// // read a single FLAC frame with some samples
1040/// assert_eq!(
1041///     r.read().unwrap(),
1042///     FrameBuf {
1043///         samples: &samples,
1044///         sample_rate: 44100,
1045///         channels: 1,
1046///         bits_per_sample: 16,
1047///     },
1048/// );
1049/// ```
1050pub struct FlacStreamWriter<W> {
1051    // the writer we're outputting to
1052    writer: W,
1053    // various encoding optins
1054    options: EncoderOptions,
1055    // various encoder caches
1056    caches: EncodingCaches,
1057    // a whole set of samples for a FLAC frame
1058    frame: Frame,
1059    // the current frame number
1060    frame_number: FrameNumber,
1061}
1062
1063impl<W: std::io::Write> FlacStreamWriter<W> {
1064    /// Creates new stream writer
1065    pub fn new(writer: W, options: Options) -> Self {
1066        Self {
1067            writer,
1068            options: EncoderOptions {
1069                max_partition_order: options.max_partition_order,
1070                mid_side: options.mid_side,
1071                seektable_interval: options.seektable_interval,
1072                max_lpc_order: options.max_lpc_order,
1073                window: options.window,
1074                exhaustive_channel_correlation: options.exhaustive_channel_correlation,
1075                use_rice2: false,
1076            },
1077            caches: EncodingCaches::default(),
1078            frame: Frame::default(),
1079            frame_number: FrameNumber::default(),
1080        }
1081    }
1082
1083    /// Writes a whole FLAC frame to our output stream
1084    ///
1085    /// Samples are interleaved by channel, like:
1086    /// [left₀ , right₀ , left₁ , right₁ , left₂ , right₂ , …]
1087    ///
1088    /// This writes a whole FLAC frame to the output stream on each call.
1089    ///
1090    /// # Errors
1091    ///
1092    /// Returns an error of any of the parameters are invalid
1093    /// or if an I/O error occurs when writing to the stream.
1094    pub fn write(
1095        &mut self,
1096        sample_rate: u32,
1097        channels: u8,
1098        bits_per_sample: u32,
1099        samples: &[i32],
1100    ) -> Result<(), Error> {
1101        use crate::crc::{Crc16, CrcWriter};
1102        use crate::stream::{BitsPerSample, FrameHeader, SampleRate};
1103
1104        let bits_per_sample: SignedBitCount<32> = bits_per_sample
1105            .try_into()
1106            .map_err(|_| Error::NonSubsetBitsPerSample)?;
1107
1108        // samples must divide evenly into channels
1109        if !samples.len().is_multiple_of(usize::from(channels)) {
1110            return Err(Error::SamplesNotDivisibleByChannels);
1111        } else if !(1..=8).contains(&channels) {
1112            return Err(Error::ExcessiveChannels);
1113        }
1114
1115        self.options.use_rice2 = u32::from(bits_per_sample) > 16;
1116
1117        self.frame
1118            .resize(bits_per_sample.into(), channels.into(), 0);
1119        self.frame.fill_from_samples(samples);
1120
1121        // block size must be valid
1122        let block_size: crate::stream::BlockSize<u16> = crate::stream::BlockSize::try_from(
1123            u16::try_from(self.frame.pcm_frames()).map_err(|_| Error::InvalidBlockSize)?,
1124        )
1125        .map_err(|_| Error::InvalidBlockSize)?;
1126
1127        // sample rate must be valid for subset streams
1128        let sample_rate: SampleRate<u32> = sample_rate.try_into().and_then(|rate| match rate {
1129            SampleRate::Streaminfo(_) => Err(Error::NonSubsetSampleRate),
1130            rate => Ok(rate),
1131        })?;
1132
1133        // bits-per-sample must be valid for subset streams
1134        let header_bits_per_sample = match BitsPerSample::from(bits_per_sample) {
1135            BitsPerSample::Streaminfo(_) => return Err(Error::NonSubsetBitsPerSample),
1136            bps => bps,
1137        };
1138
1139        let mut w: CrcWriter<_, Crc16> = CrcWriter::new(&mut self.writer);
1140        let mut bw: BitWriter<CrcWriter<&mut W, Crc16>, BigEndian>;
1141
1142        match self
1143            .frame
1144            .channels()
1145            .collect::<ArrayVec<&[i32], MAX_CHANNELS>>()
1146            .as_slice()
1147        {
1148            [channel] => {
1149                FrameHeader {
1150                    blocking_strategy: false,
1151                    frame_number: self.frame_number,
1152                    block_size: (channel.len() as u16)
1153                        .try_into()
1154                        .expect("frame cannot be empty"),
1155                    sample_rate,
1156                    bits_per_sample: header_bits_per_sample,
1157                    channel_assignment: ChannelAssignment::Independent(Independent::Mono),
1158                }
1159                .write_subset(&mut w)?;
1160
1161                bw = BitWriter::new(w);
1162
1163                self.caches.channels.resize_with(1, ChannelCache::default);
1164
1165                encode_subframe(
1166                    &self.options,
1167                    &mut self.caches.channels[0],
1168                    CorrelatedChannel::independent(bits_per_sample, channel),
1169                )?
1170                .playback(&mut bw)?;
1171            }
1172            [left, right] if self.options.exhaustive_channel_correlation => {
1173                let Correlated {
1174                    channel_assignment,
1175                    channels: [channel_0, channel_1],
1176                } = correlate_channels_exhaustive(
1177                    &self.options,
1178                    &mut self.caches.correlated,
1179                    [left, right],
1180                    bits_per_sample,
1181                )?;
1182
1183                FrameHeader {
1184                    blocking_strategy: false,
1185                    frame_number: self.frame_number,
1186                    block_size,
1187                    sample_rate,
1188                    bits_per_sample: header_bits_per_sample,
1189                    channel_assignment,
1190                }
1191                .write_subset(&mut w)?;
1192
1193                bw = BitWriter::new(w);
1194
1195                channel_0.playback(&mut bw)?;
1196                channel_1.playback(&mut bw)?;
1197            }
1198            [left, right] => {
1199                let Correlated {
1200                    channel_assignment,
1201                    channels: [channel_0, channel_1],
1202                } = correlate_channels(
1203                    &self.options,
1204                    &mut self.caches.correlated,
1205                    [left, right],
1206                    bits_per_sample,
1207                );
1208
1209                FrameHeader {
1210                    blocking_strategy: false,
1211                    frame_number: self.frame_number,
1212                    block_size,
1213                    sample_rate,
1214                    bits_per_sample: header_bits_per_sample,
1215                    channel_assignment,
1216                }
1217                .write_subset(&mut w)?;
1218
1219                self.caches.channels.resize_with(2, ChannelCache::default);
1220                let [cache_0, cache_1] = self.caches.channels.get_disjoint_mut([0, 1]).unwrap();
1221                let (channel_0, channel_1) = join(
1222                    || encode_subframe(&self.options, cache_0, channel_0),
1223                    || encode_subframe(&self.options, cache_1, channel_1),
1224                );
1225
1226                bw = BitWriter::new(w);
1227
1228                channel_0?.playback(&mut bw)?;
1229                channel_1?.playback(&mut bw)?;
1230            }
1231            channels => {
1232                // non-stereo frames are always encoded independently
1233                FrameHeader {
1234                    blocking_strategy: false,
1235                    frame_number: self.frame_number,
1236                    block_size,
1237                    sample_rate,
1238                    bits_per_sample: header_bits_per_sample,
1239                    channel_assignment: ChannelAssignment::Independent(
1240                        channels.len().try_into().expect("invalid channel count"),
1241                    ),
1242                }
1243                .write_subset(&mut w)?;
1244
1245                bw = BitWriter::new(w);
1246
1247                self.caches
1248                    .channels
1249                    .resize_with(channels.len(), ChannelCache::default);
1250
1251                vec_map(
1252                    self.caches.channels.iter_mut().zip(channels).collect(),
1253                    |(cache, channel)| {
1254                        encode_subframe(
1255                            &self.options,
1256                            cache,
1257                            CorrelatedChannel::independent(bits_per_sample, channel),
1258                        )
1259                    },
1260                )
1261                .into_iter()
1262                .try_for_each(|r| r.and_then(|r| r.playback(bw.by_ref()).map_err(Error::Io)))?;
1263            }
1264        }
1265
1266        let crc16: u16 = bw.aligned_writer()?.checksum().into();
1267        bw.write_from(crc16)?;
1268
1269        if self.frame_number.try_increment().is_err() {
1270            self.frame_number = FrameNumber::default();
1271        }
1272
1273        Ok(())
1274    }
1275
1276    /// Writes a whole FLAC frame to our output stream with CDDA parameters
1277    ///
1278    /// Samples are interleaved by channel, like:
1279    /// [left₀ , right₀ , left₁ , right₁ , left₂ , right₂ , …]
1280    ///
1281    /// This writes a whole FLAC frame to the output stream on each call.
1282    ///
1283    /// # Errors
1284    ///
1285    /// Returns an error of any of the parameters are invalid
1286    /// or if an I/O error occurs when writing to the stream.
1287    pub fn write_cdda(&mut self, samples: &[i32]) -> Result<(), Error> {
1288        self.write(44100, 2, 16, samples)
1289    }
1290}
1291
1292fn update_md5(md5: &mut md5::Context, samples: impl Iterator<Item = i32>, bytes_per_sample: usize) {
1293    use crate::byteorder::{Endianness, LittleEndian};
1294
1295    match bytes_per_sample {
1296        1 => {
1297            for s in samples {
1298                md5.consume(LittleEndian::i8_to_bytes(s as i8));
1299            }
1300        }
1301        2 => {
1302            for s in samples {
1303                md5.consume(LittleEndian::i16_to_bytes(s as i16));
1304            }
1305        }
1306        3 => {
1307            for s in samples {
1308                md5.consume(LittleEndian::i24_to_bytes(s));
1309            }
1310        }
1311        4 => {
1312            for s in samples {
1313                md5.consume(LittleEndian::i32_to_bytes(s));
1314            }
1315        }
1316        _ => panic!("unsupported number of bytes per sample"),
1317    }
1318}
1319
/// The interval of seek points to generate
#[derive(Copy, Clone, Debug)]
pub enum SeekTableInterval {
    /// Generate a seek point every nth second
    Seconds(NonZero<u8>),
    /// Generate a seek point every nth FLAC frame
    Frames(NonZero<usize>),
}

impl Default for SeekTableInterval {
    /// Generates a seek point every 10 seconds
    fn default() -> Self {
        Self::Seconds(NonZero::new(10).unwrap())
    }
}
1334
1335impl SeekTableInterval {
1336    // decimates full set of seekpoints based on the requested
1337    // seektable style, or returns None if no seektable is requested
1338    fn filter<'s>(
1339        self,
1340        sample_rate: u32,
1341        seekpoints: impl IntoIterator<Item = EncoderSeekPoint> + 's,
1342    ) -> Box<dyn Iterator<Item = EncoderSeekPoint> + 's> {
1343        match self {
1344            Self::Seconds(seconds) => {
1345                let nth_sample = u64::from(u32::from(seconds.get()) * sample_rate);
1346                let mut offset = 0;
1347                Box::new(seekpoints.into_iter().filter(move |point| {
1348                    if point.range().contains(&offset) {
1349                        offset += nth_sample;
1350                        true
1351                    } else {
1352                        false
1353                    }
1354                }))
1355            }
1356            Self::Frames(frames) => Box::new(seekpoints.into_iter().step_by(frames.get())),
1357        }
1358    }
1359}
1360
/// FLAC encoding options
///
/// Built with [`Options::default`], [`Options::fast`] or [`Options::best`]
/// and adjusted with the builder-style methods.
#[derive(Clone, Debug)]
pub struct Options {
    // whether to clobber existing file
    clobber: bool,
    // the encoder's block size (at least 16)
    block_size: u16,
    // maximum residual partition order (0 to 15)
    max_partition_order: u32,
    // whether to use mid-side encoding
    mid_side: bool,
    // metadata blocks, starting with a placeholder STREAMINFO
    // until the stream's real parameters are known
    metadata: BlockList,
    // how often to generate seek points, or None for no SEEKTABLE
    seektable_interval: Option<SeekTableInterval>,
    // maximum LPC order (1 to 32), or None to encode no LPC subframes
    max_lpc_order: Option<NonZero<u8>>,
    // the windowing function applied to input samples
    window: Window,
    // whether to find the best channel correlation exhaustively
    // (`fast_channel_correlation(true)` sets this to false)
    exhaustive_channel_correlation: bool,
}
1375
1376impl Default for Options {
1377    fn default() -> Self {
1378        // a dummy placeholder value
1379        // since we can't know the stream parameters yet
1380        let mut metadata = BlockList::new(Streaminfo {
1381            minimum_block_size: 0,
1382            maximum_block_size: 0,
1383            minimum_frame_size: None,
1384            maximum_frame_size: None,
1385            sample_rate: 0,
1386            channels: NonZero::new(1).unwrap(),
1387            bits_per_sample: SignedBitCount::new::<4>(),
1388            total_samples: None,
1389            md5: None,
1390        });
1391
1392        metadata.insert(crate::metadata::Padding {
1393            size: 4096u16.into(),
1394        });
1395
1396        Self {
1397            clobber: false,
1398            block_size: 4096,
1399            mid_side: true,
1400            max_partition_order: 5,
1401            metadata,
1402            seektable_interval: Some(SeekTableInterval::default()),
1403            max_lpc_order: NonZero::new(8),
1404            window: Window::default(),
1405            exhaustive_channel_correlation: true,
1406        }
1407    }
1408}
1409
1410impl Options {
1411    /// Sets new block size
1412    ///
1413    /// Block size must be ≥ 16
1414    ///
1415    /// For subset streams, this must be ≤ 4608
1416    /// if the sample rate is ≤ 48 kHz -
1417    /// or ≤ 16384 for higher sample rates.
1418    pub fn block_size(self, block_size: u16) -> Result<Self, OptionsError> {
1419        match block_size {
1420            0..16 => Err(OptionsError::InvalidBlockSize),
1421            16.. => Ok(Self { block_size, ..self }),
1422        }
1423    }
1424
1425    /// Sets new maximum LPC order
1426    ///
1427    /// The valid range is: 0 < `max_lpc_order` ≤ 32
1428    ///
1429    /// A value of `None` means that no LPC subframes will be encoded.
1430    pub fn max_lpc_order(self, max_lpc_order: Option<u8>) -> Result<Self, OptionsError> {
1431        Ok(Self {
1432            max_lpc_order: max_lpc_order
1433                .map(|o| {
1434                    o.try_into()
1435                        .ok()
1436                        .filter(|o| *o <= NonZero::new(32).unwrap())
1437                        .ok_or(OptionsError::InvalidLpcOrder)
1438                })
1439                .transpose()?,
1440            ..self
1441        })
1442    }
1443
1444    /// Sets maximum residual partion order.
1445    ///
1446    /// The valid range is: 0 ≤ `max_partition_order` ≤ 15
1447    pub fn max_partition_order(self, max_partition_order: u32) -> Result<Self, OptionsError> {
1448        match max_partition_order {
1449            0..=15 => Ok(Self {
1450                max_partition_order,
1451                ..self
1452            }),
1453            16.. => Err(OptionsError::InvalidMaxPartitions),
1454        }
1455    }
1456
1457    /// Whether to use mid-side encoding
1458    ///
1459    /// The default is `true`.
1460    pub fn mid_side(self, mid_side: bool) -> Self {
1461        Self { mid_side, ..self }
1462    }
1463
1464    /// The windowing function to use for input samples
1465    pub fn window(self, window: Window) -> Self {
1466        Self { window, ..self }
1467    }
1468
1469    /// Whether to calculate the best channel correlation quickly
1470    ///
1471    /// The default is `false`
1472    pub fn fast_channel_correlation(self, fast: bool) -> Self {
1473        Self {
1474            exhaustive_channel_correlation: !fast,
1475            ..self
1476        }
1477    }
1478
1479    /// Updates size of padding block
1480    ///
1481    /// `size` must be < 2²⁴
1482    ///
1483    /// If `size` is set to 0, removes the block entirely.
1484    ///
1485    /// The default is to add a 4096 byte padding block.
1486    pub fn padding(mut self, size: u32) -> Result<Self, OptionsError> {
1487        use crate::metadata::Padding;
1488
1489        match size
1490            .try_into()
1491            .map_err(|_| OptionsError::ExcessivePadding)?
1492        {
1493            BlockSize::ZERO => self.metadata.remove::<Padding>(),
1494            size => self.metadata.update::<Padding>(|p| {
1495                p.size = size;
1496            }),
1497        }
1498        Ok(self)
1499    }
1500
1501    /// Remove any padding blocks from metadata
1502    ///
1503    /// This makes the file smaller, but will likely require
1504    /// rewriting it if any metadata needs to be modified later.
1505    pub fn no_padding(mut self) -> Self {
1506        self.metadata.remove::<crate::metadata::Padding>();
1507        self
1508    }
1509
1510    /// Adds new tag to comment metadata block
1511    ///
1512    /// Creates new [`crate::metadata::VorbisComment`] block if not already present.
1513    pub fn tag<S>(mut self, field: &str, value: S) -> Self
1514    where
1515        S: std::fmt::Display,
1516    {
1517        self.metadata
1518            .update::<VorbisComment>(|vc| vc.insert(field, value));
1519        self
1520    }
1521
1522    /// Replaces entire [`crate::metadata::VorbisComment`] metadata block
1523    ///
1524    /// This may be more convenient when adding many fields at once.
1525    pub fn comment(mut self, comment: VorbisComment) -> Self {
1526        self.metadata.insert(comment);
1527        self
1528    }
1529
1530    /// Add new [`crate::metadata::Picture`] block to metadata
1531    ///
1532    /// Files may contain multiple [`crate::metadata::Picture`] blocks,
1533    /// and this adds a new block each time it is used.
1534    pub fn picture(mut self, picture: Picture) -> Self {
1535        self.metadata.insert(picture);
1536        self
1537    }
1538
1539    /// Add new [`crate::metadata::Cuesheet`] block to metadata
1540    ///
1541    /// Files may (theoretically) contain multiple [`crate::metadata::Cuesheet`] blocks,
1542    /// and this adds a new block each time it is used.
1543    ///
1544    /// In practice, CD images almost always use only a single
1545    /// cue sheet.
1546    pub fn cuesheet(mut self, cuesheet: Cuesheet) -> Self {
1547        self.metadata.insert(cuesheet);
1548        self
1549    }
1550
1551    /// Add new [`crate::metadata::Application`] block to metadata
1552    ///
1553    /// Files may contain multiple [`crate::metadata::Application`] blocks,
1554    /// and this adds a new block each time it is used.
1555    pub fn application(mut self, application: Application) -> Self {
1556        self.metadata.insert(application);
1557        self
1558    }
1559
1560    /// Generate [`crate::metadata::SeekTable`] with the given number of seconds between seek points
1561    ///
1562    /// The default is to generate a SEEKTABLE with 10 seconds between seek points.
1563    ///
1564    /// If `seconds` is 0, removes the SEEKTABLE block.
1565    ///
1566    /// The interval between seek points may be larger than requested
1567    /// if the encoder's block size is larger than the seekpoint interval.
1568    pub fn seektable_seconds(mut self, seconds: u8) -> Self {
1569        // note that we can't drop a placeholder seektable
1570        // into the metadata blocks until we know
1571        // the sample rate and total samples of our stream
1572        self.seektable_interval = NonZero::new(seconds).map(SeekTableInterval::Seconds);
1573        self
1574    }
1575
1576    /// Generate [`crate::metadata::SeekTable`] with the given number of FLAC frames between seek points
1577    ///
1578    /// If `frames` is 0, removes the SEEKTABLE block
1579    pub fn seektable_frames(mut self, frames: usize) -> Self {
1580        self.seektable_interval = NonZero::new(frames).map(SeekTableInterval::Frames);
1581        self
1582    }
1583
1584    /// Do not generate a seektable in our encoded file
1585    pub fn no_seektable(self) -> Self {
1586        Self {
1587            seektable_interval: None,
1588            ..self
1589        }
1590    }
1591
1592    /// Add new block to metadata
1593    ///
1594    /// If the block may only occur once in a file,
1595    /// any previous block of that same type is removed.
1596    pub fn add_block<B>(&mut self, block: B) -> &mut Self
1597    where
1598        B: PortableMetadataBlock,
1599    {
1600        self.metadata.insert(block);
1601        self
1602    }
1603
1604    /// Add new blocks to metadata
1605    ///
1606    /// If the block may only occur once in a file,
1607    /// any current block of that type is replaced by
1608    /// the final block in the iterator - if any.
1609    /// Otherwise, all blocks in the iterator are used.
1610    pub fn add_blocks<B>(&mut self, iter: impl IntoIterator<Item = B>) -> &mut Self
1611    where
1612        B: PortableMetadataBlock,
1613    {
1614        for block in iter {
1615            self.metadata.insert(block);
1616        }
1617        self
1618    }
1619
1620    /// Overwrites existing file if it already exists
1621    ///
1622    /// This only applies to encoding functions which
1623    /// create files from paths.
1624    ///
1625    /// The default is to not overwrite files
1626    /// if they already exist.
1627    pub fn overwrite(mut self) -> Self {
1628        self.clobber = true;
1629        self
1630    }
1631
1632    /// Returns the fastest encoding options
1633    ///
1634    /// These are tuned to encode as quickly as possible.
1635    pub fn fast() -> Self {
1636        Self {
1637            block_size: 1152,
1638            mid_side: false,
1639            max_partition_order: 3,
1640            max_lpc_order: None,
1641            exhaustive_channel_correlation: false,
1642            ..Self::default()
1643        }
1644    }
1645
1646    /// Returns the fastest encoding options
1647    ///
1648    /// These are tuned to encode files as small as possible.
1649    pub fn best() -> Self {
1650        Self {
1651            block_size: 4096,
1652            mid_side: true,
1653            max_partition_order: 6,
1654            max_lpc_order: NonZero::new(12),
1655            ..Self::default()
1656        }
1657    }
1658
1659    /// Creates files according to whether clobber is set or not
1660    fn create<P: AsRef<Path>>(&self, path: P) -> std::io::Result<File> {
1661        if self.clobber {
1662            File::create(path)
1663        } else {
1664            use std::fs::OpenOptions;
1665
1666            OpenOptions::new()
1667                .write(true)
1668                .create_new(true)
1669                .open(path.as_ref())
1670        }
1671    }
1672}
1673
/// An error when specifying encoding options
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum OptionsError {
    /// Selected block size is too small
    InvalidBlockSize,
    /// Maximum LPC order is zero or too large
    InvalidLpcOrder,
    /// Maximum residual partition order is too large
    InvalidMaxPartitions,
    /// Selected padding size is too large
    ExcessivePadding,
}

impl std::error::Error for OptionsError {}

impl std::fmt::Display for OptionsError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let message = match self {
            Self::InvalidBlockSize => "block size must be >= 16",
            Self::InvalidLpcOrder => "maximum LPC order must be between 1 and 32",
            Self::InvalidMaxPartitions => "max partition order must be <= 15",
            Self::ExcessivePadding => "padding size is too large for block",
        };
        // delegating to str's Display keeps width/alignment flags working
        std::fmt::Display::fmt(message, f)
    }
}
1699
/// A cut-down version of Options without the metadata blocks
///
/// Besides the metadata blocks, this also omits the block size
/// and file-overwrite setting of [`Options`],
/// and adds the `use_rice2` flag.
struct EncoderOptions {
    // maximum residual partition order
    max_partition_order: u32,
    // whether to use mid-side encoding
    mid_side: bool,
    // how often to generate seek points, if at all
    seektable_interval: Option<SeekTableInterval>,
    // maximum LPC order, or None to encode no LPC subframes
    max_lpc_order: Option<NonZero<u8>>,
    // the windowing function applied to input samples
    window: Window,
    // whether to find the best channel correlation exhaustively
    exhaustive_channel_correlation: bool,
    // presumably selects RICE2 residual coding - TODO confirm;
    // FlacStreamWriter sets it when bits-per-sample exceeds 16
    use_rice2: bool,
}
1710
/// The method to use for windowing the input signal
#[derive(Copy, Clone, Debug)]
pub enum Window {
    /// Basic rectangular window
    Rectangle,
    /// Hann window
    Hann,
    /// Tukey window
    ///
    /// The parameter controls how much of the window is tapered,
    /// with roughly `p / 2` of the samples tapered at each end.
    /// Values ≤ 0.0 produce a rectangular window,
    /// values ≥ 1.0 produce a Hann window,
    /// and NaN falls back to a parameter of 0.5.
    Tukey(f32),
}

// TODO - add more windowing options

impl Window {
    /// Fills `window` with this window function's coefficients
    ///
    /// Output matches libFLAC's reference windows, except that a
    /// single-sample Hann window is 1.0 rather than NaN.
    ///
    /// # Panics
    ///
    /// May panic if `window` is longer than `u16::MAX` samples,
    /// the largest possible block size.
    fn generate(&self, window: &mut [f64]) {
        use std::f64::consts::PI;

        match self {
            Self::Rectangle => window.fill(1.0),
            // The Hann formula divides by (len - 1), which is 0/0 = NaN
            // for a single sample, and NaN would poison every
            // windowed value computed from it.
            Self::Hann if window.len() <= 1 => window.fill(1.0),
            Self::Hann => {
                // verified output against reference implementation
                // See: FLAC__window_hann()

                let np =
                    f64::from(u16::try_from(window.len()).expect("window size too large")) - 1.0;

                window.iter_mut().zip(0u16..).for_each(|(w, n)| {
                    *w = 0.5 - 0.5 * (2.0 * PI * f64::from(n) / np).cos();
                });
            }
            Self::Tukey(p) => match p {
                // verified output against reference implementation
                // See: FLAC__window_tukey()
                ..=0.0 => {
                    window.fill(1.0);
                }
                1.0.. => {
                    Self::Hann.generate(window);
                }
                0.0..1.0 => {
                    // np is the number of tapered samples at each end
                    match ((f64::from(*p) / 2.0 * window.len() as f64) as usize).checked_sub(1) {
                        Some(np) => match window.get_disjoint_mut([
                            0..np,
                            np..window.len() - np,
                            window.len() - np..window.len(),
                        ]) {
                            Ok([first, mid, last]) => {
                                // u16 is maximum block size
                                let np = u16::try_from(np).expect("window size too large");

                                // both ends are mirror images of a half Hann window
                                for ((x, y), n) in
                                    first.iter_mut().zip(last.iter_mut().rev()).zip(0u16..)
                                {
                                    *x = 0.5 - 0.5 * (PI * f64::from(n) / f64::from(np)).cos();
                                    *y = *x;
                                }
                                mid.fill(1.0);
                            }
                            Err(_) => {
                                window.fill(1.0);
                            }
                        },
                        None => {
                            window.fill(1.0);
                        }
                    }
                }
                // only NaN fails to match any of the ranges above
                _ => {
                    Self::Tukey(0.5).generate(window);
                }
            },
        }
    }

    /// Multiplies `samples` by this window, storing the result in `cache`
    ///
    /// `window` holds previously generated coefficients and is only
    /// regenerated when its length differs from `samples`.
    /// Returns the windowed samples as a slice of `cache`.
    fn apply<'w>(
        &self,
        window: &mut Vec<f64>,
        cache: &'w mut Vec<f64>,
        samples: &[i32],
    ) -> &'w [f64] {
        if window.len() != samples.len() {
            // need to re-generate window to fit samples
            window.resize(samples.len(), 0.0);
            self.generate(window);
        }

        // window signal into cache and return cached slice
        cache.clear();
        cache.extend(samples.iter().zip(window).map(|(s, w)| f64::from(*s) * *w));
        cache.as_slice()
    }
}

impl Default for Window {
    /// Defaults to a Tukey window with a parameter of 0.5
    fn default() -> Self {
        Self::Tukey(0.5)
    }
}
1809
/// Scratch buffers owned by the encoder and reused for every frame it encodes
#[derive(Default)]
struct EncodingCaches {
    // one cache per subframe being encoded,
    // resized to the frame's channel count as needed
    channels: Vec<ChannelCache>,
    // buffers for two-channel decorrelation
    correlated: CorrelationCache,
}
1815
/// Scratch state for stereo channel decorrelation
#[derive(Default)]
struct CorrelationCache {
    // the average channel samples, (left + right) >> 1 (the "mid" channel)
    average_samples: Vec<i32>,
    // the difference channel samples, left - right (the "side" channel)
    difference_samples: Vec<i32>,

    // NOTE(review): these per-channel caches are ignored by
    // `correlate_channels`; presumably `correlate_channels_exhaustive`
    // uses them to encode each candidate channel — confirm
    left_cache: ChannelCache,
    right_cache: ChannelCache,
    average_cache: ChannelCache,
    difference_cache: ChannelCache,
}
1828
/// Per-channel scratch state for subframe encoding, reused between frames
///
/// NOTE(review): presumably one recorder is kept per subframe type so the
/// smallest encoding can be chosen — confirm against `encode_subframe`.
#[derive(Default)]
struct ChannelCache {
    // working buffers for FIXED prediction
    fixed: FixedCache,
    // recorded bits of a FIXED subframe
    fixed_output: BitRecorder<u32, BigEndian>,
    // working buffers for LPC prediction
    lpc: LpcCache,
    // recorded bits of an LPC subframe
    lpc_output: BitRecorder<u32, BigEndian>,
    // recorded bits of a CONSTANT subframe
    constant_output: BitRecorder<u32, BigEndian>,
    // recorded bits of a VERBATIM subframe
    verbatim_output: BitRecorder<u32, BigEndian>,
    // NOTE(review): presumably samples with wasted low-order bits
    // shifted out — confirm against `encode_subframe`
    wasted: Vec<i32>,
}
1839
/// Scratch buffers for FIXED subframe encoding
#[derive(Default)]
struct FixedCache {
    // FIXED subframe buffers, one per order 1-4
    fixed_buffers: [Vec<i32>; 4],
}
1845
/// Scratch buffers for LPC subframe encoding
#[derive(Default)]
struct LpcCache {
    // window coefficients; NOTE(review): presumably passed as the `window`
    // argument of `Window::apply`, which regenerates it only when the
    // block size changes — confirm
    window: Vec<f64>,
    // the windowed signal; presumably `Window::apply`'s `cache` argument
    windowed: Vec<f64>,
    // LPC residuals
    residuals: Vec<i32>,
}
1852
/// A FLAC encoder
///
/// Metadata blocks are written when the encoder is created, one FLAC frame
/// is written per `encode` call, and the metadata is rewritten with final
/// values when the encoder is finalized (explicitly, or on drop).
struct Encoder<W: std::io::Write + std::io::Seek> {
    // the writer we're outputting to, wrapped to count bytes written
    // (the count supplies seekpoint byte offsets)
    writer: Counter<W>,
    // the stream's starting offset in the writer, in bytes;
    // metadata blocks are rewritten from here when finalizing
    start: u64,
    // various encoding options
    options: EncoderOptions,
    // various encoder caches, reused for every frame
    caches: EncodingCaches,
    // our metadata blocks
    blocks: BlockList,
    // our stream's sample rate
    sample_rate: SampleRate<u32>,
    // the current frame number
    frame_number: FrameNumber,
    // the number of channel-independent samples written
    samples_written: u64,
    // one seekpoint per encoded frame, thinned to the
    // seektable interval when finalizing
    seekpoints: Vec<EncoderSeekPoint>,
    // our running MD5 calculation, stored in STREAMINFO when finalizing;
    // NOTE(review): it is not fed by `Encoder::encode` itself,
    // so presumably the wrapping writers update it — confirm
    md5: md5::Context,
    // whether the encoder has finalized the file
    finalized: bool,
}
1878
1879impl<W: std::io::Write + std::io::Seek> Encoder<W> {
1880    const MAX_SAMPLES: u64 = 68_719_476_736;
1881
1882    fn new(
1883        mut writer: W,
1884        options: Options,
1885        sample_rate: u32,
1886        bits_per_sample: SignedBitCount<32>,
1887        channels: u8,
1888        total_samples: Option<NonZero<u64>>,
1889    ) -> Result<Self, Error> {
1890        use crate::metadata::OptionalBlockType;
1891
1892        let mut blocks = options.metadata;
1893
1894        *blocks.streaminfo_mut() = Streaminfo {
1895            minimum_block_size: options.block_size,
1896            maximum_block_size: options.block_size,
1897            minimum_frame_size: None,
1898            maximum_frame_size: None,
1899            sample_rate: (0..1048576)
1900                .contains(&sample_rate)
1901                .then_some(sample_rate)
1902                .ok_or(Error::InvalidSampleRate)?,
1903            bits_per_sample,
1904            channels: (1..=8)
1905                .contains(&channels)
1906                .then_some(channels)
1907                .and_then(NonZero::new)
1908                .ok_or(Error::ExcessiveChannels)?,
1909            total_samples: match total_samples {
1910                None => None,
1911                total_samples @ Some(samples) => match samples.get() {
1912                    0..Self::MAX_SAMPLES => total_samples,
1913                    _ => return Err(Error::ExcessiveTotalSamples),
1914                },
1915            },
1916            md5: None,
1917        };
1918
1919        // insert a dummy SeekTable to be populated later
1920        if let Some(total_samples) = total_samples
1921            && let Some(placeholders) = options.seektable_interval.map(|s| {
1922                s.filter(
1923                    sample_rate,
1924                    EncoderSeekPoint::placeholders(total_samples.get(), options.block_size),
1925                )
1926            })
1927        {
1928            use crate::metadata::SeekTable;
1929
1930            blocks.insert(SeekTable {
1931                // placeholder points should always be contiguous
1932                points: placeholders
1933                    .take(SeekTable::MAX_POINTS)
1934                    .map(|p| p.into())
1935                    .collect::<Vec<_>>()
1936                    .try_into()
1937                    .unwrap(),
1938            });
1939        }
1940
1941        let start = writer.stream_position()?;
1942
1943        // sort blocks to put more relevant items at the front
1944        blocks.sort_by(|block| match block {
1945            OptionalBlockType::VorbisComment => 0,
1946            OptionalBlockType::SeekTable => 1,
1947            OptionalBlockType::Picture => 2,
1948            OptionalBlockType::Application => 3,
1949            OptionalBlockType::Cuesheet => 4,
1950            OptionalBlockType::Padding => 5,
1951        });
1952
1953        write_blocks(writer.by_ref(), blocks.blocks())?;
1954
1955        Ok(Self {
1956            start,
1957            writer: Counter::new(writer),
1958            options: EncoderOptions {
1959                max_partition_order: options.max_partition_order,
1960                mid_side: options.mid_side,
1961                seektable_interval: options.seektable_interval,
1962                max_lpc_order: options.max_lpc_order,
1963                window: options.window,
1964                exhaustive_channel_correlation: options.exhaustive_channel_correlation,
1965                use_rice2: u32::from(bits_per_sample) > 16,
1966            },
1967            caches: EncodingCaches::default(),
1968            sample_rate: blocks
1969                .streaminfo()
1970                .sample_rate
1971                .try_into()
1972                .expect("invalid sample rate"),
1973            blocks,
1974            frame_number: FrameNumber::default(),
1975            samples_written: 0,
1976            seekpoints: Vec::new(),
1977            md5: md5::Context::new(),
1978            finalized: false,
1979        })
1980    }
1981
1982    /// The encoder's channel count
1983    fn channel_count(&self) -> NonZero<u8> {
1984        self.blocks.streaminfo().channels
1985    }
1986
1987    /// Encodes an audio frame of PCM samples
1988    ///
1989    /// Depending on the encoder's chosen block size,
1990    /// this may encode zero or more FLAC frames to disk.
1991    ///
1992    /// # Errors
1993    ///
1994    /// Returns an I/O error from the underlying stream,
1995    /// or if the frame's parameters are not a match
1996    /// for the encoder's.
1997    fn encode(&mut self, frame: &Frame) -> Result<(), Error> {
1998        // drop in a new seekpoint
1999        self.seekpoints.push(EncoderSeekPoint {
2000            sample_offset: self.samples_written,
2001            byte_offset: Some(self.writer.count),
2002            frame_samples: frame.pcm_frames() as u16,
2003        });
2004
2005        // update running total of samples written
2006        self.samples_written += frame.pcm_frames() as u64;
2007        if let Some(total_samples) = self.blocks.streaminfo().total_samples
2008            && self.samples_written > total_samples.get()
2009        {
2010            return Err(Error::ExcessiveTotalSamples);
2011        }
2012
2013        encode_frame(
2014            &self.options,
2015            &mut self.caches,
2016            &mut self.writer,
2017            self.blocks.streaminfo_mut(),
2018            &mut self.frame_number,
2019            self.sample_rate,
2020            frame.channels().collect(),
2021        )
2022    }
2023
2024    fn finalize_inner(&mut self) -> Result<(), Error> {
2025        if !self.finalized {
2026            use crate::metadata::SeekTable;
2027
2028            self.finalized = true;
2029
2030            // update SEEKTABLE metadata block with final values
2031            if let Some(encoded_points) = self
2032                .options
2033                .seektable_interval
2034                .map(|s| s.filter(self.sample_rate.into(), self.seekpoints.iter().cloned()))
2035            {
2036                match self.blocks.get_pair_mut() {
2037                    (Some(SeekTable { points }), _) => {
2038                        // placeholder SEEKTABLE already in place,
2039                        // so no need to adjust PADDING to fit
2040
2041                        // ensure points count is the same
2042
2043                        let points_len = points.len();
2044                        points.clear();
2045                        points
2046                            .try_extend(
2047                                encoded_points
2048                                    .into_iter()
2049                                    .map(|p| p.into())
2050                                    .chain(std::iter::repeat(SeekPoint::Placeholder))
2051                                    .take(points_len),
2052                            )
2053                            .unwrap();
2054                    }
2055                    (None, Some(crate::metadata::Padding { size: padding_size })) => {
2056                        // no SEEKTABLE, but there is a PADDING block,
2057                        // so try to shrink PADDING to fit SEEKTABLE
2058
2059                        use crate::metadata::MetadataBlock;
2060
2061                        let seektable = SeekTable {
2062                            points: encoded_points
2063                                .map(|p| p.into())
2064                                .collect::<Vec<_>>()
2065                                .try_into()
2066                                .unwrap(),
2067                        };
2068                        if let Some(new_padding_size) = seektable
2069                            .total_size()
2070                            .and_then(|seektable_size| padding_size.checked_sub(seektable_size))
2071                        {
2072                            *padding_size = new_padding_size;
2073                            self.blocks.insert(seektable);
2074                        }
2075                    }
2076                    (None, None) => { /* no seektable or padding, so nothing to do */ }
2077                }
2078            }
2079
2080            // verify or update final sample count
2081            match &mut self.blocks.streaminfo_mut().total_samples {
2082                Some(expected) => {
2083                    // ensure final sample count matches
2084                    if expected.get() != self.samples_written {
2085                        return Err(Error::SampleCountMismatch);
2086                    }
2087                }
2088                expected @ None => {
2089                    // update final sample count if possible
2090                    if self.samples_written < Self::MAX_SAMPLES {
2091                        *expected =
2092                            Some(NonZero::new(self.samples_written).ok_or(Error::NoSamples)?);
2093                    } else {
2094                        // TODO - should I just leave this blank
2095                        // if too many samples are written?
2096                        return Err(Error::ExcessiveTotalSamples);
2097                    }
2098                }
2099            }
2100
2101            // update STREAMINFO MD5 sum
2102            self.blocks.streaminfo_mut().md5 = Some(self.md5.clone().finalize().0);
2103
2104            // rewrite metadata blocks, relative to the beginning
2105            // of the stream
2106            let writer = self.writer.stream();
2107            writer.seek(std::io::SeekFrom::Start(self.start))?;
2108            write_blocks(writer.by_ref(), self.blocks.blocks())
2109        } else {
2110            Ok(())
2111        }
2112    }
2113}
2114
2115impl<W: std::io::Write + std::io::Seek> Drop for Encoder<W> {
2116    fn drop(&mut self) {
2117        let _ = self.finalize_inner();
2118    }
2119}
2120
// Unlike regular SeekPoints, which can have placeholder variants,
// these are always defined to be something.  A byte offset
// of None indicates a dummy encoder point
#[derive(Debug, Clone)]
struct EncoderSeekPoint {
    // the first sample of the frame, counted from the stream's start
    sample_offset: u64,
    // the frame's byte offset, or None for a placeholder point
    byte_offset: Option<u64>,
    // the number of samples in the frame
    frame_samples: u16,
}

impl EncoderSeekPoint {
    /// Yields one placeholder point per `block_size`-sample chunk of a
    /// stream holding `total_samples` samples.
    ///
    /// Every point has no byte offset. Each point's `frame_samples` is
    /// `block_size`, except that the final point is shortened to however
    /// many samples remain.
    ///
    /// # Panics
    ///
    /// Panics if `block_size` is 0 (via `Iterator::step_by`).
    fn placeholders(total_samples: u64, block_size: u16) -> impl Iterator<Item = EncoderSeekPoint> {
        let stride = usize::from(block_size);
        (0..total_samples).step_by(stride).map(move |offset| {
            // only the trailing chunk can be shorter than a full block
            let frame_samples = match u16::try_from(total_samples - offset) {
                Ok(remaining) if remaining < block_size => remaining,
                _ => block_size,
            };
            EncoderSeekPoint {
                sample_offset: offset,
                byte_offset: None,
                frame_samples,
            }
        })
    }

    /// The half-open range of stream samples covered by this point's frame
    fn range(&self) -> std::ops::Range<u64> {
        let end = self.sample_offset + u64::from(self.frame_samples);
        self.sample_offset..end
    }
}
2150
2151impl From<EncoderSeekPoint> for SeekPoint {
2152    fn from(p: EncoderSeekPoint) -> Self {
2153        match p.byte_offset {
2154            Some(byte_offset) => Self::Defined {
2155                sample_offset: p.sample_offset,
2156                byte_offset,
2157                frame_samples: p.frame_samples,
2158            },
2159            None => Self::Placeholder,
2160        }
2161    }
2162}
2163
2164/// Given a FLAC stream, generates new seek table
2165///
2166/// Though encoders should add seek tables by default,
2167/// sometimes one isn't present.  This function takes
2168/// an existing FLAC file stream and generates a new
2169/// seek table suitable for adding to the file's metadata
2170/// via the [`crate::metadata::update`] function.
2171///
2172/// The stream should be rewound to the beginning of the file.
2173///
2174/// # Errors
2175///
2176/// Returns any error from the underlying stream.
2177///
2178/// # Example
2179/// ```
2180/// use flac_codec::{
2181///     encode::{FlacSampleWriter, Options, SeekTableInterval, generate_seektable},
2182///     metadata::{SeekTable, read_block},
2183/// };
2184/// use std::io::{Cursor, Seek};
2185///
2186/// let mut flac = Cursor::new(vec![]);  // a FLAC file in memory
2187///
2188/// // add a seekpoint every second
2189/// let options = Options::default().seektable_seconds(1);
2190///
2191/// let mut writer = FlacSampleWriter::new(
2192///     &mut flac,         // our wrapped writer
2193///     options,           // our seektable options
2194///     44100,             // sample rate
2195///     16,                // bits-per-sample
2196///     1,                 // channel count
2197///     Some(60 * 44100),  // one minute's worth of samples
2198/// ).unwrap();
2199///
2200/// // write one minute's worth of samples
2201/// writer.write(vec![0; 60 * 44100].as_slice()).unwrap();
2202///
2203/// // finalize writing file
2204/// assert!(writer.finalize().is_ok());
2205///
2206/// flac.rewind().unwrap();
2207///
2208/// // get existing seektable
2209/// let original_seektable = match read_block::<_, SeekTable>(&mut flac) {
2210///     Ok(Some(seektable)) => seektable,
2211///     _ => panic!("seektable not found"),
2212/// };
2213///
2214/// flac.rewind().unwrap();
2215///
2216/// // generate new seektable, also with seekpoints every second
2217/// let new_seektable = generate_seektable(
2218///     flac,
2219///     SeekTableInterval::Seconds(1.try_into().unwrap())
2220/// ).unwrap();
2221///
2222/// // ensure both seektables are identical
2223/// assert_eq!(original_seektable, new_seektable);
2224/// ```
2225pub fn generate_seektable<R: std::io::Read>(
2226    r: R,
2227    interval: SeekTableInterval,
2228) -> Result<crate::metadata::SeekTable, Error> {
2229    use crate::{
2230        metadata::{Metadata, SeekTable},
2231        stream::FrameIterator,
2232    };
2233
2234    let iter = FrameIterator::new(r)?;
2235    let metadata_len = iter.metadata_len();
2236    let sample_rate = iter.sample_rate();
2237    let mut sample_offset = 0;
2238
2239    iter.map(|r| {
2240        r.map(|(frame, offset)| EncoderSeekPoint {
2241            sample_offset,
2242            byte_offset: Some(offset - metadata_len),
2243            frame_samples: frame.header.block_size.into(),
2244        })
2245        .inspect(|p| {
2246            sample_offset += u64::from(p.frame_samples);
2247        })
2248    })
2249    .collect::<Result<Vec<_>, _>>()
2250    .map(|seekpoints| SeekTable {
2251        points: interval
2252            .filter(sample_rate, seekpoints)
2253            .take(SeekTable::MAX_POINTS)
2254            .map(|p| p.into())
2255            .collect::<Vec<_>>()
2256            .try_into()
2257            .unwrap(),
2258    })
2259}
2260
/// Encodes a single FLAC frame (header, subframes, CRC-16 footer) to `writer`.
///
/// `frame` holds one sample slice per channel. The block size written to the
/// frame header is taken from the first channel's length, so all channels
/// are presumably the same non-zero length — TODO confirm callers guarantee this.
///
/// Channel handling:
/// * a single channel is written as an independent mono subframe;
/// * two channels with `options.exhaustive_channel_correlation` set are
///   decorrelated by `correlate_channels_exhaustive`, which returns subframes
///   that are already encoded and only need to be played back;
/// * two channels otherwise are decorrelated by `correlate_channels`, and the
///   two resulting channels are encoded via `join` before being written in order;
/// * any other channel count is encoded as independent channels.
///
/// After the frame is written, `frame_number` is incremented. `streaminfo`'s
/// minimum/maximum frame sizes are then widened to include this frame's
/// byte length, when that length is non-zero and below
/// `Streaminfo::MAX_FRAME_SIZE`.
///
/// # Errors
///
/// Returns any error from channel correlation, subframe encoding, or writing
/// to `writer`. Also returns any error from `frame_number.try_increment()`
/// (presumably when the frame number can no longer be represented).
///
/// NOTE(review): block sizes are narrowed with `as u16`, so a channel longer
/// than 65535 samples would be silently truncated in the header. Callers are
/// presumably limited to FLAC's maximum block size — confirm.
fn encode_frame<W>(
    options: &EncoderOptions,
    cache: &mut EncodingCaches,
    mut writer: W,
    streaminfo: &mut Streaminfo,
    frame_number: &mut FrameNumber,
    sample_rate: SampleRate<u32>,
    frame: ArrayVec<&[i32], MAX_CHANNELS>,
) -> Result<(), Error>
where
    W: std::io::Write,
{
    use crate::Counter;
    use crate::crc::{Crc16, CrcWriter};
    use crate::stream::FrameHeader;
    use bitstream_io::BigEndian;

    debug_assert!(!frame.is_empty());

    // the byte counter sits beneath the CRC writer, so it measures
    // every byte of the frame, including the CRC-16 footer
    let size = Counter::new(writer.by_ref());
    // the header and subframes are written through this CRC-16 writer
    let mut w: CrcWriter<_, Crc16> = CrcWriter::new(size);
    // initialized in each arm below, once the frame header has been written to `w`
    let mut bw: BitWriter<CrcWriter<Counter<&mut W>, Crc16>, BigEndian>;

    match frame.as_slice() {
        [channel] => {
            FrameHeader {
                blocking_strategy: false,
                frame_number: *frame_number,
                block_size: (channel.len() as u16)
                    .try_into()
                    .expect("frame cannot be empty"),
                sample_rate,
                bits_per_sample: streaminfo.bits_per_sample.into(),
                channel_assignment: ChannelAssignment::Independent(Independent::Mono),
            }
            .write(&mut w, streaminfo)?;

            bw = BitWriter::new(w);

            cache.channels.resize_with(1, ChannelCache::default);

            encode_subframe(
                options,
                &mut cache.channels[0],
                CorrelatedChannel::independent(streaminfo.bits_per_sample, channel),
            )?
            .playback(&mut bw)?;
        }
        [left, right] if options.exhaustive_channel_correlation => {
            // channels come back already encoded; they only need playing back
            let Correlated {
                channel_assignment,
                channels: [channel_0, channel_1],
            } = correlate_channels_exhaustive(
                options,
                &mut cache.correlated,
                [left, right],
                streaminfo.bits_per_sample,
            )?;

            FrameHeader {
                blocking_strategy: false,
                frame_number: *frame_number,
                block_size: (frame[0].len() as u16)
                    .try_into()
                    .expect("frame cannot be empty"),
                sample_rate,
                bits_per_sample: streaminfo.bits_per_sample.into(),
                channel_assignment,
            }
            .write(&mut w, streaminfo)?;

            bw = BitWriter::new(w);

            channel_0.playback(&mut bw)?;
            channel_1.playback(&mut bw)?;
        }
        [left, right] => {
            let Correlated {
                channel_assignment,
                channels: [channel_0, channel_1],
            } = correlate_channels(
                options,
                &mut cache.correlated,
                [left, right],
                streaminfo.bits_per_sample,
            );

            FrameHeader {
                blocking_strategy: false,
                frame_number: *frame_number,
                block_size: (frame[0].len() as u16)
                    .try_into()
                    .expect("frame cannot be empty"),
                sample_rate,
                bits_per_sample: streaminfo.bits_per_sample.into(),
                channel_assignment,
            }
            .write(&mut w, streaminfo)?;

            // both subframes are encoded (possibly in parallel) before either
            // is written, since subframes must be written serially
            cache.channels.resize_with(2, ChannelCache::default);
            let [cache_0, cache_1] = cache.channels.get_disjoint_mut([0, 1]).unwrap();
            let (channel_0, channel_1) = join(
                || encode_subframe(options, cache_0, channel_0),
                || encode_subframe(options, cache_1, channel_1),
            );

            bw = BitWriter::new(w);

            channel_0?.playback(&mut bw)?;
            channel_1?.playback(&mut bw)?;
        }
        channels => {
            // non-stereo frames are always encoded independently

            FrameHeader {
                blocking_strategy: false,
                frame_number: *frame_number,
                block_size: (channels[0].len() as u16)
                    .try_into()
                    .expect("frame cannot be empty"),
                sample_rate,
                bits_per_sample: streaminfo.bits_per_sample.into(),
                channel_assignment: ChannelAssignment::Independent(
                    frame.len().try_into().expect("invalid channel count"),
                ),
            }
            .write(&mut w, streaminfo)?;

            bw = BitWriter::new(w);

            cache
                .channels
                .resize_with(channels.len(), ChannelCache::default);

            // encode every channel (via vec_map), then write the results in
            // channel order, stopping at the first error
            vec_map(
                cache.channels.iter_mut().zip(channels).collect(),
                |(cache, channel)| {
                    encode_subframe(
                        options,
                        cache,
                        CorrelatedChannel::independent(streaminfo.bits_per_sample, channel),
                    )
                },
            )
            .into_iter()
            .try_for_each(|r| r.and_then(|r| r.playback(bw.by_ref()).map_err(Error::Io)))?;
        }
    }

    // byte-align, then append the CRC-16 of everything written so far
    let crc16: u16 = bw.aligned_writer()?.checksum().into();
    bw.write_from(crc16)?;

    frame_number.try_increment()?;

    // update minimum and maximum frame size values
    if let s @ Some(size) = u32::try_from(bw.into_writer().into_writer().count)
        .ok()
        .filter(|size| *size < Streaminfo::MAX_FRAME_SIZE)
        .and_then(NonZero::new)
    {
        match &mut streaminfo.minimum_frame_size {
            Some(min_size) => {
                *min_size = size.min(*min_size);
            }
            min_size @ None => {
                *min_size = s;
            }
        }

        match &mut streaminfo.maximum_frame_size {
            Some(max_size) => {
                *max_size = size.max(*max_size);
            }
            max_size @ None => {
                *max_size = s;
            }
        }
    }

    Ok(())
}
2442
/// A stereo channel assignment together with the two channels it produces
struct Correlated<C> {
    // how the two channels relate to the original left/right pair
    channel_assignment: ChannelAssignment,
    // the two channels, in the order their subframes are written
    channels: [C; 2],
}
2447
/// One channel's samples, prepared for encoding into a subframe
struct CorrelatedChannel<'c> {
    // the samples to encode; may be a derived side (left - right)
    // or mid ((left + right) >> 1) channel rather than an input channel
    samples: &'c [i32],
    // bits per sample for this channel; side channels use one extra bit
    bits_per_sample: SignedBitCount<32>,
    // whether all samples are known to be 0
    all_0: bool,
}
2454
2455impl<'c> CorrelatedChannel<'c> {
2456    fn independent(bits_per_sample: SignedBitCount<32>, samples: &'c [i32]) -> Self {
2457        Self {
2458            all_0: samples.iter().all(|s| *s == 0),
2459            bits_per_sample,
2460            samples,
2461        }
2462    }
2463}
2464
2465fn correlate_channels<'c>(
2466    options: &EncoderOptions,
2467    CorrelationCache {
2468        average_samples,
2469        difference_samples,
2470        ..
2471    }: &'c mut CorrelationCache,
2472    [left, right]: [&'c [i32]; 2],
2473    bits_per_sample: SignedBitCount<32>,
2474) -> Correlated<CorrelatedChannel<'c>> {
2475    match bits_per_sample.checked_add::<32>(1) {
2476        Some(difference_bits_per_sample) if options.mid_side => {
2477            let mut left_abs_sum = 0;
2478            let mut right_abs_sum = 0;
2479            let mut mid_abs_sum = 0;
2480            let mut side_abs_sum = 0;
2481
2482            join(
2483                || {
2484                    average_samples.clear();
2485                    average_samples.extend(
2486                        left.iter()
2487                            .inspect(|s| left_abs_sum += u64::from(s.unsigned_abs()))
2488                            .zip(
2489                                right
2490                                    .iter()
2491                                    .inspect(|s| right_abs_sum += u64::from(s.unsigned_abs())),
2492                            )
2493                            .map(|(l, r)| (l + r) >> 1)
2494                            .inspect(|s| mid_abs_sum += u64::from(s.unsigned_abs())),
2495                    );
2496                },
2497                || {
2498                    difference_samples.clear();
2499                    difference_samples.extend(
2500                        left.iter()
2501                            .zip(right)
2502                            .map(|(l, r)| l - r)
2503                            .inspect(|s| side_abs_sum += u64::from(s.unsigned_abs())),
2504                    );
2505                },
2506            );
2507
2508            match [
2509                (
2510                    ChannelAssignment::Independent(Independent::Stereo),
2511                    left_abs_sum + right_abs_sum,
2512                ),
2513                (ChannelAssignment::LeftSide, left_abs_sum + side_abs_sum),
2514                (ChannelAssignment::SideRight, side_abs_sum + right_abs_sum),
2515                (ChannelAssignment::MidSide, mid_abs_sum + side_abs_sum),
2516            ]
2517            .into_iter()
2518            .min_by_key(|(_, total)| *total)
2519            .unwrap()
2520            .0
2521            {
2522                channel_assignment @ ChannelAssignment::LeftSide => Correlated {
2523                    channel_assignment,
2524                    channels: [
2525                        CorrelatedChannel {
2526                            samples: left,
2527                            bits_per_sample,
2528                            all_0: left_abs_sum == 0,
2529                        },
2530                        CorrelatedChannel {
2531                            samples: difference_samples,
2532                            bits_per_sample: difference_bits_per_sample,
2533                            all_0: side_abs_sum == 0,
2534                        },
2535                    ],
2536                },
2537                channel_assignment @ ChannelAssignment::SideRight => Correlated {
2538                    channel_assignment,
2539                    channels: [
2540                        CorrelatedChannel {
2541                            samples: difference_samples,
2542                            bits_per_sample: difference_bits_per_sample,
2543                            all_0: side_abs_sum == 0,
2544                        },
2545                        CorrelatedChannel {
2546                            samples: right,
2547                            bits_per_sample,
2548                            all_0: right_abs_sum == 0,
2549                        },
2550                    ],
2551                },
2552                channel_assignment @ ChannelAssignment::MidSide => Correlated {
2553                    channel_assignment,
2554                    channels: [
2555                        CorrelatedChannel {
2556                            samples: average_samples,
2557                            bits_per_sample,
2558                            all_0: mid_abs_sum == 0,
2559                        },
2560                        CorrelatedChannel {
2561                            samples: difference_samples,
2562                            bits_per_sample: difference_bits_per_sample,
2563                            all_0: side_abs_sum == 0,
2564                        },
2565                    ],
2566                },
2567                channel_assignment @ ChannelAssignment::Independent(_) => Correlated {
2568                    channel_assignment,
2569                    channels: [
2570                        CorrelatedChannel {
2571                            samples: left,
2572                            bits_per_sample,
2573                            all_0: left_abs_sum == 0,
2574                        },
2575                        CorrelatedChannel {
2576                            samples: right,
2577                            bits_per_sample,
2578                            all_0: right_abs_sum == 0,
2579                        },
2580                    ],
2581                },
2582            }
2583        }
2584        Some(difference_bits_per_sample) => {
2585            let mut left_abs_sum = 0;
2586            let mut right_abs_sum = 0;
2587            let mut side_abs_sum = 0;
2588
2589            difference_samples.clear();
2590            difference_samples.extend(
2591                left.iter()
2592                    .inspect(|s| left_abs_sum += u64::from(s.unsigned_abs()))
2593                    .zip(
2594                        right
2595                            .iter()
2596                            .inspect(|s| right_abs_sum += u64::from(s.unsigned_abs())),
2597                    )
2598                    .map(|(l, r)| l - r)
2599                    .inspect(|s| side_abs_sum += u64::from(s.unsigned_abs())),
2600            );
2601
2602            match [
2603                (ChannelAssignment::LeftSide, left_abs_sum + side_abs_sum),
2604                (ChannelAssignment::SideRight, side_abs_sum + right_abs_sum),
2605                (
2606                    ChannelAssignment::Independent(Independent::Stereo),
2607                    left_abs_sum + right_abs_sum,
2608                ),
2609            ]
2610            .into_iter()
2611            .min_by_key(|(_, total)| *total)
2612            .unwrap()
2613            .0
2614            {
2615                channel_assignment @ ChannelAssignment::LeftSide => Correlated {
2616                    channel_assignment,
2617                    channels: [
2618                        CorrelatedChannel {
2619                            samples: left,
2620                            bits_per_sample,
2621                            all_0: left_abs_sum == 0,
2622                        },
2623                        CorrelatedChannel {
2624                            samples: difference_samples,
2625                            bits_per_sample: difference_bits_per_sample,
2626                            all_0: side_abs_sum == 0,
2627                        },
2628                    ],
2629                },
2630                channel_assignment @ ChannelAssignment::SideRight => Correlated {
2631                    channel_assignment,
2632                    channels: [
2633                        CorrelatedChannel {
2634                            samples: difference_samples,
2635                            bits_per_sample: difference_bits_per_sample,
2636                            all_0: side_abs_sum == 0,
2637                        },
2638                        CorrelatedChannel {
2639                            samples: right,
2640                            bits_per_sample,
2641                            all_0: right_abs_sum == 0,
2642                        },
2643                    ],
2644                },
2645                ChannelAssignment::MidSide => unreachable!(),
2646                channel_assignment @ ChannelAssignment::Independent(_) => Correlated {
2647                    channel_assignment,
2648                    channels: [
2649                        CorrelatedChannel {
2650                            samples: left,
2651                            bits_per_sample,
2652                            all_0: left_abs_sum == 0,
2653                        },
2654                        CorrelatedChannel {
2655                            samples: right,
2656                            bits_per_sample,
2657                            all_0: right_abs_sum == 0,
2658                        },
2659                    ],
2660                },
2661            }
2662        }
2663        None => {
2664            // 32 bps stream, so forego difference channel
2665            // and encode them both indepedently
2666
2667            Correlated {
2668                channel_assignment: ChannelAssignment::Independent(Independent::Stereo),
2669                channels: [
2670                    CorrelatedChannel::independent(bits_per_sample, left),
2671                    CorrelatedChannel::independent(bits_per_sample, right),
2672                ],
2673            }
2674        }
2675    }
2676}
2677
2678fn correlate_channels_exhaustive<'c>(
2679    options: &EncoderOptions,
2680    CorrelationCache {
2681        average_samples,
2682        difference_samples,
2683        left_cache,
2684        right_cache,
2685        average_cache,
2686        difference_cache,
2687        ..
2688    }: &'c mut CorrelationCache,
2689    [left, right]: [&'c [i32]; 2],
2690    bits_per_sample: SignedBitCount<32>,
2691) -> Result<Correlated<&'c BitRecorder<u32, BigEndian>>, Error> {
2692    let (left_recorder, right_recorder) = try_join(
2693        || {
2694            encode_subframe(
2695                options,
2696                left_cache,
2697                CorrelatedChannel {
2698                    samples: left,
2699                    bits_per_sample,
2700                    all_0: false,
2701                },
2702            )
2703        },
2704        || {
2705            encode_subframe(
2706                options,
2707                right_cache,
2708                CorrelatedChannel {
2709                    samples: right,
2710                    bits_per_sample,
2711                    all_0: false,
2712                },
2713            )
2714        },
2715    )?;
2716
2717    match bits_per_sample.checked_add::<32>(1) {
2718        Some(difference_bits_per_sample) if options.mid_side => {
2719            let (average_recorder, difference_recorder) = try_join(
2720                || {
2721                    average_samples.clear();
2722                    average_samples
2723                        .extend(left.iter().zip(right.iter()).map(|(l, r)| (l + r) >> 1));
2724                    encode_subframe(
2725                        options,
2726                        average_cache,
2727                        CorrelatedChannel {
2728                            samples: average_samples,
2729                            bits_per_sample,
2730                            all_0: false,
2731                        },
2732                    )
2733                },
2734                || {
2735                    difference_samples.clear();
2736                    difference_samples.extend(left.iter().zip(right).map(|(l, r)| l - r));
2737                    encode_subframe(
2738                        options,
2739                        difference_cache,
2740                        CorrelatedChannel {
2741                            samples: difference_samples,
2742                            bits_per_sample: difference_bits_per_sample,
2743                            all_0: false,
2744                        },
2745                    )
2746                },
2747            )?;
2748
2749            match [
2750                (
2751                    ChannelAssignment::Independent(Independent::Stereo),
2752                    left_recorder.written() + right_recorder.written(),
2753                ),
2754                (
2755                    ChannelAssignment::LeftSide,
2756                    left_recorder.written() + difference_recorder.written(),
2757                ),
2758                (
2759                    ChannelAssignment::SideRight,
2760                    difference_recorder.written() + right_recorder.written(),
2761                ),
2762                (
2763                    ChannelAssignment::MidSide,
2764                    average_recorder.written() + difference_recorder.written(),
2765                ),
2766            ]
2767            .into_iter()
2768            .min_by_key(|(_, total)| *total)
2769            .unwrap()
2770            .0
2771            {
2772                channel_assignment @ ChannelAssignment::LeftSide => Ok(Correlated {
2773                    channel_assignment,
2774                    channels: [left_recorder, difference_recorder],
2775                }),
2776                channel_assignment @ ChannelAssignment::SideRight => Ok(Correlated {
2777                    channel_assignment,
2778                    channels: [difference_recorder, right_recorder],
2779                }),
2780                channel_assignment @ ChannelAssignment::MidSide => Ok(Correlated {
2781                    channel_assignment,
2782                    channels: [average_recorder, difference_recorder],
2783                }),
2784                channel_assignment @ ChannelAssignment::Independent(_) => Ok(Correlated {
2785                    channel_assignment,
2786                    channels: [left_recorder, right_recorder],
2787                }),
2788            }
2789        }
2790        Some(difference_bits_per_sample) => {
2791            let difference_recorder = {
2792                difference_samples.clear();
2793                difference_samples.extend(left.iter().zip(right).map(|(l, r)| l - r));
2794                encode_subframe(
2795                    options,
2796                    difference_cache,
2797                    CorrelatedChannel {
2798                        samples: difference_samples,
2799                        bits_per_sample: difference_bits_per_sample,
2800                        all_0: false,
2801                    },
2802                )?
2803            };
2804
2805            match [
2806                (
2807                    ChannelAssignment::Independent(Independent::Stereo),
2808                    left_recorder.written() + right_recorder.written(),
2809                ),
2810                (
2811                    ChannelAssignment::LeftSide,
2812                    left_recorder.written() + difference_recorder.written(),
2813                ),
2814                (
2815                    ChannelAssignment::SideRight,
2816                    difference_recorder.written() + right_recorder.written(),
2817                ),
2818            ]
2819            .into_iter()
2820            .min_by_key(|(_, total)| *total)
2821            .unwrap()
2822            .0
2823            {
2824                channel_assignment @ ChannelAssignment::LeftSide => Ok(Correlated {
2825                    channel_assignment,
2826                    channels: [left_recorder, difference_recorder],
2827                }),
2828                channel_assignment @ ChannelAssignment::SideRight => Ok(Correlated {
2829                    channel_assignment,
2830                    channels: [difference_recorder, right_recorder],
2831                }),
2832                ChannelAssignment::MidSide => unreachable!(),
2833                channel_assignment @ ChannelAssignment::Independent(_) => Ok(Correlated {
2834                    channel_assignment,
2835                    channels: [left_recorder, right_recorder],
2836                }),
2837            }
2838        }
2839        None => {
2840            // 32 bps stream, so forego difference channel
2841            // and encode them both indepedently
2842
2843            Ok(Correlated {
2844                channel_assignment: ChannelAssignment::Independent(Independent::Stereo),
2845                channels: [left_recorder, right_recorder],
2846            })
2847        }
2848    }
2849}
2850
/// Encodes one channel as a single subframe, returning whichever
/// output recorder in the channel cache holds the smallest encoding found.
///
/// Candidates are considered as follows:
///
/// - CONSTANT, when `all_0` is set, or when every sample is 0
///   (every sample then reports 32 trailing zero bits)
/// - FIXED, always attempted
/// - LPC, attempted only when `options.max_lpc_order` is set,
///   alongside FIXED via `join` (which may run them in parallel)
/// - VERBATIM, when every attempted predictive encoding fails,
///   or when the best one isn't smaller than the raw sample bits
///
/// Before FIXED/LPC encoding, any low-order zero bits shared by every
/// sample ("wasted bits") are shifted out into the cache's `wasted`
/// buffer, and the effective bits-per-sample is reduced to match.
///
/// # Errors
///
/// Only failures while writing a CONSTANT or VERBATIM subframe are
/// returned; a failed FIXED or LPC attempt causes a fallback instead.
fn encode_subframe<'c>(
    options: &EncoderOptions,
    ChannelCache {
        fixed: fixed_cache,
        fixed_output,
        lpc: lpc_cache,
        lpc_output,
        constant_output,
        verbatim_output,
        wasted,
    }: &'c mut ChannelCache,
    CorrelatedChannel {
        samples: channel,
        bits_per_sample,
        all_0,
    }: CorrelatedChannel,
) -> Result<&'c BitRecorder<u32, BigEndian>, Error> {
    // trailing_zeros() of a 0 sample is 32, so this is the ceiling
    const WASTED_MAX: NonZero<u32> = NonZero::new(32).unwrap();

    debug_assert!(!channel.is_empty());

    if all_0 {
        // all samples are 0
        constant_output.clear();
        encode_constant_subframe(constant_output, channel[0], bits_per_sample, 0)?;
        return Ok(constant_output);
    }

    // determine any wasted bits:
    // the fold yields None as soon as some sample has no trailing zeros,
    // otherwise the smallest trailing-zero count among all samples
    let (channel, bits_per_sample, wasted_bps) =
        match channel.iter().try_fold(WASTED_MAX, |acc, sample| {
            NonZero::new(sample.trailing_zeros()).map(|sample| sample.min(acc))
        }) {
            None => (channel, bits_per_sample, 0),
            Some(WASTED_MAX) => {
                // every sample has 32 trailing zeros, so every sample is 0
                constant_output.clear();
                encode_constant_subframe(constant_output, channel[0], bits_per_sample, 0)?;
                return Ok(constant_output);
            }
            Some(wasted_bps) => {
                // shift the shared zero bits out of every sample
                // and shrink the sample width by the same amount
                // NOTE(review): the unwrap assumes the reduced width
                // is still a valid SignedBitCount
                let wasted_bps = wasted_bps.get();
                wasted.clear();
                wasted.extend(channel.iter().map(|sample| sample >> wasted_bps));
                (
                    wasted.as_slice(),
                    bits_per_sample.checked_sub(wasted_bps).unwrap(),
                    wasted_bps,
                )
            }
        };

    fixed_output.clear();

    let best = match options.max_lpc_order {
        Some(max_lpc_order) => {
            // try FIXED and LPC side-by-side and keep the smaller
            lpc_output.clear();

            match join(
                || {
                    encode_fixed_subframe(
                        options,
                        fixed_cache,
                        fixed_output,
                        channel,
                        bits_per_sample,
                        wasted_bps,
                    )
                },
                || {
                    encode_lpc_subframe(
                        options,
                        max_lpc_order,
                        lpc_cache,
                        lpc_output,
                        channel,
                        bits_per_sample,
                        wasted_bps,
                    )
                },
            ) {
                (Ok(()), Ok(())) => [fixed_output, lpc_output]
                    .into_iter()
                    .min_by_key(|c| c.written())
                    .unwrap(),
                (Err(_), Ok(())) => lpc_output,
                (Ok(()), Err(_)) => fixed_output,
                (Err(_), Err(_)) => {
                    // neither predictor worked, so store samples as-is
                    verbatim_output.clear();
                    encode_verbatim_subframe(
                        verbatim_output,
                        channel,
                        bits_per_sample,
                        wasted_bps,
                    )?;
                    return Ok(verbatim_output);
                }
            }
        }
        _ => {
            // LPC disabled, so FIXED is the only predictor
            match encode_fixed_subframe(
                options,
                fixed_cache,
                fixed_output,
                channel,
                bits_per_sample,
                wasted_bps,
            ) {
                Ok(()) => fixed_output,
                Err(_) => {
                    verbatim_output.clear();
                    encode_verbatim_subframe(
                        verbatim_output,
                        channel,
                        bits_per_sample,
                        wasted_bps,
                    )?;
                    return Ok(verbatim_output);
                }
            }
        }
    };

    // NOTE(review): `best.written()` includes the subframe header written by
    // encode_fixed_subframe / encode_lpc_subframe, while `verbatim_len`
    // counts only sample bits, so this comparison slightly favors VERBATIM.
    let verbatim_len = channel.len() as u32 * u32::from(bits_per_sample);

    if best.written() < verbatim_len {
        Ok(best)
    } else {
        verbatim_output.clear();
        encode_verbatim_subframe(verbatim_output, channel, bits_per_sample, wasted_bps)?;
        Ok(verbatim_output)
    }
}
2983
2984fn encode_constant_subframe<W: BitWrite>(
2985    writer: &mut W,
2986    sample: i32,
2987    bits_per_sample: SignedBitCount<32>,
2988    wasted_bps: u32,
2989) -> Result<(), Error> {
2990    use crate::stream::{SubframeHeader, SubframeHeaderType};
2991
2992    writer.build(&SubframeHeader {
2993        type_: SubframeHeaderType::Constant,
2994        wasted_bps,
2995    })?;
2996
2997    writer
2998        .write_signed_counted(bits_per_sample, sample)
2999        .map_err(Error::Io)
3000}
3001
3002fn encode_verbatim_subframe<W: BitWrite>(
3003    writer: &mut W,
3004    channel: &[i32],
3005    bits_per_sample: SignedBitCount<32>,
3006    wasted_bps: u32,
3007) -> Result<(), Error> {
3008    use crate::stream::{SubframeHeader, SubframeHeaderType};
3009
3010    writer.build(&SubframeHeader {
3011        type_: SubframeHeaderType::Verbatim,
3012        wasted_bps,
3013    })?;
3014
3015    channel
3016        .iter()
3017        .try_for_each(|i| writer.write_signed_counted(bits_per_sample, *i))?;
3018
3019    Ok(())
3020}
3021
3022fn encode_fixed_subframe<W: BitWrite>(
3023    options: &EncoderOptions,
3024    FixedCache {
3025        fixed_buffers: buffers,
3026    }: &mut FixedCache,
3027    writer: &mut W,
3028    channel: &[i32],
3029    bits_per_sample: SignedBitCount<32>,
3030    wasted_bps: u32,
3031) -> Result<(), Error> {
3032    use crate::stream::{SubframeHeader, SubframeHeaderType};
3033
3034    // calculate residuals for FIXED subframe orders 0-4
3035    // (or fewer, if we don't have enough samples)
3036    let (order, warm_up, residuals) = {
3037        let mut fixed_orders = ArrayVec::<&[i32], 5>::new();
3038        fixed_orders.push(channel);
3039
3040        // accumulate a set of FIXED diffs
3041        'outer: for buf in buffers.iter_mut() {
3042            let prev_order = fixed_orders.last().unwrap();
3043            match prev_order.split_at_checked(1) {
3044                Some((_, r)) => {
3045                    buf.clear();
3046                    for (n, p) in r.iter().zip(*prev_order) {
3047                        match n.checked_sub(*p) {
3048                            Some(v) => {
3049                                buf.push(v);
3050                            }
3051                            None => break 'outer,
3052                        }
3053                    }
3054                    if buf.is_empty() {
3055                        break;
3056                    } else {
3057                        fixed_orders.push(buf.as_slice());
3058                    }
3059                }
3060                None => break,
3061            }
3062        }
3063
3064        let min_fixed = fixed_orders.last().unwrap().len();
3065
3066        // choose diff with the smallest abs sum
3067        fixed_orders
3068            .into_iter()
3069            .enumerate()
3070            .min_by_key(|(_, residuals)| {
3071                residuals[(residuals.len() - min_fixed)..]
3072                    .iter()
3073                    .map(|r| u64::from(r.unsigned_abs()))
3074                    .sum::<u64>()
3075            })
3076            .map(|(order, residuals)| (order as u8, &channel[0..order], residuals))
3077            .unwrap()
3078    };
3079
3080    writer.build(&SubframeHeader {
3081        type_: SubframeHeaderType::Fixed { order },
3082        wasted_bps,
3083    })?;
3084
3085    warm_up
3086        .iter()
3087        .try_for_each(|sample: &i32| writer.write_signed_counted(bits_per_sample, *sample))?;
3088
3089    write_residuals(options, writer, order.into(), residuals)
3090}
3091
3092fn encode_lpc_subframe<W: BitWrite>(
3093    options: &EncoderOptions,
3094    max_lpc_order: NonZero<u8>,
3095    cache: &mut LpcCache,
3096    writer: &mut W,
3097    channel: &[i32],
3098    bits_per_sample: SignedBitCount<32>,
3099    wasted_bps: u32,
3100) -> Result<(), Error> {
3101    use crate::stream::{SubframeHeader, SubframeHeaderType};
3102
3103    let LpcSubframeParameters {
3104        warm_up,
3105        residuals,
3106        parameters:
3107            LpcParameters {
3108                order,
3109                precision,
3110                shift,
3111                coefficients,
3112            },
3113    } = LpcSubframeParameters::best(options, bits_per_sample, max_lpc_order, cache, channel)?;
3114
3115    writer.build(&SubframeHeader {
3116        type_: SubframeHeaderType::Lpc { order },
3117        wasted_bps,
3118    })?;
3119
3120    for sample in warm_up {
3121        writer.write_signed_counted(bits_per_sample, *sample)?;
3122    }
3123
3124    writer.write_count::<0b1111>(
3125        precision
3126            .count()
3127            .checked_sub(1)
3128            .ok_or(Error::InvalidQlpPrecision)?,
3129    )?;
3130
3131    writer.write::<5, i32>(shift as i32)?;
3132
3133    for coeff in coefficients {
3134        writer.write_signed_counted(precision, coeff)?;
3135    }
3136
3137    write_residuals(options, writer, order.get().into(), residuals)
3138}
3139
3140struct LpcSubframeParameters<'w, 'r> {
3141    parameters: LpcParameters,
3142    warm_up: &'w [i32],
3143    residuals: &'r [i32],
3144}
3145
3146impl<'w, 'r> LpcSubframeParameters<'w, 'r> {
3147    fn best(
3148        options: &EncoderOptions,
3149        bits_per_sample: SignedBitCount<32>,
3150        max_lpc_order: NonZero<u8>,
3151        LpcCache {
3152            residuals,
3153            window,
3154            windowed,
3155        }: &'r mut LpcCache,
3156        channel: &'w [i32],
3157    ) -> Result<Self, Error> {
3158        let parameters = LpcParameters::best(
3159            options,
3160            bits_per_sample,
3161            max_lpc_order,
3162            window,
3163            windowed,
3164            channel,
3165        )?;
3166
3167        Self::encode_residuals(&parameters, channel, residuals)
3168            .map(|(warm_up, residuals)| Self {
3169                warm_up,
3170                residuals,
3171                parameters,
3172            })
3173            .map_err(|ResidualOverflow| Error::ResidualOverflow)
3174    }
3175
3176    fn encode_residuals(
3177        parameters: &LpcParameters,
3178        channel: &'w [i32],
3179        residuals: &'r mut Vec<i32>,
3180    ) -> Result<(&'w [i32], &'r [i32]), ResidualOverflow> {
3181        residuals.clear();
3182
3183        for split in usize::from(parameters.order.get())..channel.len() {
3184            let (previous, current) = channel.split_at(split);
3185
3186            residuals.push(
3187                current[0]
3188                    .checked_sub(
3189                        (previous
3190                            .iter()
3191                            .rev()
3192                            .zip(&parameters.coefficients)
3193                            .map(|(x, y)| *x as i64 * *y as i64)
3194                            .sum::<i64>()
3195                            >> parameters.shift) as i32,
3196                    )
3197                    .ok_or(ResidualOverflow)?,
3198            );
3199        }
3200
3201        Ok((
3202            &channel[0..parameters.order.get().into()],
3203            residuals.as_slice(),
3204        ))
3205    }
3206}
3207
3208#[derive(Debug)]
3209struct ResidualOverflow;
3210
3211impl From<ResidualOverflow> for Error {
3212    #[inline]
3213    fn from(_: ResidualOverflow) -> Self {
3214        Error::ResidualOverflow
3215    }
3216}
3217
#[test]
fn test_residual_encoding_1() {
    // samples resembling one cycle of a sine wave
    let samples = [
        0, 16, 31, 44, 54, 61, 64, 63, 58, 49, 38, 24, 8, -8, -24, -38, -49, -58, -63, -64, -61,
        -54, -44, -31, -16,
    ];

    let parameters = LpcParameters {
        order: NonZero::new(2).unwrap(),
        precision: SignedBitCount::new::<7>(),
        shift: 5,
        coefficients: arrayvec![59, -30],
    };

    let mut buffer = Vec::new();
    let (warm_up, residuals) =
        LpcSubframeParameters::encode_residuals(&parameters, &samples, &mut buffer).unwrap();

    // an order-2 predictor keeps the first two samples verbatim
    assert_eq!(warm_up, &samples[..2]);
    assert_eq!(
        residuals,
        &[2, 2, 2, 3, 3, 3, 2, 2, 3, 0, 0, 0, -1, -1, -1, -3, -2, -2, -2, -1, -1, 0, 0]
    );
}
3246
#[test]
fn test_residual_encoding_2() {
    // samples resembling one cycle of a cosine wave
    let samples = [
        64, 62, 56, 47, 34, 20, 4, -12, -27, -41, -52, -60, -63, -63, -60, -52, -41, -27, -12, 4,
        20, 34, 47, 56, 62,
    ];

    let parameters = LpcParameters {
        order: NonZero::new(2).unwrap(),
        precision: SignedBitCount::new::<7>(),
        shift: 5,
        coefficients: arrayvec![58, -29],
    };

    let mut buffer = Vec::new();
    let (warm_up, residuals) =
        LpcSubframeParameters::encode_residuals(&parameters, &samples, &mut buffer).unwrap();

    // an order-2 predictor keeps the first two samples verbatim
    assert_eq!(warm_up, &samples[..2]);
    assert_eq!(
        residuals,
        &[2, 2, 0, 1, -1, -1, -1, -2, -2, -2, -1, -3, -2, 0, -1, 1, 0, 2, 2, 2, 4, 2, 4]
    );
}
3275
/// Quantized LPC predictor parameters for one subframe.
#[derive(Debug)]
struct LpcParameters {
    // predictor order, which is also the number of
    // warm-up samples stored verbatim
    order: NonZero<u8>,
    // bit width of each quantized coefficient in the stream
    precision: SignedBitCount<15>,
    // right shift applied to each prediction sum
    shift: u32,
    // quantized coefficients, applied to previous samples
    // starting with the most recent
    coefficients: ArrayVec<i32, MAX_LPC_COEFFS>,
}
3283
3284// There isn't any particular *best* way to determine
3285// the ideal LPC subframe parameters (though there are
3286// some worst ways, like choosing them at random).
3287// Even the reference implementation has changed its
3288// defaults over time.  So long as the subframe's residuals
3289// are calculated correctly, decoders don't care one way or another.
3290//
3291// I'll try to use an approach similar to the reference implementation's.
3292
impl LpcParameters {
    /// Chooses quantized LPC parameters for `channel`.
    ///
    /// The coefficient precision is picked from the channel's length alone.
    /// The samples are windowed with `options.window` (using `window` and
    /// `windowed` as scratch buffers), autocorrelated for lags up to
    /// `max_lpc_order`, and turned into LP coefficients.
    /// `compute_best_order` then selects an order and its coefficients,
    /// which are quantized by [`Self::quantize`].
    ///
    /// # Errors
    ///
    /// Returns [`Error::InsufficientLpcSamples`] if `channel` has no more
    /// samples than `max_lpc_order`, and propagates errors from
    /// `compute_best_order` and [`Self::quantize`].
    fn best(
        options: &EncoderOptions,
        bits_per_sample: SignedBitCount<32>,
        max_lpc_order: NonZero<u8>,
        window: &mut Vec<f64>,
        windowed: &mut Vec<f64>,
        channel: &[i32],
    ) -> Result<Self, Error> {
        if channel.len() <= usize::from(max_lpc_order.get()) {
            // not enough samples in channel to calculate LPC parameters
            return Err(Error::InsufficientLpcSamples);
        }

        // longer blocks get more coefficient precision
        let precision = match channel.len() {
            // this shouldn't be possible
            0 => panic!("at least one sample required in channel"),
            1..=192 => SignedBitCount::new::<7>(),
            193..=384 => SignedBitCount::new::<8>(),
            385..=576 => SignedBitCount::new::<9>(),
            577..=1152 => SignedBitCount::new::<10>(),
            1153..=2304 => SignedBitCount::new::<11>(),
            2305..=4608 => SignedBitCount::new::<12>(),
            4609.. => SignedBitCount::new::<13>(),
        };

        let (order, lp_coeffs) = compute_best_order(
            bits_per_sample,
            precision,
            channel
                .len()
                .try_into()
                // this shouldn't be possible
                .expect("excessive samples for subframe"),
            lp_coefficients(autocorrelate(
                options.window.apply(window, windowed, channel),
                max_lpc_order,
            )),
        )?;

        Self::quantize(order, lp_coeffs, precision)
    }

    /// Quantizes floating-point LP coefficients to `precision`-bit integers.
    ///
    /// A shift is chosen so that the largest-magnitude coefficient, scaled
    /// by `2^shift`, uses the full signed `precision`-bit range; the shift
    /// is capped at 15.  Each coefficient's rounding error is carried into
    /// the next, and every result is clamped to the signed range.
    ///
    /// If the ideal shift is negative (very large coefficients), the
    /// coefficients are divided by `2^-shift` instead and a shift of 0
    /// is stored.
    ///
    /// # Errors
    ///
    /// - [`Error::ZeroLpCoefficients`] if every coefficient is 0 (or there are none)
    /// - [`Error::LpNegativeShiftError`] if the ideal shift is below -16
    fn quantize(
        order: NonZero<u8>,
        coeffs: ArrayVec<f64, MAX_LPC_COEFFS>,
        precision: SignedBitCount<15>,
    ) -> Result<Self, Error> {
        const MAX_SHIFT: i32 = (1 << 4) - 1;
        const MIN_SHIFT: i32 = -(1 << 4);

        // verified output against reference implementation
        // See: FLAC__lpc_quantize_coefficients

        debug_assert!(coeffs.len() == usize::from(order.get()));

        // signed range of a precision-bit coefficient
        let max_coeff = (1 << (u32::from(precision) - 1)) - 1;
        let min_coeff = -(1 << (u32::from(precision) - 1));

        // largest coefficient magnitude
        let l = coeffs
            .iter()
            .map(|c| c.abs())
            .max_by(|x, y| x.total_cmp(y))
            // f64.log2() gives unfortunate results when <= 0.0
            .filter(|l| *l > 0.0)
            .ok_or(Error::ZeroLpCoefficients)?;

        // rounding error carried from one coefficient to the next
        let mut error = 0.0;

        // floor(log2(l)) is the binary exponent of the largest magnitude,
        // so this shift scales it just below 2^(precision - 1)
        match ((u32::from(precision) - 1) as i32 - ((l.log2().floor()) as i32) - 1).min(MAX_SHIFT) {
            shift @ 0.. => {
                // normal, positive shift case
                let shift = shift as u32;

                Ok(Self {
                    order,
                    precision,
                    shift,
                    coefficients: coeffs
                        .into_iter()
                        .map(|lp_coeff| {
                            let sum: f64 = lp_coeff.mul_add((1 << shift) as f64, error);
                            let qlp_coeff = (sum.round() as i32).clamp(min_coeff, max_coeff);
                            error = sum - (qlp_coeff as f64);
                            qlp_coeff
                        })
                        .collect(),
                })
            }
            shift @ MIN_SHIFT..0 => {
                // unusual negative shift case
                let shift = -shift as u32;

                Ok(Self {
                    order,
                    precision,
                    shift: 0,
                    coefficients: coeffs
                        .into_iter()
                        .map(|lp_coeff| {
                            let sum: f64 = (lp_coeff / (1 << shift) as f64) + error;
                            let qlp_coeff = (sum.round() as i32).clamp(min_coeff, max_coeff);
                            error = sum - (qlp_coeff as f64);
                            qlp_coeff
                        })
                        .collect(),
                })
            }
            ..MIN_SHIFT => Err(Error::LpNegativeShiftError),
        }
    }
}
3405
#[test]
fn test_quantization() {
    // test against numbers generated from reference implementation

    let order = NonZero::new(4).unwrap();

    let quantized = LpcParameters::quantize(
        order,
        arrayvec![0.797774, -0.045362, -0.050136, -0.054254],
        SignedBitCount::new::<10>(),
    )
    .unwrap();

    assert_eq!(quantized.order, order);
    assert_eq!(quantized.precision, SignedBitCount::new::<10>());
    assert_eq!(quantized.shift, 9);
    assert_eq!(quantized.coefficients, arrayvec![408, -23, -25, -28]);

    // note the relationship between the un-quantized,
    // floating point parameters and the shift value (9)
    //
    // 408 / 2 ** 9 ≈ 0.796875
    // -23 / 2 ** 9 ≈ -0.044921
    // -25 / 2 ** 9 ≈ -0.048828
    // -28 / 2 ** 9 ≈ -0.054687
    //
    // we're converting floats to fractions

    let quantized = LpcParameters::quantize(
        order,
        arrayvec![-0.054687, -0.953216, -0.027115, 0.033537],
        SignedBitCount::new::<10>(),
    )
    .unwrap();

    assert_eq!(quantized.order, order);
    assert_eq!(quantized.precision, SignedBitCount::new::<10>());
    assert_eq!(quantized.shift, 9);
    assert_eq!(quantized.coefficients, arrayvec![-28, -488, -14, 17]);

    // coefficients should never be all zero, which is bad
    assert!(matches!(
        LpcParameters::quantize(
            order,
            arrayvec![0.0, 0.0, 0.0, 0.0],
            SignedBitCount::new::<10>(),
        ),
        Err(Error::ZeroLpCoefficients)
    ));

    // negative shifts should also be handled properly
    let quantized = LpcParameters::quantize(
        order,
        arrayvec![-0.1, 0.1, 10000000.0, -0.2],
        SignedBitCount::new::<10>(),
    )
    .unwrap();

    assert_eq!(quantized.order, order);
    assert_eq!(quantized.precision, SignedBitCount::new::<10>());
    assert_eq!(quantized.shift, 0);
    assert_eq!(quantized.coefficients, arrayvec![0, 0, 305, 0]);

    // and massive negative shifts must be an error
    assert!(matches!(
        LpcParameters::quantize(
            order,
            arrayvec![-0.1, 0.1, 100000000.0, -0.2],
            SignedBitCount::new::<10>(),
        ),
        Err(Error::LpNegativeShiftError)
    ));
}
3479
3480fn autocorrelate(
3481    windowed: &[f64],
3482    max_lpc_order: NonZero<u8>,
3483) -> ArrayVec<f64, { MAX_LPC_COEFFS + 1 }> {
3484    // verified output against reference implementation
3485    // See: FLAC__lpc_compute_autocorrelation
3486
3487    debug_assert!(usize::from(max_lpc_order.get()) < MAX_LPC_COEFFS);
3488
3489    let mut tail = windowed;
3490    // let mut autocorrelated = Vec::with_capacity(max_lpc_order.get().into());
3491    let mut autocorrelated = ArrayVec::default();
3492
3493    for _ in 0..=max_lpc_order.get() {
3494        if tail.is_empty() {
3495            return autocorrelated;
3496        } else {
3497            autocorrelated.push(windowed.iter().zip(tail).map(|(x, y)| x * y).sum());
3498            tail.split_off_first();
3499        }
3500    }
3501
3502    autocorrelated
3503}
3504
#[test]
fn test_autocorrelation() {
    // expected values were generated with the reference implementation
    let cases: [(&[f64], u8, &[f64]); 3] = [
        (&[1.0], 1, &[1.0]),
        (
            &[1.0, 2.0, 3.0, 4.0, 5.0],
            4,
            &[55.0, 40.0, 26.0, 14.0, 5.0],
        ),
        (
            &[
                0.0, 16.0, 31.0, 44.0, 54.0, 61.0, 64.0, 63.0, 58.0, 49.0, 38.0, 24.0, 8.0, -8.0,
                -24.0, -38.0, -49.0, -58.0, -63.0, -64.0, -61.0, -54.0, -44.0, -31.0, -16.0,
            ],
            4,
            &[51408.0, 49792.0, 45304.0, 38466.0, 29914.0],
        ),
    ];

    for (windowed, max_order, expected) in cases {
        let autocorrelated = autocorrelate(windowed, NonZero::new(max_order).unwrap());
        assert_eq!(autocorrelated.as_slice(), expected);
    }
}
3530
/// One order's output from the Levinson-Durbin recursion in `lp_coefficients`
#[derive(Debug)]
struct LpCoeff {
    /// unquantized predictor coefficients, one per order
    /// (an order-N entry holds N coefficients)
    coeffs: ArrayVec<f64, MAX_LPC_COEFFS>,
    /// prediction error remaining at this order
    error: f64,
}
3536
3537// returns a Vec of (coefficients, error) pairs
3538fn lp_coefficients(
3539    autocorrelated: ArrayVec<f64, { MAX_LPC_COEFFS + 1 }>,
3540) -> ArrayVec<LpCoeff, MAX_LPC_COEFFS> {
3541    // verified output against reference implementation
3542    // See: FLAC__lpc_compute_lp_coefficients
3543
3544    match autocorrelated.len() {
3545        0 | 1 => panic!("must have at least 2 autocorrelation values"),
3546        _ => {
3547            let k = autocorrelated[1] / autocorrelated[0];
3548            let mut lp_coefficients = arrayvec![LpCoeff {
3549                coeffs: arrayvec![k],
3550                error: autocorrelated[0] * (1.0 - k.powi(2)),
3551            }];
3552
3553            for i in 1..(autocorrelated.len() - 1) {
3554                if let [prev @ .., next] = &autocorrelated[0..=i + 1] {
3555                    let LpCoeff { coeffs, error } = lp_coefficients.last().unwrap();
3556
3557                    let q = next
3558                        - prev
3559                            .iter()
3560                            .rev()
3561                            .zip(coeffs)
3562                            .map(|(x, y)| x * y)
3563                            .sum::<f64>();
3564
3565                    let k = q / error;
3566
3567                    lp_coefficients.push(LpCoeff {
3568                        coeffs: coeffs
3569                            .iter()
3570                            .zip(coeffs.iter().rev().map(|c| k * c))
3571                            .map(|(c1, c2)| c1 - c2)
3572                            .chain(std::iter::once(k))
3573                            .collect(),
3574                        error: error * (1.0 - k.powi(2)),
3575                    });
3576                }
3577            }
3578
3579            lp_coefficients
3580        }
3581    }
3582}
3583
// Asserts that two floating-point expressions agree to within 1e-6,
// for comparing against values printed by the reference implementation.
#[allow(unused)] // only referenced from #[test] functions
macro_rules! assert_float_approx {
    ($a:expr, $b:expr) => {{
        let (a, b) = ($a, $b);
        let within_tolerance = (a - b).abs() < 1.0e-6;
        assert!(within_tolerance, "{a} != {b}");
    }};
}
3592
#[test]
fn test_lp_coefficients_1() {
    // expected values were generated with the reference implementation
    let expected: [(f64, &[f64]); 4] = [
        (25.909091, &[0.727273]),
        (25.540351, &[0.814035, -0.119298]),
        (25.316142, &[0.802858, -0.043028, -0.093694]),
        (25.241623, &[0.797774, -0.045362, -0.050136, -0.054254]),
    ];

    let lp_coeffs = lp_coefficients(arrayvec![55.0, 40.0, 26.0, 14.0, 5.0]);
    assert_eq!(lp_coeffs.len(), expected.len());

    for (LpCoeff { coeffs, error }, (expected_error, expected_coeffs)) in
        lp_coeffs.iter().zip(expected)
    {
        assert_float_approx!(*error, expected_error);
        assert_eq!(coeffs.len(), expected_coeffs.len());
        for (coeff, expected_coeff) in coeffs.iter().zip(expected_coeffs) {
            assert_float_approx!(*coeff, *expected_coeff);
        }
    }
}
3624
#[test]
fn test_lp_coefficients_2() {
    // expected values were generated with the reference implementation
    let expected: [(f64, &[f64]); 4] = [
        (3181.201369, &[0.968565]),
        (495.815931, &[1.858456, -0.918772]),
        (495.161449, &[1.891837, -0.986293, 0.036332]),
        (494.604514, &[1.890618, -0.953216, -0.027115, 0.033537]),
    ];

    let lp_coeffs = lp_coefficients(arrayvec![51408.0, 49792.0, 45304.0, 38466.0, 29914.0]);
    assert_eq!(lp_coeffs.len(), expected.len());

    for (LpCoeff { coeffs, error }, (expected_error, expected_coeffs)) in
        lp_coeffs.iter().zip(expected)
    {
        assert_float_approx!(*error, expected_error);
        assert_eq!(coeffs.len(), expected_coeffs.len());
        for (coeff, expected_coeff) in coeffs.iter().zip(expected_coeffs) {
            assert_float_approx!(*coeff, *expected_coeff);
        }
    }
}
3656
3657// Returns (bits, order, coeffients) tuples
3658fn subframe_bits_by_order(
3659    bits_per_sample: SignedBitCount<32>,
3660    precision: SignedBitCount<15>,
3661    sample_count: u16,
3662    coeffs: ArrayVec<LpCoeff, MAX_LPC_COEFFS>,
3663) -> impl Iterator<Item = (f64, u8, ArrayVec<f64, MAX_LPC_COEFFS>)> {
3664    debug_assert!(sample_count > 0);
3665
3666    let error_scale = 0.5 / f64::from(sample_count);
3667
3668    coeffs
3669        .into_iter()
3670        .take_while(|coeffs| coeffs.error > 0.0)
3671        .zip(1..)
3672        .map(move |(LpCoeff { coeffs, error }, order)| {
3673            let header_bits =
3674                u32::from(order) * (u32::from(bits_per_sample) + u32::from(precision));
3675
3676            let bits_per_residual =
3677                (error * error_scale).ln() / (2.0 * std::f64::consts::LN_2).max(0.0);
3678
3679            let subframe_bits = bits_per_residual.mul_add(
3680                f64::from(sample_count - u16::from(order)),
3681                f64::from(header_bits),
3682            );
3683
3684            (subframe_bits, order, coeffs)
3685        })
3686}
3687
3688// Uses the error in the LP coefficients to determine the best order
3689// and returns that order along with the stripped-out coefficients
3690fn compute_best_order(
3691    bits_per_sample: SignedBitCount<32>,
3692    precision: SignedBitCount<15>,
3693    sample_count: u16,
3694    coeffs: ArrayVec<LpCoeff, MAX_LPC_COEFFS>,
3695) -> Result<(NonZero<u8>, ArrayVec<f64, MAX_LPC_COEFFS>), Error> {
3696    // verified output against reference implementation
3697    // See: FLAC__lpc_compute_best_order  and
3698    // See: FLAC__lpc_compute_expected_bits_per_residual_sample_with_error_scale
3699
3700    subframe_bits_by_order(bits_per_sample, precision, sample_count, coeffs)
3701        .min_by(|(x, _, _), (y, _, _)| x.total_cmp(y))
3702        .and_then(|(_, order, coeffs)| Some((NonZero::new(order)?, coeffs)))
3703        .ok_or(Error::NoBestLpcOrder)
3704}
3705
#[test]
fn test_compute_best_order() {
    // expected values were generated with the reference implementation

    // builds LP coefficient entries carrying only an error value,
    // which is all the bit estimate depends on
    fn errors_only<const N: usize>(errors: [f64; N]) -> ArrayVec<LpCoeff, MAX_LPC_COEFFS> {
        errors
            .into_iter()
            .map(|error| LpCoeff {
                coeffs: ArrayVec::default(),
                error,
            })
            .collect()
    }

    let cases: [(
        SignedBitCount<15>,
        u16,
        ArrayVec<LpCoeff, MAX_LPC_COEFFS>,
        [f64; 4],
    ); 2] = [
        (
            SignedBitCount::new::<5>(),
            20,
            errors_only([3181.201369, 495.815931, 495.161449, 494.604514]),
            [80.977565, 74.685594, 93.853530, 113.025628],
        ),
        (
            SignedBitCount::new::<10>(),
            4096,
            errors_only([15000.0, 25000.0, 20000.0, 30000.0]),
            [1812.801817, 3346.934051, 2713.303385, 3935.492805],
        ),
    ];

    for (precision, sample_count, coeffs, expected) in cases {
        let bits: Vec<f64> = subframe_bits_by_order(
            SignedBitCount::new::<16>(),
            precision,
            sample_count,
            coeffs,
        )
        .map(|(bits, _, _)| bits)
        .collect();

        assert!(bits.len() >= expected.len());
        for (actual, expected) in bits.iter().zip(expected) {
            assert_float_approx!(*actual, expected);
        }
    }
}
3748
3749fn write_residuals<W: BitWrite>(
3750    options: &EncoderOptions,
3751    writer: &mut W,
3752    predictor_order: usize,
3753    residuals: &[i32],
3754) -> Result<(), Error> {
3755    use crate::stream::ResidualPartitionHeader;
3756    use bitstream_io::{BitCount, ToBitStream};
3757
3758    const MAX_PARTITIONS: usize = 64;
3759
3760    #[derive(Debug)]
3761    struct Partition<'r, const RICE_MAX: u32> {
3762        header: ResidualPartitionHeader<RICE_MAX>,
3763        residuals: &'r [i32],
3764    }
3765
3766    impl<'r, const RICE_MAX: u32> Partition<'r, RICE_MAX> {
3767        fn new(partition: &'r [i32], estimated_bits: &mut u32) -> Option<Self> {
3768            let partition_samples = partition.len() as u16;
3769            if partition_samples == 0 {
3770                return None;
3771            }
3772
3773            let partition_sum = partition
3774                .iter()
3775                .map(|i| u64::from(i.unsigned_abs()))
3776                .sum::<u64>();
3777
3778            if partition_sum > 0 {
3779                let rice = if partition_sum > partition_samples.into() {
3780                    let bits_needed = ((partition_sum as f64) / f64::from(partition_samples))
3781                        .log2()
3782                        .ceil() as u32;
3783
3784                    match BitCount::try_from(bits_needed).ok().filter(|rice| {
3785                        u32::from(*rice) < u32::from(BitCount::<RICE_MAX>::new::<RICE_MAX>())
3786                    }) {
3787                        Some(rice) => rice,
3788                        None => {
3789                            let escape_size = (partition
3790                                .iter()
3791                                .map(|i| u64::from(i.unsigned_abs()))
3792                                .sum::<u64>()
3793                                .ilog2()
3794                                + 2)
3795                            .try_into()
3796                            .ok()?;
3797
3798                            *estimated_bits +=
3799                                u32::from(escape_size) * u32::from(partition_samples);
3800
3801                            return Some(Self {
3802                                header: ResidualPartitionHeader::Escaped { escape_size },
3803                                residuals: partition,
3804                            });
3805                        }
3806                    }
3807                } else {
3808                    BitCount::new::<0>()
3809                };
3810
3811                let partition_size: u32 = 4u32
3812                    + ((1 + u32::from(rice)) * u32::from(partition_samples))
3813                    + if u32::from(rice) > 0 {
3814                        u32::try_from(partition_sum >> (u32::from(rice) - 1)).ok()?
3815                    } else {
3816                        u32::try_from(partition_sum << 1).ok()?
3817                    }
3818                    - (u32::from(partition_samples) / 2);
3819
3820                *estimated_bits += partition_size;
3821
3822                Some(Partition {
3823                    header: ResidualPartitionHeader::Standard { rice },
3824                    residuals: partition,
3825                })
3826            } else {
3827                // all partition residuals are 0, so use a constant
3828                Some(Partition {
3829                    header: ResidualPartitionHeader::Constant,
3830                    residuals: partition,
3831                })
3832            }
3833        }
3834    }
3835
3836    impl<const RICE_MAX: u32> ToBitStream for Partition<'_, RICE_MAX> {
3837        type Error = std::io::Error;
3838
3839        #[inline]
3840        fn to_writer<W: BitWrite + ?Sized>(&self, w: &mut W) -> Result<(), Self::Error> {
3841            w.build(&self.header)?;
3842            match self.header {
3843                ResidualPartitionHeader::Standard { rice } => {
3844                    let mask = rice.mask_lsb();
3845
3846                    self.residuals.iter().try_for_each(|s| {
3847                        let (msb, lsb) = mask(if s.is_negative() {
3848                            ((-*s as u32 - 1) << 1) + 1
3849                        } else {
3850                            (*s as u32) << 1
3851                        });
3852                        w.write_unary::<1>(msb)?;
3853                        w.write_checked(lsb)
3854                    })?;
3855                }
3856                ResidualPartitionHeader::Escaped { escape_size } => {
3857                    self.residuals
3858                        .iter()
3859                        .try_for_each(|s| w.write_signed_counted(escape_size, *s))?;
3860                }
3861                ResidualPartitionHeader::Constant => { /* nothing left to do */ }
3862            }
3863            Ok(())
3864        }
3865    }
3866
3867    fn best_partitions<'r, const RICE_MAX: u32>(
3868        options: &EncoderOptions,
3869        block_size: usize,
3870        residuals: &'r [i32],
3871    ) -> ArrayVec<Partition<'r, RICE_MAX>, MAX_PARTITIONS> {
3872        (0..=block_size.trailing_zeros().min(options.max_partition_order))
3873            .map(|partition_order| 1 << partition_order)
3874            .take_while(|partition_count: &usize| partition_count.is_power_of_two())
3875            .filter_map(|partition_count| {
3876                let mut estimated_bits = 0;
3877
3878                let partitions = residuals
3879                    .rchunks(block_size / partition_count)
3880                    .rev()
3881                    .map(|partition| Partition::new(partition, &mut estimated_bits))
3882                    .collect::<Option<ArrayVec<_, MAX_PARTITIONS>>>()
3883                    .filter(|p| !p.is_empty() && p.len().is_power_of_two())?;
3884
3885                Some((partitions, estimated_bits))
3886            })
3887            .min_by_key(|(_, estimated_bits)| *estimated_bits)
3888            .map(|(partitions, _)| partitions)
3889            .unwrap_or_else(|| {
3890                std::iter::once(Partition {
3891                    header: ResidualPartitionHeader::Escaped {
3892                        escape_size: SignedBitCount::new::<0b11111>(),
3893                    },
3894                    residuals,
3895                })
3896                .collect()
3897            })
3898    }
3899
3900    fn write_partitions<const RICE_MAX: u32, W: BitWrite>(
3901        writer: &mut W,
3902        partitions: ArrayVec<Partition<'_, RICE_MAX>, MAX_PARTITIONS>,
3903    ) -> Result<(), Error> {
3904        writer.write::<4, u32>(partitions.len().ilog2())?; // partition order
3905        for partition in partitions {
3906            writer.build(&partition)?;
3907        }
3908        Ok(())
3909    }
3910
3911    #[inline]
3912    fn try_shrink_header<const RICE_MAX: u32, const RICE_NEW_MAX: u32>(
3913        header: ResidualPartitionHeader<RICE_MAX>,
3914    ) -> Option<ResidualPartitionHeader<RICE_NEW_MAX>> {
3915        Some(match header {
3916            ResidualPartitionHeader::Standard { rice } => ResidualPartitionHeader::Standard {
3917                rice: rice.try_map(|r| (r < RICE_NEW_MAX).then_some(r))?,
3918            },
3919            ResidualPartitionHeader::Escaped { escape_size } => {
3920                ResidualPartitionHeader::Escaped { escape_size }
3921            }
3922            ResidualPartitionHeader::Constant => ResidualPartitionHeader::Constant,
3923        })
3924    }
3925
3926    enum CodingMethod<'p> {
3927        Rice(ArrayVec<Partition<'p, 0b1111>, MAX_PARTITIONS>),
3928        Rice2(ArrayVec<Partition<'p, 0b11111>, MAX_PARTITIONS>),
3929    }
3930
3931    fn try_reduce_rice(
3932        partitions: ArrayVec<Partition<'_, 0b11111>, MAX_PARTITIONS>,
3933    ) -> CodingMethod<'_> {
3934        match partitions
3935            .iter()
3936            .map(|Partition { header, residuals }| {
3937                try_shrink_header(*header).map(|header| Partition { header, residuals })
3938            })
3939            .collect()
3940        {
3941            Some(partitions) => CodingMethod::Rice(partitions),
3942            None => CodingMethod::Rice2(partitions),
3943        }
3944    }
3945
3946    let block_size = predictor_order + residuals.len();
3947
3948    if options.use_rice2 {
3949        match try_reduce_rice(best_partitions(options, block_size, residuals)) {
3950            CodingMethod::Rice(partitions) => {
3951                writer.write::<2, u8>(0)?; // coding method 0
3952                write_partitions(writer, partitions)
3953            }
3954            CodingMethod::Rice2(partitions) => {
3955                writer.write::<2, u8>(1)?; // coding method 1
3956                write_partitions(writer, partitions)
3957            }
3958        }
3959    } else {
3960        let partitions = best_partitions::<0b1111>(options, block_size, residuals);
3961        writer.write::<2, u8>(0)?; // coding method 0
3962        write_partitions(writer, partitions)
3963    }
3964}
3965
3966fn try_join<A, B, RA, RB, E>(oper_a: A, oper_b: B) -> Result<(RA, RB), E>
3967where
3968    A: FnOnce() -> Result<RA, E> + Send,
3969    B: FnOnce() -> Result<RB, E> + Send,
3970    RA: Send,
3971    RB: Send,
3972    E: Send,
3973{
3974    let (a, b) = join(oper_a, oper_b);
3975    Ok((a?, b?))
3976}
3977
3978#[cfg(feature = "rayon")]
3979use rayon::join;
3980
/// Serial stand-in for `rayon::join` when the `rayon` feature is disabled
///
/// Runs `oper_a` to completion, then `oper_b`, and returns both results.
#[cfg(not(feature = "rayon"))]
fn join<A, B, RA, RB>(oper_a: A, oper_b: B) -> (RA, RB)
where
    A: FnOnce() -> RA + Send,
    B: FnOnce() -> RB + Send,
    RA: Send,
    RB: Send,
{
    let first = oper_a();
    let second = oper_b();
    (first, second)
}
3991
/// Applies `f` to every element of `src` using rayon's thread pool,
/// collecting the results in the original element order
#[cfg(feature = "rayon")]
fn vec_map<T, U, F>(src: Vec<T>, f: F) -> Vec<U>
where
    T: Send,
    U: Send,
    F: Fn(T) -> U + Send + Sync,
{
    use rayon::iter::{IntoParallelIterator, ParallelIterator};

    let mapped = src.into_par_iter().map(f);
    mapped.collect()
}
4003
/// Serial counterpart of the rayon-backed `vec_map`:
/// applies `f` to every element of `src` in order
#[cfg(not(feature = "rayon"))]
fn vec_map<T, U, F>(src: Vec<T>, f: F) -> Vec<U>
where
    T: Send,
    U: Send,
    F: Fn(T) -> U + Send + Sync,
{
    Vec::from_iter(src.into_iter().map(f))
}
4013
/// Divides `n` by `rhs`, returning the quotient only if the division is exact
///
/// Returns `None` if `rhs` is zero (`N::default()`)
/// or if `n` is not evenly divisible by `rhs`.
fn exact_div<N>(n: N, rhs: N) -> Option<N>
where
    N: std::ops::Div<Output = N> + std::ops::Rem<Output = N> + std::cmp::PartialEq + Copy + Default,
{
    let zero = N::default();
    // a zero divisor would panic for integer types, so reject it up front;
    // the quotient is only computed once divisibility is established
    (rhs != zero && n % rhs == zero).then(|| n / rhs)
}