Skip to main content

flac_codec/
encode.rs

1// Copyright 2025 Brian Langenberger
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9//! For encoding PCM samples to FLAC files
10//!
11//! ## Multithreading
12//!
13//! Encoders will operate using multithreading if the optional `rayon` feature is enabled,
14//! typically boosting performance by processing channels in parallel.
15//! But because subframes must eventually be written serially, and their size cannot generally
16//! be known in advance, processing two channels across two threads will not
17//! encode twice as fast.
18
19use crate::audio::Frame;
20use crate::metadata::{
21    Application, BlockList, BlockSize, Cuesheet, Picture, PortableMetadataBlock, SeekPoint,
22    Streaminfo, VorbisComment, write_blocks,
23};
24use crate::stream::{ChannelAssignment, FrameNumber, Independent, SampleRate};
25use crate::{Counter, Error};
26use arrayvec::ArrayVec;
27use bitstream_io::{BigEndian, BitRecorder, BitWrite, BitWriter, SignedBitCount};
28use std::fs::File;
29use std::io::BufWriter;
30use std::num::NonZero;
31use std::path::Path;
32
/// Maximum number of audio channels; also the capacity of
/// fixed-size per-channel collections built by the writers
const MAX_CHANNELS: usize = 8;

/// Maximum number of LPC coefficients
const MAX_LPC_COEFFS: usize = 32;
36
/// Builds an `ArrayVec` from a list of elements, the way `vec!` builds a `Vec`
///
/// The `arrayvec` crate does not provide such a macro, so one is defined here.
/// The capacity is inferred from the surrounding context, and
/// `ArrayVec::push` panics if more elements are supplied than it can hold.
/// As with `vec!`, an empty list and a trailing comma are both accepted.
macro_rules! arrayvec {
    ( $( $x:expr ),* $(,)? ) => {
        {
            // `mut` goes unused when the macro is invoked with no elements
            #[allow(unused_mut)]
            let mut v = ArrayVec::default();
            $( v.push($x); )*
            v
        }
    };
}
47
48/// A FLAC writer which accepts samples as bytes
49///
50/// # Example
51///
52/// ```
53/// use flac_codec::{
54///     byteorder::LittleEndian,
55///     encode::{FlacByteWriter, Options},
56///     decode::{FlacByteReader, Metadata},
57/// };
58/// use std::io::{Cursor, Read, Seek, Write};
59///
60/// let mut flac = Cursor::new(vec![]);  // a FLAC file in memory
61///
62/// let mut writer = FlacByteWriter::endian(
63///     &mut flac,           // our wrapped writer
64///     LittleEndian,        // .wav-style byte order
65///     Options::default(),  // default encoding options
66///     44100,               // sample rate
67///     16,                  // bits-per-sample
68///     1,                   // channel count
69///     Some(2000),          // total bytes
70/// ).unwrap();
71///
72/// // write 1000 samples as 16-bit, signed, little-endian bytes (2000 bytes total)
73/// let written_bytes = (0..1000).map(i16::to_le_bytes).flatten().collect::<Vec<u8>>();
74/// assert!(writer.write_all(&written_bytes).is_ok());
75///
76/// // finalize writing file
77/// assert!(writer.finalize().is_ok());
78///
79/// flac.rewind().unwrap();
80///
81/// // open reader around written FLAC file
82/// let mut reader = FlacByteReader::endian(flac, LittleEndian).unwrap();
83///
84/// // read 2000 bytes
85/// let mut read_bytes = vec![];
86/// assert!(reader.read_to_end(&mut read_bytes).is_ok());
87///
88/// // ensure MD5 sum of signed, little-endian samples matches hash in file
89/// let mut md5 = md5::Context::new();
90/// md5.consume(&read_bytes);
91/// assert_eq!(&md5.compute().0, reader.md5().unwrap());
92///
93/// // ensure input and output matches
94/// assert_eq!(read_bytes, written_bytes);
95/// ```
96pub struct FlacByteWriter<W: std::io::Write + std::io::Seek, E: crate::byteorder::Endianness> {
97    // the wrapped encoder
98    encoder: Encoder<W>,
99    // bytes that make up a partial FLAC frame
100    buf: Vec<u8>,
101    // a whole set of samples for a FLAC frame
102    frame: Frame,
103    // size of a single sample in bytes
104    bytes_per_sample: usize,
105    // size of single set of channel-independent samples in bytes
106    pcm_frame_size: usize,
107    // size of whole FLAC frame's samples in bytes
108    frame_byte_size: usize,
109    // whether the encoder has finalized the file
110    finalized: bool,
111    // the input bytes' endianness
112    endianness: std::marker::PhantomData<E>,
113}
114
115impl<W: std::io::Write + std::io::Seek, E: crate::byteorder::Endianness> FlacByteWriter<W, E> {
116    /// Creates new FLAC writer with the given parameters
117    ///
118    /// The writer should be positioned at the start of the file.
119    ///
120    /// `sample_rate` must be between 0 (for non-audio streams) and 2²⁰.
121    ///
122    /// `bits_per_sample` must be between 1 and 32.
123    ///
124    /// `channels` must be between 1 and 8.
125    ///
126    /// Note that if `total_bytes` is indicated,
127    /// the number of channel-independent samples written *must*
128    /// be equal to that amount or an error will occur when writing
129    /// or finalizing the stream.
130    ///
131    /// # Errors
132    ///
133    /// Returns I/O error if unable to write initial
134    /// metadata blocks.
135    /// Returns error if any of the encoding parameters are invalid.
136    pub fn new(
137        writer: W,
138        options: Options,
139        sample_rate: u32,
140        bits_per_sample: u32,
141        channels: u8,
142        total_bytes: Option<u64>,
143    ) -> Result<Self, Error> {
144        let bits_per_sample: SignedBitCount<32> = bits_per_sample
145            .try_into()
146            .map_err(|_| Error::InvalidBitsPerSample)?;
147
148        let bytes_per_sample = u32::from(bits_per_sample).div_ceil(8) as usize;
149
150        let pcm_frame_size = bytes_per_sample * channels as usize;
151
152        Ok(Self {
153            buf: Vec::default(),
154            frame: Frame::empty(channels.into(), bits_per_sample.into()),
155            bytes_per_sample,
156            pcm_frame_size,
157            frame_byte_size: pcm_frame_size * options.block_size as usize,
158            encoder: Encoder::new(
159                writer,
160                options,
161                sample_rate,
162                bits_per_sample,
163                channels,
164                total_bytes
165                    .map(|bytes| {
166                        exact_div(bytes, channels.into())
167                            .and_then(|s| exact_div(s, bytes_per_sample as u64))
168                            .ok_or(Error::SamplesNotDivisibleByChannels)
169                            .and_then(|b| NonZero::new(b).ok_or(Error::InvalidTotalBytes))
170                    })
171                    .transpose()?,
172            )?,
173            finalized: false,
174            endianness: std::marker::PhantomData,
175        })
176    }
177
178    /// Creates new FLAC writer with CDDA parameters
179    ///
180    /// The writer should be positioned at the start of the file.
181    ///
182    /// Sample rate is 44100 Hz, bits-per-sample is 16,
183    /// channels is 2.
184    ///
185    /// Note that if `total_bytes` is indicated,
186    /// the number of channel-independent samples written *must*
187    /// be equal to that amount or an error will occur when writing
188    /// or finalizing the stream.
189    ///
190    /// # Errors
191    ///
192    /// Returns I/O error if unable to write initial
193    /// metadata blocks.
194    /// Returns error if any of the encoding parameters are invalid.
195    pub fn new_cdda(writer: W, options: Options, total_bytes: Option<u64>) -> Result<Self, Error> {
196        Self::new(writer, options, 44100, 16, 2, total_bytes)
197    }
198
199    /// Creates new FLAC writer in the given endianness with the given parameters
200    ///
201    /// The writer should be positioned at the start of the file.
202    ///
203    /// `sample_rate` must be between 0 (for non-audio streams) and 2²⁰.
204    ///
205    /// `bits_per_sample` must be between 1 and 32.
206    ///
207    /// `channels` must be between 1 and 8.
208    ///
209    /// Note that if `total_bytes` is indicated,
210    /// the number of bytes written *must*
211    /// be equal to that amount or an error will occur when writing
212    /// or finalizing the stream.
213    ///
214    /// # Errors
215    ///
216    /// Returns I/O error if unable to write initial
217    /// metadata blocks.
218    #[inline]
219    pub fn endian(
220        writer: W,
221        _endianness: E,
222        options: Options,
223        sample_rate: u32,
224        bits_per_sample: u32,
225        channels: u8,
226        total_bytes: Option<u64>,
227    ) -> Result<Self, Error> {
228        Self::new(
229            writer,
230            options,
231            sample_rate,
232            bits_per_sample,
233            channels,
234            total_bytes,
235        )
236    }
237
238    fn finalize_inner(&mut self) -> Result<(), Error> {
239        if !self.finalized {
240            self.finalized = true;
241
242            // encode as many bytes as possible into final frame, if necessary
243            if !self.buf.is_empty() {
244                use crate::byteorder::LittleEndian;
245
246                // truncate buffer to whole PCM frames
247                let buf = self.buf.as_mut_slice();
248                let buf_len = buf.len();
249                let buf = &mut buf[..(buf_len - buf_len % self.pcm_frame_size)];
250
251                // convert buffer to little-endian bytes
252                E::bytes_to_le(buf, self.bytes_per_sample);
253
254                // update MD5 sum with little-endian bytes
255                self.encoder.md5.consume(&buf);
256
257                self.encoder
258                    .encode(self.frame.fill_from_buf::<LittleEndian>(buf))?;
259            }
260
261            self.encoder.finalize_inner()
262        } else {
263            Ok(())
264        }
265    }
266
267    /// Attempt to finalize stream
268    ///
269    /// It is necessary to finalize the FLAC encoder
270    /// so that it will write any partially unwritten samples
271    /// to the stream and update the [`crate::metadata::Streaminfo`] and [`crate::metadata::SeekTable`] blocks
272    /// with their final values.
273    ///
274    /// Dropping the encoder will attempt to finalize the stream
275    /// automatically, but will ignore any errors that may occur.
276    pub fn finalize(mut self) -> Result<(), Error> {
277        self.finalize_inner()?;
278        Ok(())
279    }
280}
281
282impl<E: crate::byteorder::Endianness> FlacByteWriter<BufWriter<File>, E> {
283    /// Creates new FLAC file at the given path
284    ///
285    /// `sample_rate` must be between 0 (for non-audio streams) and 2²⁰.
286    ///
287    /// `bits_per_sample` must be between 1 and 32.
288    ///
289    /// `channels` must be between 1 and 8.
290    ///
291    /// Note that if `total_bytes` is indicated,
292    /// the number of bytes written *must*
293    /// be equal to that amount or an error will occur when writing
294    /// or finalizing the stream.
295    ///
296    /// # Errors
297    ///
298    /// Returns I/O error if unable to write initial
299    /// metadata blocks.
300    #[inline]
301    pub fn create<P: AsRef<Path>>(
302        path: P,
303        options: Options,
304        sample_rate: u32,
305        bits_per_sample: u32,
306        channels: u8,
307        total_bytes: Option<u64>,
308    ) -> Result<Self, Error> {
309        FlacByteWriter::new(
310            BufWriter::new(options.create(path)?),
311            options,
312            sample_rate,
313            bits_per_sample,
314            channels,
315            total_bytes,
316        )
317    }
318
319    /// Creates new FLAC file with CDDA parameters at the given path
320    ///
321    /// Sample rate is 44100 Hz, bits-per-sample is 16,
322    /// channels is 2.
323    ///
324    /// Note that if `total_bytes` is indicated,
325    /// the number of bytes written *must*
326    /// be equal to that amount or an error will occur when writing
327    /// or finalizing the stream.
328    ///
329    /// # Errors
330    ///
331    /// Returns I/O error if unable to write initial
332    /// metadata blocks.
333    pub fn create_cdda<P: AsRef<Path>>(
334        path: P,
335        options: Options,
336        total_bytes: Option<u64>,
337    ) -> Result<Self, Error> {
338        Self::create(path, options, 44100, 16, 2, total_bytes)
339    }
340}
341
342impl<W: std::io::Write + std::io::Seek, E: crate::byteorder::Endianness> std::io::Write
343    for FlacByteWriter<W, E>
344{
345    /// Writes a set of sample bytes to the FLAC file
346    ///
347    /// Samples are signed and encoded in the stream's given byte order.
348    ///
349    /// Samples are then interleaved by channel, like:
350    /// [left₀ , right₀ , left₁ , right₁ , left₂ , right₂ , …]
351    ///
352    /// This is the same format used by common PCM container
353    /// formats like WAVE and AIFF.
354    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
355        use crate::byteorder::LittleEndian;
356
357        // dump whole set of bytes into our internal buffer
358        self.buf.extend(buf);
359
360        // encode as many FLAC frames as possible (which may be 0)
361        let mut encoded_frames = 0;
362        for buf in self
363            .buf
364            .as_mut_slice()
365            .chunks_exact_mut(self.frame_byte_size)
366        {
367            // convert buffer to little-endian bytes
368            E::bytes_to_le(buf, self.bytes_per_sample);
369
370            // update MD5 sum with little-endian bytes
371            self.encoder.md5.consume(&buf);
372
373            // encode fresh FLAC frame
374            self.encoder
375                .encode(self.frame.fill_from_buf::<LittleEndian>(buf))?;
376
377            encoded_frames += 1;
378        }
379
380        self.buf.drain(0..self.frame_byte_size * encoded_frames);
381
382        // indicate whole buffer's been consumed
383        Ok(buf.len())
384    }
385
386    #[inline]
387    fn flush(&mut self) -> std::io::Result<()> {
388        // we don't want to flush a partial frame to disk,
389        // but we can at least flush our internal writer
390        self.encoder.writer.flush()
391    }
392}
393
394impl<W: std::io::Write + std::io::Seek, E: crate::byteorder::Endianness> Drop
395    for FlacByteWriter<W, E>
396{
397    fn drop(&mut self) {
398        let _ = self.finalize_inner();
399    }
400}
401
402/// A FLAC writer which accepts samples as signed integers
403///
404/// # Example
405///
406/// ```
407/// use flac_codec::{
408///     encode::{FlacSampleWriter, Options},
409///     decode::FlacSampleReader,
410/// };
411/// use std::io::{Cursor, Seek};
412///
413/// let mut flac = Cursor::new(vec![]);  // a FLAC file in memory
414///
415/// let mut writer = FlacSampleWriter::new(
416///     &mut flac,           // our wrapped writer
417///     Options::default(),  // default encoding options
418///     44100,               // sample rate
419///     16,                  // bits-per-sample
420///     1,                   // channel count
421///     Some(1000),          // total samples
422/// ).unwrap();
423///
424/// // write 1000 samples
425/// let written_samples = (0..1000).collect::<Vec<i32>>();
426/// assert!(writer.write(&written_samples).is_ok());
427///
428/// // finalize writing file
429/// assert!(writer.finalize().is_ok());
430///
431/// flac.rewind().unwrap();
432///
433/// // open reader around written FLAC file
434/// let mut reader = FlacSampleReader::new(flac).unwrap();
435///
436/// // read 1000 samples
437/// let mut read_samples = vec![0; 1000];
438/// assert!(matches!(reader.read(&mut read_samples), Ok(1000)));
439///
440/// // ensure they match
441/// assert_eq!(read_samples, written_samples);
442/// ```
443pub struct FlacSampleWriter<W: std::io::Write + std::io::Seek> {
444    // the wrapped encoder
445    encoder: Encoder<W>,
446    // samples that make up a partial FLAC frame
447    // in channel-interleaved order
448    // (must de-interleave later in case someone writes
449    // only partial set of channels in a single write call)
450    sample_buf: Vec<i32>,
451    // a whole set of samples for a FLAC frame
452    frame: Frame,
453    // size of a single frame in samples
454    frame_sample_size: usize,
455    // size of a single PCM frame in samples
456    pcm_frame_size: usize,
457    // size of a single sample in bytes
458    bytes_per_sample: usize,
459    // whether the encoder has finalized the file
460    finalized: bool,
461}
462
463impl<W: std::io::Write + std::io::Seek> FlacSampleWriter<W> {
464    /// Creates new FLAC writer with the given parameters
465    ///
466    /// `sample_rate` must be between 0 (for non-audio streams) and 2²⁰.
467    ///
468    /// `bits_per_sample` must be between 1 and 32.
469    ///
470    /// `channels` must be between 1 and 8.
471    ///
472    /// Note that if `total_samples` is indicated,
473    /// the number of samples written *must*
474    /// be equal to that amount or an error will occur when writing
475    /// or finalizing the stream.
476    ///
477    /// # Errors
478    ///
479    /// Returns I/O error if unable to write initial
480    /// metadata blocks.
481    /// Returns error if any of the encoding parameters are invalid.
482    pub fn new(
483        writer: W,
484        options: Options,
485        sample_rate: u32,
486        bits_per_sample: u32,
487        channels: u8,
488        total_samples: Option<u64>,
489    ) -> Result<Self, Error> {
490        let bits_per_sample: SignedBitCount<32> = bits_per_sample
491            .try_into()
492            .map_err(|_| Error::InvalidBitsPerSample)?;
493
494        let bytes_per_sample = u32::from(bits_per_sample).div_ceil(8) as usize;
495
496        let pcm_frame_size = usize::from(channels);
497
498        Ok(Self {
499            sample_buf: Vec::default(),
500            frame: Frame::empty(channels.into(), bits_per_sample.into()),
501            bytes_per_sample,
502            pcm_frame_size,
503            frame_sample_size: pcm_frame_size * options.block_size as usize,
504            encoder: Encoder::new(
505                writer,
506                options,
507                sample_rate,
508                bits_per_sample,
509                channels,
510                total_samples
511                    .map(|samples| {
512                        exact_div(samples, channels.into())
513                            .ok_or(Error::SamplesNotDivisibleByChannels)
514                            .and_then(|s| NonZero::new(s).ok_or(Error::InvalidTotalSamples))
515                    })
516                    .transpose()?,
517            )?,
518            finalized: false,
519        })
520    }
521
522    /// Creates new FLAC writer with CDDA parameters
523    ///
524    /// Sample rate is 44100 Hz, bits-per-sample is 16,
525    /// channels is 2.
526    ///
527    /// Note that if `total_samples` is indicated,
528    /// the number of samples written *must*
529    /// be equal to that amount or an error will occur when writing
530    /// or finalizing the stream.
531    ///
532    /// # Errors
533    ///
534    /// Returns I/O error if unable to write initial
535    /// metadata blocks.
536    /// Returns error if any of the encoding parameters are invalid.
537    pub fn new_cdda(
538        writer: W,
539        options: Options,
540        total_samples: Option<u64>,
541    ) -> Result<Self, Error> {
542        Self::new(writer, options, 44100, 16, 2, total_samples)
543    }
544
545    /// Given a set of samples, writes them to the FLAC file
546    ///
547    /// Samples are interleaved by channel, like:
548    /// [left₀ , right₀ , left₁ , right₁ , left₂ , right₂ , …]
549    ///
550    /// This may output 0 or more actual FLAC frames,
551    /// depending on the quantity of samples and the amount
552    /// previously written.
553    pub fn write(&mut self, samples: &[i32]) -> Result<(), Error> {
554        // dump whole set of samples into our internal buffer
555        self.sample_buf.extend(samples);
556
557        // encode as many FLAC frames as possible (which may be 0)
558        let mut encoded_frames = 0;
559        for buf in self
560            .sample_buf
561            .as_mut_slice()
562            .chunks_exact_mut(self.frame_sample_size)
563        {
564            // update running MD5 sum calculation
565            // since samples are already interleaved in channel order
566            update_md5(
567                &mut self.encoder.md5,
568                buf.iter().copied(),
569                self.bytes_per_sample,
570            );
571
572            // encode fresh FLAC frame
573            self.encoder.encode(self.frame.fill_from_samples(buf))?;
574
575            encoded_frames += 1;
576        }
577        self.sample_buf
578            .drain(0..self.frame_sample_size * encoded_frames);
579
580        Ok(())
581    }
582
583    fn finalize_inner(&mut self) -> Result<(), Error> {
584        if !self.finalized {
585            self.finalized = true;
586
587            // encode as many samples possible into final frame, if necessary
588            if !self.sample_buf.is_empty() {
589                // truncate buffer to whole PCM frames
590                let buf = self.sample_buf.as_mut_slice();
591                let buf_len = buf.len();
592                let buf = &mut buf[..(buf_len - buf_len % self.pcm_frame_size)];
593
594                // update running MD5 sum calculation
595                // since samples are already interleaved in channel order
596                update_md5(
597                    &mut self.encoder.md5,
598                    buf.iter().copied(),
599                    self.bytes_per_sample,
600                );
601
602                // encode final FLAC frame
603                self.encoder.encode(self.frame.fill_from_samples(buf))?;
604            }
605
606            self.encoder.finalize_inner()
607        } else {
608            Ok(())
609        }
610    }
611
612    /// Attempt to finalize stream
613    ///
614    /// It is necessary to finalize the FLAC encoder
615    /// so that it will write any partially unwritten samples
616    /// to the stream and update the [`crate::metadata::Streaminfo`] and [`crate::metadata::SeekTable`] blocks
617    /// with their final values.
618    ///
619    /// Dropping the encoder will attempt to finalize the stream
620    /// automatically, but will ignore any errors that may occur.
621    pub fn finalize(mut self) -> Result<(), Error> {
622        self.finalize_inner()?;
623        Ok(())
624    }
625}
626
627impl FlacSampleWriter<BufWriter<File>> {
628    /// Creates new FLAC file at the given path
629    ///
630    /// `sample_rate` must be between 0 (for non-audio streams) and 2²⁰.
631    ///
632    /// `bits_per_sample` must be between 1 and 32.
633    ///
634    /// `channels` must be between 1 and 8.
635    ///
636    /// Note that if `total_bytes` is indicated,
637    /// the number of bytes written *must*
638    /// be equal to that amount or an error will occur when writing
639    /// or finalizing the stream.
640    ///
641    /// # Errors
642    ///
643    /// Returns I/O error if unable to write initial
644    /// metadata blocks.
645    #[inline]
646    pub fn create<P: AsRef<Path>>(
647        path: P,
648        options: Options,
649        sample_rate: u32,
650        bits_per_sample: u32,
651        channels: u8,
652        total_samples: Option<u64>,
653    ) -> Result<Self, Error> {
654        FlacSampleWriter::new(
655            BufWriter::new(options.create(path)?),
656            options,
657            sample_rate,
658            bits_per_sample,
659            channels,
660            total_samples,
661        )
662    }
663
664    /// Creates new FLAC file at the given path with CDDA parameters
665    ///
666    /// Sample rate is 44100 Hz, bits-per-sample is 16,
667    /// channels is 2.
668    ///
669    /// Note that if `total_bytes` is indicated,
670    /// the number of bytes written *must*
671    /// be equal to that amount or an error will occur when writing
672    /// or finalizing the stream.
673    ///
674    /// # Errors
675    ///
676    /// Returns I/O error if unable to write initial
677    /// metadata blocks.
678    #[inline]
679    pub fn create_cdda<P: AsRef<Path>>(
680        path: P,
681        options: Options,
682        total_samples: Option<u64>,
683    ) -> Result<Self, Error> {
684        Self::create(path, options, 44100, 16, 2, total_samples)
685    }
686}
687
688/// A FLAC writer which accepts samples as channels of signed integers
689///
690/// # Example
691/// ```
692/// use flac_codec::{
693///     encode::{FlacChannelWriter, Options},
694///     decode::FlacChannelReader,
695/// };
696/// use std::io::{Cursor, Seek};
697///
698/// let mut flac = Cursor::new(vec![]);  // a FLAC file in memory
699///
700/// let mut writer = FlacChannelWriter::new(
701///     &mut flac,           // our wrapped writer
702///     Options::default(),  // default encoding options
703///     44100,               // sample rate
704///     16,                  // bits-per-sample
705///     2,                   // channel count
706///     Some(5),             // total channel-independent samples
707/// ).unwrap();
708///
709/// // write our samples, divided by channel
710/// let written_samples = vec![
711///     vec![1, 2, 3, 4, 5],
712///     vec![-1, -2, -3, -4, -5],
713/// ];
714/// assert!(writer.write(&written_samples).is_ok());
715///
716/// // finalize writing file
717/// assert!(writer.finalize().is_ok());
718///
719/// flac.rewind().unwrap();
720///
721/// // open reader around written FLAC file
722/// let mut reader = FlacChannelReader::new(flac).unwrap();
723///
724/// // read a buffer's worth of samples
725/// let read_samples = reader.fill_buf().unwrap();
726///
727/// // ensure the channels match
728/// assert_eq!(read_samples.len(), written_samples.len());
729/// assert_eq!(read_samples[0], written_samples[0]);
730/// assert_eq!(read_samples[1], written_samples[1]);
731/// ```
732pub struct FlacChannelWriter<W: std::io::Write + std::io::Seek> {
733    // the wrapped encoder
734    encoder: Encoder<W>,
735    // channels that make up a partial FLAC frame
736    channel_bufs: Vec<Vec<i32>>,
737    // a whole set of samples for a FLAC frame
738    frame: Frame,
739    // size of a single frame in samples
740    frame_sample_size: usize,
741    // size of a single sample in bytes
742    bytes_per_sample: usize,
743    // whether the encoder has finalized the file
744    finalized: bool,
745}
746
747impl<W: std::io::Write + std::io::Seek> FlacChannelWriter<W> {
748    /// Creates new FLAC writer with the given parameters
749    ///
750    /// `sample_rate` must be between 0 (for non-audio streams) and 2²⁰.
751    ///
752    /// `bits_per_sample` must be between 1 and 32.
753    ///
754    /// `channels` must be between 1 and 8.
755    ///
756    /// Note that if `total_samples` is indicated,
757    /// the number of samples written *must*
758    /// be equal to that amount or an error will occur when writing
759    /// or finalizing the stream.
760    ///
761    /// # Errors
762    ///
763    /// Returns I/O error if unable to write initial
764    /// metadata blocks.
765    /// Returns error if any of the encoding parameters are invalid.
766    pub fn new(
767        writer: W,
768        options: Options,
769        sample_rate: u32,
770        bits_per_sample: u32,
771        channels: u8,
772        total_samples: Option<u64>,
773    ) -> Result<Self, Error> {
774        let bits_per_sample: SignedBitCount<32> = bits_per_sample
775            .try_into()
776            .map_err(|_| Error::InvalidBitsPerSample)?;
777
778        let bytes_per_sample = u32::from(bits_per_sample).div_ceil(8) as usize;
779
780        Ok(Self {
781            channel_bufs: vec![Vec::default(); channels.into()],
782            frame: Frame::empty(channels.into(), bits_per_sample.into()),
783            bytes_per_sample,
784            frame_sample_size: options.block_size as usize,
785            encoder: Encoder::new(
786                writer,
787                options,
788                sample_rate,
789                bits_per_sample,
790                channels,
791                total_samples.and_then(NonZero::new),
792            )?,
793            finalized: false,
794        })
795    }
796
797    /// Creates new FLAC writer with CDDA parameters
798    ///
799    /// Sample rate is 44100 Hz, bits-per-sample is 16,
800    /// channels is 2.
801    ///
802    /// Note that if `total_samples` is indicated,
803    /// the number of samples written *must*
804    /// be equal to that amount or an error will occur when writing
805    /// or finalizing the stream.
806    ///
807    /// # Errors
808    ///
809    /// Returns I/O error if unable to write initial
810    /// metadata blocks.
811    /// Returns error if any of the encoding parameters are invalid.
812    pub fn new_cdda(
813        writer: W,
814        options: Options,
815        total_samples: Option<u64>,
816    ) -> Result<Self, Error> {
817        Self::new(writer, options, 44100, 16, 2, total_samples)
818    }
819
820    /// Given a set of channels containing samples, writes them to the FLAC file
821    ///
822    /// Channels should be a slice-able set of sample slices, like:
823    /// [[left₀ , left₁ , left₂ , …] , [right₀ , right₁ , right₂ , …]]
824    ///
825    /// The number of channels must be identical to the
826    /// channel count indicated when intializing the encoder.
827    ///
828    /// The number of samples in each channel must also be identical.
829    pub fn write<C, S>(&mut self, channels: C) -> Result<(), Error>
830    where
831        C: AsRef<[S]>,
832        S: AsRef<[i32]>,
833    {
834        use crate::audio::MultiZip;
835
836        // sanity-check our channel inputs
837        let channels = channels.as_ref();
838
839        match channels {
840            whole @ [first, rest @ ..]
841                if whole.len() == usize::from(self.encoder.channel_count().get()) =>
842            {
843                if rest
844                    .iter()
845                    .any(|c| c.as_ref().len() != first.as_ref().len())
846                {
847                    return Err(Error::ChannelLengthMismatch);
848                }
849            }
850            _ => {
851                return Err(Error::ChannelCountMismatch);
852            }
853        }
854
855        // dump whole set of samples into our internal channel buffers
856        for (buf, channel) in self.channel_bufs.iter_mut().zip(channels) {
857            buf.extend(channel.as_ref());
858        }
859
860        // encode as many FLAC frames as possible (which may be 0)
861        let mut encoded_frames = 0;
862        for bufs in self
863            .channel_bufs
864            .iter_mut()
865            .map(|v| v.as_mut_slice().chunks_exact_mut(self.frame_sample_size))
866            .collect::<MultiZip<_>>()
867        {
868            // update running MD5 sum calculation
869            update_md5(
870                &mut self.encoder.md5,
871                bufs.iter()
872                    .map(|c| c.iter().copied())
873                    .collect::<MultiZip<_>>()
874                    .flatten(),
875                self.bytes_per_sample,
876            );
877
878            // encode fresh FLAC frame
879            self.encoder.encode(self.frame.fill_from_channels(&bufs))?;
880
881            encoded_frames += 1;
882        }
883
884        // drain encoded samples from buffers
885        for channel in self.channel_bufs.iter_mut() {
886            channel.drain(0..self.frame_sample_size * encoded_frames);
887        }
888
889        Ok(())
890    }
891
892    fn finalize_inner(&mut self) -> Result<(), Error> {
893        use crate::audio::MultiZip;
894
895        if !self.finalized {
896            self.finalized = true;
897
898            // encode as many samples possible into final frame, if necessary
899            if !self.channel_bufs[0].is_empty() {
900                // update running MD5 sum calculation
901                update_md5(
902                    &mut self.encoder.md5,
903                    self.channel_bufs
904                        .iter()
905                        .map(|c| c.iter().copied())
906                        .collect::<MultiZip<_>>()
907                        .flatten(),
908                    self.bytes_per_sample,
909                );
910
911                // encode final FLAC frame
912                self.encoder.encode(
913                    self.frame.fill_from_channels(
914                        self.channel_bufs
915                            .iter_mut()
916                            .map(|v| v.as_mut_slice())
917                            .collect::<ArrayVec<_, MAX_CHANNELS>>()
918                            .as_slice(),
919                    ),
920                )?;
921            }
922
923            self.encoder.finalize_inner()
924        } else {
925            Ok(())
926        }
927    }
928
929    /// Attempt to finalize stream
930    ///
931    /// It is necessary to finalize the FLAC encoder
932    /// so that it will write any partially unwritten samples
933    /// to the stream and update the [`crate::metadata::Streaminfo`] and [`crate::metadata::SeekTable`] blocks
934    /// with their final values.
935    ///
936    /// Dropping the encoder will attempt to finalize the stream
937    /// automatically, but will ignore any errors that may occur.
938    pub fn finalize(mut self) -> Result<(), Error> {
939        self.finalize_inner()?;
940        Ok(())
941    }
942}
943
944impl FlacChannelWriter<BufWriter<File>> {
945    /// Creates new FLAC file at the given path
946    ///
947    /// `sample_rate` must be between 0 (for non-audio streams) and 2²⁰.
948    ///
949    /// `bits_per_sample` must be between 1 and 32.
950    ///
951    /// `channels` must be between 1 and 8.
952    ///
953    /// Note that if `total_bytes` is indicated,
954    /// the number of bytes written *must*
955    /// be equal to that amount or an error will occur when writing
956    /// or finalizing the stream.
957    ///
958    /// # Errors
959    ///
960    /// Returns I/O error if unable to write initial
961    /// metadata blocks.
962    #[inline]
963    pub fn create<P: AsRef<Path>>(
964        path: P,
965        options: Options,
966        sample_rate: u32,
967        bits_per_sample: u32,
968        channels: u8,
969        total_samples: Option<u64>,
970    ) -> Result<Self, Error> {
971        FlacChannelWriter::new(
972            BufWriter::new(options.create(path)?),
973            options,
974            sample_rate,
975            bits_per_sample,
976            channels,
977            total_samples,
978        )
979    }
980
981    /// Creates new FLAC file at the given path with CDDA parameters
982    ///
983    /// Sample rate is 44100 Hz, bits-per-sample is 16,
984    /// channels is 2.
985    ///
986    /// Note that if `total_bytes` is indicated,
987    /// the number of bytes written *must*
988    /// be equal to that amount or an error will occur when writing
989    /// or finalizing the stream.
990    ///
991    /// # Errors
992    ///
993    /// Returns I/O error if unable to write initial
994    /// metadata blocks.
995    #[inline]
996    pub fn create_cdda<P: AsRef<Path>>(
997        path: P,
998        options: Options,
999        total_samples: Option<u64>,
1000    ) -> Result<Self, Error> {
1001        Self::create(path, options, 44100, 16, 2, total_samples)
1002    }
1003}
1004
1005/// A FLAC writer which operates on streamed output
1006///
1007/// Because this encodes FLAC frames without any metadata
1008/// blocks or finalizing, it does not need to be seekable.
1009///
1010/// # Example
1011///
1012/// ```
1013/// use flac_codec::{
1014///     decode::{FlacStreamReader, FrameBuf},
1015///     encode::{FlacStreamWriter, Options},
1016/// };
1017/// use std::io::{Cursor, Seek};
1018/// use bitstream_io::SignedBitCount;
1019///
1020/// let mut flac = Cursor::new(vec![]);
1021///
1022/// let samples = (0..100).collect::<Vec<i32>>();
1023///
1024/// let mut w = FlacStreamWriter::new(&mut flac, Options::default());
1025///
1026/// // write a single FLAC frame with some samples
1027/// w.write(
1028///     44100,  // sample rate
1029///     1,      // channels
1030///     16,     // bits-per-sample
1031///     &samples,
1032/// ).unwrap();
1033///
1034/// flac.rewind().unwrap();
1035///
1036/// let mut r = FlacStreamReader::new(&mut flac);
1037///
1038/// // read a single FLAC frame with some samples
1039/// assert_eq!(
1040///     r.read().unwrap(),
1041///     FrameBuf {
1042///         samples: &samples,
1043///         sample_rate: 44100,
1044///         channels: 1,
1045///         bits_per_sample: 16,
1046///     },
1047/// );
1048/// ```
1049pub struct FlacStreamWriter<W> {
1050    // the writer we're outputting to
1051    writer: W,
1052    // various encoding optins
1053    options: EncoderOptions,
1054    // various encoder caches
1055    caches: EncodingCaches,
1056    // a whole set of samples for a FLAC frame
1057    frame: Frame,
1058    // the current frame number
1059    frame_number: FrameNumber,
1060}
1061
1062impl<W: std::io::Write> FlacStreamWriter<W> {
1063    /// Creates new stream writer
1064    pub fn new(writer: W, options: Options) -> Self {
1065        Self {
1066            writer,
1067            options: EncoderOptions {
1068                max_partition_order: options.max_partition_order,
1069                mid_side: options.mid_side,
1070                seektable_interval: options.seektable_interval,
1071                max_lpc_order: options.max_lpc_order,
1072                window: options.window,
1073                exhaustive_channel_correlation: options.exhaustive_channel_correlation,
1074                use_rice2: false,
1075            },
1076            caches: EncodingCaches::default(),
1077            frame: Frame::default(),
1078            frame_number: FrameNumber::default(),
1079        }
1080    }
1081
1082    /// Writes a whole FLAC frame to our output stream
1083    ///
1084    /// Samples are interleaved by channel, like:
1085    /// [left₀ , right₀ , left₁ , right₁ , left₂ , right₂ , …]
1086    ///
1087    /// This writes a whole FLAC frame to the output stream on each call.
1088    ///
1089    /// # Errors
1090    ///
1091    /// Returns an error of any of the parameters are invalid
1092    /// or if an I/O error occurs when writing to the stream.
1093    pub fn write(
1094        &mut self,
1095        sample_rate: u32,
1096        channels: u8,
1097        bits_per_sample: u32,
1098        samples: &[i32],
1099    ) -> Result<(), Error> {
1100        use crate::crc::{Crc16, CrcWriter};
1101        use crate::stream::{BitsPerSample, FrameHeader, SampleRate};
1102
1103        let bits_per_sample: SignedBitCount<32> = bits_per_sample
1104            .try_into()
1105            .map_err(|_| Error::NonSubsetBitsPerSample)?;
1106
1107        // samples must divide evenly into channels
1108        if !samples.len().is_multiple_of(usize::from(channels)) {
1109            return Err(Error::SamplesNotDivisibleByChannels);
1110        } else if !(1..=8).contains(&channels) {
1111            return Err(Error::ExcessiveChannels);
1112        }
1113
1114        self.options.use_rice2 = u32::from(bits_per_sample) > 16;
1115
1116        self.frame
1117            .resize(bits_per_sample.into(), channels.into(), 0);
1118        self.frame.fill_from_samples(samples);
1119
1120        // block size must be valid
1121        let block_size: crate::stream::BlockSize<u16> = crate::stream::BlockSize::try_from(
1122            u16::try_from(self.frame.pcm_frames()).map_err(|_| Error::InvalidBlockSize)?,
1123        )
1124        .map_err(|_| Error::InvalidBlockSize)?;
1125
1126        // sample rate must be valid for subset streams
1127        let sample_rate: SampleRate<u32> = sample_rate.try_into().and_then(|rate| match rate {
1128            SampleRate::Streaminfo(_) => Err(Error::NonSubsetSampleRate),
1129            rate => Ok(rate),
1130        })?;
1131
1132        // bits-per-sample must be valid for subset streams
1133        let header_bits_per_sample = match BitsPerSample::from(bits_per_sample) {
1134            BitsPerSample::Streaminfo(_) => return Err(Error::NonSubsetBitsPerSample),
1135            bps => bps,
1136        };
1137
1138        let mut w: CrcWriter<_, Crc16> = CrcWriter::new(&mut self.writer);
1139        let mut bw: BitWriter<CrcWriter<&mut W, Crc16>, BigEndian>;
1140
1141        match self
1142            .frame
1143            .channels()
1144            .collect::<ArrayVec<&[i32], MAX_CHANNELS>>()
1145            .as_slice()
1146        {
1147            [channel] => {
1148                FrameHeader {
1149                    blocking_strategy: false,
1150                    frame_number: self.frame_number,
1151                    block_size: (channel.len() as u16)
1152                        .try_into()
1153                        .expect("frame cannot be empty"),
1154                    sample_rate,
1155                    bits_per_sample: header_bits_per_sample,
1156                    channel_assignment: ChannelAssignment::Independent(Independent::Mono),
1157                }
1158                .write_subset(&mut w)?;
1159
1160                bw = BitWriter::new(w);
1161
1162                self.caches.channels.resize_with(1, ChannelCache::default);
1163
1164                encode_subframe(
1165                    &self.options,
1166                    &mut self.caches.channels[0],
1167                    CorrelatedChannel::independent(bits_per_sample, channel),
1168                )?
1169                .playback(&mut bw)?;
1170            }
1171            [left, right] if self.options.exhaustive_channel_correlation => {
1172                let Correlated {
1173                    channel_assignment,
1174                    channels: [channel_0, channel_1],
1175                } = correlate_channels_exhaustive(
1176                    &self.options,
1177                    &mut self.caches.correlated,
1178                    [left, right],
1179                    bits_per_sample,
1180                )?;
1181
1182                FrameHeader {
1183                    blocking_strategy: false,
1184                    frame_number: self.frame_number,
1185                    block_size,
1186                    sample_rate,
1187                    bits_per_sample: header_bits_per_sample,
1188                    channel_assignment,
1189                }
1190                .write_subset(&mut w)?;
1191
1192                bw = BitWriter::new(w);
1193
1194                channel_0.playback(&mut bw)?;
1195                channel_1.playback(&mut bw)?;
1196            }
1197            [left, right] => {
1198                let Correlated {
1199                    channel_assignment,
1200                    channels: [channel_0, channel_1],
1201                } = correlate_channels(
1202                    &self.options,
1203                    &mut self.caches.correlated,
1204                    [left, right],
1205                    bits_per_sample,
1206                );
1207
1208                FrameHeader {
1209                    blocking_strategy: false,
1210                    frame_number: self.frame_number,
1211                    block_size,
1212                    sample_rate,
1213                    bits_per_sample: header_bits_per_sample,
1214                    channel_assignment,
1215                }
1216                .write_subset(&mut w)?;
1217
1218                self.caches.channels.resize_with(2, ChannelCache::default);
1219                let [cache_0, cache_1] = self.caches.channels.get_disjoint_mut([0, 1]).unwrap();
1220                let (channel_0, channel_1) = join(
1221                    || encode_subframe(&self.options, cache_0, channel_0),
1222                    || encode_subframe(&self.options, cache_1, channel_1),
1223                );
1224
1225                bw = BitWriter::new(w);
1226
1227                channel_0?.playback(&mut bw)?;
1228                channel_1?.playback(&mut bw)?;
1229            }
1230            channels => {
1231                // non-stereo frames are always encoded independently
1232                FrameHeader {
1233                    blocking_strategy: false,
1234                    frame_number: self.frame_number,
1235                    block_size,
1236                    sample_rate,
1237                    bits_per_sample: header_bits_per_sample,
1238                    channel_assignment: ChannelAssignment::Independent(
1239                        channels.len().try_into().expect("invalid channel count"),
1240                    ),
1241                }
1242                .write_subset(&mut w)?;
1243
1244                bw = BitWriter::new(w);
1245
1246                self.caches
1247                    .channels
1248                    .resize_with(channels.len(), ChannelCache::default);
1249
1250                vec_map(
1251                    self.caches.channels.iter_mut().zip(channels).collect(),
1252                    |(cache, channel)| {
1253                        encode_subframe(
1254                            &self.options,
1255                            cache,
1256                            CorrelatedChannel::independent(bits_per_sample, channel),
1257                        )
1258                    },
1259                )
1260                .into_iter()
1261                .try_for_each(|r| r.and_then(|r| r.playback(bw.by_ref()).map_err(Error::Io)))?;
1262            }
1263        }
1264
1265        let crc16: u16 = bw.aligned_writer()?.checksum().into();
1266        bw.write_from(crc16)?;
1267
1268        if self.frame_number.try_increment().is_err() {
1269            self.frame_number = FrameNumber::default();
1270        }
1271
1272        Ok(())
1273    }
1274
1275    /// Writes a whole FLAC frame to our output stream with CDDA parameters
1276    ///
1277    /// Samples are interleaved by channel, like:
1278    /// [left₀ , right₀ , left₁ , right₁ , left₂ , right₂ , …]
1279    ///
1280    /// This writes a whole FLAC frame to the output stream on each call.
1281    ///
1282    /// # Errors
1283    ///
1284    /// Returns an error of any of the parameters are invalid
1285    /// or if an I/O error occurs when writing to the stream.
1286    pub fn write_cdda(&mut self, samples: &[i32]) -> Result<(), Error> {
1287        self.write(44100, 2, 16, samples)
1288    }
1289}
1290
1291fn update_md5(md5: &mut md5::Context, samples: impl Iterator<Item = i32>, bytes_per_sample: usize) {
1292    use crate::byteorder::{Endianness, LittleEndian};
1293
1294    match bytes_per_sample {
1295        1 => {
1296            for s in samples {
1297                md5.consume(LittleEndian::i8_to_bytes(s as i8));
1298            }
1299        }
1300        2 => {
1301            for s in samples {
1302                md5.consume(LittleEndian::i16_to_bytes(s as i16));
1303            }
1304        }
1305        3 => {
1306            for s in samples {
1307                md5.consume(LittleEndian::i24_to_bytes(s));
1308            }
1309        }
1310        4 => {
1311            for s in samples {
1312                md5.consume(LittleEndian::i32_to_bytes(s));
1313            }
1314        }
1315        _ => panic!("unsupported number of bytes per sample"),
1316    }
1317}
1318
/// The interval of seek points to generate
#[derive(Copy, Clone, Debug)]
pub enum SeekTableInterval {
    /// Generate a seek point every `n` seconds
    ///
    /// The actual spacing may be larger than requested
    /// if a single FLAC frame spans more than `n` seconds.
    Seconds(NonZero<u8>),
    /// Generate a seek point every `n` FLAC frames
    Frames(NonZero<usize>),
}

impl Default for SeekTableInterval {
    /// Generates a seek point every 10 seconds
    fn default() -> Self {
        Self::Seconds(NonZero::new(10).unwrap())
    }
}
1333
1334impl SeekTableInterval {
1335    // decimates full set of seekpoints based on the requested
1336    // seektable style, or returns None if no seektable is requested
1337    fn filter<'s>(
1338        self,
1339        sample_rate: u32,
1340        seekpoints: impl IntoIterator<Item = EncoderSeekPoint> + 's,
1341    ) -> Box<dyn Iterator<Item = EncoderSeekPoint> + 's> {
1342        match self {
1343            Self::Seconds(seconds) => {
1344                let nth_sample = u64::from(u32::from(seconds.get()) * sample_rate);
1345                let mut offset = 0;
1346                Box::new(seekpoints.into_iter().filter(move |point| {
1347                    if point.range().contains(&offset) {
1348                        offset += nth_sample;
1349                        true
1350                    } else {
1351                        false
1352                    }
1353                }))
1354            }
1355            Self::Frames(frames) => Box::new(seekpoints.into_iter().step_by(frames.get())),
1356        }
1357    }
1358}
1359
/// FLAC encoding options
///
/// Built from [`Options::default`], [`Options::fast`] or [`Options::best`]
/// and then adjusted with the builder-style setter methods.
#[derive(Clone, Debug)]
pub struct Options {
    // whether to clobber existing file
    clobber: bool,
    // the encoder's block size (≥ 16, as enforced by `Options::block_size`)
    block_size: u16,
    // maximum residual partition order (0 to 15)
    max_partition_order: u32,
    // whether to use mid-side encoding
    mid_side: bool,
    // metadata blocks to be written to the stream
    metadata: BlockList,
    // how often to generate seek points, or None for no SEEKTABLE
    seektable_interval: Option<SeekTableInterval>,
    // maximum LPC order, or None to encode no LPC subframes
    max_lpc_order: Option<NonZero<u8>>,
    // windowing function for input samples
    window: Window,
    // whether to search channel correlations exhaustively
    // (the inverse of `Options::fast_channel_correlation`)
    exhaustive_channel_correlation: bool,
}
1374
1375impl Default for Options {
1376    fn default() -> Self {
1377        // a dummy placeholder value
1378        // since we can't know the stream parameters yet
1379        let mut metadata = BlockList::new(Streaminfo {
1380            minimum_block_size: 0,
1381            maximum_block_size: 0,
1382            minimum_frame_size: None,
1383            maximum_frame_size: None,
1384            sample_rate: 0,
1385            channels: NonZero::new(1).unwrap(),
1386            bits_per_sample: SignedBitCount::new::<4>(),
1387            total_samples: None,
1388            md5: None,
1389        });
1390
1391        metadata.insert(crate::metadata::Padding {
1392            size: 4096u16.into(),
1393        });
1394
1395        Self {
1396            clobber: false,
1397            block_size: 4096,
1398            mid_side: true,
1399            max_partition_order: 5,
1400            metadata,
1401            seektable_interval: Some(SeekTableInterval::default()),
1402            max_lpc_order: NonZero::new(8),
1403            window: Window::default(),
1404            exhaustive_channel_correlation: true,
1405        }
1406    }
1407}
1408
1409impl Options {
1410    /// Sets new block size
1411    ///
1412    /// Block size must be ≥ 16
1413    ///
1414    /// For subset streams, this must be ≤ 4608
1415    /// if the sample rate is ≤ 48 kHz -
1416    /// or ≤ 16384 for higher sample rates.
1417    pub fn block_size(self, block_size: u16) -> Result<Self, OptionsError> {
1418        match block_size {
1419            0..16 => Err(OptionsError::InvalidBlockSize),
1420            16.. => Ok(Self { block_size, ..self }),
1421        }
1422    }
1423
1424    /// Sets new maximum LPC order
1425    ///
1426    /// The valid range is: 0 < `max_lpc_order` ≤ 32
1427    ///
1428    /// A value of `None` means that no LPC subframes will be encoded.
1429    pub fn max_lpc_order(self, max_lpc_order: Option<u8>) -> Result<Self, OptionsError> {
1430        Ok(Self {
1431            max_lpc_order: max_lpc_order
1432                .map(|o| {
1433                    o.try_into()
1434                        .ok()
1435                        .filter(|o| *o <= NonZero::new(32).unwrap())
1436                        .ok_or(OptionsError::InvalidLpcOrder)
1437                })
1438                .transpose()?,
1439            ..self
1440        })
1441    }
1442
1443    /// Sets maximum residual partion order.
1444    ///
1445    /// The valid range is: 0 ≤ `max_partition_order` ≤ 15
1446    pub fn max_partition_order(self, max_partition_order: u32) -> Result<Self, OptionsError> {
1447        match max_partition_order {
1448            0..=15 => Ok(Self {
1449                max_partition_order,
1450                ..self
1451            }),
1452            16.. => Err(OptionsError::InvalidMaxPartitions),
1453        }
1454    }
1455
1456    /// Whether to use mid-side encoding
1457    ///
1458    /// The default is `true`.
1459    pub fn mid_side(self, mid_side: bool) -> Self {
1460        Self { mid_side, ..self }
1461    }
1462
1463    /// The windowing function to use for input samples
1464    pub fn window(self, window: Window) -> Self {
1465        Self { window, ..self }
1466    }
1467
1468    /// Whether to calculate the best channel correlation quickly
1469    ///
1470    /// The default is `false`
1471    pub fn fast_channel_correlation(self, fast: bool) -> Self {
1472        Self {
1473            exhaustive_channel_correlation: !fast,
1474            ..self
1475        }
1476    }
1477
1478    /// Updates size of padding block
1479    ///
1480    /// `size` must be < 2²⁴
1481    ///
1482    /// If `size` is set to 0, removes the block entirely.
1483    ///
1484    /// The default is to add a 4096 byte padding block.
1485    pub fn padding(mut self, size: u32) -> Result<Self, OptionsError> {
1486        use crate::metadata::Padding;
1487
1488        match size
1489            .try_into()
1490            .map_err(|_| OptionsError::ExcessivePadding)?
1491        {
1492            BlockSize::ZERO => self.metadata.remove::<Padding>(),
1493            size => self.metadata.update::<Padding>(|p| {
1494                p.size = size;
1495            }),
1496        }
1497        Ok(self)
1498    }
1499
1500    /// Remove any padding blocks from metadata
1501    ///
1502    /// This makes the file smaller, but will likely require
1503    /// rewriting it if any metadata needs to be modified later.
1504    pub fn no_padding(mut self) -> Self {
1505        self.metadata.remove::<crate::metadata::Padding>();
1506        self
1507    }
1508
1509    /// Adds new tag to comment metadata block
1510    ///
1511    /// Creates new [`crate::metadata::VorbisComment`] block if not already present.
1512    pub fn tag<S>(mut self, field: &str, value: S) -> Self
1513    where
1514        S: std::fmt::Display,
1515    {
1516        self.metadata
1517            .update::<VorbisComment>(|vc| vc.insert(field, value));
1518        self
1519    }
1520
1521    /// Replaces entire [`crate::metadata::VorbisComment`] metadata block
1522    ///
1523    /// This may be more convenient when adding many fields at once.
1524    pub fn comment(mut self, comment: VorbisComment) -> Self {
1525        self.metadata.insert(comment);
1526        self
1527    }
1528
1529    /// Add new [`crate::metadata::Picture`] block to metadata
1530    ///
1531    /// Files may contain multiple [`crate::metadata::Picture`] blocks,
1532    /// and this adds a new block each time it is used.
1533    pub fn picture(mut self, picture: Picture) -> Self {
1534        self.metadata.insert(picture);
1535        self
1536    }
1537
1538    /// Add new [`crate::metadata::Cuesheet`] block to metadata
1539    ///
1540    /// Files may (theoretically) contain multiple [`crate::metadata::Cuesheet`] blocks,
1541    /// and this adds a new block each time it is used.
1542    ///
1543    /// In practice, CD images almost always use only a single
1544    /// cue sheet.
1545    pub fn cuesheet(mut self, cuesheet: Cuesheet) -> Self {
1546        self.metadata.insert(cuesheet);
1547        self
1548    }
1549
1550    /// Add new [`crate::metadata::Application`] block to metadata
1551    ///
1552    /// Files may contain multiple [`crate::metadata::Application`] blocks,
1553    /// and this adds a new block each time it is used.
1554    pub fn application(mut self, application: Application) -> Self {
1555        self.metadata.insert(application);
1556        self
1557    }
1558
1559    /// Generate [`crate::metadata::SeekTable`] with the given number of seconds between seek points
1560    ///
1561    /// The default is to generate a SEEKTABLE with 10 seconds between seek points.
1562    ///
1563    /// If `seconds` is 0, removes the SEEKTABLE block.
1564    ///
1565    /// The interval between seek points may be larger than requested
1566    /// if the encoder's block size is larger than the seekpoint interval.
1567    pub fn seektable_seconds(mut self, seconds: u8) -> Self {
1568        // note that we can't drop a placeholder seektable
1569        // into the metadata blocks until we know
1570        // the sample rate and total samples of our stream
1571        self.seektable_interval = NonZero::new(seconds).map(SeekTableInterval::Seconds);
1572        self
1573    }
1574
1575    /// Generate [`crate::metadata::SeekTable`] with the given number of FLAC frames between seek points
1576    ///
1577    /// If `frames` is 0, removes the SEEKTABLE block
1578    pub fn seektable_frames(mut self, frames: usize) -> Self {
1579        self.seektable_interval = NonZero::new(frames).map(SeekTableInterval::Frames);
1580        self
1581    }
1582
1583    /// Do not generate a seektable in our encoded file
1584    pub fn no_seektable(self) -> Self {
1585        Self {
1586            seektable_interval: None,
1587            ..self
1588        }
1589    }
1590
1591    /// Add new block to metadata
1592    ///
1593    /// If the block may only occur once in a file,
1594    /// any previous block of that same type is removed.
1595    pub fn add_block<B>(&mut self, block: B) -> &mut Self
1596    where
1597        B: PortableMetadataBlock,
1598    {
1599        self.metadata.insert(block);
1600        self
1601    }
1602
1603    /// Add new blocks to metadata
1604    ///
1605    /// If the block may only occur once in a file,
1606    /// any current block of that type is replaced by
1607    /// the final block in the iterator - if any.
1608    /// Otherwise, all blocks in the iterator are used.
1609    pub fn add_blocks<B>(&mut self, iter: impl IntoIterator<Item = B>) -> &mut Self
1610    where
1611        B: PortableMetadataBlock,
1612    {
1613        for block in iter {
1614            self.metadata.insert(block);
1615        }
1616        self
1617    }
1618
1619    /// Overwrites existing file if it already exists
1620    ///
1621    /// This only applies to encoding functions which
1622    /// create files from paths.
1623    ///
1624    /// The default is to not overwrite files
1625    /// if they already exist.
1626    pub fn overwrite(mut self) -> Self {
1627        self.clobber = true;
1628        self
1629    }
1630
1631    /// Returns the fastest encoding options
1632    ///
1633    /// These are tuned to encode as quickly as possible.
1634    pub fn fast() -> Self {
1635        Self {
1636            block_size: 1152,
1637            mid_side: false,
1638            max_partition_order: 3,
1639            max_lpc_order: None,
1640            exhaustive_channel_correlation: false,
1641            ..Self::default()
1642        }
1643    }
1644
1645    /// Returns the fastest encoding options
1646    ///
1647    /// These are tuned to encode files as small as possible.
1648    pub fn best() -> Self {
1649        Self {
1650            block_size: 4096,
1651            mid_side: true,
1652            max_partition_order: 6,
1653            max_lpc_order: NonZero::new(12),
1654            ..Self::default()
1655        }
1656    }
1657
1658    /// Creates files according to whether clobber is set or not
1659    fn create<P: AsRef<Path>>(&self, path: P) -> std::io::Result<File> {
1660        if self.clobber {
1661            File::create(path)
1662        } else {
1663            use std::fs::OpenOptions;
1664
1665            OpenOptions::new()
1666                .write(true)
1667                .create_new(true)
1668                .open(path.as_ref())
1669        }
1670    }
1671}
1672
/// An error when specifying encoding options
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum OptionsError {
    /// Selected block size is too small
    InvalidBlockSize,
    /// Maximum LPC order is zero or too large
    InvalidLpcOrder,
    /// Maximum residual partition order is too large
    InvalidMaxPartitions,
    /// Selected padding size is too large
    ExcessivePadding,
}

impl std::error::Error for OptionsError {}

impl std::fmt::Display for OptionsError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // `pad` honors width/alignment flags exactly as `str`'s Display does
        f.pad(match self {
            Self::InvalidBlockSize => "block size must be >= 16",
            Self::InvalidLpcOrder => "maximum LPC order must be between 1 and 32",
            Self::InvalidMaxPartitions => "max partition order must be <= 15",
            Self::ExcessivePadding => "padding size is too large for block",
        })
    }
}
1698
1699/// A cut-down version of Options without the metadata blocks
1700struct EncoderOptions {
1701    max_partition_order: u32,
1702    mid_side: bool,
1703    seektable_interval: Option<SeekTableInterval>,
1704    max_lpc_order: Option<NonZero<u8>>,
1705    window: Window,
1706    exhaustive_channel_correlation: bool,
1707    use_rice2: bool,
1708}
1709
/// The method to use for windowing the input signal
#[derive(Copy, Clone, Debug)]
pub enum Window {
    /// Basic rectangular window (every coefficient is 1.0)
    Rectangle,
    /// Hann window
    Hann,
    /// Tukey window, parameterized by its tapered fraction `p`
    ///
    /// Values ≤ 0.0 produce a rectangular window.
    /// Values ≥ 1.0 produce a Hann window.
    /// NaN falls back to `Tukey(0.5)`.
    Tukey(f32),
}

// TODO - add more windowing options

impl Window {
    // Fills `window` with this window's coefficients.
    //
    // A Hann window (or a Tukey window that resolves to one) panics
    // if `window` has more than u16::MAX coefficients. Other Tukey windows
    // panic only if the tapered region itself exceeds that size.
    fn generate(&self, window: &mut [f64]) {
        use std::f64::consts::PI;

        match self {
            Self::Rectangle => window.fill(1.0),
            // The Hann formula divides by N - 1, which is 0 for a single
            // coefficient and would produce NaN. A lone coefficient of 1.0
            // leaves the signal untouched instead.
            Self::Hann if window.len() <= 1 => window.fill(1.0),
            Self::Hann => {
                // verified output against reference implementation
                // See: FLAC__window_hann()

                let np =
                    f64::from(u16::try_from(window.len()).expect("window size too large")) - 1.0;

                window.iter_mut().zip(0u16..).for_each(|(w, n)| {
                    *w = 0.5 - 0.5 * (2.0 * PI * f64::from(n) / np).cos();
                });
            }
            Self::Tukey(p) => match p {
                // verified output against reference implementation
                // See: FLAC__window_tukey()
                ..=0.0 => {
                    window.fill(1.0);
                }
                1.0.. => {
                    Self::Hann.generate(window);
                }
                0.0..1.0 => {
                    // `np` is the length of each cosine-tapered end;
                    // everything between them is 1.0
                    match ((f64::from(*p) / 2.0 * window.len() as f64) as usize).checked_sub(1) {
                        Some(np) => match window.get_disjoint_mut([
                            0..np,
                            np..window.len() - np,
                            window.len() - np..window.len(),
                        ]) {
                            Ok([first, mid, last]) => {
                                // u16 is maximum block size
                                let np = u16::try_from(np).expect("window size too large");

                                // the trailing taper mirrors the leading one
                                for ((x, y), n) in
                                    first.iter_mut().zip(last.iter_mut().rev()).zip(0u16..)
                                {
                                    *x = 0.5 - 0.5 * (PI * f64::from(n) / f64::from(np)).cos();
                                    *y = *x;
                                }
                                mid.fill(1.0);
                            }
                            Err(_) => {
                                window.fill(1.0);
                            }
                        },
                        None => {
                            window.fill(1.0);
                        }
                    }
                }
                // only NaN reaches here
                _ => {
                    Self::Tukey(0.5).generate(window);
                }
            },
        }
    }

    // Multiplies `samples` by this window into `cache` and returns the result.
    //
    // `window` holds the coefficients and is regenerated only when its
    // length differs from `samples`. Callers must therefore keep using
    // the same `Window` variant with a given `window` buffer.
    fn apply<'w>(
        &self,
        window: &mut Vec<f64>,
        cache: &'w mut Vec<f64>,
        samples: &[i32],
    ) -> &'w [f64] {
        if window.len() != samples.len() {
            // need to re-generate window to fit samples
            window.resize(samples.len(), 0.0);
            self.generate(window);
        }

        // window signal into cache and return cached slice
        cache.clear();
        cache.extend(samples.iter().zip(window).map(|(s, w)| f64::from(*s) * *w));
        cache.as_slice()
    }
}

impl Default for Window {
    /// A Tukey window with p = 0.5
    fn default() -> Self {
        Self::Tukey(0.5)
    }
}
1808
/// Scratch buffers kept by the encoder and reused from frame to frame
#[derive(Default)]
struct EncodingCaches {
    // per-channel caches, resized to the frame's channel count by
    // `encode_frame` (the exhaustive stereo path uses `correlated` instead)
    channels: Vec<ChannelCache>,
    // storage used when correlating a stereo pair of channels
    correlated: CorrelationCache,
}
1814
/// Scratch storage for correlating a stereo pair of channels
#[derive(Default)]
struct CorrelationCache {
    // the average channel samples, `(left + right) >> 1`
    average_samples: Vec<i32>,
    // the difference channel samples, `left - right`
    difference_samples: Vec<i32>,

    // per-candidate channel caches; presumably used by
    // `correlate_channels_exhaustive` to encode every candidate
    // channel before choosing an assignment (TODO confirm)
    left_cache: ChannelCache,
    right_cache: ChannelCache,
    average_cache: ChannelCache,
    difference_cache: ChannelCache,
}
1827
/// Per-channel scratch storage for encoding a single subframe
#[derive(Default)]
struct ChannelCache {
    // buffers for computing FIXED subframes
    fixed: FixedCache,
    // NOTE(review): each `*_output` recorder presumably holds one candidate
    // subframe encoding so the best one can be played back later;
    // confirm against `encode_subframe`
    fixed_output: BitRecorder<u32, BigEndian>,
    // buffers for computing LPC subframes
    lpc: LpcCache,
    lpc_output: BitRecorder<u32, BigEndian>,
    constant_output: BitRecorder<u32, BigEndian>,
    verbatim_output: BitRecorder<u32, BigEndian>,
    // presumably the samples with wasted low-order bits shifted out — TODO confirm
    wasted: Vec<i32>,
}
1838
/// Scratch storage for computing FIXED subframes
#[derive(Default)]
struct FixedCache {
    // FIXED subframe buffers, one per order 1-4
    // (presumably each order's residuals — confirm in the FIXED encoder)
    fixed_buffers: [Vec<i32>; 4],
}
1844
/// Scratch storage for computing LPC subframes
#[derive(Default)]
struct LpcCache {
    // presumably the analysis window passed to `Window::apply`,
    // which only regenerates it when the block size changes — confirm
    window: Vec<f64>,
    // presumably the windowed samples produced by `Window::apply` — confirm
    windowed: Vec<f64>,
    // presumably the residuals left after LPC prediction — confirm
    residuals: Vec<i32>,
}
1851
/// A FLAC encoder
///
/// Writes the stream's metadata blocks when created, one FLAC frame
/// per call to `encode`, and rewrites the metadata blocks with their
/// final values when finalized (at the latest, when dropped).
struct Encoder<W: std::io::Write + std::io::Seek> {
    // the writer we're outputting to; its byte count starts after the
    // initial metadata blocks, so it doubles as the seekpoint byte offset
    writer: Counter<W>,
    // the stream's starting offset in the writer, in bytes,
    // which is where metadata is rewritten on finalization
    start: u64,
    // various encoding options
    options: EncoderOptions,
    // various encoder caches, reused from frame to frame
    caches: EncodingCaches,
    // our metadata blocks
    blocks: BlockList,
    // our stream's sample rate
    sample_rate: SampleRate<u32>,
    // the current frame number
    frame_number: FrameNumber,
    // the number of channel-independent samples written
    samples_written: u64,
    // one seekpoint per encoded frame, filtered down on finalization
    seekpoints: Vec<EncoderSeekPoint>,
    // our running MD5 calculation
    // NOTE(review): not updated by `Encoder::encode` itself; presumably
    // fed by the wrapping writer — confirm
    md5: md5::Context,
    // whether the encoder has finalized the file
    finalized: bool,
}
1877
1878impl<W: std::io::Write + std::io::Seek> Encoder<W> {
1879    const MAX_SAMPLES: u64 = 68_719_476_736;
1880
1881    fn new(
1882        mut writer: W,
1883        options: Options,
1884        sample_rate: u32,
1885        bits_per_sample: SignedBitCount<32>,
1886        channels: u8,
1887        total_samples: Option<NonZero<u64>>,
1888    ) -> Result<Self, Error> {
1889        use crate::metadata::OptionalBlockType;
1890
1891        let mut blocks = options.metadata;
1892
1893        *blocks.streaminfo_mut() = Streaminfo {
1894            minimum_block_size: options.block_size,
1895            maximum_block_size: options.block_size,
1896            minimum_frame_size: None,
1897            maximum_frame_size: None,
1898            sample_rate: (0..1048576)
1899                .contains(&sample_rate)
1900                .then_some(sample_rate)
1901                .ok_or(Error::InvalidSampleRate)?,
1902            bits_per_sample,
1903            channels: (1..=8)
1904                .contains(&channels)
1905                .then_some(channels)
1906                .and_then(NonZero::new)
1907                .ok_or(Error::ExcessiveChannels)?,
1908            total_samples: match total_samples {
1909                None => None,
1910                total_samples @ Some(samples) => match samples.get() {
1911                    0..Self::MAX_SAMPLES => total_samples,
1912                    _ => return Err(Error::ExcessiveTotalSamples),
1913                },
1914            },
1915            md5: None,
1916        };
1917
1918        // insert a dummy SeekTable to be populated later
1919        if let Some(total_samples) = total_samples
1920            && let Some(placeholders) = options.seektable_interval.map(|s| {
1921                s.filter(
1922                    sample_rate,
1923                    EncoderSeekPoint::placeholders(total_samples.get(), options.block_size),
1924                )
1925            })
1926        {
1927            use crate::metadata::SeekTable;
1928
1929            blocks.insert(SeekTable {
1930                // placeholder points should always be contiguous
1931                points: placeholders
1932                    .take(SeekTable::MAX_POINTS)
1933                    .map(|p| p.into())
1934                    .collect::<Vec<_>>()
1935                    .try_into()
1936                    .unwrap(),
1937            });
1938        }
1939
1940        let start = writer.stream_position()?;
1941
1942        // sort blocks to put more relevant items at the front
1943        blocks.sort_by(|block| match block {
1944            OptionalBlockType::VorbisComment => 0,
1945            OptionalBlockType::SeekTable => 1,
1946            OptionalBlockType::Picture => 2,
1947            OptionalBlockType::Application => 3,
1948            OptionalBlockType::Cuesheet => 4,
1949            OptionalBlockType::Padding => 5,
1950        });
1951
1952        write_blocks(writer.by_ref(), blocks.blocks())?;
1953
1954        Ok(Self {
1955            start,
1956            writer: Counter::new(writer),
1957            options: EncoderOptions {
1958                max_partition_order: options.max_partition_order,
1959                mid_side: options.mid_side,
1960                seektable_interval: options.seektable_interval,
1961                max_lpc_order: options.max_lpc_order,
1962                window: options.window,
1963                exhaustive_channel_correlation: options.exhaustive_channel_correlation,
1964                use_rice2: u32::from(bits_per_sample) > 16,
1965            },
1966            caches: EncodingCaches::default(),
1967            sample_rate: blocks
1968                .streaminfo()
1969                .sample_rate
1970                .try_into()
1971                .expect("invalid sample rate"),
1972            blocks,
1973            frame_number: FrameNumber::default(),
1974            samples_written: 0,
1975            seekpoints: Vec::new(),
1976            md5: md5::Context::new(),
1977            finalized: false,
1978        })
1979    }
1980
1981    /// The encoder's channel count
1982    fn channel_count(&self) -> NonZero<u8> {
1983        self.blocks.streaminfo().channels
1984    }
1985
1986    /// Encodes an audio frame of PCM samples
1987    ///
1988    /// Depending on the encoder's chosen block size,
1989    /// this may encode zero or more FLAC frames to disk.
1990    ///
1991    /// # Errors
1992    ///
1993    /// Returns an I/O error from the underlying stream,
1994    /// or if the frame's parameters are not a match
1995    /// for the encoder's.
1996    fn encode(&mut self, frame: &Frame) -> Result<(), Error> {
1997        // drop in a new seekpoint
1998        self.seekpoints.push(EncoderSeekPoint {
1999            sample_offset: self.samples_written,
2000            byte_offset: Some(self.writer.count),
2001            frame_samples: frame.pcm_frames() as u16,
2002        });
2003
2004        // update running total of samples written
2005        self.samples_written += frame.pcm_frames() as u64;
2006        if let Some(total_samples) = self.blocks.streaminfo().total_samples
2007            && self.samples_written > total_samples.get()
2008        {
2009            return Err(Error::ExcessiveTotalSamples);
2010        }
2011
2012        encode_frame(
2013            &self.options,
2014            &mut self.caches,
2015            &mut self.writer,
2016            self.blocks.streaminfo_mut(),
2017            &mut self.frame_number,
2018            self.sample_rate,
2019            frame.channels().collect(),
2020        )
2021    }
2022
2023    fn finalize_inner(&mut self) -> Result<(), Error> {
2024        if !self.finalized {
2025            use crate::metadata::SeekTable;
2026
2027            self.finalized = true;
2028
2029            // update SEEKTABLE metadata block with final values
2030            if let Some(encoded_points) = self
2031                .options
2032                .seektable_interval
2033                .map(|s| s.filter(self.sample_rate.into(), self.seekpoints.iter().cloned()))
2034            {
2035                match self.blocks.get_pair_mut() {
2036                    (Some(SeekTable { points }), _) => {
2037                        // placeholder SEEKTABLE already in place,
2038                        // so no need to adjust PADDING to fit
2039
2040                        // ensure points count is the same
2041
2042                        let points_len = points.len();
2043                        points.clear();
2044                        points
2045                            .try_extend(
2046                                encoded_points
2047                                    .into_iter()
2048                                    .map(|p| p.into())
2049                                    .chain(std::iter::repeat(SeekPoint::Placeholder))
2050                                    .take(points_len),
2051                            )
2052                            .unwrap();
2053                    }
2054                    (None, Some(crate::metadata::Padding { size: padding_size })) => {
2055                        // no SEEKTABLE, but there is a PADDING block,
2056                        // so try to shrink PADDING to fit SEEKTABLE
2057
2058                        use crate::metadata::MetadataBlock;
2059
2060                        let seektable = SeekTable {
2061                            points: encoded_points
2062                                .map(|p| p.into())
2063                                .collect::<Vec<_>>()
2064                                .try_into()
2065                                .unwrap(),
2066                        };
2067                        if let Some(new_padding_size) = seektable
2068                            .total_size()
2069                            .and_then(|seektable_size| padding_size.checked_sub(seektable_size))
2070                        {
2071                            *padding_size = new_padding_size;
2072                            self.blocks.insert(seektable);
2073                        }
2074                    }
2075                    (None, None) => { /* no seektable or padding, so nothing to do */ }
2076                }
2077            }
2078
2079            // verify or update final sample count
2080            match &mut self.blocks.streaminfo_mut().total_samples {
2081                Some(expected) => {
2082                    // ensure final sample count matches
2083                    if expected.get() != self.samples_written {
2084                        return Err(Error::SampleCountMismatch);
2085                    }
2086                }
2087                expected @ None => {
2088                    // update final sample count if possible
2089                    if self.samples_written < Self::MAX_SAMPLES {
2090                        *expected =
2091                            Some(NonZero::new(self.samples_written).ok_or(Error::NoSamples)?);
2092                    } else {
2093                        // TODO - should I just leave this blank
2094                        // if too many samples are written?
2095                        return Err(Error::ExcessiveTotalSamples);
2096                    }
2097                }
2098            }
2099
2100            // update STREAMINFO MD5 sum
2101            self.blocks.streaminfo_mut().md5 = Some(self.md5.clone().finalize().0);
2102
2103            // rewrite metadata blocks, relative to the beginning
2104            // of the stream
2105            let writer = self.writer.stream();
2106            writer.seek(std::io::SeekFrom::Start(self.start))?;
2107            write_blocks(writer.by_ref(), self.blocks.blocks())
2108        } else {
2109            Ok(())
2110        }
2111    }
2112}
2113
2114impl<W: std::io::Write + std::io::Seek> Drop for Encoder<W> {
2115    fn drop(&mut self) {
2116        let _ = self.finalize_inner();
2117    }
2118}
2119
/// A seekpoint as tracked by the encoder
///
/// Unlike `SeekPoint`, which has a dedicated placeholder variant,
/// these always carry a sample offset and a frame size.
/// A `byte_offset` of `None` marks a placeholder (dummy) point,
/// which converts to `SeekPoint::Placeholder`.
#[derive(Debug, Clone)]
struct EncoderSeekPoint {
    // the frame's first sample, counted from the start of the stream
    sample_offset: u64,
    // the frame's byte offset relative to the first frame,
    // or None for a placeholder point
    byte_offset: Option<u64>,
    // the number of samples in the frame
    frame_samples: u16,
}
2129
2130impl EncoderSeekPoint {
2131    // generates set of placeholder points
2132    fn placeholders(total_samples: u64, block_size: u16) -> impl Iterator<Item = EncoderSeekPoint> {
2133        (0..total_samples)
2134            .step_by(usize::from(block_size))
2135            .map(move |sample_offset| EncoderSeekPoint {
2136                sample_offset,
2137                byte_offset: None,
2138                frame_samples: u16::try_from(total_samples - sample_offset)
2139                    .map(|s| s.min(block_size))
2140                    .unwrap_or(block_size),
2141            })
2142    }
2143
2144    // returns sample range of point
2145    fn range(&self) -> std::ops::Range<u64> {
2146        self.sample_offset..(self.sample_offset + u64::from(self.frame_samples))
2147    }
2148}
2149
2150impl From<EncoderSeekPoint> for SeekPoint {
2151    fn from(p: EncoderSeekPoint) -> Self {
2152        match p.byte_offset {
2153            Some(byte_offset) => Self::Defined {
2154                sample_offset: p.sample_offset,
2155                byte_offset,
2156                frame_samples: p.frame_samples,
2157            },
2158            None => Self::Placeholder,
2159        }
2160    }
2161}
2162
2163/// Given a FLAC stream, generates new seek table
2164///
2165/// Though encoders should add seek tables by default,
2166/// sometimes one isn't present.  This function takes
2167/// an existing FLAC file stream and generates a new
2168/// seek table suitable for adding to the file's metadata
2169/// via the [`crate::metadata::update`] function.
2170///
2171/// The stream should be rewound to the beginning of the file.
2172///
2173/// # Errors
2174///
2175/// Returns any error from the underlying stream.
2176///
2177/// # Example
2178/// ```
2179/// use flac_codec::{
2180///     encode::{FlacSampleWriter, Options, SeekTableInterval, generate_seektable},
2181///     metadata::{SeekTable, read_block},
2182/// };
2183/// use std::io::{Cursor, Seek};
2184///
2185/// let mut flac = Cursor::new(vec![]);  // a FLAC file in memory
2186///
2187/// // add a seekpoint every second
2188/// let options = Options::default().seektable_seconds(1);
2189///
2190/// let mut writer = FlacSampleWriter::new(
2191///     &mut flac,         // our wrapped writer
2192///     options,           // our seektable options
2193///     44100,             // sample rate
2194///     16,                // bits-per-sample
2195///     1,                 // channel count
2196///     Some(60 * 44100),  // one minute's worth of samples
2197/// ).unwrap();
2198///
2199/// // write one minute's worth of samples
2200/// writer.write(vec![0; 60 * 44100].as_slice()).unwrap();
2201///
2202/// // finalize writing file
2203/// assert!(writer.finalize().is_ok());
2204///
2205/// flac.rewind().unwrap();
2206///
2207/// // get existing seektable
2208/// let original_seektable = match read_block::<_, SeekTable>(&mut flac) {
2209///     Ok(Some(seektable)) => seektable,
2210///     _ => panic!("seektable not found"),
2211/// };
2212///
2213/// flac.rewind().unwrap();
2214///
2215/// // generate new seektable, also with seekpoints every second
2216/// let new_seektable = generate_seektable(
2217///     flac,
2218///     SeekTableInterval::Seconds(1.try_into().unwrap())
2219/// ).unwrap();
2220///
2221/// // ensure both seektables are identical
2222/// assert_eq!(original_seektable, new_seektable);
2223/// ```
2224pub fn generate_seektable<R: std::io::Read>(
2225    r: R,
2226    interval: SeekTableInterval,
2227) -> Result<crate::metadata::SeekTable, Error> {
2228    use crate::{
2229        metadata::{Metadata, SeekTable},
2230        stream::FrameIterator,
2231    };
2232
2233    let iter = FrameIterator::new(r)?;
2234    let metadata_len = iter.metadata_len();
2235    let sample_rate = iter.sample_rate();
2236    let mut sample_offset = 0;
2237
2238    iter.map(|r| {
2239        r.map(|(frame, offset)| EncoderSeekPoint {
2240            sample_offset,
2241            byte_offset: Some(offset - metadata_len),
2242            frame_samples: frame.header.block_size.into(),
2243        })
2244        .inspect(|p| {
2245            sample_offset += u64::from(p.frame_samples);
2246        })
2247    })
2248    .collect::<Result<Vec<_>, _>>()
2249    .map(|seekpoints| SeekTable {
2250        points: interval
2251            .filter(sample_rate, seekpoints)
2252            .take(SeekTable::MAX_POINTS)
2253            .map(|p| p.into())
2254            .collect::<Vec<_>>()
2255            .try_into()
2256            .unwrap(),
2257    })
2258}
2259
/// Encodes a single FLAC frame and writes it to `writer`
///
/// `frame` holds one sample slice per channel, presumably all of equal
/// length (only the first channel's length is used as the block size).
/// The frame header is written first, followed by one subframe per
/// channel and a trailing CRC-16 covering the whole frame.
///
/// - mono frames are encoded independently
/// - stereo frames choose a channel assignment with
///   `correlate_channels_exhaustive` when
///   `options.exhaustive_channel_correlation` is set,
///   or with `correlate_channels` otherwise
/// - frames with more than two channels are encoded independently
///
/// On success, `frame_number` is incremented and `streaminfo`'s
/// minimum/maximum frame sizes are widened to include this frame's
/// size in bytes (when that size is non-zero and below
/// `Streaminfo::MAX_FRAME_SIZE`).
///
/// # Errors
///
/// Returns any I/O error from `writer`, any error from correlating
/// channels or encoding subframes, or an error if the frame number
/// cannot be incremented.
///
/// # Panics
///
/// Panics if the block size conversion fails, presumably only for
/// empty channels.  In debug builds, also panics if `frame` has no channels.
fn encode_frame<W>(
    options: &EncoderOptions,
    cache: &mut EncodingCaches,
    mut writer: W,
    streaminfo: &mut Streaminfo,
    frame_number: &mut FrameNumber,
    sample_rate: SampleRate<u32>,
    frame: ArrayVec<&[i32], MAX_CHANNELS>,
) -> Result<(), Error>
where
    W: std::io::Write,
{
    use crate::Counter;
    use crate::crc::{Crc16, CrcWriter};
    use crate::stream::FrameHeader;
    use bitstream_io::BigEndian;

    debug_assert!(!frame.is_empty());

    // count every byte of the frame, so its total size is known afterward
    let size = Counter::new(writer.by_ref());
    // checksum every byte of the frame for the trailing CRC-16
    let mut w: CrcWriter<_, Crc16> = CrcWriter::new(size);
    // assigned in each arm below, once the frame header
    // has been written directly through `w`
    let mut bw: BitWriter<CrcWriter<Counter<&mut W>, Crc16>, BigEndian>;

    match frame.as_slice() {
        [channel] => {
            FrameHeader {
                blocking_strategy: false,
                frame_number: *frame_number,
                // NOTE(review): `as u16` silently truncates lengths above
                // u16::MAX before the conversion is checked
                block_size: (channel.len() as u16)
                    .try_into()
                    .expect("frame cannot be empty"),
                sample_rate,
                bits_per_sample: streaminfo.bits_per_sample.into(),
                channel_assignment: ChannelAssignment::Independent(Independent::Mono),
            }
            .write(&mut w, streaminfo)?;

            bw = BitWriter::new(w);

            cache.channels.resize_with(1, ChannelCache::default);

            encode_subframe(
                options,
                &mut cache.channels[0],
                CorrelatedChannel::independent(streaminfo.bits_per_sample, channel),
            )?
            .playback(&mut bw)?;
        }
        [left, right] if options.exhaustive_channel_correlation => {
            // channels come back already encoded, ready for playback
            let Correlated {
                channel_assignment,
                channels: [channel_0, channel_1],
            } = correlate_channels_exhaustive(
                options,
                &mut cache.correlated,
                [left, right],
                streaminfo.bits_per_sample,
            )?;

            FrameHeader {
                blocking_strategy: false,
                frame_number: *frame_number,
                block_size: (frame[0].len() as u16)
                    .try_into()
                    .expect("frame cannot be empty"),
                sample_rate,
                bits_per_sample: streaminfo.bits_per_sample.into(),
                channel_assignment,
            }
            .write(&mut w, streaminfo)?;

            bw = BitWriter::new(w);

            channel_0.playback(&mut bw)?;
            channel_1.playback(&mut bw)?;
        }
        [left, right] => {
            // pick an assignment from the channels' sample magnitudes,
            // then encode only the two chosen channels
            let Correlated {
                channel_assignment,
                channels: [channel_0, channel_1],
            } = correlate_channels(
                options,
                &mut cache.correlated,
                [left, right],
                streaminfo.bits_per_sample,
            );

            FrameHeader {
                blocking_strategy: false,
                frame_number: *frame_number,
                block_size: (frame[0].len() as u16)
                    .try_into()
                    .expect("frame cannot be empty"),
                sample_rate,
                bits_per_sample: streaminfo.bits_per_sample.into(),
                channel_assignment,
            }
            .write(&mut w, streaminfo)?;

            cache.channels.resize_with(2, ChannelCache::default);
            let [cache_0, cache_1] = cache.channels.get_disjoint_mut([0, 1]).unwrap();
            // encode both subframes (presumably in parallel when the `rayon`
            // feature is enabled, per the module docs),
            // then write them out serially in subframe order
            let (channel_0, channel_1) = join(
                || encode_subframe(options, cache_0, channel_0),
                || encode_subframe(options, cache_1, channel_1),
            );

            bw = BitWriter::new(w);

            channel_0?.playback(&mut bw)?;
            channel_1?.playback(&mut bw)?;
        }
        channels => {
            // non-stereo frames are always encoded independently

            FrameHeader {
                blocking_strategy: false,
                frame_number: *frame_number,
                block_size: (channels[0].len() as u16)
                    .try_into()
                    .expect("frame cannot be empty"),
                sample_rate,
                bits_per_sample: streaminfo.bits_per_sample.into(),
                channel_assignment: ChannelAssignment::Independent(
                    frame.len().try_into().expect("invalid channel count"),
                ),
            }
            .write(&mut w, streaminfo)?;

            bw = BitWriter::new(w);

            cache
                .channels
                .resize_with(channels.len(), ChannelCache::default);

            // encode every channel, then play the results back in
            // channel order, stopping at the first error
            vec_map(
                cache.channels.iter_mut().zip(channels).collect(),
                |(cache, channel)| {
                    encode_subframe(
                        options,
                        cache,
                        CorrelatedChannel::independent(streaminfo.bits_per_sample, channel),
                    )
                },
            )
            .into_iter()
            .try_for_each(|r| r.and_then(|r| r.playback(bw.by_ref()).map_err(Error::Io)))?;
        }
    }

    // the CRC-16 covers every byte written so far; `aligned_writer`
    // presumably byte-aligns the bit writer before exposing the CRC writer
    let crc16: u16 = bw.aligned_writer()?.checksum().into();
    bw.write_from(crc16)?;

    frame_number.try_increment()?;

    // update minimum and maximum frame size values
    // (frames whose byte count is zero or too large to record are skipped)
    if let s @ Some(size) = u32::try_from(bw.into_writer().into_writer().count)
        .ok()
        .filter(|size| *size < Streaminfo::MAX_FRAME_SIZE)
        .and_then(NonZero::new)
    {
        match &mut streaminfo.minimum_frame_size {
            Some(min_size) => {
                *min_size = size.min(*min_size);
            }
            min_size @ None => {
                *min_size = s;
            }
        }

        match &mut streaminfo.maximum_frame_size {
            Some(max_size) => {
                *max_size = size.max(*max_size);
            }
            max_size @ None => {
                *max_size = s;
            }
        }
    }

    Ok(())
}
2441
/// A chosen stereo channel assignment along with its two channels
struct Correlated<C> {
    // the assignment recorded in the frame header
    channel_assignment: ChannelAssignment,
    // the two channels, in the order their subframes are written
    channels: [C; 2],
}
2446
/// One channel's samples, ready to be encoded as a subframe
struct CorrelatedChannel<'c> {
    // the channel's samples
    samples: &'c [i32],
    // the channel's bit depth; difference (side) channels
    // use one more bit than the stream's
    bits_per_sample: SignedBitCount<32>,
    // whether all samples are known to be 0
    all_0: bool,
}
2453
2454impl<'c> CorrelatedChannel<'c> {
2455    fn independent(bits_per_sample: SignedBitCount<32>, samples: &'c [i32]) -> Self {
2456        Self {
2457            all_0: samples.iter().all(|s| *s == 0),
2458            bits_per_sample,
2459            samples,
2460        }
2461    }
2462}
2463
2464fn correlate_channels<'c>(
2465    options: &EncoderOptions,
2466    CorrelationCache {
2467        average_samples,
2468        difference_samples,
2469        ..
2470    }: &'c mut CorrelationCache,
2471    [left, right]: [&'c [i32]; 2],
2472    bits_per_sample: SignedBitCount<32>,
2473) -> Correlated<CorrelatedChannel<'c>> {
2474    match bits_per_sample.checked_add::<32>(1) {
2475        Some(difference_bits_per_sample) if options.mid_side => {
2476            let mut left_abs_sum = 0;
2477            let mut right_abs_sum = 0;
2478            let mut mid_abs_sum = 0;
2479            let mut side_abs_sum = 0;
2480
2481            join(
2482                || {
2483                    average_samples.clear();
2484                    average_samples.extend(
2485                        left.iter()
2486                            .inspect(|s| left_abs_sum += u64::from(s.unsigned_abs()))
2487                            .zip(
2488                                right
2489                                    .iter()
2490                                    .inspect(|s| right_abs_sum += u64::from(s.unsigned_abs())),
2491                            )
2492                            .map(|(l, r)| (l + r) >> 1)
2493                            .inspect(|s| mid_abs_sum += u64::from(s.unsigned_abs())),
2494                    );
2495                },
2496                || {
2497                    difference_samples.clear();
2498                    difference_samples.extend(
2499                        left.iter()
2500                            .zip(right)
2501                            .map(|(l, r)| l - r)
2502                            .inspect(|s| side_abs_sum += u64::from(s.unsigned_abs())),
2503                    );
2504                },
2505            );
2506
2507            match [
2508                (
2509                    ChannelAssignment::Independent(Independent::Stereo),
2510                    left_abs_sum + right_abs_sum,
2511                ),
2512                (ChannelAssignment::LeftSide, left_abs_sum + side_abs_sum),
2513                (ChannelAssignment::SideRight, side_abs_sum + right_abs_sum),
2514                (ChannelAssignment::MidSide, mid_abs_sum + side_abs_sum),
2515            ]
2516            .into_iter()
2517            .min_by_key(|(_, total)| *total)
2518            .unwrap()
2519            .0
2520            {
2521                channel_assignment @ ChannelAssignment::LeftSide => Correlated {
2522                    channel_assignment,
2523                    channels: [
2524                        CorrelatedChannel {
2525                            samples: left,
2526                            bits_per_sample,
2527                            all_0: left_abs_sum == 0,
2528                        },
2529                        CorrelatedChannel {
2530                            samples: difference_samples,
2531                            bits_per_sample: difference_bits_per_sample,
2532                            all_0: side_abs_sum == 0,
2533                        },
2534                    ],
2535                },
2536                channel_assignment @ ChannelAssignment::SideRight => Correlated {
2537                    channel_assignment,
2538                    channels: [
2539                        CorrelatedChannel {
2540                            samples: difference_samples,
2541                            bits_per_sample: difference_bits_per_sample,
2542                            all_0: side_abs_sum == 0,
2543                        },
2544                        CorrelatedChannel {
2545                            samples: right,
2546                            bits_per_sample,
2547                            all_0: right_abs_sum == 0,
2548                        },
2549                    ],
2550                },
2551                channel_assignment @ ChannelAssignment::MidSide => Correlated {
2552                    channel_assignment,
2553                    channels: [
2554                        CorrelatedChannel {
2555                            samples: average_samples,
2556                            bits_per_sample,
2557                            all_0: mid_abs_sum == 0,
2558                        },
2559                        CorrelatedChannel {
2560                            samples: difference_samples,
2561                            bits_per_sample: difference_bits_per_sample,
2562                            all_0: side_abs_sum == 0,
2563                        },
2564                    ],
2565                },
2566                channel_assignment @ ChannelAssignment::Independent(_) => Correlated {
2567                    channel_assignment,
2568                    channels: [
2569                        CorrelatedChannel {
2570                            samples: left,
2571                            bits_per_sample,
2572                            all_0: left_abs_sum == 0,
2573                        },
2574                        CorrelatedChannel {
2575                            samples: right,
2576                            bits_per_sample,
2577                            all_0: right_abs_sum == 0,
2578                        },
2579                    ],
2580                },
2581            }
2582        }
2583        Some(difference_bits_per_sample) => {
2584            let mut left_abs_sum = 0;
2585            let mut right_abs_sum = 0;
2586            let mut side_abs_sum = 0;
2587
2588            difference_samples.clear();
2589            difference_samples.extend(
2590                left.iter()
2591                    .inspect(|s| left_abs_sum += u64::from(s.unsigned_abs()))
2592                    .zip(
2593                        right
2594                            .iter()
2595                            .inspect(|s| right_abs_sum += u64::from(s.unsigned_abs())),
2596                    )
2597                    .map(|(l, r)| l - r)
2598                    .inspect(|s| side_abs_sum += u64::from(s.unsigned_abs())),
2599            );
2600
2601            match [
2602                (ChannelAssignment::LeftSide, left_abs_sum + side_abs_sum),
2603                (ChannelAssignment::SideRight, side_abs_sum + right_abs_sum),
2604                (
2605                    ChannelAssignment::Independent(Independent::Stereo),
2606                    left_abs_sum + right_abs_sum,
2607                ),
2608            ]
2609            .into_iter()
2610            .min_by_key(|(_, total)| *total)
2611            .unwrap()
2612            .0
2613            {
2614                channel_assignment @ ChannelAssignment::LeftSide => Correlated {
2615                    channel_assignment,
2616                    channels: [
2617                        CorrelatedChannel {
2618                            samples: left,
2619                            bits_per_sample,
2620                            all_0: left_abs_sum == 0,
2621                        },
2622                        CorrelatedChannel {
2623                            samples: difference_samples,
2624                            bits_per_sample: difference_bits_per_sample,
2625                            all_0: side_abs_sum == 0,
2626                        },
2627                    ],
2628                },
2629                channel_assignment @ ChannelAssignment::SideRight => Correlated {
2630                    channel_assignment,
2631                    channels: [
2632                        CorrelatedChannel {
2633                            samples: difference_samples,
2634                            bits_per_sample: difference_bits_per_sample,
2635                            all_0: side_abs_sum == 0,
2636                        },
2637                        CorrelatedChannel {
2638                            samples: right,
2639                            bits_per_sample,
2640                            all_0: right_abs_sum == 0,
2641                        },
2642                    ],
2643                },
2644                ChannelAssignment::MidSide => unreachable!(),
2645                channel_assignment @ ChannelAssignment::Independent(_) => Correlated {
2646                    channel_assignment,
2647                    channels: [
2648                        CorrelatedChannel {
2649                            samples: left,
2650                            bits_per_sample,
2651                            all_0: left_abs_sum == 0,
2652                        },
2653                        CorrelatedChannel {
2654                            samples: right,
2655                            bits_per_sample,
2656                            all_0: right_abs_sum == 0,
2657                        },
2658                    ],
2659                },
2660            }
2661        }
2662        None => {
2663            // 32 bps stream, so forego difference channel
2664            // and encode them both indepedently
2665
2666            Correlated {
2667                channel_assignment: ChannelAssignment::Independent(Independent::Stereo),
2668                channels: [
2669                    CorrelatedChannel::independent(bits_per_sample, left),
2670                    CorrelatedChannel::independent(bits_per_sample, right),
2671                ],
2672            }
2673        }
2674    }
2675}
2676
2677fn correlate_channels_exhaustive<'c>(
2678    options: &EncoderOptions,
2679    CorrelationCache {
2680        average_samples,
2681        difference_samples,
2682        left_cache,
2683        right_cache,
2684        average_cache,
2685        difference_cache,
2686        ..
2687    }: &'c mut CorrelationCache,
2688    [left, right]: [&'c [i32]; 2],
2689    bits_per_sample: SignedBitCount<32>,
2690) -> Result<Correlated<&'c BitRecorder<u32, BigEndian>>, Error> {
2691    let (left_recorder, right_recorder) = try_join(
2692        || {
2693            encode_subframe(
2694                options,
2695                left_cache,
2696                CorrelatedChannel {
2697                    samples: left,
2698                    bits_per_sample,
2699                    all_0: false,
2700                },
2701            )
2702        },
2703        || {
2704            encode_subframe(
2705                options,
2706                right_cache,
2707                CorrelatedChannel {
2708                    samples: right,
2709                    bits_per_sample,
2710                    all_0: false,
2711                },
2712            )
2713        },
2714    )?;
2715
2716    match bits_per_sample.checked_add::<32>(1) {
2717        Some(difference_bits_per_sample) if options.mid_side => {
2718            let (average_recorder, difference_recorder) = try_join(
2719                || {
2720                    average_samples.clear();
2721                    average_samples
2722                        .extend(left.iter().zip(right.iter()).map(|(l, r)| (l + r) >> 1));
2723                    encode_subframe(
2724                        options,
2725                        average_cache,
2726                        CorrelatedChannel {
2727                            samples: average_samples,
2728                            bits_per_sample,
2729                            all_0: false,
2730                        },
2731                    )
2732                },
2733                || {
2734                    difference_samples.clear();
2735                    difference_samples.extend(left.iter().zip(right).map(|(l, r)| l - r));
2736                    encode_subframe(
2737                        options,
2738                        difference_cache,
2739                        CorrelatedChannel {
2740                            samples: difference_samples,
2741                            bits_per_sample: difference_bits_per_sample,
2742                            all_0: false,
2743                        },
2744                    )
2745                },
2746            )?;
2747
2748            match [
2749                (
2750                    ChannelAssignment::Independent(Independent::Stereo),
2751                    left_recorder.written() + right_recorder.written(),
2752                ),
2753                (
2754                    ChannelAssignment::LeftSide,
2755                    left_recorder.written() + difference_recorder.written(),
2756                ),
2757                (
2758                    ChannelAssignment::SideRight,
2759                    difference_recorder.written() + right_recorder.written(),
2760                ),
2761                (
2762                    ChannelAssignment::MidSide,
2763                    average_recorder.written() + difference_recorder.written(),
2764                ),
2765            ]
2766            .into_iter()
2767            .min_by_key(|(_, total)| *total)
2768            .unwrap()
2769            .0
2770            {
2771                channel_assignment @ ChannelAssignment::LeftSide => Ok(Correlated {
2772                    channel_assignment,
2773                    channels: [left_recorder, difference_recorder],
2774                }),
2775                channel_assignment @ ChannelAssignment::SideRight => Ok(Correlated {
2776                    channel_assignment,
2777                    channels: [difference_recorder, right_recorder],
2778                }),
2779                channel_assignment @ ChannelAssignment::MidSide => Ok(Correlated {
2780                    channel_assignment,
2781                    channels: [average_recorder, difference_recorder],
2782                }),
2783                channel_assignment @ ChannelAssignment::Independent(_) => Ok(Correlated {
2784                    channel_assignment,
2785                    channels: [left_recorder, right_recorder],
2786                }),
2787            }
2788        }
2789        Some(difference_bits_per_sample) => {
2790            let difference_recorder = {
2791                difference_samples.clear();
2792                difference_samples.extend(left.iter().zip(right).map(|(l, r)| l - r));
2793                encode_subframe(
2794                    options,
2795                    difference_cache,
2796                    CorrelatedChannel {
2797                        samples: difference_samples,
2798                        bits_per_sample: difference_bits_per_sample,
2799                        all_0: false,
2800                    },
2801                )?
2802            };
2803
2804            match [
2805                (
2806                    ChannelAssignment::Independent(Independent::Stereo),
2807                    left_recorder.written() + right_recorder.written(),
2808                ),
2809                (
2810                    ChannelAssignment::LeftSide,
2811                    left_recorder.written() + difference_recorder.written(),
2812                ),
2813                (
2814                    ChannelAssignment::SideRight,
2815                    difference_recorder.written() + right_recorder.written(),
2816                ),
2817            ]
2818            .into_iter()
2819            .min_by_key(|(_, total)| *total)
2820            .unwrap()
2821            .0
2822            {
2823                channel_assignment @ ChannelAssignment::LeftSide => Ok(Correlated {
2824                    channel_assignment,
2825                    channels: [left_recorder, difference_recorder],
2826                }),
2827                channel_assignment @ ChannelAssignment::SideRight => Ok(Correlated {
2828                    channel_assignment,
2829                    channels: [difference_recorder, right_recorder],
2830                }),
2831                ChannelAssignment::MidSide => unreachable!(),
2832                channel_assignment @ ChannelAssignment::Independent(_) => Ok(Correlated {
2833                    channel_assignment,
2834                    channels: [left_recorder, right_recorder],
2835                }),
2836            }
2837        }
2838        None => {
2839            // 32 bps stream, so forego difference channel
2840            // and encode them both indepedently
2841
2842            Ok(Correlated {
2843                channel_assignment: ChannelAssignment::Independent(Independent::Stereo),
2844                channels: [left_recorder, right_recorder],
2845            })
2846        }
2847    }
2848}
2849
/// Encodes one (already channel-correlated) channel as a single subframe
/// and returns the cache buffer holding the encoding that was chosen.
///
/// Candidates, in the order they are considered:
/// - CONSTANT, if `all_0` is set or every sample turns out to be zero;
/// - FIXED and, when `options.max_lpc_order` is set, LPC, after any
///   wasted low-order bits shared by every sample are shifted out;
///   the smaller of the two wins (FIXED on a tie);
/// - VERBATIM, if neither FIXED nor LPC could be encoded, or if the
///   best of them is not smaller than the raw sample bits.
///
/// # Errors
///
/// Returns errors raised while writing a CONSTANT or VERBATIM subframe.
/// FIXED and LPC failures are not returned; they only trigger a fallback.
fn encode_subframe<'c>(
    options: &EncoderOptions,
    ChannelCache {
        fixed: fixed_cache,
        fixed_output,
        lpc: lpc_cache,
        lpc_output,
        constant_output,
        verbatim_output,
        wasted,
    }: &'c mut ChannelCache,
    CorrelatedChannel {
        samples: channel,
        bits_per_sample,
        all_0,
    }: CorrelatedChannel,
) -> Result<&'c BitRecorder<u32, BigEndian>, Error> {
    // an i32 only has 32 trailing zeros when it is 0
    const WASTED_MAX: NonZero<u32> = NonZero::new(32).unwrap();

    debug_assert!(!channel.is_empty());

    if all_0 {
        // all samples are 0
        // (trusts the caller's flag; channel[0] is written as the constant)
        constant_output.clear();
        encode_constant_subframe(constant_output, channel[0], bits_per_sample, 0)?;
        return Ok(constant_output);
    }

    // determine any wasted bits:
    // the fewest trailing zero bits among all samples, where any sample
    // with no trailing zeros stops the fold early with None (nothing wasted)
    let (channel, bits_per_sample, wasted_bps) =
        match channel.iter().try_fold(WASTED_MAX, |acc, sample| {
            NonZero::new(sample.trailing_zeros()).map(|sample| sample.min(acc))
        }) {
            None => (channel, bits_per_sample, 0),
            Some(WASTED_MAX) => {
                // every sample had 32 trailing zeros, so every sample is 0
                constant_output.clear();
                encode_constant_subframe(constant_output, channel[0], bits_per_sample, 0)?;
                return Ok(constant_output);
            }
            Some(wasted_bps) => {
                // shift the shared zero bits out and narrow the sample width to match
                // (the unwrap assumes samples fit in bits_per_sample, so a nonzero
                // sample's trailing zeros stay below that width)
                let wasted_bps = wasted_bps.get();
                wasted.clear();
                wasted.extend(channel.iter().map(|sample| sample >> wasted_bps));
                (
                    wasted.as_slice(),
                    bits_per_sample.checked_sub(wasted_bps).unwrap(),
                    wasted_bps,
                )
            }
        };

    fixed_output.clear();

    let best = match options.max_lpc_order {
        Some(max_lpc_order) => {
            lpc_output.clear();

            // FIXED and LPC are encoded into separate buffers via `join`
            // (presumably in parallel when the rayon feature is enabled)
            match join(
                || {
                    encode_fixed_subframe(
                        options,
                        fixed_cache,
                        fixed_output,
                        channel,
                        bits_per_sample,
                        wasted_bps,
                    )
                },
                || {
                    encode_lpc_subframe(
                        options,
                        max_lpc_order,
                        lpc_cache,
                        lpc_output,
                        channel,
                        bits_per_sample,
                        wasted_bps,
                    )
                },
            ) {
                // min_by_key keeps the first minimum, so FIXED wins ties
                (Ok(()), Ok(())) => [fixed_output, lpc_output]
                    .into_iter()
                    .min_by_key(|c| c.written())
                    .unwrap(),
                (Err(_), Ok(())) => lpc_output,
                (Ok(()), Err(_)) => fixed_output,
                (Err(_), Err(_)) => {
                    verbatim_output.clear();
                    encode_verbatim_subframe(
                        verbatim_output,
                        channel,
                        bits_per_sample,
                        wasted_bps,
                    )?;
                    return Ok(verbatim_output);
                }
            }
        }
        _ => {
            // LPC disabled: only FIXED is attempted
            match encode_fixed_subframe(
                options,
                fixed_cache,
                fixed_output,
                channel,
                bits_per_sample,
                wasted_bps,
            ) {
                Ok(()) => fixed_output,
                Err(_) => {
                    verbatim_output.clear();
                    encode_verbatim_subframe(
                        verbatim_output,
                        channel,
                        bits_per_sample,
                        wasted_bps,
                    )?;
                    return Ok(verbatim_output);
                }
            }
        }
    };

    // raw sample bits only; no subframe header is counted here
    let verbatim_len = channel.len() as u32 * u32::from(bits_per_sample);

    // NOTE(review): `best.written()` includes its subframe header while
    // `verbatim_len` does not, so this leans toward VERBATIM when the two
    // sizes are within a header's length of each other.
    if best.written() < verbatim_len {
        Ok(best)
    } else {
        verbatim_output.clear();
        encode_verbatim_subframe(verbatim_output, channel, bits_per_sample, wasted_bps)?;
        Ok(verbatim_output)
    }
}
2982
2983fn encode_constant_subframe<W: BitWrite>(
2984    writer: &mut W,
2985    sample: i32,
2986    bits_per_sample: SignedBitCount<32>,
2987    wasted_bps: u32,
2988) -> Result<(), Error> {
2989    use crate::stream::{SubframeHeader, SubframeHeaderType};
2990
2991    writer.build(&SubframeHeader {
2992        type_: SubframeHeaderType::Constant,
2993        wasted_bps,
2994    })?;
2995
2996    writer
2997        .write_signed_counted(bits_per_sample, sample)
2998        .map_err(Error::Io)
2999}
3000
3001fn encode_verbatim_subframe<W: BitWrite>(
3002    writer: &mut W,
3003    channel: &[i32],
3004    bits_per_sample: SignedBitCount<32>,
3005    wasted_bps: u32,
3006) -> Result<(), Error> {
3007    use crate::stream::{SubframeHeader, SubframeHeaderType};
3008
3009    writer.build(&SubframeHeader {
3010        type_: SubframeHeaderType::Verbatim,
3011        wasted_bps,
3012    })?;
3013
3014    channel
3015        .iter()
3016        .try_for_each(|i| writer.write_signed_counted(bits_per_sample, *i))?;
3017
3018    Ok(())
3019}
3020
/// Encodes `channel` as a FIXED subframe, using the predictor order whose
/// residuals have the smallest sum of absolute values.
///
/// Order 0's residuals are the samples themselves; order `k + 1`'s are the
/// successive differences of order `k`'s, each kept in one of the cached
/// `buffers`.  Higher orders are abandoned once a difference would overflow
/// an `i32` or no more differences can be taken.
///
/// Writes the subframe header, the first `order` samples as warm-up,
/// and then the residuals via `write_residuals`.
///
/// # Errors
///
/// Propagates errors from building the header, writing the warm-up
/// samples, and `write_residuals`.
fn encode_fixed_subframe<W: BitWrite>(
    options: &EncoderOptions,
    FixedCache {
        fixed_buffers: buffers,
    }: &mut FixedCache,
    writer: &mut W,
    channel: &[i32],
    bits_per_sample: SignedBitCount<32>,
    wasted_bps: u32,
) -> Result<(), Error> {
    use crate::stream::{SubframeHeader, SubframeHeaderType};

    // calculate residuals for FIXED subframe orders 0-4
    // (or fewer, if we don't have enough samples)
    let (order, warm_up, residuals) = {
        // index N holds the residuals for FIXED order N; a capacity of 5
        // covers orders 0-4, so `buffers` presumably holds at most four buffers
        let mut fixed_orders = ArrayVec::<&[i32], 5>::new();
        fixed_orders.push(channel);

        // accumulate a set of FIXED diffs
        'outer: for buf in buffers.iter_mut() {
            let prev_order = fixed_orders.last().unwrap();
            match prev_order.split_at_checked(1) {
                Some((_, r)) => {
                    buf.clear();
                    // r is prev_order shifted by one, so each pair is (next, current)
                    for (n, p) in r.iter().zip(*prev_order) {
                        match n.checked_sub(*p) {
                            Some(v) => {
                                buf.push(v);
                            }
                            // an overflowing difference rules out this order and all higher ones
                            None => break 'outer,
                        }
                    }
                    if buf.is_empty() {
                        break;
                    } else {
                        fixed_orders.push(buf.as_slice());
                    }
                }
                None => break,
            }
        }

        // the highest order kept has the fewest residuals
        let min_fixed = fixed_orders.last().unwrap().len();

        // choose diff with the smallest abs sum, comparing every order over
        // the same trailing `min_fixed` residuals so shorter (higher-order)
        // lists aren't favored just for having fewer terms
        fixed_orders
            .into_iter()
            .enumerate()
            .min_by_key(|(_, residuals)| {
                residuals[(residuals.len() - min_fixed)..]
                    .iter()
                    .map(|r| u64::from(r.unsigned_abs()))
                    .sum::<u64>()
            })
            // an order-N predictor needs the first N samples as warm-up
            .map(|(order, residuals)| (order as u8, &channel[0..order], residuals))
            .unwrap()
    };

    writer.build(&SubframeHeader {
        type_: SubframeHeaderType::Fixed { order },
        wasted_bps,
    })?;

    warm_up
        .iter()
        .try_for_each(|sample: &i32| writer.write_signed_counted(bits_per_sample, *sample))?;

    write_residuals(options, writer, order.into(), residuals)
}
3090
3091fn encode_lpc_subframe<W: BitWrite>(
3092    options: &EncoderOptions,
3093    max_lpc_order: NonZero<u8>,
3094    cache: &mut LpcCache,
3095    writer: &mut W,
3096    channel: &[i32],
3097    bits_per_sample: SignedBitCount<32>,
3098    wasted_bps: u32,
3099) -> Result<(), Error> {
3100    use crate::stream::{SubframeHeader, SubframeHeaderType};
3101
3102    let LpcSubframeParameters {
3103        warm_up,
3104        residuals,
3105        parameters:
3106            LpcParameters {
3107                order,
3108                precision,
3109                shift,
3110                coefficients,
3111            },
3112    } = LpcSubframeParameters::best(options, bits_per_sample, max_lpc_order, cache, channel)?;
3113
3114    writer.build(&SubframeHeader {
3115        type_: SubframeHeaderType::Lpc { order },
3116        wasted_bps,
3117    })?;
3118
3119    for sample in warm_up {
3120        writer.write_signed_counted(bits_per_sample, *sample)?;
3121    }
3122
3123    writer.write_count::<0b1111>(
3124        precision
3125            .count()
3126            .checked_sub(1)
3127            .ok_or(Error::InvalidQlpPrecision)?,
3128    )?;
3129
3130    writer.write::<5, i32>(shift as i32)?;
3131
3132    for coeff in coefficients {
3133        writer.write_signed_counted(precision, coeff)?;
3134    }
3135
3136    write_residuals(options, writer, order.get().into(), residuals)
3137}
3138
3139struct LpcSubframeParameters<'w, 'r> {
3140    parameters: LpcParameters,
3141    warm_up: &'w [i32],
3142    residuals: &'r [i32],
3143}
3144
3145impl<'w, 'r> LpcSubframeParameters<'w, 'r> {
3146    fn best(
3147        options: &EncoderOptions,
3148        bits_per_sample: SignedBitCount<32>,
3149        max_lpc_order: NonZero<u8>,
3150        LpcCache {
3151            residuals,
3152            window,
3153            windowed,
3154        }: &'r mut LpcCache,
3155        channel: &'w [i32],
3156    ) -> Result<Self, Error> {
3157        let parameters = LpcParameters::best(
3158            options,
3159            bits_per_sample,
3160            max_lpc_order,
3161            window,
3162            windowed,
3163            channel,
3164        )?;
3165
3166        Self::encode_residuals(&parameters, channel, residuals)
3167            .map(|(warm_up, residuals)| Self {
3168                warm_up,
3169                residuals,
3170                parameters,
3171            })
3172            .map_err(|ResidualOverflow| Error::ResidualOverflow)
3173    }
3174
3175    fn encode_residuals(
3176        parameters: &LpcParameters,
3177        channel: &'w [i32],
3178        residuals: &'r mut Vec<i32>,
3179    ) -> Result<(&'w [i32], &'r [i32]), ResidualOverflow> {
3180        residuals.clear();
3181
3182        for split in usize::from(parameters.order.get())..channel.len() {
3183            let (previous, current) = channel.split_at(split);
3184
3185            residuals.push(
3186                current[0]
3187                    .checked_sub(
3188                        (previous
3189                            .iter()
3190                            .rev()
3191                            .zip(&parameters.coefficients)
3192                            .map(|(x, y)| *x as i64 * *y as i64)
3193                            .sum::<i64>()
3194                            >> parameters.shift) as i32,
3195                    )
3196                    .ok_or(ResidualOverflow)?,
3197            );
3198        }
3199
3200        Ok((
3201            &channel[0..parameters.order.get().into()],
3202            residuals.as_slice(),
3203        ))
3204    }
3205}
3206
3207#[derive(Debug)]
3208struct ResidualOverflow;
3209
3210impl From<ResidualOverflow> for Error {
3211    #[inline]
3212    fn from(_: ResidualOverflow) -> Self {
3213        Error::ResidualOverflow
3214    }
3215}
3216
#[test]
fn test_residual_encoding_1() {
    // a sine-like wave and its order-2 residuals
    // for quantized coefficients [59, -30] with a shift of 5
    let samples = [
        0, 16, 31, 44, 54, 61, 64, 63, 58, 49, 38, 24, 8, -8, -24, -38, -49, -58, -63, -64, -61,
        -54, -44, -31, -16,
    ];

    let expected_residuals = [
        2, 2, 2, 3, 3, 3, 2, 2, 3, 0, 0, 0, -1, -1, -1, -3, -2, -2, -2, -1, -1, 0, 0,
    ];

    let parameters = LpcParameters {
        order: NonZero::new(2).unwrap(),
        precision: SignedBitCount::new::<7>(),
        shift: 5,
        coefficients: arrayvec![59, -30],
    };

    let mut buffer = Vec::with_capacity(expected_residuals.len());

    let (warm_up, residuals) =
        LpcSubframeParameters::encode_residuals(&parameters, &samples, &mut buffer).unwrap();

    // the first `order` samples are passed through untouched
    assert_eq!(warm_up, &samples[..2]);
    assert_eq!(residuals, &expected_residuals);
}
3245
#[test]
fn test_residual_encoding_2() {
    // a cosine-like wave and its order-2 residuals
    // for quantized coefficients [58, -29] with a shift of 5
    let samples = [
        64, 62, 56, 47, 34, 20, 4, -12, -27, -41, -52, -60, -63, -63, -60, -52, -41, -27, -12, 4,
        20, 34, 47, 56, 62,
    ];

    let expected_residuals = [
        2, 2, 0, 1, -1, -1, -1, -2, -2, -2, -1, -3, -2, 0, -1, 1, 0, 2, 2, 2, 4, 2, 4,
    ];

    let parameters = LpcParameters {
        order: NonZero::new(2).unwrap(),
        precision: SignedBitCount::new::<7>(),
        shift: 5,
        coefficients: arrayvec![58, -29],
    };

    let mut buffer = Vec::with_capacity(expected_residuals.len());

    let (warm_up, residuals) =
        LpcSubframeParameters::encode_residuals(&parameters, &samples, &mut buffer).unwrap();

    // the first `order` samples are passed through untouched
    assert_eq!(warm_up, &samples[..2]);
    assert_eq!(residuals, &expected_residuals);
}
3274
3275#[derive(Debug)]
3276struct LpcParameters {
3277    order: NonZero<u8>,
3278    precision: SignedBitCount<15>,
3279    shift: u32,
3280    coefficients: ArrayVec<i32, MAX_LPC_COEFFS>,
3281}
3282
3283// There isn't any particular *best* way to determine
3284// the ideal LPC subframe parameters (though there are
3285// some worst ways, like choosing them at random).
3286// Even the reference implementation has changed its
3287// defaults over time.  So long as the subframe's residuals
3288// are calculated correctly, decoders don't care one way or another.
3289//
3290// I'll try to use an approach similar to the reference implementation's.
3291
3292impl LpcParameters {
3293    fn best(
3294        options: &EncoderOptions,
3295        bits_per_sample: SignedBitCount<32>,
3296        max_lpc_order: NonZero<u8>,
3297        window: &mut Vec<f64>,
3298        windowed: &mut Vec<f64>,
3299        channel: &[i32],
3300    ) -> Result<Self, Error> {
3301        if channel.len() <= usize::from(max_lpc_order.get()) {
3302            // not enough samples in channel to calculate LPC parameters
3303            return Err(Error::InsufficientLpcSamples);
3304        }
3305
3306        let precision = match channel.len() {
3307            // this shouldn't be possible
3308            0 => panic!("at least one sample required in channel"),
3309            1..=192 => SignedBitCount::new::<7>(),
3310            193..=384 => SignedBitCount::new::<8>(),
3311            385..=576 => SignedBitCount::new::<9>(),
3312            577..=1152 => SignedBitCount::new::<10>(),
3313            1153..=2304 => SignedBitCount::new::<11>(),
3314            2305..=4608 => SignedBitCount::new::<12>(),
3315            4609.. => SignedBitCount::new::<13>(),
3316        };
3317
3318        let (order, lp_coeffs) = compute_best_order(
3319            bits_per_sample,
3320            precision,
3321            channel
3322                .len()
3323                .try_into()
3324                // this shouldn't be possible
3325                .expect("excessive samples for subframe"),
3326            lp_coefficients(autocorrelate(
3327                options.window.apply(window, windowed, channel),
3328                max_lpc_order,
3329            )),
3330        )?;
3331
3332        Self::quantize(order, lp_coeffs, precision)
3333    }
3334
3335    fn quantize(
3336        order: NonZero<u8>,
3337        coeffs: ArrayVec<f64, MAX_LPC_COEFFS>,
3338        precision: SignedBitCount<15>,
3339    ) -> Result<Self, Error> {
3340        const MAX_SHIFT: i32 = (1 << 4) - 1;
3341        const MIN_SHIFT: i32 = -(1 << 4);
3342
3343        // verified output against reference implementation
3344        // See: FLAC__lpc_quantize_coefficients
3345
3346        debug_assert!(coeffs.len() == usize::from(order.get()));
3347
3348        let max_coeff = (1 << (u32::from(precision) - 1)) - 1;
3349        let min_coeff = -(1 << (u32::from(precision) - 1));
3350
3351        let l = coeffs
3352            .iter()
3353            .map(|c| c.abs())
3354            .max_by(|x, y| x.total_cmp(y))
3355            // f64.log2() gives unfortunate results when <= 0.0
3356            .filter(|l| *l > 0.0)
3357            .ok_or(Error::ZeroLpCoefficients)?;
3358
3359        let mut error = 0.0;
3360
3361        match ((u32::from(precision) - 1) as i32 - ((l.log2().floor()) as i32) - 1).min(MAX_SHIFT) {
3362            shift @ 0.. => {
3363                // normal, positive shift case
3364                let shift = shift as u32;
3365
3366                Ok(Self {
3367                    order,
3368                    precision,
3369                    shift,
3370                    coefficients: coeffs
3371                        .into_iter()
3372                        .map(|lp_coeff| {
3373                            let sum: f64 = lp_coeff.mul_add((1 << shift) as f64, error);
3374                            let qlp_coeff = (sum.round() as i32).clamp(min_coeff, max_coeff);
3375                            error = sum - (qlp_coeff as f64);
3376                            qlp_coeff
3377                        })
3378                        .collect(),
3379                })
3380            }
3381            shift @ MIN_SHIFT..0 => {
3382                // unusual negative shift case
3383                let shift = -shift as u32;
3384
3385                Ok(Self {
3386                    order,
3387                    precision,
3388                    shift: 0,
3389                    coefficients: coeffs
3390                        .into_iter()
3391                        .map(|lp_coeff| {
3392                            let sum: f64 = (lp_coeff / (1 << shift) as f64) + error;
3393                            let qlp_coeff = (sum.round() as i32).clamp(min_coeff, max_coeff);
3394                            error = sum - (qlp_coeff as f64);
3395                            qlp_coeff
3396                        })
3397                        .collect(),
3398                })
3399            }
3400            ..MIN_SHIFT => Err(Error::LpNegativeShiftError),
3401        }
3402    }
3403}
3404
#[test]
fn test_quantization() {
    // test against numbers generated from reference implementation

    let order = NonZero::new(4).unwrap();

    let quantized = LpcParameters::quantize(
        order,
        arrayvec![0.797774, -0.045362, -0.050136, -0.054254],
        SignedBitCount::new::<10>(),
    )
    .unwrap();

    assert_eq!(quantized.order, order);
    assert_eq!(quantized.precision, SignedBitCount::new::<10>());
    assert_eq!(quantized.shift, 9);
    assert_eq!(quantized.coefficients, arrayvec![408, -23, -25, -28]);

    // note the relationship between the un-quantized,
    // floating point parameters and the shift value (9):
    // each quantized coefficient divided by 2 ** 9
    // approximates its floating point original
    //
    //  408 / 2 ** 9 =  0.796875    (from  0.797774)
    //  -23 / 2 ** 9 ≈ -0.044921    (from -0.045362)
    //  -25 / 2 ** 9 ≈ -0.048828    (from -0.050136)
    //  -28 / 2 ** 9 = -0.0546875   (from -0.054254)
    //
    // we're converting floats to fractions

    let quantized = LpcParameters::quantize(
        order,
        arrayvec![-0.054687, -0.953216, -0.027115, 0.033537],
        SignedBitCount::new::<10>(),
    )
    .unwrap();

    assert_eq!(quantized.order, order);
    assert_eq!(quantized.precision, SignedBitCount::new::<10>());
    assert_eq!(quantized.shift, 9);
    assert_eq!(quantized.coefficients, arrayvec![-28, -488, -14, 17]);

    // all-zero coefficients have no magnitude to derive
    // a shift from, so quantization must reject them
    assert!(matches!(
        LpcParameters::quantize(
            order,
            arrayvec![0.0, 0.0, 0.0, 0.0],
            SignedBitCount::new::<10>(),
        ),
        Err(Error::ZeroLpCoefficients)
    ));

    // a coefficient too large for any non-negative shift is handled
    // by scaling the coefficients down, leaving a stored shift of 0
    let quantized = LpcParameters::quantize(
        order,
        arrayvec![-0.1, 0.1, 10000000.0, -0.2],
        SignedBitCount::new::<10>(),
    )
    .unwrap();

    assert_eq!(quantized.order, order);
    assert_eq!(quantized.precision, SignedBitCount::new::<10>());
    assert_eq!(quantized.shift, 0);
    assert_eq!(quantized.coefficients, arrayvec![0, 0, 305, 0]);

    // but needing a negative shift beyond the representable
    // range (below -16) must be an error
    assert!(matches!(
        LpcParameters::quantize(
            order,
            arrayvec![-0.1, 0.1, 100000000.0, -0.2],
            SignedBitCount::new::<10>(),
        ),
        Err(Error::LpNegativeShiftError)
    ));
}
3478
3479fn autocorrelate(
3480    windowed: &[f64],
3481    max_lpc_order: NonZero<u8>,
3482) -> ArrayVec<f64, { MAX_LPC_COEFFS + 1 }> {
3483    // verified output against reference implementation
3484    // See: FLAC__lpc_compute_autocorrelation
3485
3486    debug_assert!(usize::from(max_lpc_order.get()) < MAX_LPC_COEFFS);
3487
3488    let mut tail = windowed;
3489    // let mut autocorrelated = Vec::with_capacity(max_lpc_order.get().into());
3490    let mut autocorrelated = ArrayVec::default();
3491
3492    for _ in 0..=max_lpc_order.get() {
3493        if tail.is_empty() {
3494            return autocorrelated;
3495        } else {
3496            autocorrelated.push(windowed.iter().zip(tail).map(|(x, y)| x * y).sum());
3497            tail.split_off_first();
3498        }
3499    }
3500
3501    autocorrelated
3502}
3503
#[test]
fn test_autocorrelation() {
    // expected values were generated by the reference implementation
    let waveform: [f64; 25] = [
        0.0, 16.0, 31.0, 44.0, 54.0, 61.0, 64.0, 63.0, 58.0, 49.0, 38.0, 24.0, 8.0, -8.0, -24.0,
        -38.0, -49.0, -58.0, -63.0, -64.0, -61.0, -54.0, -44.0, -31.0, -16.0,
    ];

    // (window, maximum LPC order, expected autocorrelation values)
    let cases: [(&[f64], u8, &[f64]); 3] = [
        (&[1.0], 1, &[1.0]),
        (&[1.0, 2.0, 3.0, 4.0, 5.0], 4, &[55.0, 40.0, 26.0, 14.0, 5.0]),
        (&waveform, 4, &[51408.0, 49792.0, 45304.0, 38466.0, 29914.0]),
    ];

    for (window, order, expected) in cases {
        let order = NonZero::new(order).unwrap();
        assert_eq!(autocorrelate(window, order).as_slice(), expected);
    }
}
3529
// Linear-prediction coefficients for a single predictor order, together
// with the prediction error left at that order.
// `lp_coefficients` produces one of these per order.
#[derive(Debug)]
struct LpCoeff {
    // Unquantized predictor coefficients; entries built by
    // `lp_coefficients` for order `n` hold `n` values.
    coeffs: ArrayVec<f64, MAX_LPC_COEFFS>,
    // Prediction error for this order; `subframe_bits_by_order` uses it
    // to estimate the size of the residuals.
    error: f64,
}
3535
3536// returns a Vec of (coefficients, error) pairs
3537fn lp_coefficients(
3538    autocorrelated: ArrayVec<f64, { MAX_LPC_COEFFS + 1 }>,
3539) -> ArrayVec<LpCoeff, MAX_LPC_COEFFS> {
3540    // verified output against reference implementation
3541    // See: FLAC__lpc_compute_lp_coefficients
3542
3543    match autocorrelated.len() {
3544        0 | 1 => panic!("must have at least 2 autocorrelation values"),
3545        _ => {
3546            let k = autocorrelated[1] / autocorrelated[0];
3547            let mut lp_coefficients = arrayvec![LpCoeff {
3548                coeffs: arrayvec![k],
3549                error: autocorrelated[0] * (1.0 - k.powi(2)),
3550            }];
3551
3552            for i in 1..(autocorrelated.len() - 1) {
3553                if let [prev @ .., next] = &autocorrelated[0..=i + 1] {
3554                    let LpCoeff { coeffs, error } = lp_coefficients.last().unwrap();
3555
3556                    let q = next
3557                        - prev
3558                            .iter()
3559                            .rev()
3560                            .zip(coeffs)
3561                            .map(|(x, y)| x * y)
3562                            .sum::<f64>();
3563
3564                    let k = q / error;
3565
3566                    lp_coefficients.push(LpCoeff {
3567                        coeffs: coeffs
3568                            .iter()
3569                            .zip(coeffs.iter().rev().map(|c| k * c))
3570                            .map(|(c1, c2)| c1 - c2)
3571                            .chain(std::iter::once(k))
3572                            .collect(),
3573                        error: error * (1.0 - k.powi(2)),
3574                    });
3575                }
3576            }
3577
3578            lp_coefficients
3579        }
3580    }
3581}
3582
// Asserts that two f64 expressions differ by less than 1e-6.
// Used by the unit tests in this file; non-test builds never expand it,
// hence the `allow(unused)`.
#[allow(unused)]
macro_rules! assert_float_approx {
    ($a:expr, $b:expr) => {{
        let (a, b) = ($a, $b);
        let close_enough = (a - b).abs() < 1.0e-6;
        assert!(close_enough, "{a} != {b}");
    }};
}
3591
#[test]
fn test_lp_coefficients_1() {
    // expected values were generated by the reference implementation
    let expected_errors: [f64; 4] = [25.909091, 25.540351, 25.316142, 25.241623];
    let expected_coeffs: [&[f64]; 4] = [
        &[0.727273],
        &[0.814035, -0.119298],
        &[0.802858, -0.043028, -0.093694],
        &[0.797774, -0.045362, -0.050136, -0.054254],
    ];

    let lp_coeffs = lp_coefficients(arrayvec![55.0, 40.0, 26.0, 14.0, 5.0]);
    assert_eq!(lp_coeffs.len(), expected_errors.len());

    for (lp, wanted_error) in lp_coeffs.iter().zip(expected_errors) {
        assert_float_approx!(lp.error, wanted_error);
    }

    for (lp, wanted_coeffs) in lp_coeffs.iter().zip(expected_coeffs) {
        assert_eq!(lp.coeffs.len(), wanted_coeffs.len());
        for (&actual, &wanted) in lp.coeffs.iter().zip(wanted_coeffs) {
            assert_float_approx!(actual, wanted);
        }
    }
}
3623
#[test]
fn test_lp_coefficients_2() {
    // expected values were generated by the reference implementation
    let expected_errors: [f64; 4] = [3181.201369, 495.815931, 495.161449, 494.604514];
    let expected_coeffs: [&[f64]; 4] = [
        &[0.968565],
        &[1.858456, -0.918772],
        &[1.891837, -0.986293, 0.036332],
        &[1.890618, -0.953216, -0.027115, 0.033537],
    ];

    let lp_coeffs = lp_coefficients(arrayvec![51408.0, 49792.0, 45304.0, 38466.0, 29914.0]);
    assert_eq!(lp_coeffs.len(), expected_errors.len());

    for (lp, wanted_error) in lp_coeffs.iter().zip(expected_errors) {
        assert_float_approx!(lp.error, wanted_error);
    }

    for (lp, wanted_coeffs) in lp_coeffs.iter().zip(expected_coeffs) {
        assert_eq!(lp.coeffs.len(), wanted_coeffs.len());
        for (&actual, &wanted) in lp.coeffs.iter().zip(wanted_coeffs) {
            assert_float_approx!(actual, wanted);
        }
    }
}
3655
3656// Returns (bits, order, coeffients) tuples
3657fn subframe_bits_by_order(
3658    bits_per_sample: SignedBitCount<32>,
3659    precision: SignedBitCount<15>,
3660    sample_count: u16,
3661    coeffs: ArrayVec<LpCoeff, MAX_LPC_COEFFS>,
3662) -> impl Iterator<Item = (f64, u8, ArrayVec<f64, MAX_LPC_COEFFS>)> {
3663    debug_assert!(sample_count > 0);
3664
3665    let error_scale = 0.5 / f64::from(sample_count);
3666
3667    coeffs
3668        .into_iter()
3669        .take_while(|coeffs| coeffs.error > 0.0)
3670        .zip(1..)
3671        .map(move |(LpCoeff { coeffs, error }, order)| {
3672            let header_bits =
3673                u32::from(order) * (u32::from(bits_per_sample) + u32::from(precision));
3674
3675            let bits_per_residual =
3676                (error * error_scale).ln() / (2.0 * std::f64::consts::LN_2).max(0.0);
3677
3678            let subframe_bits = bits_per_residual.mul_add(
3679                f64::from(sample_count - u16::from(order)),
3680                f64::from(header_bits),
3681            );
3682
3683            (subframe_bits, order, coeffs)
3684        })
3685}
3686
3687// Uses the error in the LP coefficients to determine the best order
3688// and returns that order along with the stripped-out coefficients
3689fn compute_best_order(
3690    bits_per_sample: SignedBitCount<32>,
3691    precision: SignedBitCount<15>,
3692    sample_count: u16,
3693    coeffs: ArrayVec<LpCoeff, MAX_LPC_COEFFS>,
3694) -> Result<(NonZero<u8>, ArrayVec<f64, MAX_LPC_COEFFS>), Error> {
3695    // verified output against reference implementation
3696    // See: FLAC__lpc_compute_best_order  and
3697    // See: FLAC__lpc_compute_expected_bits_per_residual_sample_with_error_scale
3698
3699    subframe_bits_by_order(bits_per_sample, precision, sample_count, coeffs)
3700        .min_by(|(x, _, _), (y, _, _)| x.total_cmp(y))
3701        .and_then(|(_, order, coeffs)| Some((NonZero::new(order)?, coeffs)))
3702        .ok_or(Error::NoBestLpcOrder)
3703}
3704
#[test]
fn test_compute_best_order() {
    // test against numbers generated from reference implementation

    // LP entries carrying only an error value, which is all the
    // size estimate inspects
    fn errors_only(errors: [f64; 4]) -> ArrayVec<LpCoeff, MAX_LPC_COEFFS> {
        errors
            .into_iter()
            .map(|error| LpCoeff {
                coeffs: ArrayVec::default(),
                error,
            })
            .collect()
    }

    let short_errors = [3181.201369, 495.815931, 495.161449, 494.604514];

    let mut bits = subframe_bits_by_order(
        SignedBitCount::new::<16>(),
        SignedBitCount::new::<5>(),
        20,
        errors_only(short_errors),
    )
    .map(|t| t.0);

    assert_float_approx!(bits.next().unwrap(), 80.977565);
    assert_float_approx!(bits.next().unwrap(), 74.685594);
    assert_float_approx!(bits.next().unwrap(), 93.853530);
    assert_float_approx!(bits.next().unwrap(), 113.025628);
    assert!(bits.next().is_none());

    // order 2 has the smallest estimate above
    let (order, _) = compute_best_order(
        SignedBitCount::new::<16>(),
        SignedBitCount::new::<5>(),
        20,
        errors_only(short_errors),
    )
    .unwrap();
    assert_eq!(order.get(), 2);

    let long_errors = [15000.0, 25000.0, 20000.0, 30000.0];

    let mut bits = subframe_bits_by_order(
        SignedBitCount::new::<16>(),
        SignedBitCount::new::<10>(),
        4096,
        errors_only(long_errors),
    )
    .map(|t| t.0);

    assert_float_approx!(bits.next().unwrap(), 1812.801817);
    assert_float_approx!(bits.next().unwrap(), 3346.934051);
    assert_float_approx!(bits.next().unwrap(), 2713.303385);
    assert_float_approx!(bits.next().unwrap(), 3935.492805);
    assert!(bits.next().is_none());

    // order 1 has the smallest estimate above
    let (order, _) = compute_best_order(
        SignedBitCount::new::<16>(),
        SignedBitCount::new::<10>(),
        4096,
        errors_only(long_errors),
    )
    .unwrap();
    assert_eq!(order.get(), 1);
}
3747
3748fn write_residuals<W: BitWrite>(
3749    options: &EncoderOptions,
3750    writer: &mut W,
3751    predictor_order: usize,
3752    residuals: &[i32],
3753) -> Result<(), Error> {
3754    use crate::stream::ResidualPartitionHeader;
3755    use bitstream_io::{BitCount, ToBitStream};
3756
3757    const MAX_PARTITIONS: usize = 64;
3758
3759    #[derive(Debug)]
3760    struct Partition<'r, const RICE_MAX: u32> {
3761        header: ResidualPartitionHeader<RICE_MAX>,
3762        residuals: &'r [i32],
3763    }
3764
3765    impl<'r, const RICE_MAX: u32> Partition<'r, RICE_MAX> {
3766        fn new(partition: &'r [i32], estimated_bits: &mut u32) -> Option<Self> {
3767            let partition_samples = partition.len() as u16;
3768            if partition_samples == 0 {
3769                return None;
3770            }
3771
3772            let partition_sum = partition
3773                .iter()
3774                .map(|i| u64::from(i.unsigned_abs()))
3775                .sum::<u64>();
3776
3777            if partition_sum > 0 {
3778                let rice = if partition_sum > partition_samples.into() {
3779                    let bits_needed = ((partition_sum as f64) / f64::from(partition_samples))
3780                        .log2()
3781                        .ceil() as u32;
3782
3783                    match BitCount::try_from(bits_needed).ok().filter(|rice| {
3784                        u32::from(*rice) < u32::from(BitCount::<RICE_MAX>::new::<RICE_MAX>())
3785                    }) {
3786                        Some(rice) => rice,
3787                        None => {
3788                            let escape_size = (partition
3789                                .iter()
3790                                .map(|i| u64::from(i.unsigned_abs()))
3791                                .sum::<u64>()
3792                                .ilog2()
3793                                + 2)
3794                            .try_into()
3795                            .ok()?;
3796
3797                            *estimated_bits +=
3798                                u32::from(escape_size) * u32::from(partition_samples);
3799
3800                            return Some(Self {
3801                                header: ResidualPartitionHeader::Escaped { escape_size },
3802                                residuals: partition,
3803                            });
3804                        }
3805                    }
3806                } else {
3807                    BitCount::new::<0>()
3808                };
3809
3810                let partition_size: u32 = 4u32
3811                    + ((1 + u32::from(rice)) * u32::from(partition_samples))
3812                    + if u32::from(rice) > 0 {
3813                        u32::try_from(partition_sum >> (u32::from(rice) - 1)).ok()?
3814                    } else {
3815                        u32::try_from(partition_sum << 1).ok()?
3816                    }
3817                    - (u32::from(partition_samples) / 2);
3818
3819                *estimated_bits += partition_size;
3820
3821                Some(Partition {
3822                    header: ResidualPartitionHeader::Standard { rice },
3823                    residuals: partition,
3824                })
3825            } else {
3826                // all partition residuals are 0, so use a constant
3827                Some(Partition {
3828                    header: ResidualPartitionHeader::Constant,
3829                    residuals: partition,
3830                })
3831            }
3832        }
3833    }
3834
3835    impl<const RICE_MAX: u32> ToBitStream for Partition<'_, RICE_MAX> {
3836        type Error = std::io::Error;
3837
3838        #[inline]
3839        fn to_writer<W: BitWrite + ?Sized>(&self, w: &mut W) -> Result<(), Self::Error> {
3840            w.build(&self.header)?;
3841            match self.header {
3842                ResidualPartitionHeader::Standard { rice } => {
3843                    let mask = rice.mask_lsb();
3844
3845                    self.residuals.iter().try_for_each(|s| {
3846                        let (msb, lsb) = mask(if s.is_negative() {
3847                            ((-*s as u32 - 1) << 1) + 1
3848                        } else {
3849                            (*s as u32) << 1
3850                        });
3851                        w.write_unary::<1>(msb)?;
3852                        w.write_checked(lsb)
3853                    })?;
3854                }
3855                ResidualPartitionHeader::Escaped { escape_size } => {
3856                    self.residuals
3857                        .iter()
3858                        .try_for_each(|s| w.write_signed_counted(escape_size, *s))?;
3859                }
3860                ResidualPartitionHeader::Constant => { /* nothing left to do */ }
3861            }
3862            Ok(())
3863        }
3864    }
3865
3866    fn best_partitions<'r, const RICE_MAX: u32>(
3867        options: &EncoderOptions,
3868        block_size: usize,
3869        residuals: &'r [i32],
3870    ) -> ArrayVec<Partition<'r, RICE_MAX>, MAX_PARTITIONS> {
3871        (0..=block_size.trailing_zeros().min(options.max_partition_order))
3872            .map(|partition_order| 1 << partition_order)
3873            .take_while(|partition_count: &usize| partition_count.is_power_of_two())
3874            .filter_map(|partition_count| {
3875                let mut estimated_bits = 0;
3876
3877                let partitions = residuals
3878                    .rchunks(block_size / partition_count)
3879                    .rev()
3880                    .map(|partition| Partition::new(partition, &mut estimated_bits))
3881                    .collect::<Option<ArrayVec<_, MAX_PARTITIONS>>>()
3882                    .filter(|p| !p.is_empty() && p.len().is_power_of_two())?;
3883
3884                Some((partitions, estimated_bits))
3885            })
3886            .min_by_key(|(_, estimated_bits)| *estimated_bits)
3887            .map(|(partitions, _)| partitions)
3888            .unwrap_or_else(|| {
3889                std::iter::once(Partition {
3890                    header: ResidualPartitionHeader::Escaped {
3891                        escape_size: SignedBitCount::new::<0b11111>(),
3892                    },
3893                    residuals,
3894                })
3895                .collect()
3896            })
3897    }
3898
3899    fn write_partitions<const RICE_MAX: u32, W: BitWrite>(
3900        writer: &mut W,
3901        partitions: ArrayVec<Partition<'_, RICE_MAX>, MAX_PARTITIONS>,
3902    ) -> Result<(), Error> {
3903        writer.write::<4, u32>(partitions.len().ilog2())?; // partition order
3904        for partition in partitions {
3905            writer.build(&partition)?;
3906        }
3907        Ok(())
3908    }
3909
3910    #[inline]
3911    fn try_shrink_header<const RICE_MAX: u32, const RICE_NEW_MAX: u32>(
3912        header: ResidualPartitionHeader<RICE_MAX>,
3913    ) -> Option<ResidualPartitionHeader<RICE_NEW_MAX>> {
3914        Some(match header {
3915            ResidualPartitionHeader::Standard { rice } => ResidualPartitionHeader::Standard {
3916                rice: rice.try_map(|r| (r < RICE_NEW_MAX).then_some(r))?,
3917            },
3918            ResidualPartitionHeader::Escaped { escape_size } => {
3919                ResidualPartitionHeader::Escaped { escape_size }
3920            }
3921            ResidualPartitionHeader::Constant => ResidualPartitionHeader::Constant,
3922        })
3923    }
3924
3925    enum CodingMethod<'p> {
3926        Rice(ArrayVec<Partition<'p, 0b1111>, MAX_PARTITIONS>),
3927        Rice2(ArrayVec<Partition<'p, 0b11111>, MAX_PARTITIONS>),
3928    }
3929
3930    fn try_reduce_rice(
3931        partitions: ArrayVec<Partition<'_, 0b11111>, MAX_PARTITIONS>,
3932    ) -> CodingMethod<'_> {
3933        match partitions
3934            .iter()
3935            .map(|Partition { header, residuals }| {
3936                try_shrink_header(*header).map(|header| Partition { header, residuals })
3937            })
3938            .collect()
3939        {
3940            Some(partitions) => CodingMethod::Rice(partitions),
3941            None => CodingMethod::Rice2(partitions),
3942        }
3943    }
3944
3945    let block_size = predictor_order + residuals.len();
3946
3947    if options.use_rice2 {
3948        match try_reduce_rice(best_partitions(options, block_size, residuals)) {
3949            CodingMethod::Rice(partitions) => {
3950                writer.write::<2, u8>(0)?; // coding method 0
3951                write_partitions(writer, partitions)
3952            }
3953            CodingMethod::Rice2(partitions) => {
3954                writer.write::<2, u8>(1)?; // coding method 1
3955                write_partitions(writer, partitions)
3956            }
3957        }
3958    } else {
3959        let partitions = best_partitions::<0b1111>(options, block_size, residuals);
3960        writer.write::<2, u8>(0)?; // coding method 0
3961        write_partitions(writer, partitions)
3962    }
3963}
3964
3965fn try_join<A, B, RA, RB, E>(oper_a: A, oper_b: B) -> Result<(RA, RB), E>
3966where
3967    A: FnOnce() -> Result<RA, E> + Send,
3968    B: FnOnce() -> Result<RB, E> + Send,
3969    RA: Send,
3970    RB: Send,
3971    E: Send,
3972{
3973    let (a, b) = join(oper_a, oper_b);
3974    Ok((a?, b?))
3975}
3976
3977#[cfg(feature = "rayon")]
3978use rayon::join;
3979
// Sequential stand-in for `rayon::join` used when the `rayon` feature is
// disabled: runs `oper_a` to completion, then `oper_b`, and returns both
// results as a pair.
#[cfg(not(feature = "rayon"))]
fn join<A, B, RA, RB>(oper_a: A, oper_b: B) -> (RA, RB)
where
    A: FnOnce() -> RA + Send,
    B: FnOnce() -> RB + Send,
    RA: Send,
    RB: Send,
{
    let first = oper_a();
    let second = oper_b();
    (first, second)
}
3990
// Applies `f` to every element of `src` on rayon's thread pool and
// collects the results into a new Vec.
#[cfg(feature = "rayon")]
fn vec_map<T, U, F>(src: Vec<T>, f: F) -> Vec<U>
where
    T: Send,
    U: Send,
    F: Fn(T) -> U + Send + Sync,
{
    use rayon::iter::{IntoParallelIterator, ParallelIterator};

    let mapped = src.into_par_iter().map(f);
    mapped.collect::<Vec<U>>()
}
4002
// Sequential counterpart of the rayon-backed `vec_map`: applies `f` to
// each element of `src` in order and returns the results.
#[cfg(not(feature = "rayon"))]
fn vec_map<T, U, F>(src: Vec<T>, f: F) -> Vec<U>
where
    T: Send,
    U: Send,
    F: Fn(T) -> U + Send + Sync,
{
    let mut mapped = Vec::with_capacity(src.len());
    for item in src {
        mapped.push(f(item));
    }
    mapped
}
4012
// Divides `n` by `rhs`, returning the quotient only when the division
// leaves no remainder.
//
// Returns None when `rhs` does not evenly divide `n`, and also when
// `rhs` is zero (for integers, `%` by zero would otherwise panic).
fn exact_div<N>(n: N, rhs: N) -> Option<N>
where
    N: std::ops::Div<Output = N> + std::ops::Rem<Output = N> + std::cmp::PartialEq + Copy + Default,
{
    let zero = N::default();
    // the zero check short-circuits before `%` is evaluated
    (rhs != zero && n % rhs == zero).then(|| n / rhs)
}