use std::{
alloc::Layout,
fs::File,
io::{BufReader, BufWriter, Read, Seek, SeekFrom, Write},
path::Path,
u8,
};
use crate::sample::{IterAudioConversion, Sample};
// Four-byte chunk identifiers used when parsing and writing files in this module.

/// Tag expected in the first four bytes of a file by `find_sub_chunk_id`.
const RIFF: &[u8; 4] = b"RIFF";
/// Tag written immediately before the data length and the sample bytes.
const DATA: &[u8; 4] = b"data";
/// Tag written right after the RIFF size field.
const WAVE: &[u8; 4] = b"WAVE";
/// Tag of the format chunk (note the trailing space); `FmtChunk::as_bytes` emits it as its first four bytes.
const FMT: &[u8; 4] = b"fmt ";
/// Tag of a LIST chunk; not referenced by the code visible in this file, hence the `allow`.
#[allow(unused)] const LIST: &[u8; 4] = b"LIST";
/// An in-memory wav file: the parsed format chunk plus the raw bytes of the data chunk.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WavFile {
/// Header fields read from the file's "fmt " chunk.
pub fmt_chunk: FmtChunk,
/// Raw bytes of the data chunk; `from_file` fills this with the chunk payload.
pub data: Box<[u8]>,
/// Offset subtracted from the data length by `data_size`; `from_file` initialises it to 0.
/// NOTE(review): the unit (bytes vs. samples) is not consistent across the methods that use it - confirm.
pub seek_pos: u64,
}
impl WavFile {
pub fn new(fmt_chunk: FmtChunk, data: Box<[u8]>, seek_pos: u64) -> WavFile {
WavFile {
fmt_chunk,
data,
seek_pos,
}
}
pub fn from_file(fp: &Path) -> Result<WavFile, std::io::Error> {
let file = File::open(fp)?;
let mut buf_reader = std::io::BufReader::new(file);
let fmt_chunk = FmtChunk::from_buf_reader(&mut buf_reader)?;
let (data_offset, data_len) = find_sub_chunk_id(&mut buf_reader, &b"data")?;
let mut data = alloc_box_buffer(data_len);
buf_reader.seek(SeekFrom::Start(data_offset as u64 + 4))?;
match buf_reader.read(&mut data) {
Ok(_) => (),
Err(err) => {
eprintln!("Error reading data chunk: {}", err);
return Err(err);
}
}
Ok(WavFile::new(fmt_chunk, data, 0))
}
/// Decodes the data chunk into samples, optionally converting them to the
/// sample type given by `as_wav_type`.
///
/// The native sample type is chosen from the format tag and bit depth of the
/// format chunk: tag 1 with 16/32 bits decodes as `I16`/`I32`, tag 3 with
/// 32/64 bits decodes as `F32`/`F64`.
///
/// # Panics
///
/// Panics when the format tag is neither 1 nor 3, or when the bit depth is
/// not one of the depths listed above for that tag.
#[inline]
pub fn read(&self, as_wav_type: Option<Sample>) -> Vec<Sample> {
    let bit_depth = self.bits_per_sample();
    let format_tag = self.format();
    let (mut samples, native_type) = match (format_tag, bit_depth) {
        (1, 16) => (self.read_pcm_i16(), Sample::I16(0)),
        (1, 32) => (self.read_pcm_i32(), Sample::I32(0)),
        (1, _) => panic!("Unsupported bit depth for PCM: {}", bit_depth),
        (3, 32) => (self.read_ieee_f32(), Sample::F32(0.0)),
        (3, 64) => (self.read_ieee_f64(), Sample::F64(0.0)),
        (3, _) => panic!("Unsupported bit depth for float: {}", bit_depth),
        _ => panic!("Unsupported format: {}", format_tag),
    };
    // No requested type means the caller wants the native representation.
    let Some(requested) = as_wav_type else {
        return samples;
    };
    if requested == native_type {
        samples
    } else {
        samples.as_sample_type(requested)
    }
}
/// Writes this wav file (header, format chunk and all data bytes) to `fp`.
///
/// The layout is `RIFF <size> WAVE`, the 24 bytes produced by
/// `FmtChunk::as_bytes` (which already start with the "fmt " tag), then
/// `data <len>` followed by the sample bytes. Size fields are written as
/// little-endian `u32` values.
///
/// # Errors
///
/// Returns `InvalidInput` when the data is too large for the 32-bit size
/// fields, and any I/O error raised while creating or writing the file.
pub fn write_wav(&self, fp: &Path) -> Result<(), std::io::Error> {
    // Bytes that follow the RIFF size field besides the samples:
    // "WAVE" (4) + fmt chunk (24) + "data" tag and length (8).
    const HEADER_LEN: u32 = 36;
    if self.data.len() > (u32::MAX - HEADER_LEN) as usize {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "data chunk is too large for a wav file",
        ));
    }
    let data_len = self.data.len() as u32;

    let file = File::create(fp)?;
    let mut buf_writer = BufWriter::new(file);
    buf_writer.write_all(RIFF)?;
    buf_writer.write_all(&(data_len + HEADER_LEN).to_le_bytes())?;
    buf_writer.write_all(WAVE)?;
    // `as_bytes` emits the "fmt " tag itself, so it must not be written separately.
    buf_writer.write_all(&self.fmt_chunk.as_bytes())?;
    buf_writer.write_all(DATA)?;
    // The length field is 4 bytes wide; a `usize` would occupy 8 on 64-bit targets.
    buf_writer.write_all(&data_len.to_le_bytes())?;
    buf_writer.write_all(self.data.as_ref())?;
    // Flush explicitly: errors raised while dropping a `BufWriter` are discarded.
    buf_writer.flush()
}
/// Decodes the data bytes as little-endian 16-bit integer samples.
///
/// Decoding starts at `seek_pos`, interpreted as a byte offset into `data`
/// (the same unit `data_size` uses) and clamped to the data length.
/// NOTE(review): confirm that `seek_pos` is meant to be a byte offset.
/// Trailing bytes that do not form a whole sample are ignored.
#[inline]
fn read_pcm_i16(&self) -> Vec<Sample> {
    const WIDTH: usize = std::mem::size_of::<i16>();
    // Clamp so an out-of-range seek position yields no samples instead of panicking.
    let start = self.seek_pos.min(self.data.len() as u64) as usize;
    self.data[start..]
        .chunks_exact(WIDTH)
        .map(|bytes| {
            let mut raw = [0u8; WIDTH];
            raw.copy_from_slice(bytes);
            Sample::I16(i16::from_le_bytes(raw))
        })
        .collect()
}
/// Decodes the data bytes as little-endian 32-bit integer samples.
///
/// Decoding starts at `seek_pos`, interpreted as a byte offset into `data`
/// (the same unit `data_size` uses) and clamped to the data length.
/// NOTE(review): confirm that `seek_pos` is meant to be a byte offset.
/// Trailing bytes that do not form a whole sample are ignored.
fn read_pcm_i32(&self) -> Vec<Sample> {
    const WIDTH: usize = std::mem::size_of::<i32>();
    // Clamp so an out-of-range seek position yields no samples instead of panicking.
    let start = self.seek_pos.min(self.data.len() as u64) as usize;
    self.data[start..]
        .chunks_exact(WIDTH)
        .map(|bytes| {
            let mut raw = [0u8; WIDTH];
            raw.copy_from_slice(bytes);
            Sample::I32(i32::from_le_bytes(raw))
        })
        .collect()
}
/// Decodes the data bytes as little-endian 32-bit floating point samples.
///
/// Decoding starts at `seek_pos`, interpreted as a byte offset into `data`
/// (the same unit `data_size` uses) and clamped to the data length.
/// NOTE(review): confirm that `seek_pos` is meant to be a byte offset.
/// Trailing bytes that do not form a whole sample are ignored.
fn read_ieee_f32(&self) -> Vec<Sample> {
    const WIDTH: usize = std::mem::size_of::<f32>();
    // Clamp so an out-of-range seek position yields no samples instead of panicking.
    let start = self.seek_pos.min(self.data.len() as u64) as usize;
    self.data[start..]
        .chunks_exact(WIDTH)
        .map(|bytes| {
            let mut raw = [0u8; WIDTH];
            raw.copy_from_slice(bytes);
            Sample::F32(f32::from_le_bytes(raw))
        })
        .collect()
}
/// Decodes the data bytes as little-endian 64-bit floating point samples.
///
/// Decoding starts at `seek_pos`, interpreted as a byte offset into `data`
/// (the same unit `data_size` uses) and clamped to the data length.
/// NOTE(review): confirm that `seek_pos` is meant to be a byte offset.
/// Trailing bytes that do not form a whole sample are ignored.
fn read_ieee_f64(&self) -> Vec<Sample> {
    const WIDTH: usize = std::mem::size_of::<f64>();
    // Clamp so an out-of-range seek position yields no samples instead of panicking.
    let start = self.seek_pos.min(self.data.len() as u64) as usize;
    self.data[start..]
        .chunks_exact(WIDTH)
        .map(|bytes| {
            let mut raw = [0u8; WIDTH];
            raw.copy_from_slice(bytes);
            Sample::F64(f64::from_le_bytes(raw))
        })
        .collect()
}
/// Returns the length of the remaining data in whole seconds.
///
/// Computed as `data_size()` divided by the number of bytes per second
/// (`sample_rate * channels * bits_per_sample / 8`). Returns 0 when the
/// format chunk describes zero bytes per second (for example a zero or
/// negative sample rate, or zero channels) instead of dividing by zero.
pub fn duration(&self) -> u64 {
    // Widen before multiplying so large header values cannot overflow `i32`.
    let bytes_per_second = self.sample_rate().max(0) as u64
        * u64::from(self.channels())
        * u64::from(self.bits_per_sample() / 8);
    if bytes_per_second == 0 {
        return 0;
    }
    self.data_size() as u64 / bytes_per_second
}
/// Returns the sample rate recorded in the format chunk.
pub fn sample_rate(&self) -> i32 {
    let chunk = &self.fmt_chunk;
    chunk.sample_rate()
}
/// Returns the channel count recorded in the format chunk.
pub fn channels(&self) -> u16 {
    let chunk = &self.fmt_chunk;
    chunk.channels()
}
/// Returns the bit depth recorded in the format chunk.
fn bits_per_sample(&self) -> u16 {
    let chunk = &self.fmt_chunk;
    chunk.bits_per_sample()
}
/// Returns the format tag recorded in the format chunk.
fn format(&self) -> u16 {
    let chunk = &self.fmt_chunk;
    chunk.format()
}
/// Returns the number of data bytes remaining after `seek_pos`.
///
/// A seek position beyond the end of the data yields 0 rather than
/// underflowing the subtraction.
fn data_size(&self) -> usize {
    let total = self.data.len();
    // Clamp in `u64` first so the cast back to `usize` can never truncate.
    let consumed = self.seek_pos.min(total as u64) as usize;
    total - consumed
}
}
pub fn signal_duration(signal_fp: &Path) -> Result<u64, std::io::Error> {
let wav_file = File::open(signal_fp)?;
let mut br = BufReader::new(wav_file);
let fmt_chunk = FmtChunk::from_buf_reader(&mut br)?;
let (data_offset, _) = find_sub_chunk_id(&mut br, &b"data")?;
let mut data_size_buf: [u8; 4] = [0; 4];
br.seek(SeekFrom::Start(data_offset as u64))?;
br.read_exact(&mut data_size_buf)?;
Ok(i32::from_ne_bytes(data_size_buf) as u64
/ (fmt_chunk.sample_rate()
* fmt_chunk.channels() as i32
* (fmt_chunk.bits_per_sample() / 8) as i32) as u64)
}
/// Reads only the format chunk of the wav file at `signal_fp` and returns
/// its sample rate.
///
/// # Errors
///
/// Returns any error raised while opening the file or parsing its format chunk.
pub fn signal_sample_rate(signal_fp: &Path) -> Result<i32, std::io::Error> {
    let mut reader = BufReader::new(File::open(signal_fp)?);
    FmtChunk::from_buf_reader(&mut reader).map(|chunk| chunk.sample_rate())
}
/// Reads only the format chunk of the wav file at `signal_fp` and returns
/// its channel count.
///
/// # Errors
///
/// Returns any error raised while opening the file or parsing its format chunk.
pub fn signal_channels(signal_fp: &Path) -> Result<u16, std::io::Error> {
    let mut reader = BufReader::new(File::open(signal_fp)?);
    FmtChunk::from_buf_reader(&mut reader).map(|chunk| chunk.channels())
}
/// Summary of a wav file's header: sample rate, channel count, bit depth and
/// duration (the duration is produced by `signal_info` as a whole number of seconds).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SignalInfo {
    /// Sample rate taken from the format chunk.
    pub sample_rate: i32,
    /// Channel count taken from the format chunk.
    pub channels: u16,
    /// Bit depth taken from the format chunk.
    pub bits_per_sample: u16,
    /// Duration of the signal.
    pub duration: u64,
}

impl SignalInfo {
    /// Bundles the four header values into a [`SignalInfo`].
    pub fn new(sample_rate: i32, channels: u16, bits_per_sample: u16, duration: u64) -> Self {
        SignalInfo {
            duration,
            bits_per_sample,
            channels,
            sample_rate,
        }
    }
}
pub fn signal_info(signal_fp: &Path) -> Result<SignalInfo, std::io::Error> {
let wav_file = File::open(signal_fp)?;
let mut br = BufReader::new(wav_file);
let fmt_chunk = FmtChunk::from_buf_reader(&mut br)?;
let (data_offset, _) = find_sub_chunk_id(&mut br, &b"data")?;
let mut data_size_buf: [u8; 4] = [0; 4];
br.seek(SeekFrom::Start(data_offset as u64))?;
br.read_exact(&mut data_size_buf)?;
Ok(SignalInfo::new(
fmt_chunk.sample_rate(),
fmt_chunk.channels(),
fmt_chunk.bits_per_sample(),
i32::from_ne_bytes(data_size_buf) as u64
/ (fmt_chunk.sample_rate()
* fmt_chunk.channels() as i32
* (fmt_chunk.bits_per_sample() / 8) as i32) as u64,
))
}
/// Loads the wav file at `fp` and decodes its samples, optionally converting
/// them to the sample type given by `as_type`.
///
/// # Errors
///
/// Returns any error raised by `WavFile::from_file`.
#[inline]
pub fn read(fp: &Path, as_type: Option<Sample>) -> Result<Vec<Sample>, std::io::Error> {
    WavFile::from_file(fp).map(|wav| wav.read(as_type))
}
/// Writes `data` to `fp` as a wav file, encoding every sample as `as_type`
/// (or as the type of the first sample when `as_type` is `None`).
///
/// The header describes the encoded samples: the byte rate and block
/// alignment are derived from the target sample width and `n_channels`.
/// All multi-byte values are written little-endian.
///
/// # Errors
///
/// Returns `InvalidInput` when `data` is empty and `as_type` is `None` (no
/// sample type can be inferred) or when the encoded data does not fit the
/// 32-bit size fields, and any I/O error raised while creating or writing
/// the file.
pub fn write_wav_as(
    fp: &Path,
    data: &mut Vec<Sample>,
    as_type: Option<Sample>,
    n_channels: u16,
    sample_rate: i32,
) -> Result<(), std::io::Error> {
    // Bytes that follow the RIFF size field besides the samples:
    // "WAVE" (4) + fmt chunk (24) + "data" tag and length (8).
    const HEADER_LEN: u32 = 36;

    let sample_type = match as_type.or_else(|| data.first().copied()) {
        Some(sample_type) => sample_type,
        None => {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "cannot infer a sample type from empty data",
            ))
        }
    };
    let (bytes_per_sample, format, bits_per_sample): (u16, u16, u16) = match sample_type {
        Sample::I16(_) => (2, 1, 16),
        Sample::I32(_) => (4, 1, 32),
        Sample::F32(_) => (4, 3, 32),
        Sample::F64(_) => (8, 3, 64),
    };
    // One block holds a sample for every channel; the byte rate follows the
    // encoded sample width, not the width of the incoming samples.
    let block_align = n_channels.saturating_mul(bytes_per_sample);
    let byte_rate = sample_rate.saturating_mul(i32::from(block_align));
    let data_len = data
        .len()
        .checked_mul(usize::from(bytes_per_sample))
        .filter(|len| *len <= (u32::MAX - HEADER_LEN) as usize)
        .ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "sample data is too large for a wav file",
            )
        })? as u32;
    let fmt_bytes = FmtChunk::new(
        16,
        format,
        n_channels,
        sample_rate,
        byte_rate,
        block_align,
        bits_per_sample,
    )
    .as_bytes();

    let file = File::create(fp)?;
    let mut buf_writer = BufWriter::new(file);
    buf_writer.write_all(RIFF)?;
    buf_writer.write_all(&(data_len + HEADER_LEN).to_le_bytes())?;
    buf_writer.write_all(WAVE)?;
    // `fmt_bytes` already starts with the "fmt " tag.
    buf_writer.write_all(&fmt_bytes)?;
    buf_writer.write_all(DATA)?;
    buf_writer.write_all(&data_len.to_le_bytes())?;
    match sample_type {
        Sample::I16(_) => {
            for sample in data.as_i16().iter() {
                buf_writer.write_all(&sample.to_le_bytes())?;
            }
        }
        Sample::I32(_) => {
            for sample in data.as_i32().iter() {
                buf_writer.write_all(&sample.to_le_bytes())?;
            }
        }
        Sample::F32(_) => {
            for sample in data.as_f32().iter() {
                buf_writer.write_all(&sample.to_le_bytes())?;
            }
        }
        Sample::F64(_) => {
            for sample in data.as_f64().iter() {
                buf_writer.write_all(&sample.to_le_bytes())?;
            }
        }
    }
    // Flush explicitly: errors raised while dropping a `BufWriter` are discarded.
    buf_writer.flush()
}
/// Allocates a zero-filled boxed byte slice of exactly `len` bytes.
///
/// A `len` of 0 returns an empty box without allocating.
///
/// # Panics
///
/// Panics when `len` bytes cannot be described by a `Layout`. Allocation
/// failure is routed through `std::alloc::handle_alloc_error`.
pub fn alloc_box_buffer(len: usize) -> Box<[u8]> {
    if len == 0 {
        return <Box<[u8]>>::default();
    }
    let layout = match Layout::array::<u8>(len) {
        Ok(layout) => layout,
        Err(_) => panic!("Failed to allocate buffer of size {}", len),
    };
    // SAFETY: `layout` has a non-zero size because `len > 0`.
    let ptr = unsafe { std::alloc::alloc_zeroed(layout) };
    if ptr.is_null() {
        // Building a `Box` from a null pointer would be undefined behaviour.
        std::alloc::handle_alloc_error(layout);
    }
    let slice_ptr = core::ptr::slice_from_raw_parts_mut(ptr, len);
    // SAFETY: `ptr` is non-null, was allocated by the global allocator with
    // the layout of `[u8; len]`, and every byte is initialised (zeroed).
    unsafe { Box::from_raw(slice_ptr) }
}
/// Fields of a wav file's "fmt " chunk, in the order `from_buf_reader` reads them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(C)]
pub struct FmtChunk {
/// Length field that follows the "fmt " tag.
pub size: i32,
/// Format tag; `WavFile::read` handles 1 (integer samples) and 3 (floating point samples).
pub format: u16,
/// Number of channels.
pub channels: u16,
/// Sample rate.
pub sample_rate: i32,
/// Byte rate; `write_wav_as` computes it as sample rate * channels * sample size.
pub byte_rate: i32,
/// Block alignment.
pub block_align: u16,
/// Bits per sample; divided by 8 to obtain the sample size in bytes.
pub bits_per_sample: u16,
}
impl FmtChunk {
pub fn new(
size: i32, format: u16, channels: u16, sample_rate: i32, byte_rate: i32, block_align: u16, bits_per_sample: u16, ) -> FmtChunk {
FmtChunk {
size,
format,
channels,
sample_rate,
byte_rate,
block_align,
bits_per_sample,
}
}
pub fn from_path(signal_fp: &Path) -> Result<FmtChunk, std::io::Error> {
let wav_file = File::open(signal_fp)?;
let mut br = BufReader::new(wav_file);
FmtChunk::from_buf_reader(&mut br)
}
/// Locates the "fmt " chunk in `br` and parses its size field plus the
/// 16-byte base layout, then rewinds `br` to the start of the file.
///
/// All fields are decoded as little-endian, the byte order of RIFF files,
/// so parsing does not depend on the host's endianness.
///
/// # Errors
///
/// Returns the error produced by `find_sub_chunk_id` when the chunk is
/// missing, and any I/O error raised while seeking or reading (including
/// `UnexpectedEof` when the chunk is truncated).
fn from_buf_reader(br: &mut BufReader<File>) -> Result<FmtChunk, std::io::Error> {
    let (offset, _) = find_sub_chunk_id(br, b"fmt ")?;
    br.seek(SeekFrom::Start(offset as u64))?;
    // Size field (4 bytes) followed by the 16-byte base layout.
    let mut raw = [0u8; 20];
    br.read_exact(&mut raw)?;
    br.seek(SeekFrom::Start(0))?;

    let u16_at = |at: usize| u16::from_le_bytes([raw[at], raw[at + 1]]);
    let i32_at =
        |at: usize| i32::from_le_bytes([raw[at], raw[at + 1], raw[at + 2], raw[at + 3]]);
    Ok(FmtChunk::new(
        i32_at(0),  // size
        u16_at(4),  // format
        u16_at(6),  // channels
        i32_at(8),  // sample_rate
        i32_at(12), // byte_rate
        u16_at(16), // block_align
        u16_at(18), // bits_per_sample
    ))
}
/// Serialises this chunk as the 24 bytes of a base "fmt " chunk: the tag,
/// the size field and the six 16-byte-layout fields, all little-endian.
///
/// The size field is always written as 16, because exactly 16 bytes follow
/// it here. Writing a larger `self.size` (as parsed from a file with an
/// extended format chunk) would make readers skip into the next chunk.
pub fn as_bytes(&self) -> [u8; 24] {
    // Number of bytes emitted after the size field.
    const BODY_LEN: i32 = 16;
    let mut buf = [0u8; 24];
    buf[0..4].copy_from_slice(FMT);
    buf[4..8].copy_from_slice(&BODY_LEN.to_le_bytes());
    buf[8..10].copy_from_slice(&self.format.to_le_bytes());
    buf[10..12].copy_from_slice(&self.channels.to_le_bytes());
    buf[12..16].copy_from_slice(&self.sample_rate.to_le_bytes());
    buf[16..20].copy_from_slice(&self.byte_rate.to_le_bytes());
    buf[20..22].copy_from_slice(&self.block_align.to_le_bytes());
    buf[22..24].copy_from_slice(&self.bits_per_sample.to_le_bytes());
    buf
}
/// Returns the size of one sample in bytes (`bits_per_sample / 8`, rounded down).
pub fn get_sample_size(&self) -> usize {
    let bits = usize::from(self.bits_per_sample);
    bits / 8
}
pub fn format(&self) -> u16 {
self.format
}
pub fn channels(&self) -> u16 {
self.channels
}
pub fn sample_rate(&self) -> i32 {
self.sample_rate
}
pub fn byte_rate(&self) -> i32 {
self.byte_rate
}
pub fn block_align(&self) -> u16 {
self.block_align
}
pub fn bits_per_sample(&self) -> u16 {
self.bits_per_sample
}
}
pub fn find_sub_chunk_id(
file: &mut BufReader<File>,
chunk_id: &[u8; 4],
) -> Result<(usize, usize), std::io::Error> {
let mut buf: [u8; 4] = [0; 4];
file.read_exact(&mut buf)?;
if !buf_eq(&buf, RIFF) {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
format!("Failed to find RIFF tag in {:?}", file.get_ref()),
));
}
file.seek(SeekFrom::Current(8))?;
let mut tag_offset: usize = 0;
let mut bytes_traversed: usize = 12;
loop {
let bytes_read = file.read(&mut buf)?;
if bytes_read == 0 {
break;
}
bytes_traversed += bytes_read;
if buf_eq(&buf, chunk_id) {
tag_offset = bytes_traversed;
}
let bytes_read = file.read(&mut buf)?;
if bytes_read == 0 {
break;
}
bytes_traversed += bytes_read;
let chunk_len =
buf[0] as u32 | (buf[1] as u32) << 8 | (buf[2] as u32) << 16 | (buf[3] as u32) << 24;
if tag_offset > 0 {
let chunk_size = chunk_len as usize;
file.seek(SeekFrom::Start(0))?; return Ok((tag_offset, chunk_size));
}
file.seek(SeekFrom::Current(chunk_len as i64))?;
bytes_traversed += chunk_len as usize;
}
file.seek(SeekFrom::Start(0))?;
Err(std::io::Error::new(
std::io::ErrorKind::Other,
format!(
"Failed to find {:?} tag in {:?}",
std::str::from_utf8(chunk_id).unwrap(),
file.get_ref()
),
))
}
/// Returns `true` when the two 4-byte tags are byte-for-byte identical.
fn buf_eq(buf: &[u8; 4], chunk_id: &[u8; 4]) -> bool {
    buf.iter().zip(chunk_id.iter()).all(|(lhs, rhs)| lhs == rhs)
}
#[cfg(feature = "ndarray")]
use ndarray::{Array2, ShapeError};
/// Conversion of a flat sample container into a two-dimensional `ndarray` array.
/// Only compiled when the `ndarray` feature is enabled.
#[cfg(feature = "ndarray")]
pub trait IntoArray<T> {
/// Consumes `self` and returns an `Array2<T>`, or a `ShapeError` when the shape is rejected.
/// The implementations in this file use `n_channels` as the second dimension.
fn into_array(self, n_channels: usize) -> Result<Array2<T>, ShapeError>;
}
#[cfg(feature = "ndarray")]
impl IntoArray<Sample> for Vec<Sample> {
    /// Reshapes the flat vector into `(len / n_channels, n_channels)`.
    ///
    /// # Panics
    ///
    /// Panics when `n_channels` is 0 (division by zero).
    fn into_array(self, n_channels: usize) -> Result<Array2<Sample>, ShapeError> {
        let n_frames = self.len() / n_channels;
        let shape = (n_frames, n_channels);
        Array2::from_shape_vec(shape, self)
    }
}
#[cfg(feature = "ndarray")]
impl IntoArray<i16> for Vec<i16> {
    /// Reshapes the flat vector into `(len / n_channels, n_channels)`.
    ///
    /// # Panics
    ///
    /// Panics when `n_channels` is 0 (division by zero).
    fn into_array(self, n_channels: usize) -> Result<Array2<i16>, ShapeError> {
        let n_frames = self.len() / n_channels;
        let shape = (n_frames, n_channels);
        Array2::from_shape_vec(shape, self)
    }
}
#[cfg(feature = "ndarray")]
impl IntoArray<i32> for Vec<i32> {
    /// Reshapes the flat vector into `(len / n_channels, n_channels)`.
    ///
    /// # Panics
    ///
    /// Panics when `n_channels` is 0 (division by zero).
    fn into_array(self, n_channels: usize) -> Result<Array2<i32>, ShapeError> {
        let n_frames = self.len() / n_channels;
        let shape = (n_frames, n_channels);
        Array2::from_shape_vec(shape, self)
    }
}
#[cfg(feature = "ndarray")]
impl IntoArray<f32> for Vec<f32> {
    /// Reshapes the flat vector into `(len / n_channels, n_channels)`.
    ///
    /// # Panics
    ///
    /// Panics when `n_channels` is 0 (division by zero).
    fn into_array(self, n_channels: usize) -> Result<Array2<f32>, ShapeError> {
        let n_frames = self.len() / n_channels;
        let shape = (n_frames, n_channels);
        Array2::from_shape_vec(shape, self)
    }
}
#[cfg(feature = "ndarray")]
impl IntoArray<f64> for Vec<f64> {
    /// Reshapes the flat vector into `(len / n_channels, n_channels)`.
    ///
    /// # Panics
    ///
    /// Panics when `n_channels` is 0 (division by zero).
    fn into_array(self, n_channels: usize) -> Result<Array2<f64>, ShapeError> {
        let n_frames = self.len() / n_channels;
        let shape = (n_frames, n_channels);
        Array2::from_shape_vec(shape, self)
    }
}