use std::{io::SeekFrom, num::NonZeroUsize};
use audio_samples::{AudioSamples, ConvertTo, I24, StandardSample};
use non_empty_slice::non_empty_vec;
use crate::{
WriteSeek,
error::{AudioIOError, AudioIOResult},
traits::{AudioStreamWrite, AudioStreamWriter},
types::ValidatedSampleType,
wav::FormatCode,
};
/// Streaming WAV (RIFF/WAVE) writer.
///
/// The header is written up front with zeroed size fields, audio is appended
/// incrementally, and `finalize` seeks back to patch the RIFF and `data` chunk
/// sizes. That in-place patching is why the sink must be seekable.
#[derive(Debug)]
pub struct StreamedWavWriter<W>
where
    W: WriteSeek,
{
    /// Seekable sink that receives the encoded bytes.
    writer: W,
    /// Interleaved channels per frame; the constructor rejects 0.
    channels: u16,
    /// Sample rate written into the `fmt ` chunk and reported by `sample_rate()`.
    sample_rate: u32,
    /// On-disk sample encoding; `write_frames` converts incoming samples to it.
    sample_type: ValidatedSampleType,
    // Not read by any method in this file (it only surfaces through the
    // derived `Debug`), so the dead-code lint is silenced explicitly.
    #[allow(dead_code)]
    bytes_per_sample: u16,
    /// Bytes per frame (`channels * bytes_per_sample`); `write_raw_bytes`
    /// requires its input to be a multiple of this.
    block_align: u16,
    /// Total frames accepted across all write calls.
    frames_written: usize,
    /// Payload bytes written to the `data` chunk, excluding the pad byte that
    /// `finalize` may append.
    data_bytes_written: u64,
    /// Stream position of the RIFF chunk-size field, patched by `finalize`.
    riff_size_offset: u64,
    /// Stream position of the `data` chunk-size field, patched by `finalize`.
    data_size_offset: u64,
    /// Set once `finalize` has patched the header; later writes are rejected.
    finalized: bool,
}
impl<W> StreamedWavWriter<W>
where
    W: WriteSeek,
{
    /// Bytes 2..16 of a `KSDATAFORMAT_SUBTYPE_*` SubFormat GUID
    /// (`{0000XXXX-0000-0010-8000-00AA00389B71}`) in on-disk order.
    ///
    /// A GUID is serialized as `Data1: u32 LE`, `Data2: u16 LE`,
    /// `Data3: u16 LE`, `Data4: [u8; 8]`. The low 16 bits of `Data1` carry the
    /// format tag and are written separately, so this tail starts with the
    /// two high bytes of `Data1`.
    const SUBFORMAT_GUID_TAIL: [u8; 14] = [
        0x00, 0x00, // Data1, high half
        0x00, 0x00, // Data2 = 0x0000
        0x10, 0x00, // Data3 = 0x0010
        0x80, 0x00, 0x00, 0xAA, 0x00, 0x38, 0x9B, 0x71, // Data4
    ];

    /// Creates a writer that stores samples as 16-bit signed PCM.
    ///
    /// The header is written to `writer` immediately, with placeholder sizes
    /// that `finalize` patches later.
    ///
    /// # Errors
    ///
    /// Fails if `channels` is zero, if the frame layout does not fit the WAV
    /// header fields, or if writing the header fails.
    pub fn new_i16(writer: W, channels: u16, sample_rate: u32) -> AudioIOResult<Self> {
        Self::new_with_sample_type(writer, channels, sample_rate, ValidatedSampleType::I16)
    }

    /// Creates a writer that stores samples as 24-bit signed PCM
    /// (always uses the extensible `fmt ` chunk).
    ///
    /// # Errors
    ///
    /// Same conditions as [`Self::new_i16`].
    pub fn new_i24(writer: W, channels: u16, sample_rate: u32) -> AudioIOResult<Self> {
        Self::new_with_sample_type(writer, channels, sample_rate, ValidatedSampleType::I24)
    }

    /// Creates a writer that stores samples as 32-bit signed PCM.
    ///
    /// # Errors
    ///
    /// Same conditions as [`Self::new_i16`].
    pub fn new_i32(writer: W, channels: u16, sample_rate: u32) -> AudioIOResult<Self> {
        Self::new_with_sample_type(writer, channels, sample_rate, ValidatedSampleType::I32)
    }

    /// Creates a writer that stores samples as 32-bit IEEE float.
    ///
    /// # Errors
    ///
    /// Same conditions as [`Self::new_i16`].
    pub fn new_f32(writer: W, channels: u16, sample_rate: u32) -> AudioIOResult<Self> {
        Self::new_with_sample_type(writer, channels, sample_rate, ValidatedSampleType::F32)
    }

    /// Creates a writer that stores samples as 64-bit IEEE float
    /// (always uses the extensible `fmt ` chunk).
    ///
    /// # Errors
    ///
    /// Same conditions as [`Self::new_i16`].
    pub fn new_f64(writer: W, channels: u16, sample_rate: u32) -> AudioIOResult<Self> {
        Self::new_with_sample_type(writer, channels, sample_rate, ValidatedSampleType::F64)
    }

    /// Validates the layout, then writes the RIFF/WAVE header with zeroed
    /// size fields and records where those fields live so `finalize` can
    /// patch them.
    fn new_with_sample_type(
        mut writer: W,
        channels: u16,
        sample_rate: u32,
        sample_type: ValidatedSampleType,
    ) -> AudioIOResult<Self> {
        if channels == 0 {
            return Err(AudioIOError::corrupted_data_simple(
                "Invalid channel count",
                "Channel count must be at least 1",
            ));
        }
        // Validate before touching the writer so a bad layout leaves no
        // partial header behind.
        let (block_align, _byte_rate) = Self::frame_layout(channels, sample_rate, sample_type)?;
        let bytes_per_sample = sample_type.bytes_per_sample().get() as u16;

        writer.write_all(b"RIFF")?;
        let riff_size_offset = writer.stream_position()?;
        // Placeholder RIFF size; patched by `finalize`.
        writer.write_all(&0u32.to_le_bytes())?;
        writer.write_all(b"WAVE")?;

        if Self::needs_extensible_format(channels, sample_type) {
            Self::write_extensible_fmt(&mut writer, channels, sample_rate, sample_type)?;
        } else {
            Self::write_base_fmt(&mut writer, channels, sample_rate, sample_type)?;
        }

        writer.write_all(b"data")?;
        let data_size_offset = writer.stream_position()?;
        // Placeholder data size; patched by `finalize`.
        writer.write_all(&0u32.to_le_bytes())?;

        Ok(StreamedWavWriter {
            writer,
            channels,
            sample_rate,
            sample_type,
            bytes_per_sample,
            block_align,
            frames_written: 0,
            data_bytes_written: 0,
            riff_size_offset,
            data_size_offset,
            finalized: false,
        })
    }

    /// Computes `(block_align, byte_rate)` for the `fmt ` chunk.
    ///
    /// Rejects layouts whose values overflow the header's `u16` block-align
    /// field or `u32` byte-rate field. Unchecked arithmetic here would panic
    /// in debug builds and silently wrap in release.
    fn frame_layout(
        channels: u16,
        sample_rate: u32,
        sample_type: ValidatedSampleType,
    ) -> AudioIOResult<(u16, u32)> {
        let bytes_per_sample = sample_type.bytes_per_sample().get() as u16;
        let block_align = channels.checked_mul(bytes_per_sample).ok_or_else(|| {
            AudioIOError::corrupted_data_simple(
                "Invalid channel count",
                format!(
                    "{channels} channels of {bytes_per_sample}-byte samples exceed the \
                     maximum WAV frame size of {} bytes",
                    u16::MAX
                ),
            )
        })?;
        let byte_rate = sample_rate
            .checked_mul(u32::from(block_align))
            .ok_or_else(|| {
                AudioIOError::corrupted_data_simple(
                    "Invalid sample rate",
                    format!(
                        "Byte rate for sample rate {sample_rate} with {block_align}-byte frames \
                         exceeds {} bytes per second",
                        u32::MAX
                    ),
                )
            })?;
        Ok((block_align, byte_rate))
    }

    /// Whether the layout needs a `WAVE_FORMAT_EXTENSIBLE` `fmt ` chunk:
    /// more than two channels, or 24-bit PCM / 64-bit float samples.
    const fn needs_extensible_format(channels: u16, sample_type: ValidatedSampleType) -> bool {
        channels > 2
            || matches!(
                sample_type,
                ValidatedSampleType::I24 | ValidatedSampleType::F64
            )
    }

    /// Writes a plain 16-byte `fmt ` chunk (PCM or IEEE float).
    fn write_base_fmt(
        writer: &mut W,
        channels: u16,
        sample_rate: u32,
        sample_type: ValidatedSampleType,
    ) -> AudioIOResult<()> {
        let format_code = Self::sample_type_to_format(sample_type);
        let bits_per_sample = sample_type.bits_per_sample();
        let (block_align, byte_rate) = Self::frame_layout(channels, sample_rate, sample_type)?;

        writer.write_all(b"fmt ")?;
        writer.write_all(&16u32.to_le_bytes())?;
        writer.write_all(&format_code.as_u16().to_le_bytes())?;
        writer.write_all(&channels.to_le_bytes())?;
        writer.write_all(&sample_rate.to_le_bytes())?;
        writer.write_all(&byte_rate.to_le_bytes())?;
        writer.write_all(&block_align.to_le_bytes())?;
        writer.write_all(&bits_per_sample.get().to_le_bytes())?;
        Ok(())
    }

    /// Default `dwChannelMask` for `channels` speakers.
    ///
    /// 1–8 channels map to the standard mono/stereo/.../7.1 layouts. Other
    /// counts set the lowest `channels` speaker bits, or all bits from 32
    /// channels upward.
    fn default_channel_mask(channels: u16) -> u32 {
        match channels {
            1 => 0x4,   // FC
            2 => 0x3,   // FL | FR
            3 => 0x7,   // FL | FR | FC
            4 => 0x33,  // FL | FR | BL | BR
            5 => 0x37,  // FL | FR | FC | BL | BR
            6 => 0x3F,  // 5.1
            7 => 0x13F, // 5.1 + BC
            8 => 0x63F, // 5.1 + SL + SR (7.1)
            n if n < 32 => (1u32 << n) - 1,
            _ => 0xFFFF_FFFF,
        }
    }

    /// Writes a 40-byte `WAVE_FORMAT_EXTENSIBLE` `fmt ` chunk.
    ///
    /// The actual format tag is embedded in the SubFormat GUID.
    fn write_extensible_fmt(
        writer: &mut W,
        channels: u16,
        sample_rate: u32,
        sample_type: ValidatedSampleType,
    ) -> AudioIOResult<()> {
        let format_code = Self::sample_type_to_format(sample_type);
        let bits_per_sample = sample_type.bits_per_sample();
        let (block_align, byte_rate) = Self::frame_layout(channels, sample_rate, sample_type)?;
        let channel_mask = Self::default_channel_mask(channels);

        writer.write_all(b"fmt ")?;
        writer.write_all(&40u32.to_le_bytes())?;
        writer.write_all(&FormatCode::Extensible.as_u16().to_le_bytes())?;
        writer.write_all(&channels.to_le_bytes())?;
        writer.write_all(&sample_rate.to_le_bytes())?;
        writer.write_all(&byte_rate.to_le_bytes())?;
        writer.write_all(&block_align.to_le_bytes())?;
        writer.write_all(&bits_per_sample.get().to_le_bytes())?;
        // cbSize: bytes of extension data that follow.
        writer.write_all(&22u16.to_le_bytes())?;
        // wValidBitsPerSample: every container bit carries signal.
        writer.write_all(&bits_per_sample.get().to_le_bytes())?;
        writer.write_all(&channel_mask.to_le_bytes())?;

        let mut sub_format = [0u8; 16];
        sub_format[..2].copy_from_slice(&format_code.as_u16().to_le_bytes());
        sub_format[2..].copy_from_slice(&Self::SUBFORMAT_GUID_TAIL);
        writer.write_all(&sub_format)?;
        Ok(())
    }

    /// Maps the on-disk sample type to its WAV format tag.
    const fn sample_type_to_format(sample_type: ValidatedSampleType) -> FormatCode {
        match sample_type {
            ValidatedSampleType::U8
            | ValidatedSampleType::I16
            | ValidatedSampleType::I24
            | ValidatedSampleType::I32 => FormatCode::Pcm,
            ValidatedSampleType::F32 | ValidatedSampleType::F64 => FormatCode::IeeeFloat,
        }
    }

    /// The sample encoding this writer stores on disk.
    pub const fn target_sample_type(&self) -> ValidatedSampleType {
        self.sample_type
    }

    /// Appends pre-encoded, interleaved sample bytes to the `data` chunk
    /// without any conversion.
    ///
    /// Returns the number of frames written.
    ///
    /// # Errors
    ///
    /// Fails if the stream is already finalized, if `bytes.len()` is not a
    /// multiple of the frame size, or if the sink reports an I/O error.
    pub fn write_raw_bytes(&mut self, bytes: &[u8]) -> AudioIOResult<usize> {
        if self.finalized {
            return Err(AudioIOError::corrupted_data_simple(
                "Cannot write to finalized stream",
                "Call write_frames before finalize()",
            ));
        }
        let frame_bytes = self.block_align as usize;
        if !bytes.len().is_multiple_of(frame_bytes) {
            return Err(AudioIOError::corrupted_data_simple(
                "Byte count must be a multiple of frame size",
                format!(
                    "Got {} bytes, frame size is {} bytes",
                    bytes.len(),
                    frame_bytes
                ),
            ));
        }
        self.writer.write_all(bytes)?;
        let frames = bytes.len() / frame_bytes;
        self.frames_written += frames;
        self.data_bytes_written += bytes.len() as u64;
        Ok(frames)
    }
}
impl<W> AudioStreamWriter for StreamedWavWriter<W>
where
W: WriteSeek,
{
fn flush(&mut self) -> AudioIOResult<()> {
self.writer.flush()?;
Ok(())
}
fn finalize(&mut self) -> AudioIOResult<()> {
if self.finalized {
return Ok(()); }
if self.data_bytes_written % 2 == 1 {
self.writer.write_all(&[0])?;
}
let data_size = self.data_bytes_written as u32;
let padded_data_size = if self.data_bytes_written % 2 == 1 {
self.data_bytes_written + 1
} else {
self.data_bytes_written
};
let use_extensible = Self::needs_extensible_format(self.channels, self.sample_type);
let fmt_chunk_size: u64 = if use_extensible { 40 } else { 16 };
let fmt_total_size = 8 + fmt_chunk_size;
let riff_size = 4 + fmt_total_size + 8 + padded_data_size;
let current_pos = self.writer.stream_position()?;
self.writer.seek(SeekFrom::Start(self.riff_size_offset))?;
self.writer.write_all(&(riff_size as u32).to_le_bytes())?;
self.writer.seek(SeekFrom::Start(self.data_size_offset))?;
self.writer.write_all(&data_size.to_le_bytes())?;
self.writer.seek(SeekFrom::Start(current_pos))?;
self.writer.flush()?;
self.finalized = true;
Ok(())
}
fn is_finalized(&self) -> bool {
self.finalized
}
fn frames_written(&self) -> usize {
self.frames_written
}
fn sample_rate(&self) -> u32 {
self.sample_rate
}
fn num_channels(&self) -> u16 {
self.channels
}
}
impl<W> AudioStreamWrite for StreamedWavWriter<W>
where
    W: WriteSeek,
{
    /// Converts `samples` to the writer's on-disk sample type and appends them,
    /// interleaved, to the `data` chunk.
    ///
    /// Returns the number of frames (samples per channel) appended.
    ///
    /// # Errors
    ///
    /// Rejects writes after `finalize` and inputs whose channel count differs
    /// from the writer's. Propagates I/O errors from the sink.
    fn write_frames<T>(&mut self, samples: &AudioSamples<'_, T>) -> AudioIOResult<usize>
    where
        T: StandardSample + 'static,
    {
        if self.finalized {
            return Err(AudioIOError::corrupted_data_simple(
                "Cannot write to finalized stream",
                "Call write_frames before finalize()",
            ));
        }

        let expected = self.channels;
        let got = samples.num_channels();
        if got.get() != u32::from(expected) {
            return Err(AudioIOError::corrupted_data_simple(
                "Channel count mismatch",
                format!(
                    "Writer configured for {} channels, got {} channels",
                    expected, got
                ),
            ));
        }

        let frames = samples.samples_per_channel().get();
        let interleaved = samples.data.as_interleaved_vec();

        // Dispatch on the on-disk encoding; each arm converts T -> target type.
        let encoded_bytes = match self.sample_type {
            ValidatedSampleType::U8 => self.write_samples_as::<T, u8>(&interleaved),
            ValidatedSampleType::I16 => self.write_samples_as::<T, i16>(&interleaved),
            ValidatedSampleType::I24 => self.write_samples_as::<T, I24>(&interleaved),
            ValidatedSampleType::I32 => self.write_samples_as::<T, i32>(&interleaved),
            ValidatedSampleType::F32 => self.write_samples_as::<T, f32>(&interleaved),
            ValidatedSampleType::F64 => self.write_samples_as::<T, f64>(&interleaved),
        }?;

        self.frames_written += frames;
        self.data_bytes_written += encoded_bytes as u64;
        Ok(frames)
    }
}
impl<W> StreamedWavWriter<W>
where
    W: WriteSeek,
{
    /// Converts `samples` to `U` and writes their little-endian encoding to the
    /// sink in bounded chunks, returning the number of bytes written.
    ///
    /// One scratch buffer is allocated and reused for every chunk. It is sized
    /// for roughly 256 KiB of output (but at least `channels` samples) and
    /// never larger than the input itself, so tiny writes stay cheap and huge
    /// writes stay memory-bounded.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from the sink. Chunks already written before a
    /// failure are not reflected in the return value.
    fn write_samples_as<T, U>(&mut self, samples: &[T]) -> AudioIOResult<usize>
    where
        T: StandardSample + ConvertTo<U> + 'static,
        U: StandardSample + 'static,
    {
        const TARGET_CHUNK_BYTES: usize = 256 * 1024;

        if samples.is_empty() {
            return Ok(0);
        }

        // Checked construction replaces the previous unjustified
        // `new_unchecked`: a zero width would be a bug in the sample type,
        // so fail loudly instead of invoking UB.
        let bytes_per_sample = NonZeroUsize::new(U::BYTES as usize)
            .expect("every sample type occupies at least one byte");
        let samples_per_chunk = (TARGET_CHUNK_BYTES / bytes_per_sample)
            .max(usize::from(self.channels))
            .min(samples.len());
        let samples_per_chunk = NonZeroUsize::new(samples_per_chunk)
            .expect("a non-empty input yields a chunk of at least one sample");
        let chunk_bytes = samples_per_chunk
            .checked_mul(bytes_per_sample)
            .expect("chunk size is at most max(TARGET_CHUNK_BYTES, channels * sample width)");

        let width = bytes_per_sample.get();
        let mut buf = non_empty_vec![0u8; chunk_bytes];
        let mut total_bytes = 0usize;
        for chunk in samples.chunks(samples_per_chunk.get()) {
            let used = chunk.len() * width;
            for (sample, dst) in chunk.iter().zip(buf[0..used].chunks_exact_mut(width)) {
                let converted: U = (*sample).convert_to();
                dst.copy_from_slice(converted.to_le_bytes().as_ref());
            }
            self.writer.write_all(&buf[..used])?;
            total_bytes += used;
        }
        Ok(total_bytes)
    }
}
impl<W> Drop for StreamedWavWriter<W>
where
    W: WriteSeek,
{
    /// Warns (debug builds only) when a writer that received audio is dropped
    /// without `finalize()`. Its header size fields were never patched in that
    /// case. Release builds drop silently.
    fn drop(&mut self) {
        let wrote_audio = self.frames_written > 0;
        if self.finalized || !wrote_audio {
            return;
        }
        #[cfg(debug_assertions)]
        eprintln!(
            "Warning: StreamedWavWriter dropped without calling finalize(). \
             The output file may have invalid headers."
        );
    }
}
#[cfg(test)]
mod tests {
    use audio_samples::{channels, nzu};
    use super::*;
    use std::io::Cursor;
    use std::num::NonZeroU32;

    /// Reads the little-endian `u32` stored at `offset` in `bytes`.
    fn le_u32(bytes: &[u8], offset: usize) -> u32 {
        u32::from_le_bytes([
            bytes[offset],
            bytes[offset + 1],
            bytes[offset + 2],
            bytes[offset + 3],
        ])
    }

    #[test]
    fn test_streaming_writer_basic() {
        let mut buffer = Vec::new();
        {
            let cursor = Cursor::new(&mut buffer);
            let mut writer =
                StreamedWavWriter::new_f32(cursor, 2, 44100).expect("Failed to create writer");
            let sample_rate = NonZeroU32::new(44100).expect("Invalid sample rate");
            let samples = AudioSamples::<f32>::zeros_multi(channels!(2), nzu!(1024), sample_rate);
            let frames = writer.write_frames(&samples).expect("Write failed");
            assert_eq!(frames, 1024);
            assert_eq!(writer.frames_written(), 1024);
            writer.finalize().expect("Finalize failed");
            assert!(writer.is_finalized());
        }
        // Stereo f32 uses the plain 16-byte `fmt ` chunk, so a 44-byte header
        // precedes 1024 frames * 2 channels * 4 bytes of audio.
        let data_bytes = 1024 * 2 * 4;
        assert_eq!(buffer.len(), 44 + data_bytes);
        assert_eq!(&buffer[0..4], b"RIFF");
        assert_eq!(le_u32(&buffer, 4), (36 + data_bytes) as u32);
        assert_eq!(&buffer[8..12], b"WAVE");
        assert_eq!(&buffer[12..16], b"fmt ");
        assert_eq!(le_u32(&buffer, 16), 16);
        assert_eq!(&buffer[36..40], b"data");
        assert_eq!(le_u32(&buffer, 40), data_bytes as u32);
    }

    #[test]
    fn test_streaming_writer_pads_odd_data_chunk() {
        let mut buffer = Vec::new();
        {
            let cursor = Cursor::new(&mut buffer);
            let mut writer =
                StreamedWavWriter::new_i24(cursor, 1, 8000).expect("Failed to create writer");
            // One mono 24-bit frame is 3 bytes, leaving the data chunk odd-sized.
            let frames = writer.write_raw_bytes(&[1, 2, 3]).expect("Write failed");
            assert_eq!(frames, 1);
            writer.finalize().expect("Finalize failed");
        }
        // 24-bit audio uses the 40-byte extensible `fmt ` chunk:
        // 12 (RIFF/WAVE) + 8 + 40 (fmt) + 8 (data header) = 68 header bytes.
        assert_eq!(buffer.len(), 68 + 4, "3 data bytes plus one pad byte");
        assert_eq!(le_u32(&buffer, 16), 40);
        assert_eq!(&buffer[60..64], b"data");
        // The data size excludes the pad byte; the RIFF size includes it.
        assert_eq!(le_u32(&buffer, 64), 3);
        assert_eq!(le_u32(&buffer, 4), (buffer.len() - 8) as u32);
        assert_eq!(&buffer[68..72], &[1u8, 2, 3, 0]);
    }

    #[test]
    fn test_streaming_writer_multiple_writes() {
        let mut buffer = Vec::new();
        let cursor = Cursor::new(&mut buffer);
        let mut writer =
            StreamedWavWriter::new_i16(cursor, 1, 22050).expect("Failed to create writer");
        let sample_rate = NonZeroU32::new(22050).expect("Invalid sample rate");
        let chunk1 = AudioSamples::<f32>::zeros_mono(nzu!(512), sample_rate);
        let chunk2 = AudioSamples::<f32>::zeros_mono(nzu!(512), sample_rate);
        writer.write_frames(&chunk1).expect("Write 1 failed");
        writer.write_frames(&chunk2).expect("Write 2 failed");
        assert_eq!(writer.frames_written(), 1024);
        writer.finalize().expect("Finalize failed");
    }

    #[test]
    fn test_streaming_writer_idempotent_finalize() {
        let mut buffer = Vec::new();
        let cursor = Cursor::new(&mut buffer);
        let mut writer =
            StreamedWavWriter::new_f32(cursor, 1, 44100).expect("Failed to create writer");
        writer.finalize().expect("First finalize failed");
        writer.finalize().expect("Second finalize should succeed");
        assert!(writer.is_finalized());
    }

    #[test]
    fn test_streaming_writer_channel_mismatch() {
        let mut buffer = Vec::new();
        let cursor = Cursor::new(&mut buffer);
        let mut writer =
            StreamedWavWriter::new_f32(cursor, 2, 44100).expect("Failed to create writer");
        let sample_rate = NonZeroU32::new(44100).expect("Invalid sample rate");
        let mono_samples = AudioSamples::<f32>::zeros_mono(nzu!(1024), sample_rate);
        let result = writer.write_frames(&mono_samples);
        assert!(result.is_err());
    }

    #[test]
    fn test_streaming_writer_write_after_finalize() {
        let mut buffer = Vec::new();
        let cursor = Cursor::new(&mut buffer);
        let mut writer =
            StreamedWavWriter::new_f32(cursor, 1, 44100).expect("Failed to create writer");
        writer.finalize().expect("Finalize failed");
        let sample_rate = NonZeroU32::new(44100).expect("Invalid sample rate");
        let samples = AudioSamples::<f32>::zeros_mono(nzu!(1024), sample_rate);
        let result = writer.write_frames(&samples);
        assert!(result.is_err());
    }
}