use audio_samples::traits::StandardSample;
use non_empty_slice::NonEmptyVec;
use numpy::{Element, PyArray1, PyArray2, PyArrayMethods};
use pyo3::prelude::*;
use std::any::TypeId;
use std::io::{BufReader, Read};
use std::path::Path;
use crate::traits::{AudioFile, AudioFileMetadata};
use crate::types::{OpenOptions, ValidatedSampleType};
use crate::wav::{WavFile, wav_file::parse_wav_header_streaming};
use crate::{AudioIOError, AudioIOResult, BaseAudioInfo, FileType};
use audio_samples::I24;
/// A decoded audio array tagged with the sample type it was stored as on disk.
///
/// Every variant pairs a 2-D NumPy array handle with the `BaseAudioInfo` that
/// was parsed from the file header. Values are produced by
/// `read_pyarray_native`, which picks the variant from `info.sample_type`.
///
/// Only compiled on little-endian targets.
// NOTE(review): presumably gated because the fill path copies file bytes into
// the array without byte swapping — confirm against the WAV parser.
#[cfg(target_endian = "little")]
pub enum NativeAudioArray {
/// Array of `u8` samples plus the parsed header info.
U8(Py<PyArray2<u8>>, BaseAudioInfo),
/// Array of `i16` samples plus the parsed header info.
I16(Py<PyArray2<i16>>, BaseAudioInfo),
/// Array of `i32` samples plus the parsed header info.
I32(Py<PyArray2<i32>>, BaseAudioInfo),
/// Array of `f32` samples plus the parsed header info.
F32(Py<PyArray2<f32>>, BaseAudioInfo),
/// Array of `f64` samples plus the parsed header info.
F64(Py<PyArray2<f64>>, BaseAudioInfo),
}
#[cfg(target_endian = "little")]
pub fn read_pyarray_native(
py: Python<'_>,
path: &Path,
) -> Option<PyResult<NativeAudioArray>> {
if !matches!(FileType::from_path(path), FileType::WAV) {
return None;
}
let parse_result: AudioIOResult<(BaseAudioInfo, BufReader<std::fs::File>)> = py.detach(|| {
let file = std::fs::File::open(path).map_err(AudioIOError::from)?;
let mut reader = BufReader::with_capacity(65536, file);
let (info, _) = parse_wav_header_streaming(&mut reader)?;
Ok((info, reader))
});
let (info, reader) = parse_result.ok()?;
if matches!(info.sample_type, audio_samples::SampleType::I24) {
return None;
}
let channels = info.channels as usize;
let total_samples = info.total_samples;
if channels == 0 || total_samples == 0 || total_samples % channels != 0 {
return None;
}
let result = match info.sample_type {
audio_samples::SampleType::U8 => {
alloc_and_fill::<u8>(py, reader, info).map(|(a, i)| NativeAudioArray::U8(a, i))
}
audio_samples::SampleType::I16 => {
alloc_and_fill::<i16>(py, reader, info).map(|(a, i)| NativeAudioArray::I16(a, i))
}
audio_samples::SampleType::I32 => {
alloc_and_fill::<i32>(py, reader, info).map(|(a, i)| NativeAudioArray::I32(a, i))
}
audio_samples::SampleType::F32 => {
alloc_and_fill::<f32>(py, reader, info).map(|(a, i)| NativeAudioArray::F32(a, i))
}
audio_samples::SampleType::F64 => {
alloc_and_fill::<f64>(py, reader, info).map(|(a, i)| NativeAudioArray::F64(a, i))
}
_ => return None,
};
Some(result)
}
/// Allocates a Fortran-ordered `(channels, frames)` NumPy array and fills it
/// with raw sample bytes read from `reader`.
///
/// The bytes are copied verbatim from the reader's current position into the
/// array's backing storage, so `T` must have the same in-memory representation
/// as the samples in the file. In a Fortran-ordered `(channels, frames)` array
/// element `(c, f)` lives at linear index `c + f * channels`, which is exactly
/// the order in which the bytes are written.
///
/// # Parameters
///
/// * `py` - interpreter token used to allocate the array.
/// * `reader` - buffered file reader; it is consumed by this call.
/// * `info` - header information; `channels` and `total_samples` determine the
///   array shape. It is returned unchanged alongside the array.
///
/// # Errors
///
/// * `PyValueError` if `info.channels` is zero.
/// * `PyIOError` if the reader cannot supply enough bytes to fill the array.
#[cfg(target_endian = "little")]
fn alloc_and_fill<T>(
    py: Python<'_>,
    reader: BufReader<std::fs::File>,
    info: BaseAudioInfo,
) -> PyResult<(Py<PyArray2<T>>, BaseAudioInfo)>
where
    T: Element + 'static,
{
    use std::mem::size_of;

    let channels = info.channels as usize;
    // Guard locally instead of relying on every caller: a zero channel count
    // would otherwise panic on the division below.
    if channels == 0 {
        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
            "channels must be non-zero",
        ));
    }
    let frames = info.total_samples / channels;

    // Zero-initialised allocation: the buffer handed to `read_exact` below is
    // fully initialised, as `std::io::Read` requires of its callers.
    let array = PyArray2::<T>::zeros(py, [channels, frames], true);

    // Size the write from the array's actual shape rather than from
    // `total_samples`, so it can never exceed the allocation even if
    // `total_samples` is not a multiple of `channels`. The product cannot
    // overflow because an allocation of this many bytes has just succeeded.
    let byte_count = channels * frames * size_of::<T>();

    // The pointer is carried across `py.detach` as an integer.
    // NOTE(review): presumably because a raw pointer is not accepted by the
    // closure bounds of `detach` — confirm against the pyo3 documentation.
    let data_ptr_usize = array.data() as usize;

    let read_result: AudioIOResult<()> = py.detach(|| {
        let mut r = reader;
        // SAFETY: `data_ptr_usize` is the data pointer of `array`, a freshly
        // allocated array of `channels * frames` elements of `T` that is kept
        // alive by the enclosing function and has not been shared with any
        // other code yet. `byte_count` is exactly the size of that storage.
        let bytes =
            unsafe { std::slice::from_raw_parts_mut(data_ptr_usize as *mut u8, byte_count) };
        r.read_exact(bytes).map_err(AudioIOError::from)
    });
    read_result.map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()))?;

    Ok((array.unbind(), info))
}
/// Reads an audio file into a `(channels, frames)` NumPy array of `T`.
///
/// On little-endian targets, and when `T` is not `I24`, a direct read into the
/// array is attempted first via `read_pyarray_direct`; if that path declines
/// (returns `None`) the file is decoded into an interleaved vector and then
/// wrapped as a Fortran-ordered array.
///
/// # Errors
///
/// * Whatever error the direct path reports once it has accepted the file.
/// * `PyIOError` carrying the decoder's message if the general path fails.
/// * Any error from `create_pyarray_fortran` when the decoded data does not
///   match the header's channel/sample counts.
pub fn read_pyarray<P, T>(py: Python<'_>, fp: P) -> PyResult<(Py<PyArray2<T>>, BaseAudioInfo)>
where
    P: AsRef<Path>,
    T: StandardSample + Element + 'static,
{
    let path = fp.as_ref();

    #[cfg(target_endian = "little")]
    if TypeId::of::<T>() != TypeId::of::<I24>() {
        if let Some(direct_outcome) = read_pyarray_direct::<T>(py, path) {
            return direct_outcome;
        }
    }

    // General path: decode while detached from the interpreter, then build the
    // array once the interleaved samples are available.
    let decoded = py.detach(|| read_interleaved_with_info::<_, T>(path));
    let (samples, info) = match decoded {
        Ok(pair) => pair,
        Err(e) => return Err(PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string())),
    };

    let channel_count = info.channels as usize;
    let array = create_pyarray_fortran(py, samples, channel_count, info.total_samples)?;
    Ok((array, info))
}
#[cfg(target_endian = "little")]
fn read_pyarray_direct<T>(
py: Python<'_>,
path: &Path,
) -> Option<PyResult<(Py<PyArray2<T>>, BaseAudioInfo)>>
where
T: StandardSample + Element + 'static,
{
use std::mem::size_of;
let parse_result: AudioIOResult<(BaseAudioInfo, BufReader<std::fs::File>)> = py.detach(|| {
let file = std::fs::File::open(path).map_err(AudioIOError::from)?;
let mut reader = BufReader::with_capacity(65536, file);
let (info, _) = parse_wav_header_streaming(&mut reader)?;
Ok((info, reader))
});
let (info, reader) = match parse_result {
Ok(v) => v,
Err(_) => return None,
};
let native_matches = match info.sample_type {
audio_samples::SampleType::U8 => TypeId::of::<T>() == TypeId::of::<u8>(),
audio_samples::SampleType::I16 => TypeId::of::<T>() == TypeId::of::<i16>(),
audio_samples::SampleType::I32 => TypeId::of::<T>() == TypeId::of::<i32>(),
audio_samples::SampleType::F32 => TypeId::of::<T>() == TypeId::of::<f32>(),
audio_samples::SampleType::F64 => TypeId::of::<T>() == TypeId::of::<f64>(),
_ => false,
};
if !native_matches {
return None;
}
let channels = info.channels as usize;
let total_samples = info.total_samples;
if channels == 0 || total_samples == 0 || total_samples % channels != 0 {
return None;
}
let frames = total_samples / channels;
let byte_count = total_samples * size_of::<T>();
let array = unsafe { PyArray2::<T>::new(py, [channels, frames], true) };
let data_ptr_usize = array.data() as usize;
let read_result: AudioIOResult<()> = py.detach(|| {
let mut r = reader;
let bytes = unsafe { std::slice::from_raw_parts_mut(data_ptr_usize as *mut u8, byte_count) };
r.read_exact(bytes).map_err(AudioIOError::from)
});
if let Err(e) = read_result {
return Some(Err(PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string())));
}
Some(Ok((array.unbind(), info)))
}
/// Decodes an audio file into an interleaved, non-empty vector of `T` together
/// with the file's `BaseAudioInfo`.
///
/// The format is chosen by `FileType::from_path`:
///
/// * **WAV** - on little-endian targets, and when `T` is not `I24`, a direct
///   byte read (`read_wav_direct`) is tried first; any failure there silently
///   falls through to the general `WavFile` reader, which converts from the
///   file's sample type to `T`.
/// * **FLAC** (only with the `flac` feature) - the decoded buffer is treated as
///   planar (all frames of channel 0, then channel 1, ...) and re-ordered into
///   interleaved form.
///
/// # Errors
///
/// * Any error from opening or decoding the file.
/// * A corrupted-data error if a FLAC header reports zero channels, if the
///   decoded FLAC buffer is non-contiguous, shorter than the header declares,
///   or empty.
/// * An unsupported-format error for every other file type.
fn read_interleaved_with_info<P, T>(fp: P) -> AudioIOResult<(NonEmptyVec<T>, BaseAudioInfo)>
where
    P: AsRef<Path>,
    T: StandardSample + 'static,
{
    let path = fp.as_ref();
    match FileType::from_path(path) {
        FileType::WAV => {
            // Best-effort fast path: errors are deliberately ignored because
            // the general reader below handles every case the fast path rejects.
            #[cfg(target_endian = "little")]
            if TypeId::of::<T>() != TypeId::of::<I24>() {
                if let Ok(result) = read_wav_direct::<T>(path) {
                    return Ok(result);
                }
            }

            let wav_file = WavFile::open_with_options(path, OpenOptions::default())?;
            let info = wav_file.base_info()?;
            let data_chunk = wav_file.data();
            let sample_type = wav_file.sample_type();
            // Dispatch on the stored sample type; `read_samples` converts to `T`.
            let interleaved_vec = match sample_type {
                ValidatedSampleType::U8 => data_chunk.read_samples::<u8, T>(),
                ValidatedSampleType::I16 => data_chunk.read_samples::<i16, T>(),
                ValidatedSampleType::I24 => data_chunk.read_samples::<I24, T>(),
                ValidatedSampleType::I32 => data_chunk.read_samples::<i32, T>(),
                ValidatedSampleType::F32 => data_chunk.read_samples::<f32, T>(),
                ValidatedSampleType::F64 => data_chunk.read_samples::<f64, T>(),
            }?;
            Ok((interleaved_vec, info))
        }
        #[cfg(feature = "flac")]
        FileType::FLAC => {
            use crate::flac::FlacFile;
            use crate::traits::{AudioFile, AudioFileRead};

            let flac_file = FlacFile::open_with_options(path, OpenOptions::default())?;
            let info = flac_file.base_info()?;

            let channels = info.channels as usize;
            // A zero channel count would panic on the division below; report
            // it as corrupted data instead.
            if channels == 0 {
                return Err(AudioIOError::corrupted_data_simple(
                    "FLAC stream reports zero channels",
                    "Cannot extract samples",
                ));
            }
            let total_samples = info.total_samples;
            let frames = total_samples / channels;

            let audio = flac_file.read::<T>()?.into_owned();
            let planar = audio.as_slice().ok_or_else(|| {
                AudioIOError::corrupted_data_simple(
                    "FLAC decode produced non-contiguous data",
                    "Cannot extract samples",
                )
            })?;
            // The re-ordering below indexes up to `channels * frames - 1`;
            // make sure the decoder really produced that many samples so a
            // header/stream mismatch yields an error rather than a panic.
            if planar.len() < channels * frames {
                return Err(AudioIOError::corrupted_data_simple(
                    "FLAC decode produced fewer samples than the header declares",
                    "Cannot extract samples",
                ));
            }

            // Planar -> interleaved: for each frame, emit one sample per channel.
            let mut interleaved: Vec<T> = Vec::with_capacity(total_samples);
            for f in 0..frames {
                interleaved.extend((0..channels).map(|c| planar[c * frames + f]));
            }

            let nev = NonEmptyVec::try_from(interleaved).map_err(|_| {
                AudioIOError::corrupted_data_simple("Empty FLAC file", "No samples decoded")
            })?;
            Ok((nev, info))
        }
        other => Err(AudioIOError::unsupported_format(format!(
            "Unsupported file format: {:?}",
            other
        ))),
    }
}
/// Reads a WAV file's samples directly into a `Vec<T>` with a single bulk read
/// and no per-sample conversion.
///
/// This only succeeds when `T` is exactly the sample type stored in the file
/// (`u8`, `i16`, `i32`, `f32` or `f64`); the file bytes are then copied
/// verbatim into the vector's storage.
///
/// # Errors
///
/// * Any error from opening the file or parsing its header.
/// * An unsupported-format error when the stored sample type is not `T`.
/// * A corrupted-data error when a buffer for the declared sample count cannot
///   be allocated, or when the file declares no samples.
/// * An I/O error when the file holds fewer bytes than the header declares.
#[cfg(target_endian = "little")]
fn read_wav_direct<T>(path: &Path) -> AudioIOResult<(NonEmptyVec<T>, BaseAudioInfo)>
where
    T: StandardSample + 'static,
{
    use std::mem::size_of;

    let file = std::fs::File::open(path).map_err(AudioIOError::from)?;
    let mut reader = BufReader::with_capacity(65536, file);
    // The samples are read from the reader's current position right after
    // header parsing, so the returned byte offset is not needed.
    let (info, _data_byte_offset) = parse_wav_header_streaming(&mut reader)?;

    let requested = TypeId::of::<T>();
    let native_matches = match info.sample_type {
        audio_samples::SampleType::U8 => requested == TypeId::of::<u8>(),
        audio_samples::SampleType::I16 => requested == TypeId::of::<i16>(),
        audio_samples::SampleType::I32 => requested == TypeId::of::<i32>(),
        audio_samples::SampleType::F32 => requested == TypeId::of::<f32>(),
        audio_samples::SampleType::F64 => requested == TypeId::of::<f64>(),
        _ => false,
    };
    if !native_matches {
        return Err(AudioIOError::unsupported_format(
            "Type mismatch — use mmap path for conversion",
        ));
    }

    let total_samples = info.total_samples;

    // The sample count comes from the file header, so allocate fallibly: an
    // absurd value becomes an error instead of a capacity-overflow panic or an
    // out-of-memory abort.
    let mut vec: Vec<T> = Vec::new();
    vec.try_reserve_exact(total_samples).map_err(|_| {
        AudioIOError::corrupted_data_simple(
            "WAV header declares more samples than can be allocated",
            "Cannot allocate sample buffer",
        )
    })?;

    // SAFETY: the reservation above guarantees capacity for `total_samples`
    // elements, so the zero-fill stays inside the allocation. `native_matches`
    // restricted `T` to u8/i16/i32/f32/f64, for all of which the all-zero bit
    // pattern is a valid value, so every element is initialised before
    // `set_len` exposes it.
    unsafe {
        std::ptr::write_bytes(vec.as_mut_ptr(), 0, total_samples);
        vec.set_len(total_samples);
    }

    // Cannot overflow: an allocation of at least this many bytes just succeeded.
    let byte_count = total_samples * size_of::<T>();
    // SAFETY: `vec` owns `total_samples` initialised elements of `T`, i.e.
    // exactly `byte_count` bytes, `u8` has no alignment requirement, and `vec`
    // is not touched again while `bytes` is alive.
    let bytes =
        unsafe { std::slice::from_raw_parts_mut(vec.as_mut_ptr().cast::<u8>(), byte_count) };
    reader.read_exact(bytes).map_err(AudioIOError::from)?;

    let nev = NonEmptyVec::try_from(vec)
        .map_err(|_| AudioIOError::corrupted_data_simple("Empty WAV file", "No audio samples"))?;
    Ok((nev, info))
}
/// Wraps an interleaved sample vector as a Fortran-ordered `(channels, frames)`
/// NumPy array.
///
/// The vector is handed to NumPy as a 1-D array and then reshaped in Fortran
/// order, so interleaved sample `i` ends up at `[i % channels, i / channels]`.
///
/// # Parameters
///
/// * `py` - interpreter token.
/// * `interleaved_vec` - samples ordered frame by frame; consumed by the call.
/// * `channels` - number of channels (rows of the result).
/// * `total_samples` - expected length of `interleaved_vec`.
///
/// # Errors
///
/// Returns `PyValueError` if `channels` is zero, if `total_samples` is not a
/// multiple of `channels`, if the vector length differs from `total_samples`,
/// or if the reshape itself fails.
fn create_pyarray_fortran<T>(
    py: Python<'_>,
    interleaved_vec: NonEmptyVec<T>,
    channels: usize,
    total_samples: usize,
) -> PyResult<Py<PyArray2<T>>>
where
    T: StandardSample + Element,
{
    use pyo3::exceptions::PyValueError;

    if channels == 0 {
        return Err(PyErr::new::<PyValueError, _>("channels must be non-zero"));
    }

    if total_samples % channels != 0 {
        let message = format!(
            "total_samples ({}) not divisible by channels ({})",
            total_samples, channels
        );
        return Err(PyErr::new::<PyValueError, _>(message));
    }

    if interleaved_vec.len().get() != total_samples {
        let message = format!(
            "Vec length ({}) does not match total_samples ({})",
            interleaved_vec.len(),
            total_samples
        );
        return Err(PyErr::new::<PyValueError, _>(message));
    }

    let frames = total_samples / channels;
    let flat = PyArray1::from_vec(py, interleaved_vec.into_vec());

    // Fortran order makes the first axis (channels) vary fastest, which is the
    // order the interleaved samples are already stored in.
    match flat.reshape_with_order(
        (channels, frames),
        numpy::npyffi::NPY_ORDER::NPY_FORTRANORDER,
    ) {
        Ok(shaped) => Ok(shaped.unbind()),
        Err(e) => Err(PyErr::new::<PyValueError, _>(format!(
            "Failed to reshape array: {}",
            e
        ))),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use non_empty_slice::non_empty_vec;
    use numpy::{PyReadonlyArray2, PyUntypedArrayMethods};

    /// Interleaved stereo input becomes a Fortran-contiguous 2x3 array with
    /// channels on the first axis.
    #[test]
    #[cfg(feature = "numpy")]
    fn test_create_pyarray_fortran_stereo() {
        Python::initialize();
        Python::attach(|py| {
            let samples = non_empty_vec![1i16, 2, 3, 4, 5, 6];
            let array =
                create_pyarray_fortran(py, samples, 2, 6).expect("Failed to create PyArray");
            let bound = array.bind(py);

            assert_eq!(bound.shape(), &[2, 3]);
            assert!(bound.is_fortran_contiguous());
            assert!(!bound.is_c_contiguous());

            let view: PyReadonlyArray2<i16> = bound.readonly();
            let nd = view.as_array();
            // ((channel, frame), expected sample)
            let expected = [
                ((0, 0), 1),
                ((1, 0), 2),
                ((0, 1), 3),
                ((1, 1), 4),
                ((0, 2), 5),
                ((1, 2), 6),
            ];
            for &((channel, frame), want) in expected.iter() {
                assert_eq!(nd[[channel, frame]], want);
            }
        });
    }

    /// Mono input becomes a single-row array that preserves sample order.
    #[test]
    #[cfg(feature = "numpy")]
    fn test_create_pyarray_fortran_mono() {
        Python::initialize();
        Python::attach(|py| {
            let samples = non_empty_vec![10i16, 20, 30, 40];
            let array =
                create_pyarray_fortran(py, samples, 1, 4).expect("Failed to create PyArray");
            let bound = array.bind(py);

            assert_eq!(bound.shape(), &[1, 4]);

            let view: PyReadonlyArray2<i16> = bound.readonly();
            let nd = view.as_array();
            for (frame, want) in [10i16, 20, 30, 40].iter().enumerate() {
                assert_eq!(nd[[0, frame]], *want);
            }
        });
    }

    /// Invalid channel/sample-count combinations are rejected.
    #[test]
    #[cfg(feature = "numpy")]
    fn test_create_pyarray_fortran_validation() {
        Python::initialize();
        Python::attach(|py| {
            // (samples, channels, total_samples)
            let rejected = vec![
                // Zero channels.
                (non_empty_vec![1i16, 2], 0, 2),
                // Vector length differs from the declared sample count.
                (non_empty_vec![1i16, 2, 3], 2, 2),
                // Sample count not divisible by the channel count.
                (non_empty_vec![1i16, 2, 3], 2, 3),
            ];
            for (samples, channels, total_samples) in rejected {
                let result = create_pyarray_fortran(py, samples, channels, total_samples);
                assert!(result.is_err());
            }
        });
    }
}