use std::fs::File;
use std::path::Path;
use num::Float;
use symphonia::core::audio::SampleBuffer;
use symphonia::core::codecs::{CODEC_TYPE_NULL, DecoderOptions};
use symphonia::core::errors::Error;
use symphonia::core::formats::{FormatOptions, SeekMode, SeekTo};
use symphonia::core::io::MediaSourceStream;
use symphonia::core::meta::MetadataOptions;
use symphonia::core::probe::Hint;
use thiserror::Error;
use crate::resample::{ResampleError, resample};
/// Decoded audio returned by `read`, stored as one interleaved sample buffer.
#[derive(Debug, Clone)]
pub struct Audio<F> {
/// Samples in frame order: every selected channel of frame 0, then every selected
/// channel of frame 1, and so on.
pub samples_interleaved: Vec<F>,
/// Sample rate that `samples_interleaved` is expressed at.
pub sample_rate: u32,
/// Number of channels interleaved in `samples_interleaved`.
pub num_channels: u16,
}
#[derive(Debug, Error)]
pub enum ReadError {
#[error("could not read file")]
Io(#[from] std::io::Error),
#[error("could not decode audio")]
Decode(#[from] symphonia::core::errors::Error),
#[error("no track found")]
NoTrack,
#[error("no sample rate found")]
NoSampleRate,
#[error("end frame ({end}) must not exceed start frame ({start})")]
InvalidFrameRange { start: usize, end: usize },
#[error("start channel {index} out of bounds (file has {total} channels)")]
InvalidChannel { index: usize, total: usize },
#[error("invalid channel count: {0}")]
InvalidChannelCount(usize),
#[error("resample failed")]
Resample(#[from] ResampleError),
}
/// A point on the file's timeline, used for the `start` and `stop` fields of `ReadConfig`.
#[derive(Default, Debug, Clone, Copy)]
pub enum Position {
/// No explicit position: `read` treats this as frame 0 when used as `start` and as
/// "read until the end of the file" when used as `stop`.
#[default]
Default,
/// Offset from the beginning of the file; `read` converts it to a frame index by
/// multiplying the seconds by the file's sample rate and truncating.
Time(std::time::Duration),
/// Frame index, counted at the file's own sample rate.
Frame(usize),
}
/// Options that select which part of a file [`read`] returns.
///
/// The `Default` value reads every frame of every channel at the file's own sample rate.
#[derive(Default, Debug, Clone)]
pub struct ReadConfig {
    /// First frame to return. [`Position::Default`] starts at frame 0.
    pub start: Position,
    /// Frame at which reading stops (exclusive). [`Position::Default`] reads to the end.
    pub stop: Position,
    /// Index of the first channel to keep; `None` means channel 0.
    pub start_channel: Option<usize>,
    /// Number of consecutive channels to keep, counted from `start_channel`; `None`
    /// keeps every channel from `start_channel` to the last one.
    pub num_channels: Option<usize>,
    /// Sample rate of the returned audio; `None` keeps the file's own rate.
    pub sample_rate: Option<u32>,
}
pub fn read<F: Float + rubato::Sample>(
path: impl AsRef<Path>,
config: ReadConfig,
) -> Result<Audio<F>, ReadError> {
let src = File::open(path.as_ref())?;
let mss = MediaSourceStream::new(Box::new(src), Default::default());
let mut hint = Hint::new();
if let Some(ext) = path.as_ref().extension()
&& let Some(ext_str) = ext.to_str()
{
hint.with_extension(ext_str);
}
let meta_opts: MetadataOptions = Default::default();
let fmt_opts: FormatOptions = Default::default();
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
let mut format = probed.format;
let track = format
.tracks()
.iter()
.find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
.ok_or(ReadError::NoTrack)?;
let sample_rate = track
.codec_params
.sample_rate
.ok_or(ReadError::NoSampleRate)?;
let track_id = track.id;
let codec_params = track.codec_params.clone();
let time_base = track.codec_params.time_base;
let start_frame = match config.start {
Position::Default => 0,
Position::Time(duration) => {
let secs = duration.as_secs_f64();
(secs * sample_rate as f64) as usize
}
Position::Frame(frame) => frame,
};
let end_frame: Option<usize> = match config.stop {
Position::Default => None,
Position::Time(duration) => {
let secs = duration.as_secs_f64();
Some((secs * sample_rate as f64) as usize)
}
Position::Frame(frame) => Some(frame),
};
if let Some(end_frame) = end_frame
&& start_frame > end_frame
{
return Err(ReadError::InvalidFrameRange {
start: start_frame,
end: end_frame,
});
}
if start_frame > sample_rate as usize
&& let Some(tb) = time_base
{
let seek_sample = (start_frame as f64 * 0.9) as u64;
let seek_ts = (seek_sample * tb.denom as u64) / (sample_rate as u64);
let _ = format.seek(
SeekMode::Accurate,
SeekTo::TimeStamp {
ts: seek_ts,
track_id,
},
);
}
let dec_opts: DecoderOptions = Default::default();
let mut decoder = symphonia::default::get_codecs().make(&codec_params, &dec_opts)?;
let mut sample_buf = None;
let mut samples = Vec::new();
let mut num_channels = 0usize;
let start_channel = config.start_channel;
let mut current_sample: Option<u64> = None;
loop {
let packet = match format.next_packet() {
Ok(packet) => packet,
Err(Error::ResetRequired) => {
decoder.reset();
continue;
}
Err(Error::IoError(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
break;
}
Err(err) => return Err(err.into()),
};
if packet.track_id() != track_id {
continue;
}
let decoded = decoder.decode(&packet)?;
if current_sample.is_none() {
let ts = packet.ts();
if let Some(tb) = time_base {
current_sample = Some((ts * sample_rate as u64) / tb.denom as u64);
} else {
current_sample = Some(0);
}
}
if sample_buf.is_none() {
let spec = *decoded.spec();
let duration = decoded.capacity() as u64;
sample_buf = Some(SampleBuffer::<f32>::new(duration, spec));
num_channels = spec.channels.count();
let ch_start = start_channel.unwrap_or(0);
let ch_count = config.num_channels.unwrap_or(num_channels - ch_start);
if ch_start >= num_channels {
return Err(ReadError::InvalidChannel {
index: ch_start,
total: num_channels,
});
}
if ch_count == 0 {
return Err(ReadError::InvalidChannelCount(0));
}
if ch_start + ch_count > num_channels {
return Err(ReadError::InvalidChannelCount(ch_count));
}
}
if let Some(buf) = &mut sample_buf {
buf.copy_interleaved_ref(decoded);
let packet_samples = buf.samples();
let mut pos = current_sample.unwrap_or(0);
let ch_start = start_channel.unwrap_or(0);
let ch_count = config.num_channels.unwrap_or(num_channels - ch_start);
let ch_end = ch_start + ch_count;
let frames = packet_samples.len() / num_channels;
for frame_idx in 0..frames {
if let Some(end) = end_frame
&& pos >= end as u64
{
return Ok(Audio {
samples_interleaved: samples,
sample_rate,
num_channels: ch_count as u16,
});
}
if pos >= start_frame as u64 {
for ch in ch_start..ch_end {
let sample_idx = frame_idx * num_channels + ch;
samples.push(F::from(packet_samples[sample_idx]).unwrap());
}
}
pos += 1;
}
current_sample = Some(pos);
}
}
let ch_start = start_channel.unwrap_or(0);
let ch_count = config.num_channels.unwrap_or(num_channels - ch_start);
let samples = if let Some(sr_out) = config.sample_rate
&& sr_out != sample_rate
{
resample(&samples, ch_count, sample_rate, sr_out)?
} else {
samples
};
let actual_sample_rate = config.sample_rate.unwrap_or(sample_rate);
Ok(Audio {
samples_interleaved: samples,
sample_rate: actual_sample_rate,
num_channels: ch_count as u16,
})
}
/// Reads a file with [`read`] and returns the result as an
/// `audio_blocks::Interleaved` block together with its sample rate.
///
/// Any [`ReadError`] from [`read`] is passed through unchanged.
#[cfg(feature = "audio-blocks")]
pub fn read_block<F: num::Float + 'static + rubato::Sample>(
    path: impl AsRef<Path>,
    config: ReadConfig,
) -> Result<(audio_blocks::Interleaved<F>, u32), ReadError> {
    let Audio {
        samples_interleaved,
        sample_rate,
        num_channels,
    } = read::<F>(path, config)?;
    let block = audio_blocks::Interleaved::from_slice(&samples_interleaved, num_channels);
    Ok((block, sample_rate))
}
// Tests for `read`. They load fixture files from `test_data/` and, judging by the
// assertions below, expect one second of 48 kHz audio in each fixture.
#[cfg(test)]
mod tests {
use std::time::Duration;
use audio_blocks::{AudioBlock, InterleavedView};
use super::*;
// Wraps the interleaved samples of `audio` in a borrowed block view so the tests can
// address samples by (channel, frame).
fn to_block<F: num::Float + 'static>(audio: &Audio<F>) -> InterleavedView<'_, F> {
InterleavedView::from_slice(&audio.samples_interleaved, audio.num_channels)
}
// Checks that every decoded sample of the four-channel fixture matches a sine wave at
// the expected per-channel frequency, both for a full read and for a frame sub-range.
#[test]
fn test_sine_wave_data_integrity() {
const SAMPLE_RATE: f64 = 48000.0;
const N_SAMPLES: usize = 48000;
// Expected frequency of channel 0..3 of the fixture.
const FREQUENCIES: [f64; 4] = [440.0, 554.37, 659.25, 880.0];
let audio = read::<f32>("test_data/test_4ch.wav", ReadConfig::default()).unwrap();
let block = to_block(&audio);
assert_eq!(audio.sample_rate, 48000);
assert_eq!(block.num_frames(), N_SAMPLES);
assert_eq!(block.num_channels(), 4);
for (ch, &freq) in FREQUENCIES.iter().enumerate() {
for frame in 0..N_SAMPLES {
let expected =
(2.0 * std::f64::consts::PI * freq * frame as f64 / SAMPLE_RATE).sin() as f32;
let actual = block.sample(ch as u16, frame);
assert!(
(actual - expected).abs() < 1e-4,
"Mismatch at channel {ch}, frame {frame}: expected {expected}, got {actual}"
);
}
}
// Second pass: read only frames 24000..24100 and compare against the same sine
// evaluated at the absolute frame index.
let audio = read::<f32>(
"test_data/test_4ch.wav",
ReadConfig {
start: Position::Frame(24000),
stop: Position::Frame(24100),
..Default::default()
},
)
.unwrap();
let block = to_block(&audio);
for (ch, &freq) in FREQUENCIES.iter().enumerate() {
for frame in 0..100 {
let actual_frame = 24000 + frame;
let expected = (2.0 * std::f64::consts::PI * freq * actual_frame as f64
/ SAMPLE_RATE)
.sin() as f32;
let actual = block.sample(ch as u16, frame);
assert!(
(actual - expected).abs() < 1e-4,
"Offset mismatch at channel {ch}, frame {actual_frame}: expected {expected}, got {actual}"
);
}
}
}
// A frame-based start/stop must return exactly the matching slice of a full read.
#[test]
fn test_samples_selection() {
let audio1 = read::<f32>("test_data/test_1ch.wav", ReadConfig::default()).unwrap();
let block1 = to_block(&audio1);
assert_eq!(audio1.sample_rate, 48000);
assert_eq!(block1.num_frames(), 48000);
assert_eq!(block1.num_channels(), 1);
let audio2 = read::<f32>(
"test_data/test_1ch.wav",
ReadConfig {
start: Position::Frame(1100),
stop: Position::Frame(1200),
..Default::default()
},
)
.unwrap();
let block2 = to_block(&audio2);
assert_eq!(audio2.sample_rate, 48000);
assert_eq!(block2.num_frames(), 100);
assert_eq!(block2.num_channels(), 1);
assert_eq!(block1.raw_data()[1100..1200], block2.raw_data()[..]);
}
// A time-based start/stop (0.5 s to 0.6 s) must map to frames 24000..28800 at 48 kHz.
#[test]
fn test_time_selection() {
let audio1 = read::<f32>("test_data/test_1ch.wav", ReadConfig::default()).unwrap();
let block1 = to_block(&audio1);
assert_eq!(audio1.sample_rate, 48000);
assert_eq!(block1.num_frames(), 48000);
assert_eq!(block1.num_channels(), 1);
let audio2 = read::<f32>(
"test_data/test_1ch.wav",
ReadConfig {
start: Position::Time(Duration::from_secs_f32(0.5)),
stop: Position::Time(Duration::from_secs_f32(0.6)),
..Default::default()
},
)
.unwrap();
let block2 = to_block(&audio2);
assert_eq!(audio2.sample_rate, 48000);
assert_eq!(block2.num_frames(), 4800);
assert_eq!(block2.num_channels(), 1);
assert_eq!(block1.raw_data()[24000..28800], block2.raw_data()[..]);
}
// Selecting channels 1 and 2 must yield a two-channel result whose samples equal the
// corresponding channels of a full read.
#[test]
fn test_channel_selection() {
let audio1 = read::<f32>("test_data/test_4ch.wav", ReadConfig::default()).unwrap();
let block1 = to_block(&audio1);
assert_eq!(audio1.sample_rate, 48000);
assert_eq!(block1.num_frames(), 48000);
assert_eq!(block1.num_channels(), 4);
let audio2 = read::<f32>(
"test_data/test_4ch.wav",
ReadConfig {
start_channel: Some(1),
num_channels: Some(2),
..Default::default()
},
)
.unwrap();
let block2 = to_block(&audio2);
assert_eq!(audio2.sample_rate, 48000);
assert_eq!(block2.num_frames(), 48000);
assert_eq!(block2.num_channels(), 2);
// Only the first ten frames are compared.
for frame in 0..10 {
assert_eq!(block2.sample(0, frame), block1.sample(1, frame));
assert_eq!(block2.sample(1, frame), block1.sample(2, frame));
}
}
// Invalid selections must be rejected with the matching `ReadError` variant.
#[test]
fn test_fail_selection() {
// Start frame after stop frame.
match read::<f32>(
"test_data/test_1ch.wav",
ReadConfig {
start: Position::Frame(100),
stop: Position::Frame(99),
..Default::default()
},
) {
Err(ReadError::InvalidFrameRange { start: _, end: _ }) => (),
_ => panic!(),
}
// Start time after stop time.
match read::<f32>(
"test_data/test_1ch.wav",
ReadConfig {
start: Position::Time(Duration::from_secs_f32(0.6)),
stop: Position::Time(Duration::from_secs_f32(0.5)),
..Default::default()
},
) {
Err(ReadError::InvalidFrameRange { start: _, end: _ }) => (),
_ => panic!(),
}
// Start channel 1 does not exist in a one-channel file.
match read::<f32>(
"test_data/test_1ch.wav",
ReadConfig {
start_channel: Some(1),
..Default::default()
},
) {
Err(ReadError::InvalidChannel { index: _, total: _ }) => (),
_ => panic!(),
}
// Requesting zero channels.
match read::<f32>(
"test_data/test_1ch.wav",
ReadConfig {
num_channels: Some(0),
..Default::default()
},
) {
Err(ReadError::InvalidChannelCount(0)) => (),
_ => panic!(),
}
// Requesting more channels than the file has.
match read::<f32>(
"test_data/test_1ch.wav",
ReadConfig {
num_channels: Some(2),
..Default::default()
},
) {
Err(ReadError::InvalidChannelCount(2)) => (),
_ => panic!(),
}
}
// Resampling 48 kHz to 22.05 kHz must keep each channel's sine frequency: the output is
// compared against a sine evaluated at the output rate, within a loose tolerance.
#[test]
fn test_resample_preserves_frequency() {
const FREQUENCIES: [f64; 4] = [440.0, 554.37, 659.25, 880.0];
let sr_out: u32 = 22050;
let audio = read::<f32>(
"test_data/test_4ch.wav",
ReadConfig {
sample_rate: Some(sr_out),
..Default::default()
},
)
.unwrap();
let block = to_block(&audio);
assert_eq!(audio.sample_rate, sr_out); assert_eq!(block.num_channels(), 4);
// One second of audio at the output rate.
let expected_frames = 22050;
assert_eq!(
block.num_frames(),
expected_frames,
"Expected {} frames, got {}",
expected_frames,
block.num_frames()
);
// The comparison window skips the first 100 output frames.
// NOTE(review): presumably to avoid resampler edge effects - confirm.
let start_frame = 100;
let test_frames = 1000;
for (ch, &freq) in FREQUENCIES.iter().enumerate() {
let mut max_error: f32 = 0.0;
for frame in start_frame..(start_frame + test_frames) {
let expected =
(2.0 * std::f64::consts::PI * freq * frame as f64 / sr_out as f64).sin() as f32;
let actual = block.sample(ch as u16, frame);
let error = (actual - expected).abs();
max_error = max_error.max(error);
}
assert!(
max_error < 0.02,
"Channel {} ({}Hz): max error {} exceeds threshold",
ch,
freq,
max_error
);
}
}
// Channel selection combined with resampling: the two selected channels must carry the
// frequencies of source channels 1 and 2 at the output rate.
#[test]
fn test_channel_selection_with_resampling() {
const FREQUENCIES: [f64; 4] = [440.0, 554.37, 659.25, 880.0];
let sr_out: u32 = 22050;
let audio = read::<f32>(
"test_data/test_4ch.wav",
ReadConfig {
start_channel: Some(1),
num_channels: Some(2),
sample_rate: Some(sr_out),
..Default::default()
},
)
.unwrap();
let block = to_block(&audio);
assert_eq!(audio.num_channels, 2, "Should have 2 channels");
assert_eq!(
audio.sample_rate, sr_out,
"Sample rate should be the resampled rate"
);
let expected_frames = 22050;
assert_eq!(
block.num_frames(),
expected_frames,
"Expected {} frames, got {}",
expected_frames,
block.num_frames()
);
// Frequencies of the source channels that were selected (1 and 2).
let selected_freqs = &FREQUENCIES[1..3];
let start_frame = 100;
let test_frames = 1000;
for (ch, &freq) in selected_freqs.iter().enumerate() {
let mut max_error: f32 = 0.0;
for frame in start_frame..(start_frame + test_frames) {
let expected =
(2.0 * std::f64::consts::PI * freq * frame as f64 / sr_out as f64).sin() as f32;
let actual = block.sample(ch as u16, frame);
let error = (actual - expected).abs();
max_error = max_error.max(error);
}
assert!(
max_error < 0.02,
"Channel {} ({}Hz): max error {} exceeds threshold",
ch,
freq,
max_error
);
}
}
}