use std::io::Write;
use crate::error::AudioIOResult;
use crate::types::ValidatedSampleType;
use crate::wav::FormatCode;
/// Trailing 12 bytes (on-disk order) of the `KSDATAFORMAT_SUBTYPE_*` sub-format GUIDs,
/// i.e. `{XXXXXXXX-0000-0010-8000-00AA00389B71}`. The leading 4 bytes are the
/// little-endian format code and are written separately, just before this tail.
const SUBFORMAT_GUID_TAIL: [u8; 12] = *b"\x00\x00\x10\x00\x80\x00\x00\xAA\x00\x38\x9B\x71";
/// Returns `true` when the `fmt ` chunk must use the `WAVE_FORMAT_EXTENSIBLE` layout.
///
/// That is the case for more than two channels, or for 24-bit integer and
/// 64-bit float samples.
pub const fn needs_extensible(channels: u16, sample_type: ValidatedSampleType) -> bool {
    let multichannel = channels > 2;
    let wide_sample = matches!(
        sample_type,
        ValidatedSampleType::I24 | ValidatedSampleType::F64
    );
    multichannel || wide_sample
}
/// Size in bytes of the `fmt ` chunk body, excluding its 8-byte chunk header.
///
/// The extensible layout is 40 bytes; the classic layout is 16 bytes.
const fn fmt_body_len(channels: u16, sample_type: ValidatedSampleType) -> usize {
    const CLASSIC_BODY: usize = 16;
    const EXTENSIBLE_BODY: usize = 40;
    match needs_extensible(channels, sample_type) {
        true => EXTENSIBLE_BODY,
        false => CLASSIC_BODY,
    }
}
/// Total number of header bytes that precede the sample payload.
///
/// The header consists of:
/// - the 12-byte `RIFF`/size/`WAVE` preamble,
/// - the complete `fmt ` chunk (8-byte chunk header plus body),
/// - the 8-byte `data` chunk header.
pub const fn wav_header_len(channels: u16, sample_type: ValidatedSampleType) -> usize {
    const RIFF_PREAMBLE: usize = 12; // "RIFF" + u32 size + "WAVE"
    const CHUNK_HEADER: usize = 8; // four-cc + u32 size
    let fmt_chunk = CHUNK_HEADER + fmt_body_len(channels, sample_type);
    RIFF_PREAMBLE + fmt_chunk + CHUNK_HEADER
}
/// Number of payload bytes for `total_frames` frames of `channels` samples each.
///
/// This excludes the RIFF pad byte that an odd-length payload needs on disk.
pub const fn wav_data_len(channels: u16, sample_type: ValidatedSampleType, total_frames: usize) -> usize {
    let sample_count = total_frames * channels as usize;
    sample_count * sample_type.bytes_per_sample().get()
}
/// Total on-disk size of a WAV file holding `total_frames` frames.
///
/// This is the header plus the payload. An odd-length payload is followed by
/// one pad byte so the file length is always even.
pub const fn wav_file_len(channels: u16, sample_type: ValidatedSampleType, total_frames: usize) -> usize {
    let payload = wav_data_len(channels, sample_type, total_frames);
    let pad_byte = payload % 2;
    wav_header_len(channels, sample_type) + payload + pad_byte
}
/// Default `dwChannelMask` written in the extensible `fmt ` chunk for `channels`.
///
/// Channel counts map as follows:
/// - 1 to 8: conventional speaker layouts.
///   - 1 = front center
///   - 2 = front L/R
///   - 3 = L/R/C
///   - 4 = front + back L/R
///   - 5 = L/R/C + back L/R
///   - 6 = 5.1
///   - 7 = 5.1 + back center
///   - 8 = 5.1 + side L/R
/// - 0 and 9 to 31: the low `channels` bits are set (0 yields an empty mask).
/// - 32 or more: every bit is set.
const fn channel_mask(channels: u16) -> u32 {
    const LAYOUTS: [u32; 8] = [0x4, 0x3, 0x7, 0x33, 0x37, 0x3F, 0x13F, 0x63F];
    match channels {
        1..=8 => LAYOUTS[channels as usize - 1],
        // A shift by 32 or more would overflow, so saturate to a full mask.
        32.. => 0xFFFF_FFFF,
        n => (1u32 << n) - 1,
    }
}
/// Base format tag for `sample_type`.
///
/// Returns IEEE float for `F32`/`F64` and integer PCM for everything else. In the
/// extensible layout this value also becomes the first 4 bytes of the SubFormat GUID.
const fn format_code(sample_type: ValidatedSampleType) -> FormatCode {
    let is_float = matches!(
        sample_type,
        ValidatedSampleType::F32 | ValidatedSampleType::F64
    );
    if is_float {
        FormatCode::IeeeFloat
    } else {
        FormatCode::Pcm
    }
}
fn write_fmt_chunk<W: Write>(
w: &mut W,
channels: u16,
sample_rate: u32,
sample_type: ValidatedSampleType,
) -> AudioIOResult<()> {
let bits = sample_type.bits_per_sample().get() as u16;
let bytes = sample_type.bytes_per_sample().get() as u16;
let block_align = channels * bytes;
let byte_rate = sample_rate * block_align as u32;
let fc = format_code(sample_type);
if needs_extensible(channels, sample_type) {
w.write_all(b"fmt ")?;
w.write_all(&40u32.to_le_bytes())?;
w.write_all(&FormatCode::Extensible.as_u16().to_le_bytes())?;
w.write_all(&channels.to_le_bytes())?;
w.write_all(&sample_rate.to_le_bytes())?;
w.write_all(&byte_rate.to_le_bytes())?;
w.write_all(&block_align.to_le_bytes())?;
w.write_all(&bits.to_le_bytes())?;
w.write_all(&22u16.to_le_bytes())?; w.write_all(&bits.to_le_bytes())?; w.write_all(&channel_mask(channels).to_le_bytes())?;
w.write_all(&u32::from(fc.as_u16()).to_le_bytes())?;
w.write_all(&SUBFORMAT_GUID_TAIL)?;
} else {
w.write_all(b"fmt ")?;
w.write_all(&16u32.to_le_bytes())?;
w.write_all(&fc.as_u16().to_le_bytes())?;
w.write_all(&channels.to_le_bytes())?;
w.write_all(&sample_rate.to_le_bytes())?;
w.write_all(&byte_rate.to_le_bytes())?;
w.write_all(&block_align.to_le_bytes())?;
w.write_all(&bits.to_le_bytes())?;
}
Ok(())
}
pub fn build_wav_header(
channels: u16,
sample_rate: u32,
sample_type: ValidatedSampleType,
total_frames: usize,
) -> AudioIOResult<Vec<u8>> {
let data_size = wav_data_len(channels, sample_type, total_frames);
let padded = data_size + (data_size & 1);
let fmt_total = 8 + fmt_body_len(channels, sample_type);
let file_size = 4 + fmt_total + 8 + padded;
let mut header = Vec::with_capacity(wav_header_len(channels, sample_type));
header.extend_from_slice(b"RIFF");
header.extend_from_slice(&(file_size as u32).to_le_bytes());
header.extend_from_slice(b"WAVE");
write_fmt_chunk(&mut header, channels, sample_rate, sample_type)?;
header.extend_from_slice(b"data");
header.extend_from_slice(&(data_size as u32).to_le_bytes());
Ok(header)
}
/// Builds a WAV header whose RIFF and `data` sizes are both `0xFFFF_FFFF`.
///
/// Use this when the payload length is not known up front. Apart from the two size
/// fields, the layout matches the header from `build_wav_header`, including the
/// [`wav_header_len`] length.
///
/// # Errors
///
/// Propagates any error from writing the `fmt ` chunk.
pub fn build_wav_header_infinite(
    channels: u16,
    sample_rate: u32,
    sample_type: ValidatedSampleType,
) -> AudioIOResult<Vec<u8>> {
    const UNKNOWN_SIZE: [u8; 4] = u32::MAX.to_le_bytes();

    let mut out = Vec::with_capacity(wav_header_len(channels, sample_type));
    for part in [b"RIFF", &UNKNOWN_SIZE, b"WAVE"] {
        out.extend_from_slice(part);
    }
    write_fmt_chunk(&mut out, channels, sample_rate, sample_type)?;
    for part in [b"data", &UNKNOWN_SIZE] {
        out.extend_from_slice(part);
    }
    Ok(out)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes the little-endian `u32` stored at `offset` in `bytes`.
    fn le_u32_at(bytes: &[u8], offset: usize) -> u32 {
        let mut word = [0u8; 4];
        word.copy_from_slice(&bytes[offset..offset + 4]);
        u32::from_le_bytes(word)
    }

    #[test]
    fn header_len_matches_built_header() {
        let cases = [
            // Classic 16-byte fmt bodies.
            (1u16, ValidatedSampleType::I16),
            (2, ValidatedSampleType::F32),
            // 24-bit samples and more than two channels both force the extensible layout.
            (1, ValidatedSampleType::I24),
            (6, ValidatedSampleType::I16),
        ];
        for (ch, st) in cases {
            let built = build_wav_header(ch, 44_100, st, 100).expect("build header");
            assert_eq!(built.len(), wav_header_len(ch, st), "ch={ch} st={st:?}");
        }
    }

    #[test]
    fn file_len_accounts_for_pad_byte() {
        let st = ValidatedSampleType::U8;
        // Three mono U8 frames give a 3-byte payload, padded to 4 on disk.
        let len = wav_file_len(1, st, 3);
        assert_eq!(len % 2, 0);
        assert_eq!(len, wav_header_len(1, st) + 4);
    }

    #[test]
    fn header_declares_expected_sizes() {
        let st = ValidatedSampleType::I16;
        let header = build_wav_header(2, 48_000, st, 10).expect("build header");
        // The data chunk size is the final field of the header.
        let data_size = le_u32_at(&header, header.len() - 4);
        assert_eq!(data_size, 40);
        // The RIFF size sits right after the "RIFF" four-cc.
        let riff_size = le_u32_at(&header, 4);
        assert_eq!(riff_size as usize, wav_file_len(2, st, 10) - 8);
    }
}