use std::{
any::TypeId,
collections::HashMap,
fmt::{Debug, Display},
io::SeekFrom,
};
#[cfg(feature = "pyo3")]
use pyo3::prelude::*;
use crate::{
conversion::AudioSample,
core::{ReadSeek, WavInfo, WavType},
error::{WaversError, WaversResult},
};
// Four-byte identifiers and fixed sizes used when reading and writing wav headers.

// Identifier expected in the first four bytes of a file.
const RIFF: [u8; 4] = *b"RIFF";
/// Identifier of the chunk that holds the audio sample data.
pub const DATA: [u8; 4] = *b"data";
// Form type written to bytes 8..12 of the header by `WavHeader::as_bytes`.
const WAVE: [u8; 4] = *b"WAVE";
// Identifier of the format chunk; the trailing space is part of the identifier.
const FMT: [u8; 4] = *b"fmt ";
// Size recorded for the RIFF entry of the chunk map: the three 4-byte fields
// (identifier, size, form type) that open the file.
const RIFF_SIZE: usize = 12;
// Number of fmt chunk body bytes that are read into / written from a `FmtChunk`.
const FMT_SIZE: usize = 16;
/// Location and size of a single chunk of a wav file.
///
/// Values of this type are stored in the chunk map of a `WavHeader`, keyed by
/// the chunk's identifier.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct HeaderChunkInfo {
    /// Byte offset recorded for the chunk.
    pub offset: usize,
    /// Size of the chunk in bytes.
    pub size: u32,
}

impl Display for HeaderChunkInfo {
    /// Formats the entry as `(offset: <offset>, size: <size>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "(offset: {}, size: {})", self.offset, self.size)
    }
}

impl HeaderChunkInfo {
    /// Creates a new entry from an offset and a size.
    pub fn new(offset: usize, size: u32) -> Self {
        HeaderChunkInfo { offset, size }
    }
}

// `From` is implemented rather than `Into`: the standard library's blanket
// impl then provides the matching `Into`, so existing `.into()` calls keep
// working and `<(usize, u32)>::from(..)` becomes available as well.

impl From<HeaderChunkInfo> for (usize, u32) {
    /// Splits the entry into an `(offset, size)` pair.
    fn from(info: HeaderChunkInfo) -> Self {
        (info.offset, info.size)
    }
}

impl From<&HeaderChunkInfo> for (usize, u32) {
    /// Copies the `(offset, size)` pair out of a borrowed entry.
    fn from(info: &HeaderChunkInfo) -> Self {
        (info.offset, info.size)
    }
}
/// In-memory description of a wav file header (variant built when the `pyo3`
/// feature is enabled; the struct is additionally exposed through `#[pyclass]`).
///
/// Unlike the non-`pyo3` variant, `header_info` is private here.
#[cfg(feature = "pyo3")]
#[pyclass]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WavHeader {
/// Map from chunk identifier to the offset/size recorded for that chunk.
header_info: HashMap<ChunkIdentifier, HeaderChunkInfo>,
/// Decoded contents of the `fmt ` chunk.
pub fmt_chunk: FmtChunk,
/// Value returned by `file_size()`; `new_header` and `read_header` set it to
/// the data chunk size plus 44.
pub current_file_size: usize, }
/// In-memory description of a wav file header (variant built when the `pyo3`
/// feature is disabled).
#[cfg(not(feature = "pyo3"))]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WavHeader {
/// Map from chunk identifier to the offset/size recorded for that chunk.
pub header_info: HashMap<ChunkIdentifier, HeaderChunkInfo>,
/// Decoded contents of the `fmt ` chunk.
pub fmt_chunk: FmtChunk,
/// Value returned by `file_size()`; `new_header` and `read_header` set it to
/// the data chunk size plus 44.
pub current_file_size: usize, }
impl WavHeader {
pub fn new(
header_info: HashMap<ChunkIdentifier, HeaderChunkInfo>,
fmt_chunk: FmtChunk,
current_file_size: usize,
) -> Self {
assert!(
header_info.contains_key(&DATA.into()),
"Header info must contain a DATA chunk"
);
WavHeader {
header_info,
fmt_chunk,
current_file_size,
}
}
pub fn data(&self) -> &HeaderChunkInfo {
self.header_info.get(&DATA.into()).unwrap() }
pub fn fmt(&self) -> &HeaderChunkInfo {
self.header_info.get(&FMT.into()).unwrap() }
pub fn new_header<T>(sample_rate: i32, n_channels: u16, n_samples: usize) -> WaversResult<Self>
where
T: AudioSample,
{
let wav_type: WavType = TypeId::of::<T>().try_into()?;
let (format, bits_per_sample, _) = wav_type.into();
let fmt_chunk = FmtChunk::new(format, n_channels, sample_rate, bits_per_sample);
let mut header_info = HashMap::new();
let data_size_bytes = n_samples * (bits_per_sample / 8) as usize;
let file_size_bytes = data_size_bytes + 44;
header_info.insert(RIFF.into(), HeaderChunkInfo::new(0, RIFF_SIZE as u32));
header_info.insert(FMT.into(), HeaderChunkInfo::new(12, FMT_SIZE as u32));
header_info.insert(
DATA.into(),
HeaderChunkInfo::new(36, data_size_bytes as u32),
);
let current_file_size = file_size_bytes;
Ok(WavHeader {
header_info,
fmt_chunk,
current_file_size,
})
}
pub fn file_size(&self) -> usize {
self.current_file_size
}
pub fn as_bytes(&self) -> [u8; 36] {
let mut bytes = [0; 36]; bytes[0..4].copy_from_slice(&RIFF);
let size = self.file_size() as u32;
bytes[4..8].copy_from_slice(&size.to_ne_bytes());
bytes[8..12].copy_from_slice(&WAVE);
bytes[12..16].copy_from_slice(&FMT);
bytes[16..20].copy_from_slice(&(FMT_SIZE as u32).to_ne_bytes());
let fmt_bytes: [u8; FMT_SIZE] = self.fmt_chunk.into();
bytes[20..36].copy_from_slice(&fmt_bytes);
bytes
}
pub fn get_chunk(&self, chunk_identifier: ChunkIdentifier) -> Option<&HeaderChunkInfo> {
self.header_info.get(&chunk_identifier)
}
}
/// Four-byte identifier of a chunk, such as `b"fmt "` or `b"data"`.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub struct ChunkIdentifier {
    identifier: [u8; 4],
}

impl ChunkIdentifier {
    /// Wraps the raw four identifier bytes.
    pub fn new(identifier: [u8; 4]) -> Self {
        ChunkIdentifier { identifier }
    }

    /// Returns the identifier as text, or a fixed placeholder when the bytes
    /// are not valid UTF-8. Shared by the `Display` and `Debug` impls.
    fn as_text(&self) -> &str {
        std::str::from_utf8(&self.identifier).unwrap_or("Invalid identifier")
    }
}

// `From` is implemented in both directions; the blanket impl in the standard
// library supplies the corresponding `Into`, so `.into()` call sites keep
// working unchanged.

impl From<ChunkIdentifier> for [u8; 4] {
    /// Unwraps the raw identifier bytes.
    fn from(id: ChunkIdentifier) -> Self {
        id.identifier
    }
}

impl From<[u8; 4]> for ChunkIdentifier {
    /// Wraps raw identifier bytes; equivalent to [`ChunkIdentifier::new`].
    fn from(identifier: [u8; 4]) -> Self {
        ChunkIdentifier::new(identifier)
    }
}

impl Display for ChunkIdentifier {
    /// Writes the identifier text surrounded by double quotes.
    ///
    /// NOTE(review): `Display` being quoted while `Debug` is unquoted is the
    /// reverse of the usual convention; the output is kept as-is so existing
    /// consumers are not affected.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}", self.as_text())
    }
}

impl Debug for ChunkIdentifier {
    /// Writes the identifier text without quotes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.as_text())
    }
}
/// Contents of the `fmt ` chunk (variant built when the `pyo3` feature is
/// disabled).
///
/// The struct is `#[repr(C)]` and its fields add up to 16 bytes (`FMT_SIZE`);
/// the conversions to and from `[u8; FMT_SIZE]` in this file rely on that.
#[cfg(not(feature = "pyo3"))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(C)]
pub struct FmtChunk {
/// Format code of the sample encoding.
pub format: u16,
/// Number of channels.
pub channels: u16,
/// Sample rate.
pub sample_rate: i32,
/// `FmtChunk::new` sets this to `sample_rate * block_align`.
pub byte_rate: i32,
/// `FmtChunk::new` sets this to `(channels * bits_per_sample) / 8`.
pub block_align: u16,
/// Number of bits used to store one sample.
pub bits_per_sample: u16,
}
/// Contents of the `fmt ` chunk (variant built when the `pyo3` feature is
/// enabled; every field is exposed through a `#[pyo3(get)]` getter).
///
/// The struct is `#[repr(C)]` and its fields add up to 16 bytes (`FMT_SIZE`);
/// the conversions to and from `[u8; FMT_SIZE]` in this file rely on that.
#[cfg(feature = "pyo3")]
#[pyclass]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(C)]
pub struct FmtChunk {
/// Format code of the sample encoding.
#[pyo3(get)]
pub format: u16,
/// Number of channels.
#[pyo3(get)]
pub channels: u16,
/// Sample rate.
#[pyo3(get)]
pub sample_rate: i32,
/// `FmtChunk::new` sets this to `sample_rate * block_align`.
#[pyo3(get)]
pub byte_rate: i32,
/// `FmtChunk::new` sets this to `(channels * bits_per_sample) / 8`.
#[pyo3(get)]
pub block_align: u16,
/// Number of bits used to store one sample.
#[pyo3(get)]
pub bits_per_sample: u16,
}
impl FmtChunk {
    /// Creates a fmt chunk, deriving `block_align` and `byte_rate` from the
    /// channel count, sample rate and bits per sample.
    pub fn new(format: u16, channels: u16, sample_rate: i32, bits_per_sample: u16) -> Self {
        let (block_align, byte_rate) = Self::frame_layout(channels, sample_rate, bits_per_sample);
        FmtChunk {
            format,
            channels,
            sample_rate,
            byte_rate,
            block_align,
            bits_per_sample,
        }
    }

    /// Computes `(block_align, byte_rate)`:
    /// `block_align = channels * bits_per_sample / 8` and
    /// `byte_rate = sample_rate * block_align`.
    ///
    /// The product is formed in `u32` so that large channel counts cannot
    /// overflow the `u16` multiplication.
    fn frame_layout(channels: u16, sample_rate: i32, bits_per_sample: u16) -> (u16, i32) {
        let block_align = (u32::from(channels) * u32::from(bits_per_sample) / 8) as u16;
        let byte_rate = sample_rate * i32::from(block_align);
        (block_align, byte_rate)
    }

    /// Rewrites the chunk so that it describes samples of `new_type`, keeping
    /// the channel count and sample rate.
    ///
    /// Does nothing when the chunk already describes `new_type`. Otherwise
    /// `format` and `bits_per_sample` are taken from `new_type` and
    /// `block_align` / `byte_rate` are recomputed exactly as [`FmtChunk::new`]
    /// computes them, so `block_align` accounts for the number of channels.
    ///
    /// # Errors
    ///
    /// Returns an error if the current `format` / `bits_per_sample` pair does
    /// not correspond to a `WavType`.
    #[inline(always)]
    pub fn update_fmt_chunk(&mut self, new_type: WavType) -> WaversResult<()> {
        let current_type = WavType::try_from((self.format, self.bits_per_sample))?;
        if current_type == new_type {
            return Ok(());
        }

        // The third tuple element is not needed: the frame layout is derived
        // from the bit depth and this chunk's channel count.
        let (new_format, new_bits_per_sample, _): (u16, u16, u16) = new_type.into();
        let (new_block_align, new_byte_rate) =
            Self::frame_layout(self.channels, self.sample_rate, new_bits_per_sample);

        self.format = new_format;
        self.block_align = new_block_align;
        self.byte_rate = new_byte_rate;
        self.bits_per_sample = new_bits_per_sample;
        Ok(())
    }
}
impl Into<[u8; FMT_SIZE]> for FmtChunk {
fn into(self) -> [u8; FMT_SIZE] {
unsafe { std::mem::transmute_copy::<FmtChunk, [u8; FMT_SIZE]>(&self) }
}
}
impl Into<FmtChunk> for [u8; FMT_SIZE] {
fn into(self) -> FmtChunk {
unsafe { std::mem::transmute_copy::<[u8; FMT_SIZE], FmtChunk>(&self) }
}
}
pub(crate) fn read_header(readable: &mut Box<dyn ReadSeek>) -> WaversResult<WavInfo> {
readable.seek(SeekFrom::Start(0))?;
let header_info: HashMap<ChunkIdentifier, HeaderChunkInfo> =
discover_all_header_chunks(readable)?;
match header_info.contains_key(&FMT.into()) {
true => (),
false => {
return Err(WaversError::from(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"File does not contain a fmt chunk",
)));
}
}
let fmt_entry = header_info.get(&FMT.into()).unwrap(); readable.seek(SeekFrom::Start(fmt_entry.offset as u64))?;
let mut fmt_buf: [u8; FMT_SIZE] = [0; FMT_SIZE as usize];
readable.read_exact(&mut fmt_buf)?;
let fmt_chunk: FmtChunk = fmt_buf.into();
let wav_type = crate::core::WavType::try_from((fmt_chunk.format, fmt_chunk.bits_per_sample))?;
let total_size_in_bytes = header_info
.get(&DATA.into())
.expect("File does not contain a data chunk")
.size
+ 44; let wav_header = WavHeader::new(header_info, fmt_chunk, total_size_in_bytes as usize);
Ok(WavInfo {
wav_type,
wav_header,
})
}
fn discover_all_header_chunks(
reader: &mut Box<dyn ReadSeek>,
) -> WaversResult<HashMap<ChunkIdentifier, HeaderChunkInfo>> {
let mut entries: HashMap<ChunkIdentifier, HeaderChunkInfo> = HashMap::new();
let mut buf: [u8; 4] = [0; 4];
reader.read_exact(&mut buf)?;
match buf_eq(&RIFF, &buf) {
true => (),
false => {
return Err(WaversError::from(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"File is not a valid RIFF file",
)));
}
}
reader.read_exact(&mut buf)?;
entries.insert(RIFF.into(), HeaderChunkInfo::new(0, RIFF_SIZE as u32));
reader.read_exact(&mut buf)?;
let _: ChunkIdentifier = buf.into();
while let Ok(_) = reader.read_exact(&mut buf) {
let chunk_identifier: ChunkIdentifier = buf.into();
reader.read_exact(&mut buf)?;
let chunk_size: u32 =
buf[0] as u32 | (buf[1] as u32) << 8 | (buf[2] as u32) << 16 | (buf[3] as u32) << 24;
entries.insert(
chunk_identifier,
HeaderChunkInfo::new(reader.stream_position()? as usize, chunk_size),
);
reader.seek(SeekFrom::Current(chunk_size as i64))?;
}
Ok(entries)
}
/// Returns `true` when the two four-byte buffers hold the same bytes.
#[inline(always)]
fn buf_eq(buf: &[u8; 4], chunk_id: &[u8; 4]) -> bool {
    buf.iter()
        .zip(chunk_id.iter())
        .all(|(lhs, rhs)| lhs == rhs)
}
// Tests for header reading and for the FmtChunk <-> byte array conversions.
// Both tests open the wav fixture at `TEST_FILE`, so they need that file on disk.
#[cfg(test)]
mod header_tests {
use super::*;
use crate::FmtChunk;
use std::fs::File;
// Path of the fixture file, relative to the directory the tests are run from.
const TEST_FILE: &str = "./test_resources/one_channel_i16.wav";
// Fmt chunk the fixture is expected to decode to.
const ONE_CHANNEL_FMT_CHUNK: FmtChunk = FmtChunk {
format: 1,
channels: 1,
sample_rate: 16000,
byte_rate: 16000 * 2 * 1,
block_align: 2,
bits_per_sample: 16,
};
// Reading the fixture's header yields the expected fmt chunk.
#[test]
fn can_read_header() {
let file = File::open(TEST_FILE).unwrap();
let mut file = Box::new(file) as Box<dyn ReadSeek>;
let wav_info = read_header(&mut file).expect("Failed to read header");
assert_eq!(
wav_info.wav_header.fmt_chunk, ONE_CHANNEL_FMT_CHUNK,
"Fmt chunk does not match"
);
}
// Converting the fmt chunk to bytes and back gives an equal fmt chunk.
#[test]
fn can_convert_to_and_from_bytes() {
let file = File::open(TEST_FILE).unwrap();
let mut file = Box::new(file) as Box<dyn ReadSeek>;
let wav_info = read_header(&mut file).expect("Failed to read header");
let fmt_bytes: [u8; FMT_SIZE] = wav_info.wav_header.fmt_chunk.into();
let new_fmt = fmt_bytes.into();
assert_eq!(
wav_info.wav_header.fmt_chunk, new_fmt,
"Fmt chunk does not match"
);
}
}