use std::io::{SeekFrom, Write};
use audio_samples::{AudioSamples, SampleType, traits::StandardSample};
use super::{
CompressionLevel, constants::FLAC_MARKER, flac_file::audio_to_planar_i32, frame::encode_frame_from_channels,
metadata::StreamInfo,
};
use crate::{
WriteSeek,
error::{AudioIOError, AudioIOResult},
traits::{AudioStreamWrite, AudioStreamWriter},
types::ValidatedSampleType,
};
/// Encoder tuning knobs resolved once from a [`CompressionLevel`] and passed
/// unchanged to `encode_frame_from_channels` for every frame this module encodes.
#[derive(Debug, Clone, Copy)]
struct FlacEncodeParams {
    /// Maximum LPC order, taken from `CompressionLevel::max_lpc_order`.
    max_lpc_order: usize,
    /// Coefficient precision, taken from `CompressionLevel::qlp_precision`.
    qlp_precision: u8,
    /// Lower bound of `CompressionLevel::rice_partition_order_range`.
    min_partition_order: u8,
    /// Upper bound of `CompressionLevel::rice_partition_order_range`.
    max_partition_order: u8,
    /// Taken from `CompressionLevel::try_mid_side`.
    try_mid_side: bool,
    /// Taken from `CompressionLevel::exhaustive_rice_search`.
    exhaustive_rice: bool,
}
impl FlacEncodeParams {
    /// Resolves every encoder setting from the given compression preset.
    ///
    /// All queried methods are `const fn`s on `level`, so this can be evaluated
    /// at compile time when `level` is a constant.
    const fn from_level(level: CompressionLevel) -> Self {
        let lpc_order = level.max_lpc_order() as usize;
        let precision = level.qlp_precision();
        let partition_range = level.rice_partition_order_range();
        let mid_side = level.try_mid_side();
        let exhaustive = level.exhaustive_rice_search();

        Self {
            max_lpc_order: lpc_order,
            qlp_precision: precision,
            min_partition_order: partition_range.0,
            max_partition_order: partition_range.1,
            try_mid_side: mid_side,
            exhaustive_rice: exhaustive,
        }
    }
}
/// Streaming FLAC encoder that writes frames to a seekable sink as samples arrive.
///
/// Incoming samples are buffered per channel and emitted one `block_size`-sample
/// block (frame) at a time. The STREAMINFO block written by `new` is a placeholder
/// that `finalize` overwrites in place once the final sample count is known.
#[derive(Debug)]
pub struct StreamedFlacWriter<W>
where
    W: WriteSeek,
{
    /// Destination sink; it must be seekable so STREAMINFO can be rewritten.
    writer: W,
    /// Channel count, validated to `1..=8` by `new`.
    channels: u16,
    /// Sample rate in Hz, recorded in STREAMINFO and passed to the frame encoder.
    sample_rate: u32,
    /// Stored bit depth: 16 for `I16` input, 24 for every other sample type.
    bits_per_sample: u8,
    /// Samples per channel in each full frame, taken from the compression level.
    block_size: u32,
    /// Encoder settings derived from the compression level.
    params: FlacEncodeParams,
    /// Per-channel samples that have not yet been encoded into a frame.
    accum: Vec<Vec<i32>>,
    /// Frame number to assign to the next encoded frame.
    frame_number: u64,
    /// Per-channel sample count reported by `write_frames`; `finalize` stores it
    /// as STREAMINFO's total-samples value.
    frames_written: usize,
    /// Byte offset of the STREAMINFO body, so `finalize` can overwrite it.
    streaminfo_offset: u64,
    /// Set once `finalize` succeeds; later `write_frames` calls are rejected.
    finalized: bool,
}
impl<W> StreamedFlacWriter<W>
where
    W: WriteSeek,
{
    /// Creates a writer for 16-bit integer input at the default compression level.
    ///
    /// # Errors
    ///
    /// Same conditions as [`StreamedFlacWriter::new`].
    pub fn new_i16(writer: W, channels: u16, sample_rate: u32) -> AudioIOResult<Self> {
        Self::new(
            writer,
            channels,
            sample_rate,
            ValidatedSampleType::I16,
            CompressionLevel::default(),
        )
    }

    /// Creates a writer for 24-bit integer input at the default compression level.
    ///
    /// # Errors
    ///
    /// Same conditions as [`StreamedFlacWriter::new`].
    pub fn new_i24(writer: W, channels: u16, sample_rate: u32) -> AudioIOResult<Self> {
        Self::new(
            writer,
            channels,
            sample_rate,
            ValidatedSampleType::I24,
            CompressionLevel::default(),
        )
    }

    /// Creates a writer for 32-bit integer input at the default compression level.
    ///
    /// The stream is stored at 24 bits per sample.
    ///
    /// # Errors
    ///
    /// Same conditions as [`StreamedFlacWriter::new`].
    pub fn new_i32(writer: W, channels: u16, sample_rate: u32) -> AudioIOResult<Self> {
        Self::new(
            writer,
            channels,
            sample_rate,
            ValidatedSampleType::I32,
            CompressionLevel::default(),
        )
    }

    /// Creates a writer for `f32` input at the default compression level.
    ///
    /// The stream is stored at 24 bits per sample.
    ///
    /// # Errors
    ///
    /// Same conditions as [`StreamedFlacWriter::new`].
    pub fn new_f32(writer: W, channels: u16, sample_rate: u32) -> AudioIOResult<Self> {
        Self::new(
            writer,
            channels,
            sample_rate,
            ValidatedSampleType::F32,
            CompressionLevel::default(),
        )
    }

    /// Creates a writer for `f64` input at the default compression level.
    ///
    /// The stream is stored at 24 bits per sample.
    ///
    /// # Errors
    ///
    /// Same conditions as [`StreamedFlacWriter::new`].
    pub fn new_f64(writer: W, channels: u16, sample_rate: u32) -> AudioIOResult<Self> {
        Self::new(
            writer,
            channels,
            sample_rate,
            ValidatedSampleType::F64,
            CompressionLevel::default(),
        )
    }

    /// Creates a streaming FLAC writer and writes the stream preamble.
    ///
    /// The FLAC marker (`FLAC_MARKER`) and a single STREAMINFO metadata block are
    /// written immediately. The STREAMINFO body is a placeholder: total samples `0`
    /// and an all-zero MD5. Its byte offset is remembered so `finalize` can overwrite
    /// it once the final sample count is known.
    ///
    /// `sample_type` selects the stored bit depth: `I16` is stored as 16-bit FLAC and
    /// every other sample type as 24-bit.
    ///
    /// # Errors
    ///
    /// Nothing is written to `writer` if any of the first three checks fail.
    ///
    /// * `channels` is outside `1..=8`.
    /// * `sample_rate` is `0` or does not fit STREAMINFO's 20-bit sample-rate field.
    /// * The block size of `level` is outside the `16..=65535` range FLAC allows.
    /// * Writing to `writer`, or querying its position, fails.
    pub fn new(
        mut writer: W,
        channels: u16,
        sample_rate: u32,
        sample_type: ValidatedSampleType,
        level: CompressionLevel,
    ) -> AudioIOResult<Self> {
        // STREAMINFO stores the sample rate in a 20-bit field.
        const MAX_SAMPLE_RATE: u32 = (1 << 20) - 1;
        // FLAC block sizes live in 16-bit STREAMINFO fields and must be at least 16.
        const MIN_BLOCK_SIZE: u32 = 16;

        if channels == 0 || channels > 8 {
            return Err(AudioIOError::corrupted_data_simple(
                "Invalid channel count for FLAC",
                format!("FLAC supports 1-8 channels, got {channels}"),
            ));
        }
        if sample_rate == 0 || sample_rate > MAX_SAMPLE_RATE {
            return Err(AudioIOError::corrupted_data_simple(
                "Invalid sample rate for FLAC",
                format!(
                    "FLAC supports sample rates of 1-{MAX_SAMPLE_RATE} Hz, got {sample_rate}"
                ),
            ));
        }
        let block_size = level.block_size();
        if !(MIN_BLOCK_SIZE..=u32::from(u16::MAX)).contains(&block_size) {
            return Err(AudioIOError::corrupted_data_simple(
                "Invalid block size for FLAC",
                format!(
                    "FLAC block sizes must be {MIN_BLOCK_SIZE}-{}, got {block_size}",
                    u16::MAX
                ),
            ));
        }

        let bits_per_sample = flac_bits_for(sample_type);
        let stream_info = StreamInfo {
            // Range-checked above, so these narrowing casts cannot truncate.
            min_block_size: block_size as u16,
            max_block_size: block_size as u16,
            min_frame_size: 0,
            max_frame_size: 0,
            sample_rate,
            channels: channels as u8,
            bits_per_sample,
            total_samples: 0,
            md5_signature: [0; 16],
        };
        let streaminfo_bytes = stream_info.to_bytes();
        let body_len = streaminfo_bytes.len();

        writer.write_all(&FLAC_MARKER)?;
        // Metadata block header: 0x80 is the "last metadata block" flag with block
        // type 0 (STREAMINFO). It is followed by the body length as a 24-bit
        // big-endian integer.
        writer.write_all(&[
            0x80,
            (body_len >> 16) as u8,
            (body_len >> 8) as u8,
            body_len as u8,
        ])?;
        // Remember where the body starts so finalize() can patch it in place.
        let streaminfo_offset = writer.stream_position()?;
        writer.write_all(&streaminfo_bytes)?;

        let accum = (0..channels as usize).map(|_| Vec::new()).collect();
        Ok(Self {
            writer,
            channels,
            sample_rate,
            bits_per_sample,
            block_size,
            params: FlacEncodeParams::from_level(level),
            accum,
            frame_number: 0,
            frames_written: 0,
            streaminfo_offset,
            finalized: false,
        })
    }
}
/// Maps the caller's sample type to the bit depth stored in the FLAC stream.
///
/// 16-bit integer input is stored as 16-bit FLAC. Every other sample type
/// (wider integers and floats) is stored as 24-bit FLAC.
const fn flac_bits_for(sample_type: ValidatedSampleType) -> u8 {
    if matches!(sample_type, ValidatedSampleType::I16) {
        16
    } else {
        24
    }
}
fn encode_and_write_block<W: Write>(
writer: &mut W,
accum: &[Vec<i32>],
start: usize,
len: usize,
bits_per_sample: u8,
sample_rate: u32,
frame_number: u64,
params: &FlacEncodeParams,
) -> AudioIOResult<()> {
let ch_slices: Vec<&[i32]> = accum.iter().map(|ch| &ch[start..start + len]).collect();
let frame_bytes = encode_frame_from_channels(
&ch_slices,
bits_per_sample,
sample_rate,
frame_number,
params.max_lpc_order,
params.qlp_precision,
params.min_partition_order,
params.max_partition_order,
params.try_mid_side,
params.exhaustive_rice,
)
.map_err(AudioIOError::FlacError)?;
writer.write_all(&frame_bytes)?;
Ok(())
}
impl<W> AudioStreamWriter for StreamedFlacWriter<W>
where
    W: WriteSeek,
{
    /// Flushes the underlying writer.
    ///
    /// Samples still buffered for an incomplete block are not encoded here;
    /// `finalize` emits them.
    fn flush(&mut self) -> AudioIOResult<()> {
        self.writer.flush()?;
        Ok(())
    }

    /// Finishes the stream:
    ///
    /// 1. Encodes any buffered tail as a final, possibly short, frame.
    /// 2. Overwrites the placeholder STREAMINFO with the final sample count.
    /// 3. Flushes the writer.
    ///
    /// After a successful call, further calls are no-ops.
    ///
    /// # Errors
    ///
    /// Propagates frame-encoding, seek, write, and flush errors. On error the
    /// writer is not marked finalized.
    fn finalize(&mut self) -> AudioIOResult<()> {
        // STREAMINFO's total-samples field is 36 bits wide; 0 means "unknown".
        const MAX_TOTAL_SAMPLES: u64 = (1 << 36) - 1;

        if self.finalized {
            return Ok(());
        }

        let remaining = self.accum.first().map_or(0, Vec::len);
        if remaining > 0 {
            encode_and_write_block(
                &mut self.writer,
                &self.accum,
                0,
                remaining,
                self.bits_per_sample,
                self.sample_rate,
                self.frame_number,
                &self.params,
            )?;
            self.frame_number += 1;
            self.accum.iter_mut().for_each(Vec::clear);
        }

        let total = self.frames_written as u64;
        // Compare in u64. Casting the count to u32 first would truncate counts
        // above u32::MAX and record a bogus block size.
        let final_block = if total == 0 || total >= u64::from(self.block_size) {
            self.block_size
        } else {
            // total < block_size here, so it fits in a u32.
            total as u32
        };
        // Counts that don't fit the 36-bit field are recorded as "unknown"
        // rather than silently wrapped.
        let total_samples = if total <= MAX_TOTAL_SAMPLES { total } else { 0 };

        let stream_info = StreamInfo {
            min_block_size: final_block as u16,
            max_block_size: final_block as u16,
            min_frame_size: 0,
            max_frame_size: 0,
            sample_rate: self.sample_rate,
            channels: self.channels as u8,
            bits_per_sample: self.bits_per_sample,
            total_samples,
            md5_signature: [0; 16],
        };
        let streaminfo_bytes = stream_info.to_bytes();

        // Patch STREAMINFO in place, then return to the end of the stream.
        let end = self.writer.stream_position()?;
        self.writer.seek(SeekFrom::Start(self.streaminfo_offset))?;
        self.writer.write_all(&streaminfo_bytes)?;
        self.writer.seek(SeekFrom::Start(end))?;
        self.writer.flush()?;

        self.finalized = true;
        Ok(())
    }

    /// Returns `true` once `finalize` has completed successfully.
    fn is_finalized(&self) -> bool {
        self.finalized
    }

    /// Returns the per-channel sample count accumulated by `write_frames`.
    fn frames_written(&self) -> usize {
        self.frames_written
    }

    /// Returns the configured sample rate in Hz.
    fn sample_rate(&self) -> u32 {
        self.sample_rate
    }

    /// Returns the configured channel count.
    fn num_channels(&self) -> u16 {
        self.channels
    }
}
impl<W> AudioStreamWrite for StreamedFlacWriter<W>
where
    W: WriteSeek,
{
    /// Buffers `samples` and encodes every complete `block_size` block as a FLAC frame.
    ///
    /// Samples that do not fill a whole block stay buffered until a later call or
    /// `finalize`. Returns the number of per-channel samples accepted.
    ///
    /// # Errors
    ///
    /// The first three checks fail before anything is buffered:
    ///
    /// * The writer is already finalized.
    /// * The channel count differs from the writer's configuration.
    /// * The input's sample type maps to a different FLAC bit depth.
    ///
    /// Conversion errors from `audio_to_planar_i32` also occur before buffering.
    ///
    /// If frame encoding or writing fails, the input has already been buffered and
    /// counted. Every block written before the failure is removed from the buffer,
    /// so neither a retry nor `finalize` emits that audio twice.
    fn write_frames<T>(&mut self, samples: &AudioSamples<'_, T>) -> AudioIOResult<usize>
    where
        T: StandardSample + 'static,
    {
        if self.finalized {
            return Err(AudioIOError::corrupted_data_simple(
                "Cannot write to finalized stream",
                "Call write_frames before finalize()",
            ));
        }
        let input_channels = samples.num_channels();
        if input_channels.get() != self.channels as u32 {
            return Err(AudioIOError::corrupted_data_simple(
                "Channel count mismatch",
                format!(
                    "Writer configured for {} channels, got {} channels",
                    self.channels, input_channels
                ),
            ));
        }
        // Must agree with flac_bits_for: only 16-bit input is stored as 16-bit FLAC.
        let chunk_bits: u8 = match T::SAMPLE_TYPE {
            SampleType::I16 => 16,
            _ => 24,
        };
        if chunk_bits != self.bits_per_sample {
            return Err(AudioIOError::corrupted_data_simple(
                "Sample bit-depth mismatch",
                format!(
                    "Writer configured for {}-bit FLAC, but input maps to {}-bit",
                    self.bits_per_sample, chunk_bits
                ),
            ));
        }

        let frames_per_channel = samples.samples_per_channel().get();
        let planar = audio_to_planar_i32(samples)?;
        for (buffer, data) in self.accum.iter_mut().zip(planar.iter()) {
            buffer.extend_from_slice(data);
        }
        // Buffered samples are always emitted, here or in finalize(). Counting them
        // now keeps STREAMINFO's total in step with the file even if encoding below
        // fails.
        self.frames_written += frames_per_channel;

        let bs = self.block_size as usize;
        if bs == 0 {
            return Ok(frames_per_channel);
        }

        let available = self.accum.first().map_or(0, Vec::len);
        let mut consumed = 0;
        let mut outcome: AudioIOResult<()> = Ok(());
        while available - consumed >= bs {
            if let Err(err) = encode_and_write_block(
                &mut self.writer,
                &self.accum,
                consumed,
                bs,
                self.bits_per_sample,
                self.sample_rate,
                self.frame_number,
                &self.params,
            ) {
                outcome = Err(err);
                break;
            }
            self.frame_number += 1;
            consumed += bs;
        }

        // Drop exactly the samples that made it into written frames, even on failure.
        if consumed > 0 {
            for buffer in &mut self.accum {
                buffer.drain(..consumed);
            }
        }

        outcome.map(|()| frames_per_channel)
    }
}
impl<W> Drop for StreamedFlacWriter<W>
where
    W: WriteSeek,
{
    /// Best-effort finalization for writers that were never finalized explicitly.
    ///
    /// `drop` cannot return an error, so any failure from `finalize` is discarded.
    /// Debug builds warn on stderr when samples had been written.
    fn drop(&mut self) {
        if self.finalized {
            return;
        }

        if cfg!(debug_assertions) && self.frames_written > 0 {
            eprintln!(
                "Warning: StreamedFlacWriter dropped without calling finalize(); \
                 finalizing now. Call finalize() explicitly to surface I/O errors."
            );
        }

        // Deliberately ignored: there is no caller to report the error to.
        let _ = self.finalize();
    }
}