use core::fmt::{Display, Formatter, Result as FmtResult};
use audio_samples::SampleType;
use crate::{
types::ValidatedSampleType,
wav::{FormatCode, error::WavError},
};
/// Borrowed view over the raw payload of a WAV `fmt ` chunk.
///
/// The variant records how many bytes the payload holds; the bytes themselves
/// are never copied. `FmtChunk::from_bytes` selects the variant purely from
/// the payload length.
///
/// All fields are plain byte references, so equality is total and the type
/// implements `Eq` in addition to `PartialEq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FmtChunk<'a> {
    /// Payload of exactly 16 bytes.
    Base(&'a [u8; 16]),
    /// Payload of exactly 40 bytes.
    Extensible(&'a [u8; 40]),
    /// Payload of any other length. `FmtChunk::from_bytes` only produces this
    /// variant for slices of at least 18 bytes, but the variant itself does
    /// not enforce a minimum length when constructed directly.
    Extended(&'a [u8]),
}
impl<'a> FmtChunk<'a> {
    /// Classifies a raw `fmt ` chunk payload by its length.
    ///
    /// * exactly 16 bytes becomes [`FmtChunk::Base`]
    /// * exactly 40 bytes becomes [`FmtChunk::Extensible`]
    /// * any other length of at least 18 bytes becomes [`FmtChunk::Extended`]
    ///
    /// Only the length is inspected; field values are not checked. Use
    /// [`FmtChunk::from_bytes_validated`] to also validate the fields.
    ///
    /// # Errors
    ///
    /// Returns [`WavError::InvalidFmtChunkSize`] carrying the offending length
    /// when `bytes` is shorter than 16 bytes or exactly 17 bytes long.
    pub fn from_bytes(bytes: &'a [u8]) -> Result<Self, WavError> {
        // The slice-to-array conversions succeed exactly when the length
        // matches, so they double as the length checks.
        if let Ok(base) = <&[u8; 16]>::try_from(bytes) {
            return Ok(FmtChunk::Base(base));
        }
        if let Ok(extensible) = <&[u8; 40]>::try_from(bytes) {
            return Ok(FmtChunk::Extensible(extensible));
        }
        if bytes.len() >= 18 {
            Ok(FmtChunk::Extended(bytes))
        } else {
            Err(WavError::InvalidFmtChunkSize(bytes.len()))
        }
    }

    /// Parses `bytes` with [`FmtChunk::from_bytes`] and then runs
    /// [`FmtChunk::validate_format_consistency`] on the result.
    ///
    /// # Errors
    ///
    /// Propagates the error of whichever of the two steps fails first.
    pub fn from_bytes_validated(bytes: &'a [u8]) -> Result<Self, WavError> {
        let fmt_chunk = Self::from_bytes(bytes)?;
        fmt_chunk.validate_format_consistency()?;
        Ok(fmt_chunk)
    }

    /// Returns the complete raw payload, whatever the variant.
    ///
    /// The returned slice borrows the underlying data (`'a`), not `self`.
    pub const fn as_bytes(&self) -> &'a [u8] {
        match *self {
            FmtChunk::Base(bytes) => bytes,
            FmtChunk::Extensible(bytes) => bytes,
            FmtChunk::Extended(bytes) => bytes,
        }
    }

    /// Returns the 16-byte payload if this is a [`FmtChunk::Base`], otherwise `None`.
    pub const fn try_into_base(&self) -> Option<&'a [u8; 16]> {
        match *self {
            FmtChunk::Base(bytes) => Some(bytes),
            FmtChunk::Extensible(_) | FmtChunk::Extended(_) => None,
        }
    }

    /// Returns the 40-byte payload if this is a [`FmtChunk::Extensible`], otherwise `None`.
    pub const fn try_into_extensible(&self) -> Option<&'a [u8; 40]> {
        match *self {
            FmtChunk::Extensible(bytes) => Some(bytes),
            FmtChunk::Base(_) | FmtChunk::Extended(_) => None,
        }
    }

    /// Reads the little-endian `u16` starting at `offset` in the raw payload.
    ///
    /// Panics if the payload is shorter than `offset + 2` bytes.
    const fn read_u16_le(&self, offset: usize) -> u16 {
        let bytes = self.as_bytes();
        u16::from_le_bytes([bytes[offset], bytes[offset + 1]])
    }

    /// Reads the little-endian `u32` starting at `offset` in the raw payload.
    ///
    /// Panics if the payload is shorter than `offset + 4` bytes.
    const fn read_u32_le(&self, offset: usize) -> u32 {
        let bytes = self.as_bytes();
        u32::from_le_bytes([
            bytes[offset],
            bytes[offset + 1],
            bytes[offset + 2],
            bytes[offset + 3],
        ])
    }

    /// Format code decoded from the little-endian `u16` at bytes `0..2`.
    ///
    /// # Panics
    ///
    /// Like every field accessor on this type, panics if the payload is
    /// shorter than the field. Values built by [`FmtChunk::from_bytes`] always
    /// hold at least 16 bytes; only a directly constructed
    /// [`FmtChunk::Extended`] can be shorter.
    pub const fn format_code(&self) -> FormatCode {
        FormatCode::const_from(self.read_u16_le(0))
    }

    /// Channel count: little-endian `u16` at bytes `2..4`.
    pub const fn channels(&self) -> u16 {
        self.read_u16_le(2)
    }

    /// Sample rate: little-endian `u32` at bytes `4..8`.
    pub const fn sample_rate(&self) -> u32 {
        self.read_u32_le(4)
    }

    /// Byte rate: little-endian `u32` at bytes `8..12`.
    pub const fn byte_rate(&self) -> u32 {
        self.read_u32_le(8)
    }

    /// Block alignment: little-endian `u16` at bytes `12..14`.
    pub const fn block_align(&self) -> u16 {
        self.read_u16_le(12)
    }

    /// Bits per sample: little-endian `u16` at bytes `14..16`.
    pub const fn bits_per_sample(&self) -> u16 {
        self.read_u16_le(14)
    }

    /// Bits per sample divided by 8 using integer division.
    ///
    /// The division truncates, so a width that is not a multiple of 8 loses
    /// its remainder (for example 12 bits yields 1).
    pub const fn bytes_per_sample(&self) -> u16 {
        self.bits_per_sample() / 8
    }

    /// Returns all six base fields at once, in the order
    /// `(format_code, channels, sample_rate, byte_rate, block_align, bits_per_sample)`.
    pub const fn fmt_chunk(&self) -> (FormatCode, u16, u32, u32, u16, u16) {
        (
            self.format_code(),
            self.channels(),
            self.sample_rate(),
            self.byte_rate(),
            self.block_align(),
            self.bits_per_sample(),
        )
    }

    /// Returns bytes `16..40` of a [`FmtChunk::Extensible`] payload.
    ///
    /// Yields `None` for the other two variants, including
    /// [`FmtChunk::Extended`] payloads that do carry bytes past offset 16.
    pub fn extended_bytes(&self) -> Option<&'a [u8; 24]> {
        match *self {
            // A 40-byte array always has exactly 24 bytes after offset 16, so
            // this conversion cannot fail.
            FmtChunk::Extensible(bytes) => <&[u8; 24]>::try_from(&bytes[16..]).ok(),
            FmtChunk::Base(_) | FmtChunk::Extended(_) => None,
        }
    }

    /// Returns the sub-format carried by a [`FmtChunk::Extensible`] payload.
    ///
    /// For that variant the result pairs the format code decoded from the
    /// little-endian `u16` at bytes `24..26` with the `SampleType` that
    /// `SampleType::from_bits` derives from [`FmtChunk::bits_per_sample`].
    /// The other variants yield `Ok(None)`.
    ///
    /// NOTE(review): the variant is chosen by payload length alone, so a
    /// 40-byte chunk is read this way regardless of its top-level format
    /// code, and the remaining sub-format bytes are not inspected.
    ///
    /// # Errors
    ///
    /// The current implementation never returns `Err`; the `Result` is kept
    /// for interface stability.
    pub const fn subformat(&self) -> Result<Option<(FormatCode, SampleType)>, WavError> {
        match self {
            FmtChunk::Base(_) | FmtChunk::Extended(_) => Ok(None),
            FmtChunk::Extensible(_) => {
                let format_code = FormatCode::const_from(self.read_u16_le(24));
                let sample_type = SampleType::from_bits(self.bits_per_sample());
                Ok(Some((format_code, sample_type)))
            },
        }
    }

    /// Resolves the sample type the audio data should be decoded as.
    ///
    /// The effective format code is the sub-format code when
    /// [`FmtChunk::subformat`] provides one, otherwise the top-level format
    /// code. It is then mapped as follows:
    ///
    /// * IEEE float: 32 bits maps to `F32`, 64 bits maps to `F64`
    /// * A-law, mu-law, MS ADPCM, IMA ADPCM: always `I16`
    /// * anything else: whatever `SampleType::from_bits` yields for the bit
    ///   depth, converted to a `ValidatedSampleType`
    ///
    /// # Errors
    ///
    /// Returns [`WavError::UnsupportedSampleType`] for an IEEE float width
    /// other than 32 or 64, or when the conversion to `ValidatedSampleType`
    /// fails.
    pub fn actual_sample_type(&self) -> Result<ValidatedSampleType, WavError> {
        let effective_format = match self.subformat()? {
            Some((sub_format, _)) => sub_format,
            None => self.format_code(),
        };
        let bits_per_sample = self.bits_per_sample();

        match effective_format {
            FormatCode::IeeeFloat => match bits_per_sample {
                32 => Ok(ValidatedSampleType::F32),
                64 => Ok(ValidatedSampleType::F64),
                _ => Err(WavError::UnsupportedSampleType),
            },
            FormatCode::ALaw | FormatCode::MuLaw | FormatCode::MsAdpcm | FormatCode::ImaAdpcm => {
                Ok(ValidatedSampleType::I16)
            },
            _ => ValidatedSampleType::try_from(SampleType::from_bits(bits_per_sample))
                .map_err(|_| WavError::UnsupportedSampleType),
        }
    }

    /// Looks up the companding scheme for the effective format code.
    ///
    /// Uses the sub-format code when [`FmtChunk::subformat`] provides one and
    /// the top-level format code otherwise, then defers to
    /// `Companding::from_format`.
    pub fn companding(&self) -> Option<crate::wav::Companding> {
        let format_code = match self.subformat() {
            Ok(Some((sub_format, _))) => sub_format,
            _ => self.format_code(),
        };
        crate::wav::Companding::from_format(format_code)
    }

    /// Checks that the header fields are non-zero, mutually consistent, and in range.
    ///
    /// For ADPCM format codes only `channels`, `sample_rate` and `block_align`
    /// are required to be non-zero. For every other format the checks run in
    /// this order:
    ///
    /// 1. `channels`, `sample_rate`, `byte_rate`, `block_align` and
    ///    `bits_per_sample` are non-zero
    /// 2. `bits_per_sample` is a multiple of 8
    /// 3. `block_align == channels * bits_per_sample / 8`
    /// 4. `byte_rate == sample_rate * block_align`
    /// 5. `channels <= 256`, `sample_rate <= 384000`, `bits_per_sample <= 64`
    /// 6. [`FmtChunk::actual_sample_type`] succeeds
    ///
    /// # Errors
    ///
    /// Returns the error built by `WavError::invalid_format` describing the
    /// first failed check, or the error from [`FmtChunk::actual_sample_type`].
    pub fn validate_format_consistency(&self) -> Result<(), WavError> {
        let channels = self.channels();
        let sample_rate = self.sample_rate();
        let byte_rate = self.byte_rate();
        let block_align = self.block_align();
        let bits_per_sample = self.bits_per_sample();

        if channels == 0 {
            return Err(WavError::invalid_format("Channels cannot be zero"));
        }
        if sample_rate == 0 {
            return Err(WavError::invalid_format("Sample rate cannot be zero"));
        }

        // ADPCM formats skip the byte-rate and PCM-style consistency checks
        // below; only the block alignment is additionally required.
        if self.format_code().is_adpcm() {
            if block_align == 0 {
                return Err(WavError::invalid_format("Block align cannot be zero"));
            }
            return Ok(());
        }

        if byte_rate == 0 {
            return Err(WavError::invalid_format("Byte rate cannot be zero"));
        }
        if block_align == 0 {
            return Err(WavError::invalid_format("Block align cannot be zero"));
        }
        if bits_per_sample == 0 {
            return Err(WavError::invalid_format("Bits per sample cannot be zero"));
        }
        if !bits_per_sample.is_multiple_of(8) {
            return Err(WavError::invalid_format(format!(
                "Bits per sample {bits_per_sample} is not byte-aligned"
            )));
        }

        // The header fields come straight from the file, so the products are
        // computed in wider integer types: a hostile header must be rejected
        // rather than overflow (a panic in debug builds, a wrapped and
        // possibly matching value in release builds).
        let bytes_per_sample = bits_per_sample / 8;
        let expected_block_align = u32::from(channels) * u32::from(bytes_per_sample);
        if u32::from(block_align) != expected_block_align {
            return Err(WavError::invalid_format(format!(
                "Block align {block_align} does not match expected {expected_block_align} (channels {channels} * bytes_per_sample {bytes_per_sample})"
            )));
        }

        let expected_byte_rate = u64::from(sample_rate) * u64::from(block_align);
        if u64::from(byte_rate) != expected_byte_rate {
            return Err(WavError::invalid_format(format!(
                "Byte rate {byte_rate} does not match expected {expected_byte_rate} (sample_rate {sample_rate} * block_align {block_align})"
            )));
        }

        if channels > 256 {
            return Err(WavError::invalid_format(format!(
                "Too many channels: {channels} (maximum 256)"
            )));
        }
        if sample_rate > 384000 {
            return Err(WavError::invalid_format(format!(
                "Sample rate too high: {sample_rate} Hz (maximum 384000)"
            )));
        }
        if bits_per_sample > 64 {
            return Err(WavError::invalid_format(format!(
                "Bits per sample too high: {bits_per_sample} (maximum 64)"
            )));
        }

        // Finally make sure the format/bit-depth combination maps to a
        // supported sample type; the resolved type itself is not needed here.
        self.actual_sample_type().map(|_| ())
    }
}
/// Single-line, human-readable summary of the six base header fields.
///
/// The format code is rendered with its `Debug` representation; the numeric
/// fields are rendered in decimal.
impl Display for FmtChunk<'_> {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        write!(
            f,
            "FmtChunk {{ format: {:?}, channels: {}, sample_rate: {}, byte_rate: {}, block_align: {}, bits_per_sample: {} }}",
            self.format_code(),
            self.channels(),
            self.sample_rate(),
            self.byte_rate(),
            self.block_align(),
            self.bits_per_sample(),
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Assembles a 16-byte base `fmt ` payload from its six fields, each
    /// written little-endian in header order.
    fn make_base_fmt_bytes(
        format_code: u16,
        channels: u16,
        sample_rate: u32,
        byte_rate: u32,
        block_align: u16,
        bits_per_sample: u16,
    ) -> [u8; 16] {
        let mut out = [0u8; 16];
        let mut cursor = 0;
        // Appends one field at the current write position.
        let mut put = |field: &[u8]| {
            out[cursor..cursor + field.len()].copy_from_slice(field);
            cursor += field.len();
        };
        put(&format_code.to_le_bytes());
        put(&channels.to_le_bytes());
        put(&sample_rate.to_le_bytes());
        put(&byte_rate.to_le_bytes());
        put(&block_align.to_le_bytes());
        put(&bits_per_sample.to_le_bytes());
        out
    }

    /// Assembles a 40-byte payload with format code `0xFFFE` whose block
    /// alignment and byte rate are derived from the other arguments, and
    /// whose bytes `24..28` carry `sub_format_code`.
    fn make_extensible_fmt_bytes(
        sub_format_code: u16,
        channels: u16,
        sample_rate: u32,
        bits_per_sample: u16,
    ) -> [u8; 40] {
        let block_align = channels * (bits_per_sample / 8);
        let byte_rate = sample_rate * block_align as u32;

        let mut out = [0u8; 40];
        let mut cursor = 0;
        // Appends one field at the current write position.
        let mut put = |field: &[u8]| {
            out[cursor..cursor + field.len()].copy_from_slice(field);
            cursor += field.len();
        };
        // Bytes 0..16: the six base fields.
        put(&0xFFFEu16.to_le_bytes());
        put(&channels.to_le_bytes());
        put(&sample_rate.to_le_bytes());
        put(&byte_rate.to_le_bytes());
        put(&block_align.to_le_bytes());
        put(&bits_per_sample.to_le_bytes());
        // Bytes 16..24: extension size (22), a repeat of the bit depth, and a zeroed u32.
        put(&22u16.to_le_bytes());
        put(&bits_per_sample.to_le_bytes());
        put(&0u32.to_le_bytes());
        // Bytes 24..40: sub-format code widened to u32, followed by a fixed 12-byte tail.
        put(&(sub_format_code as u32).to_le_bytes());
        put(&0u16.to_le_bytes());
        put(&0x0010u16.to_le_bytes());
        put(&[0x80, 0x00, 0x00, 0xAA, 0x00, 0x38, 0x9B, 0x71]);
        out
    }

    /// Parses `bytes`, expects validation to fail, and returns the rendered error.
    fn validation_error(bytes: &[u8]) -> String {
        let fmt = FmtChunk::from_bytes(bytes).expect("Failed to create FmtChunk");
        let err = fmt
            .validate_format_consistency()
            .expect_err("Expected validation to fail");
        err.to_string()
    }

    /// Builds an extensible chunk and resolves its actual sample type.
    fn extensible_sample_type(
        sub_format_code: u16,
        channels: u16,
        sample_rate: u32,
        bits_per_sample: u16,
    ) -> ValidatedSampleType {
        let bytes = make_extensible_fmt_bytes(sub_format_code, channels, sample_rate, bits_per_sample);
        let fmt = FmtChunk::from_bytes(&bytes).expect("Failed to create extensible FmtChunk");
        fmt.actual_sample_type().expect("actual_sample_type failed")
    }

    #[test]
    fn test_fmt_validate_rejects_zero_channels() {
        let message = validation_error(&make_base_fmt_bytes(1, 0, 44_100, 176_400, 4, 16));
        assert!(message.contains("Channels cannot be zero"));
    }

    #[test]
    fn test_fmt_validate_rejects_block_align_mismatch() {
        let message = validation_error(&make_base_fmt_bytes(1, 2, 44_100, 176_400, 2, 16));
        assert!(message.contains("Block align 2 does not match expected 4"));
    }

    #[test]
    fn test_fmt_validate_rejects_byte_rate_mismatch() {
        let message = validation_error(&make_base_fmt_bytes(1, 2, 48_000, 1_000, 4, 16));
        assert!(message.contains("Byte rate 1000 does not match expected 192000"));
    }

    #[test]
    fn test_fmt_validate_rejects_non_byte_aligned_bits() {
        let message = validation_error(&make_base_fmt_bytes(1, 1, 44_100, 132_300, 3, 12));
        assert!(message.contains("Bits per sample 12 is not byte-aligned"));
    }

    #[test]
    fn test_fmt_validate_rejects_excess_channels() {
        // Keep block align and byte rate self-consistent so that only the
        // channel-count limit can trip.
        let channels = 300u16;
        let bits_per_sample = 16u16;
        let sample_rate = 44_100u32;
        let block_align = channels * (bits_per_sample / 8);
        let byte_rate = sample_rate * block_align as u32;

        let message = validation_error(&make_base_fmt_bytes(
            1,
            channels,
            sample_rate,
            byte_rate,
            block_align,
            bits_per_sample,
        ));
        assert!(message.contains("Too many channels"));
    }

    #[test]
    fn test_subformat_extensible_pcm_identified_correctly() {
        assert_eq!(
            extensible_sample_type(0x0001, 1, 48000, 32),
            ValidatedSampleType::I32
        );
    }

    #[test]
    fn test_subformat_extensible_ieee_float_identified_correctly() {
        assert_eq!(
            extensible_sample_type(0x0003, 1, 44100, 32),
            ValidatedSampleType::F32
        );
    }

    #[test]
    fn test_subformat_extensible_24bit_pcm_identified_correctly() {
        assert_eq!(
            extensible_sample_type(0x0001, 1, 192_000, 24),
            ValidatedSampleType::I24
        );
    }
}