use std::{
alloc::Layout,
fs::File,
io::{BufReader, BufWriter, Read, Seek, SeekFrom, Write},
path::Path,
u8,
};
use ndarray::Array2;
use pyo3::prelude::*;
use crate::sample::SampleType;
/// Four-byte tag expected at the very start of a file; `find_sub_chunk_id`
/// rejects any file that does not begin with it.
const RIFF: &[u8; 4] = b"RIFF";
/// Identifier of the chunk that carries the raw sample bytes.
const DATA: &[u8; 4] = b"data";
/// Form type written right after the RIFF size field.
const WAVE: &[u8; 4] = b"WAVE";
/// Identifier of the format chunk (note the trailing space); it is the first
/// four bytes produced by `FmtChunk::as_bytes`.
const FMT: &[u8; 4] = b"fmt ";
/// Identifier of a `LIST` chunk; not referenced by the code visible in this file.
#[allow(unused)] const LIST: &[u8; 4] = b"LIST";
/// A WAV file held in memory: its parsed `fmt ` chunk plus the undecoded
/// bytes of its `data` chunk.
#[derive(Debug, Clone, PartialEq, Eq)]
#[pyclass]
pub struct WavFile {
/// Format description (sample format, channel count, rates, bit depth).
pub fmt_chunk: FmtChunk,
/// Raw sample bytes, either read from a file's `data` chunk or serialised
/// from an array by `from_data`.
pub data: Box<[u8]>,
/// Read position inside `data`. Every constructor call in this file passes 0.
// NOTE(review): the unit is ambiguous - `data_size` subtracts it from a byte
// count while the decoders subtract it from a sample count. Confirm intent.
pub seek_pos: u64,
}
impl WavFile {
pub fn new(fmt_chunk: FmtChunk, data: Box<[u8]>, seek_pos: u64) -> WavFile {
WavFile {
fmt_chunk,
data,
seek_pos,
}
}
#[inline(always)]
pub fn from_file(fp: &Path) -> Result<WavFile, std::io::Error> {
let file = File::open(fp)?;
let mut buf_reader = std::io::BufReader::new(file);
let fmt_chunk = FmtChunk::from_buf_reader(&mut buf_reader)?;
let (data_offset, data_len) = find_sub_chunk_id(&mut buf_reader, &b"data")?;
let mut data = alloc_box_buffer(data_len);
buf_reader.seek(SeekFrom::Start(data_offset as u64 + 4))?;
match buf_reader.read_exact(&mut data) {
Ok(_) => (),
Err(err) => {
eprintln!("Error reading data chunk: {}", err);
return Err(err);
}
}
Ok(WavFile::new(fmt_chunk, data, 0))
}
pub fn from_data(data: Array2<SampleType>, sample_rate: i32) -> WavFile {
let type_info = match data.first().expect("Empty array") {
SampleType::I16(_) => (1, 16),
SampleType::I32(_) => (1, 32),
SampleType::F32(_) => (3, 32),
SampleType::F64(_) => (3, 64),
};
let n_channels: u16 = data.shape()[1] as u16;
let block_align = n_channels * type_info.1 / 8;
let byte_rate = sample_rate * block_align as i32;
let fmt_chunk = match type_info {
(1, 16) => FmtChunk::new(16, 1, n_channels, sample_rate, byte_rate, block_align, 16),
(1, 32) => FmtChunk::new(16, 1, n_channels, sample_rate, byte_rate, block_align, 32),
(3, 32) => FmtChunk::new(16, 3, n_channels, sample_rate, byte_rate, block_align, 32),
(3, 64) => FmtChunk::new(16, 3, n_channels, sample_rate, byte_rate, block_align, 64),
_ => panic!("Unsupported sample type"),
};
let data = array_to_box_buffer(&data);
WavFile::new(fmt_chunk, data, 0)
}
#[inline]
pub fn read(&self, as_wav_type: Option<SampleType>) -> Array2<SampleType> {
let bits_per_sample = self.get_bits_per_sample(); let base_format = self.get_format();
let sample_format = match base_format {
1 => match bits_per_sample {
16 => SampleType::I16(0),
32 => SampleType::I32(0),
_ => panic!("Unsupported bit depth for PCM: {}", bits_per_sample),
},
3 => match bits_per_sample {
32 => SampleType::F32(0.0),
64 => SampleType::F64(0.0),
_ => panic!("Unsupported bit depth for float: {}", bits_per_sample),
},
_ => panic!("Unsupported format: {}", base_format),
};
let data = match sample_format {
SampleType::I16(_) => self.read_pcm_i16(),
SampleType::I32(_) => self.read_pcm_i32(),
SampleType::F32(_) => self.read_ieee_f32(),
SampleType::F64(_) => self.read_ieee_f64(),
};
match as_wav_type {
Some(wav_type) => match wav_type {
SampleType::I16(_) => data.mapv(|sample| sample.convert_to(wav_type)),
SampleType::I32(_) => data.mapv(|sample| sample.convert_to(wav_type)),
SampleType::F32(_) => data.mapv(|sample| sample.convert_to(wav_type)),
SampleType::F64(_) => data.mapv(|sample| sample.convert_to(wav_type)),
},
None => data,
}
}
#[inline(always)]
pub fn write_wav(&self, fp: &Path) -> Result<(), std::io::Error> {
let file = File::create(fp)?;
let mut buf_writer = BufWriter::new(file);
buf_writer.write(RIFF)?;
buf_writer.write(&(self.data.len() as u32 + 36).to_le_bytes())?;
buf_writer.write(WAVE)?;
buf_writer.write(FMT)?;
buf_writer.write_all(&self.fmt_chunk.as_bytes())?;
buf_writer.write(DATA)?;
buf_writer.write(&self.data.len().to_le_bytes())?;
buf_writer.write_all(self.data.as_ref())?;
Ok(())
}
#[inline(always)]
fn read_pcm_i16(&self) -> Array2<SampleType> {
let n_channels = self.fmt_chunk.channels as usize;
let mut channel_data: Vec<SampleType> =
Vec::with_capacity((self.data.len() / 2) - self.seek_pos as usize);
unsafe {
channel_data.set_len((self.data.len() / 2) - self.seek_pos as usize);
}
let mut idx = 0;
let iter_step = 2 * n_channels;
for samples in self.data.chunks(iter_step) {
unsafe {
for channel_sample in samples.as_chunks_unchecked::<2>() {
channel_data[idx] = SampleType::I16(i16::from_ne_bytes(*channel_sample));
idx += 1;
}
}
}
let out_array: Array2<SampleType> = match Array2::from_shape_vec(
(channel_data.len() / n_channels, n_channels),
channel_data,
) {
Ok(arr) => arr,
Err(err) => {
panic!("Error while shaping data : {}", err);
}
};
out_array
}
#[inline(always)]
fn read_pcm_i32(&self) -> Array2<SampleType> {
let n_channels = self.fmt_chunk.channels as usize;
let mut channel_data: Vec<SampleType> =
Vec::with_capacity((self.data.len() / 4) - self.seek_pos as usize); unsafe {
channel_data.set_len((self.data.len() / 4) - self.seek_pos as usize);
}
let mut idx = 0;
let iter_step = 4 * n_channels;
for samples in self.data.chunks(iter_step) {
for channel_sample in samples.chunks(std::mem::size_of::<i32>()) {
channel_data[idx] =
SampleType::I32(i32::from_ne_bytes(channel_sample.try_into().unwrap()));
idx += 1;
}
}
let out_array: Array2<SampleType> = match Array2::from_shape_vec(
(channel_data.len() / n_channels, n_channels),
channel_data,
) {
Ok(arr) => arr,
Err(err) => {
eprintln!("Error reading data chunk: {}", err);
panic!("Error reading data chunk: {}", err);
}
};
out_array
}
#[inline(always)]
fn read_ieee_f32(&self) -> Array2<SampleType> {
let n_channels = self.fmt_chunk.channels as usize;
let mut channel_data: Vec<SampleType> =
Vec::with_capacity((self.data.len() / 4) - self.seek_pos as usize); unsafe {
channel_data.set_len((self.data.len() / 4) - self.seek_pos as usize);
}
let mut idx = 0;
let iter_step = 4 * n_channels;
for samples in self.data.chunks(iter_step) {
for channel_sample in samples.chunks(std::mem::size_of::<f32>()) {
channel_data[idx] =
SampleType::F32(f32::from_ne_bytes(channel_sample.try_into().unwrap()));
idx += 1;
}
}
let out_array: Array2<SampleType> = match Array2::from_shape_vec(
(channel_data.len() / n_channels, n_channels),
channel_data,
) {
Ok(arr) => arr,
Err(err) => {
eprintln!("Error reading data chunk: {}", err);
panic!("Error reading data chunk: {}", err);
}
};
out_array
}
#[inline(always)]
fn read_ieee_f64(&self) -> Array2<SampleType> {
let n_channels = self.fmt_chunk.channels as usize;
let mut channel_data: Vec<SampleType> =
Vec::with_capacity((self.data.len() / 8) - self.seek_pos as usize); unsafe {
channel_data.set_len((self.data.len() / 8) - self.seek_pos as usize);
}
let mut idx = 0;
let iter_step = 8 * n_channels;
for samples in self.data.chunks(iter_step) {
for channel_sample in samples.chunks(std::mem::size_of::<f64>()) {
channel_data[idx] =
SampleType::F32(f32::from_ne_bytes(channel_sample.try_into().unwrap()));
idx += 1;
}
}
let out_array: Array2<SampleType> = match Array2::from_shape_vec(
(channel_data.len() / n_channels, n_channels),
channel_data,
) {
Ok(arr) => arr,
Err(err) => {
eprintln!("Error reading data chunk: {}", err);
panic!("Error reading data chunk: {}", err);
}
};
out_array
}
#[inline(always)]
pub fn duration(&self) -> u64 {
self.data_size() as u64
/ (self.sample_rate() * self.channels() as i32 * (self.bits_per_sample() / 8) as i32)
as u64
}
pub fn sample_rate(&self) -> i32 {
self.fmt_chunk.sample_rate()
}
pub fn channels(&self) -> u16 {
self.fmt_chunk.channels()
}
fn bits_per_sample(&self) -> u16 {
self.fmt_chunk.bits_per_sample()
}
fn data_size(&self) -> usize {
self.data.len() - self.seek_pos as usize
}
}
#[pymethods]
impl WavFile {
    /// Format tag stored in the `fmt ` chunk.
    pub fn get_format(&self) -> u16 {
        self.fmt_chunk.format()
    }

    /// Bits per sample stored in the `fmt ` chunk.
    pub fn get_bits_per_sample(&self) -> u16 {
        self.fmt_chunk.bits_per_sample()
    }
}
/// Returns the length, in whole seconds, of the WAV file at `signal_fp`
/// without loading its samples.
///
/// # Errors
///
/// Returns any I/O error raised while opening or scanning the file, the
/// errors of `find_sub_chunk_id` when a chunk is missing, and `InvalidData`
/// when the `fmt ` chunk describes a zero or negative byte rate.
pub fn signal_duration(signal_fp: &Path) -> Result<u64, std::io::Error> {
    let wav_file = File::open(signal_fp)?;
    let mut br = BufReader::new(wav_file);
    let fmt_chunk = FmtChunk::from_buf_reader(&mut br)?;
    // The chunk scan already decodes the unsigned little-endian size, so there
    // is no need to seek back and re-read it.
    let (_, data_len) = find_sub_chunk_id(&mut br, b"data")?;

    // Widen before multiplying: the product easily exceeds `i32`.
    let bytes_per_second = u64::try_from(fmt_chunk.sample_rate()).unwrap_or(0)
        * u64::from(fmt_chunk.channels())
        * u64::from(fmt_chunk.bits_per_sample() / 8);
    if bytes_per_second == 0 {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            "fmt chunk describes a zero byte rate",
        ));
    }
    Ok(data_len as u64 / bytes_per_second)
}
#[inline(always)]
pub fn signal_sample_rate(signal_fp: &Path) -> Result<i32, std::io::Error> {
let wav_file = File::open(signal_fp)?;
let mut br = BufReader::new(wav_file);
let fmt_chunk = FmtChunk::from_buf_reader(&mut br)?;
Ok(fmt_chunk.sample_rate())
}
#[inline(always)]
pub fn signal_channels(signal_fp: &Path) -> Result<u16, std::io::Error> {
let wav_file = File::open(signal_fp)?;
let mut br = BufReader::new(wav_file);
let fmt_chunk = FmtChunk::from_buf_reader(&mut br)?;
Ok(fmt_chunk.channels())
}
/// Summary of a WAV file's header, as produced by `signal_info`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[pyclass]
pub struct SignalInfo {
/// Sample rate taken from the `fmt ` chunk.
pub sample_rate: i32,
/// Channel count taken from the `fmt ` chunk.
pub channels: u16,
/// Bits per sample taken from the `fmt ` chunk.
pub bits_per_sample: u16,
/// Data size divided by (sample rate * channels * bytes per sample), using
/// integer division.
pub duration: u64,
}
impl SignalInfo {
pub fn new(sample_rate: i32, channels: u16, bits_per_sample: u16, duration: u64) -> Self {
Self {
sample_rate,
channels,
bits_per_sample,
duration,
}
}
}
/// Collects sample rate, channel count, bit depth and duration (in whole
/// seconds) of the WAV file at `signal_fp` without loading its samples.
///
/// # Errors
///
/// Returns any I/O error raised while opening or scanning the file, the
/// errors of `find_sub_chunk_id` when a chunk is missing, and `InvalidData`
/// when the `fmt ` chunk describes a zero or negative byte rate.
pub fn signal_info(signal_fp: &Path) -> Result<SignalInfo, std::io::Error> {
    let wav_file = File::open(signal_fp)?;
    let mut br = BufReader::new(wav_file);
    let fmt_chunk = FmtChunk::from_buf_reader(&mut br)?;
    // The chunk scan already decodes the unsigned little-endian size, so there
    // is no need to seek back and re-read it.
    let (_, data_len) = find_sub_chunk_id(&mut br, b"data")?;

    // Widen before multiplying: the product easily exceeds `i32`.
    let bytes_per_second = u64::try_from(fmt_chunk.sample_rate()).unwrap_or(0)
        * u64::from(fmt_chunk.channels())
        * u64::from(fmt_chunk.bits_per_sample() / 8);
    if bytes_per_second == 0 {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            "fmt chunk describes a zero byte rate",
        ));
    }

    Ok(SignalInfo::new(
        fmt_chunk.sample_rate(),
        fmt_chunk.channels(),
        fmt_chunk.bits_per_sample(),
        data_len as u64 / bytes_per_second,
    ))
}
/// Loads the WAV file at `fp` and decodes its samples, converting them to
/// `as_type` when one is given.
#[inline(always)]
pub fn read(fp: &Path, as_type: Option<SampleType>) -> Result<Array2<SampleType>, std::io::Error> {
    WavFile::from_file(fp).map(|loaded| loaded.read(as_type))
}
/// Writes `data` to `fp` as a WAV file, optionally converting every sample to
/// `as_type` first.
///
/// The channel count is taken from the second axis of `data`. When `as_type`
/// is `None`, the sample format is taken from the variant of the first
/// element and samples are written unchanged.
///
/// Assumes `SampleType::to_le_bytes` yields exactly the encoded width of the
/// sample's variant (2, 4 or 8 bytes) - TODO confirm against `SampleType`.
///
/// # Errors
///
/// Returns `InvalidInput` when `data` is empty and no `as_type` is given, or
/// when the channel count, byte rate or data size does not fit the WAV
/// header fields. Returns any I/O error raised while creating or writing the
/// file.
pub fn write_wav_as(
    fp: &Path,
    data: &Array2<SampleType>,
    as_type: Option<SampleType>,
    sample_rate: i32,
) -> Result<(), std::io::Error> {
    fn invalid_input(msg: &str) -> std::io::Error {
        std::io::Error::new(std::io::ErrorKind::InvalidInput, msg)
    }

    // The header must describe what is actually written: the requested type
    // if there is one, otherwise the type the samples already have.
    let layout_source = match as_type.as_ref() {
        Some(target) => target,
        None => data.first().ok_or_else(|| invalid_input("Empty array"))?,
    };
    // (format tag, bits per sample): 1 = integer PCM, 3 = IEEE float.
    let (format, bits_per_sample): (u16, u16) = match layout_source {
        SampleType::I16(_) => (1, 16),
        SampleType::I32(_) => (1, 32),
        SampleType::F32(_) => (3, 32),
        SampleType::F64(_) => (3, 64),
    };
    let bytes_per_sample = bits_per_sample / 8;

    let n_channels = u16::try_from(data.shape()[1])
        .map_err(|_| invalid_input("too many channels for a WAV file"))?;
    let block_align = n_channels
        .checked_mul(bytes_per_sample)
        .ok_or_else(|| invalid_input("block alignment does not fit in 16 bits"))?;
    let byte_rate = sample_rate
        .checked_mul(i32::from(block_align))
        .ok_or_else(|| invalid_input("byte rate does not fit in 32 bits"))?;
    let data_len = data
        .len()
        .checked_mul(usize::from(bytes_per_sample))
        .and_then(|len| u32::try_from(len).ok())
        .ok_or_else(|| invalid_input("sample data does not fit in a RIFF container"))?;
    // 36 = "WAVE" (4) + fmt chunk (24) + data chunk header (8).
    let riff_len = data_len
        .checked_add(36)
        .ok_or_else(|| invalid_input("sample data does not fit in a RIFF container"))?;

    let fmt_chunk = FmtChunk::new(
        16,
        format,
        n_channels,
        sample_rate,
        byte_rate,
        block_align,
        bits_per_sample,
    );

    // Create the file only after the header has been validated, so invalid
    // input does not leave a truncated file behind.
    let mut buf_writer = BufWriter::new(File::create(fp)?);
    buf_writer.write_all(RIFF)?;
    buf_writer.write_all(&riff_len.to_le_bytes())?;
    buf_writer.write_all(WAVE)?;
    // `as_bytes` already includes the "fmt " tag.
    buf_writer.write_all(&fmt_chunk.as_bytes())?;
    buf_writer.write_all(DATA)?;
    buf_writer.write_all(&data_len.to_le_bytes())?;
    for frame in data.rows() {
        for sample in frame.iter() {
            let encoded = match as_type {
                Some(target) => sample.convert_to(target).to_le_bytes(),
                None => sample.to_le_bytes(),
            };
            buf_writer.write_all(&encoded)?;
        }
    }
    // Flush explicitly: errors raised while dropping a `BufWriter` are lost.
    buf_writer.flush()
}
fn array_to_box_buffer(data: &Array2<SampleType>) -> Box<[u8]> {
let mut box_buf = alloc_box_buffer(data.len() * std::mem::size_of::<SampleType>());
let mut idx = 0;
let type_size = std::mem::size_of::<SampleType>();
for column in data.rows() {
for sample in column.iter() {
let sample_bytes = sample.to_le_bytes();
box_buf[idx..idx + type_size].copy_from_slice(&sample_bytes);
idx += type_size;
}
}
box_buf
}
/// Allocates a zero-filled byte buffer of exactly `len` bytes.
///
/// The buffer is always initialised, so callers can read it safely even if
/// they never fill it completely. An allocation failure aborts through the
/// standard allocator path instead of producing a box around a null pointer.
///
/// # Panics
///
/// Panics when `len` is larger than any allocation Rust can represent.
pub fn alloc_box_buffer(len: usize) -> Box<[u8]> {
    // Keep the original diagnostic for impossible sizes.
    if Layout::array::<u8>(len).is_err() {
        panic!("Failed to allocate buffer of size {}", len);
    }
    vec![0u8; len].into_boxed_slice()
}
/// The fields of a WAV `fmt ` chunk, in the order they are read from and
/// written to a file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(C)]
#[pyclass]
pub struct FmtChunk {
/// Size field of the chunk; every `FmtChunk` built in this file uses 16.
pub size: i32,
/// Format tag: `WavFile::read` treats 1 as integer PCM and 3 as float.
pub format: u16,
/// Number of channels.
pub channels: u16,
/// Sample rate.
pub sample_rate: i32,
/// Byte rate; `from_data` computes it as `sample_rate * block_align`.
pub byte_rate: i32,
/// Block alignment; `from_data` computes it from channels and bit depth.
pub block_align: u16,
/// Bits per sample (16, 32 or 64 for the formats `WavFile::read` accepts).
pub bits_per_sample: u16,
}
impl FmtChunk {
    /// Creates a [`FmtChunk`] holding exactly the given field values.
    pub fn new(
        size: i32,
        format: u16,
        channels: u16,
        sample_rate: i32,
        byte_rate: i32,
        block_align: u16,
        bits_per_sample: u16,
    ) -> FmtChunk {
        FmtChunk {
            size,
            format,
            channels,
            sample_rate,
            byte_rate,
            block_align,
            bits_per_sample,
        }
    }

    /// Opens the file at `signal_fp` and parses its `fmt ` chunk.
    ///
    /// # Errors
    ///
    /// Returns any I/O error raised while opening or reading the file, and
    /// the errors of `find_sub_chunk_id` when the chunk is missing.
    pub fn from_path(signal_fp: &Path) -> Result<FmtChunk, std::io::Error> {
        let mut br = BufReader::new(File::open(signal_fp)?);
        FmtChunk::from_buf_reader(&mut br)
    }

    /// Parses the `fmt ` chunk from `br` and rewinds the reader to the start
    /// of the file, so further chunk lookups can follow.
    ///
    /// All fields are decoded as little-endian, the byte order of WAV files,
    /// independent of the host's byte order.
    fn from_buf_reader(br: &mut BufReader<File>) -> Result<FmtChunk, std::io::Error> {
        let (offset, chunk_len) = find_sub_chunk_id(br, FMT)?;
        if chunk_len < 16 {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "fmt chunk is shorter than 16 bytes",
            ));
        }

        // `offset` addresses the size field; read it together with the 16
        // mandatory payload bytes in one go.
        br.seek(SeekFrom::Start(offset as u64))?;
        let mut raw = [0u8; 20];
        br.read_exact(&mut raw)?;
        br.seek(SeekFrom::Start(0))?;

        let u16_at = |at: usize| u16::from_le_bytes([raw[at], raw[at + 1]]);
        let i32_at =
            |at: usize| i32::from_le_bytes([raw[at], raw[at + 1], raw[at + 2], raw[at + 3]]);

        Ok(FmtChunk::new(
            i32_at(0),
            u16_at(4),
            u16_at(6),
            i32_at(8),
            i32_at(12),
            u16_at(16),
            u16_at(18),
        ))
    }

    /// Serialises the chunk, including its leading `"fmt "` tag, as the 24
    /// bytes that go into a WAV file.
    ///
    /// The size field is always written as 16, because exactly 16 payload
    /// bytes follow it. Echoing a larger `size` parsed from a file that
    /// carried extension bytes would make readers consume the start of the
    /// next chunk as part of this one.
    pub fn as_bytes(&self) -> [u8; 24] {
        const PAYLOAD_LEN: u32 = 16;
        let mut buf: [u8; 24] = [0; 24];
        buf[0..4].copy_from_slice(FMT);
        buf[4..8].copy_from_slice(&PAYLOAD_LEN.to_le_bytes());
        buf[8..10].copy_from_slice(&self.format.to_le_bytes());
        buf[10..12].copy_from_slice(&self.channels.to_le_bytes());
        buf[12..16].copy_from_slice(&self.sample_rate.to_le_bytes());
        buf[16..20].copy_from_slice(&self.byte_rate.to_le_bytes());
        buf[20..22].copy_from_slice(&self.block_align.to_le_bytes());
        buf[22..24].copy_from_slice(&self.bits_per_sample.to_le_bytes());
        buf
    }

    /// Width of one sample in bytes (`bits_per_sample / 8`).
    #[inline]
    pub fn get_sample_size(&self) -> usize {
        usize::from(self.bits_per_sample) / 8
    }

    /// Format tag.
    #[inline]
    pub fn format(&self) -> u16 {
        self.format
    }

    /// Number of channels.
    #[inline]
    pub fn channels(&self) -> u16 {
        self.channels
    }

    /// Sample rate.
    #[inline]
    pub fn sample_rate(&self) -> i32 {
        self.sample_rate
    }

    /// Byte rate.
    #[inline]
    pub fn byte_rate(&self) -> i32 {
        self.byte_rate
    }

    /// Block alignment.
    #[inline]
    pub fn block_align(&self) -> u16 {
        self.block_align
    }

    /// Bits per sample.
    #[inline]
    pub fn bits_per_sample(&self) -> u16 {
        self.bits_per_sample
    }
}
/// Scans the chunks of a RIFF file for `chunk_id`.
///
/// `file` must be positioned at the start of the file. On success the reader
/// is rewound to the start and the function returns
/// `(offset of the chunk's 4-byte size field, chunk size in bytes)`; the
/// chunk's payload therefore starts at `offset + 4`.
///
/// # Errors
///
/// Returns an error of kind `Other` when the file does not start with `RIFF`
/// or when no chunk with the requested id exists, and forwards any I/O error
/// raised while reading or seeking.
pub fn find_sub_chunk_id(
    file: &mut BufReader<File>,
    chunk_id: &[u8; 4],
) -> Result<(usize, usize), std::io::Error> {
    /// Fills `field` completely; `Ok(false)` means the file ended first.
    fn read_field(
        file: &mut BufReader<File>,
        field: &mut [u8; 4],
    ) -> Result<bool, std::io::Error> {
        match file.read_exact(field) {
            Ok(()) => Ok(true),
            Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => Ok(false),
            Err(err) => Err(err),
        }
    }

    let mut id: [u8; 4] = [0; 4];
    file.read_exact(&mut id)?;
    if !buf_eq(&id, RIFF) {
        return Err(std::io::Error::new(
            std::io::ErrorKind::Other,
            format!("Failed to find RIFF tag in {:?}", file.get_ref()),
        ));
    }
    // Skip the RIFF size field and the "WAVE" form type.
    file.seek(SeekFrom::Current(8))?;

    // Offset of the chunk header currently being inspected.
    let mut chunk_start: u64 = 12;
    let mut size_field: [u8; 4] = [0; 4];
    // A plain `read` may legally return fewer than four bytes, so both
    // fields are read in full; a header cut off by end-of-file ends the scan.
    while read_field(file, &mut id)? && read_field(file, &mut size_field)? {
        let chunk_len = u32::from_le_bytes(size_field);
        if buf_eq(&id, chunk_id) {
            file.seek(SeekFrom::Start(0))?;
            let size_offset = usize::try_from(chunk_start + 4).map_err(|_| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    "chunk offset does not fit in usize",
                )
            })?;
            return Ok((size_offset, chunk_len as usize));
        }
        // Chunks are word-aligned: an odd-sized payload is followed by one
        // pad byte that is not part of the recorded size.
        let padded_len = u64::from(chunk_len) + u64::from(chunk_len & 1);
        file.seek(SeekFrom::Current(padded_len as i64))?;
        chunk_start += 8 + padded_len;
    }

    file.seek(SeekFrom::Start(0))?;
    Err(std::io::Error::new(
        std::io::ErrorKind::Other,
        format!(
            "Failed to find {:?} tag in {:?}",
            // Lossy conversion: building the error must not panic on a
            // non-UTF-8 chunk id.
            String::from_utf8_lossy(chunk_id),
            file.get_ref()
        ),
    ))
}
/// Returns `true` when the two four-byte tags are identical.
#[inline(always)]
fn buf_eq(buf: &[u8; 4], chunk_id: &[u8; 4]) -> bool {
    buf.iter().zip(chunk_id.iter()).all(|(left, right)| left == right)
}