use std::fs::File;
use std::path::Path;
use symphonia::core::audio::{AudioBuffer, AudioBufferRef, Signal};
use symphonia::core::codecs::{DecoderOptions, CODEC_TYPE_NULL};
use symphonia::core::errors::Error as SymphoniaError;
use symphonia::core::formats::{FormatOptions, FormatReader, SeekMode, SeekTo};
use symphonia::core::io::MediaSourceStream;
use symphonia::core::meta::MetadataOptions;
use symphonia::core::probe::Hint;
use symphonia::default::get_codecs;
use symphonia::default::get_probe;
use super::error::SampleSourceError;
use super::traits::SampleSource;
#[cfg(test)]
use super::traits::SampleSourceTestExt;
/// Streaming sample source that decodes one track of an audio file with
/// symphonia and hands out interleaved `f32` samples.
pub struct AudioSampleSource {
    // ---- decoding pipeline ----
    /// Container reader the packets come from.
    format_reader: Box<dyn FormatReader>,
    /// Decoder for the selected track.
    decoder: Box<dyn symphonia::core::codecs::Decoder>,
    /// Track whose packets are decoded; packets of other tracks are skipped.
    track_id: u32,

    // ---- buffering ----
    /// Interleaved samples ready to be returned.
    sample_buffer: Vec<f32>,
    /// Index of the next sample in `sample_buffer` to return.
    buffer_position: usize,
    /// Frames per refill; a refill targets `buffer_size * channels` samples.
    buffer_size: usize,
    /// Decoded samples that did not fit into the previous refill; the next
    /// refill consumes them first.
    leftover_samples: Vec<f32>,
    /// Set once a read finds no more samples; later reads return nothing.
    is_finished: bool,

    // ---- stream properties ----
    channels: u16,
    sample_rate: u32,
    bits_per_sample: u16,
    sample_format: crate::audio::SampleFormat,
    /// Length derived from the track's frame count (zero when not reported).
    duration: std::time::Duration,
}
impl SampleSource for AudioSampleSource {
    /// Returns the next interleaved sample.
    ///
    /// The internal buffer is refilled from the decoder when drained. Returns
    /// `Ok(None)` once the stream is exhausted, and on every call after that.
    fn next_sample(&mut self) -> Result<Option<f32>, SampleSourceError> {
        if self.is_finished {
            return Ok(None);
        }
        if self.buffer_position >= self.sample_buffer.len() {
            // `refill_buffer` rewinds `buffer_position` to 0, so the lookup
            // below only misses when the refill produced no samples at all.
            self.refill_buffer()?;
        }
        match self.sample_buffer.get(self.buffer_position) {
            Some(&sample) => {
                self.buffer_position += 1;
                Ok(Some(sample))
            }
            None => {
                self.is_finished = true;
                Ok(None)
            }
        }
    }

    /// Number of interleaved channels per frame.
    fn channel_count(&self) -> u16 {
        self.channels
    }

    /// Sample rate of the decoded track, in Hz.
    fn sample_rate(&self) -> u32 {
        self.sample_rate
    }

    /// Bit depth reported by the codec parameters.
    fn bits_per_sample(&self) -> u16 {
        self.bits_per_sample
    }

    /// Whether the source data is integer or floating-point PCM.
    fn sample_format(&self) -> crate::audio::SampleFormat {
        self.sample_format
    }

    /// Total stream length.
    ///
    /// NOTE(review): this is always `Some`. When the track does not report a
    /// frame count, the stored duration is `Duration::ZERO`, so a zero here
    /// may mean "unknown" rather than "empty".
    fn duration(&self) -> Option<std::time::Duration> {
        Some(self.duration)
    }
}
impl AudioSampleSource {
/// Opens an audio file and prepares its first decodable track for streaming.
///
/// * `path` - file to open. Its extension (when valid UTF-8) is passed to the
///   prober as a hint.
/// * `start_time` - when `Some`, the reader is seeked with
///   [`SeekMode::Accurate`] to this offset before any sample is returned.
/// * `buffer_size` - frames decoded per refill. Each refill holds up to
///   `buffer_size * channels` interleaved samples. A value of `0` is treated
///   as `1`.
///
/// The reported duration is computed from the track's frame count. It is
/// `Duration::ZERO` when the container does not provide one.
///
/// # Errors
///
/// * [`SampleSourceError::IoError`] if the file cannot be opened.
/// * [`SampleSourceError::SampleConversionFailed`] if probing fails, no track
///   has a non-null codec, the sample rate is missing, no decoder can be
///   created, or the channel count cannot be detected.
/// * Seek failures are propagated from symphonia through `?`.
pub fn from_file<P: AsRef<Path>>(
    path: P,
    start_time: Option<std::time::Duration>,
    buffer_size: usize,
) -> Result<Self, SampleSourceError> {
    let path_ref = path.as_ref();
    // With zero frames per refill, every refill comes back empty. The readers
    // treat that as end of stream, so always buffer at least one frame.
    let buffer_size = buffer_size.max(1);

    let file = File::open(path_ref).map_err(|e| {
        SampleSourceError::IoError(std::io::Error::new(
            e.kind(),
            format!("{}: {}", path_ref.display(), e),
        ))
    })?;
    let mss = MediaSourceStream::new(Box::new(file), Default::default());

    let mut hint = Hint::new();
    if let Some(extension) = path_ref.extension().and_then(|ext| ext.to_str()) {
        hint.with_extension(extension);
    }
    let meta_opts: MetadataOptions = Default::default();
    let fmt_opts: FormatOptions = Default::default();
    let file_path = path_ref.to_string_lossy().to_string();
    let probed = get_probe()
        .format(&hint, mss, &fmt_opts, &meta_opts)
        .map_err(|e| {
            SampleSourceError::SampleConversionFailed(format!("'{}': {}", file_path, e))
        })?;
    let mut format_reader = probed.format;

    let track = format_reader
        .tracks()
        .iter()
        .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
        .ok_or_else(|| {
            SampleSourceError::SampleConversionFailed("No audio track found".to_string())
        })?;
    let track_id = track.id;
    let params = &track.codec_params;

    let sample_rate = params.sample_rate.ok_or_else(|| {
        SampleSourceError::SampleConversionFailed("Sample rate not specified".to_string())
    })?;
    let bits_per_sample = params.bits_per_sample.unwrap_or(16) as u16;
    let sample_format = if params.codec == symphonia::core::codecs::CODEC_TYPE_PCM_F32LE
        || params.codec == symphonia::core::codecs::CODEC_TYPE_PCM_F32BE
        || params.codec == symphonia::core::codecs::CODEC_TYPE_PCM_F64LE
        || params.codec == symphonia::core::codecs::CODEC_TYPE_PCM_F64BE
    {
        crate::audio::SampleFormat::Float
    } else {
        crate::audio::SampleFormat::Int
    };
    let duration = if let Some(n_frames) = params.n_frames {
        std::time::Duration::from_secs_f64(n_frames as f64 / sample_rate as f64)
    } else {
        std::time::Duration::ZERO
    };

    let decoder_opts: DecoderOptions = Default::default();
    let mut decoder = get_codecs().make(params, &decoder_opts).map_err(|e| {
        SampleSourceError::SampleConversionFailed(format!("'{}': {}", file_path, e))
    })?;

    // A channel count of 0 means the codec parameters did not specify one.
    // In that case, decode the first packet to learn it and keep its samples
    // so they are not lost. Test builds can force this path via an env var.
    let channels = params.channels.map(|c| c.count() as u16).unwrap_or(0);
    let force_detect = cfg!(test) && std::env::var("MTRACK_FORCE_DETECT_CHANNELS").is_ok();
    let (channels, initial_leftover) = if channels > 0 && !force_detect {
        (channels, Vec::new())
    } else {
        Self::detect_channels_and_prime_buffer(
            format_reader.as_mut(),
            decoder.as_mut(),
            track_id,
        )?
    };

    let mut source = Self {
        format_reader,
        decoder,
        track_id,
        is_finished: false,
        sample_buffer: Vec::with_capacity(buffer_size * channels as usize),
        buffer_position: 0,
        buffer_size,
        leftover_samples: initial_leftover,
        bits_per_sample,
        channels,
        sample_rate,
        sample_format,
        duration,
    };

    if let Some(start) = start_time {
        // Samples primed during channel detection come from the start of the
        // stream, not from the seek target.
        source.leftover_samples.clear();
        use symphonia::core::units::Time;
        let seek_to = SeekTo::Time {
            time: Time::from(start),
            track_id: Some(track_id),
        };
        source.format_reader.seek(SeekMode::Accurate, seek_to)?;
        // Symphonia requires decoders fed by a reader to be reset after a
        // seek. Otherwise a codec that keeps state between packets would
        // decode the first post-seek packets against pre-seek state. This
        // matters here because detection may already have decoded a packet.
        // NOTE(review): if the reader lands before the requested time
        // (`SeekedTo::actual_ts < required_ts`), those leading frames are not
        // trimmed.
        source.decoder.reset();
    }
    Ok(source)
}
/// Fetches the next packet from `format_reader`.
///
/// Returns `Ok(None)` for an unexpected-EOF I/O error, which is how the reader
/// reports end of stream. A container-level `DecodeError` also yields
/// `Ok(None)`, so malformed trailing data is treated as end of stream.
///
/// # Errors
///
/// Every other symphonia error is wrapped in [`SampleSourceError::AudioError`].
/// This includes `ResetRequired`, which is passed through untouched so the
/// caller can reset its decoder.
fn read_next_packet(
    format_reader: &mut dyn FormatReader,
) -> Result<Option<symphonia::core::formats::Packet>, SampleSourceError> {
    let failure = match format_reader.next_packet() {
        Ok(packet) => return Ok(Some(packet)),
        Err(failure) => failure,
    };
    match failure {
        SymphoniaError::IoError(ref io) if io.kind() == std::io::ErrorKind::UnexpectedEof => {
            Ok(None)
        }
        SymphoniaError::DecodeError(_) => Ok(None),
        other => Err(SampleSourceError::AudioError(other)),
    }
}
/// Reads packets until one belonging to `track_id` decodes into a non-empty
/// buffer, and returns its interleaved `f32` samples with the channel count.
///
/// * Packets of other tracks are skipped.
/// * A `ResetRequired` reported by the reader resets the decoder and reading
///   continues.
/// * A `ResetRequired` reported by the decoder resets it and decodes the same
///   packet once more.
/// * A packet the decoder rejects with `DecodeError` is discarded and decoding
///   continues with the next packet, as symphonia documents for that error.
///   One corrupt frame therefore causes a short dropout instead of ending the
///   stream.
///
/// Returns `Ok(None)` when [`Self::read_next_packet`] reports no more packets.
///
/// # Errors
///
/// Any other reader or decoder error, wrapped in
/// [`SampleSourceError::AudioError`], and conversion errors from
/// [`Self::decode_buffer_to_f32`].
fn read_and_decode_next_packet_for_track(
    format_reader: &mut dyn FormatReader,
    decoder: &mut dyn symphonia::core::codecs::Decoder,
    track_id: u32,
) -> Result<Option<(Vec<f32>, usize)>, SampleSourceError> {
    loop {
        let packet = match Self::read_next_packet(format_reader) {
            Ok(Some(packet)) => packet,
            Ok(None) => return Ok(None),
            Err(SampleSourceError::AudioError(SymphoniaError::ResetRequired)) => {
                decoder.reset();
                continue;
            }
            Err(e) => return Err(e),
        };
        if packet.track_id() != track_id {
            continue;
        }
        let decoded = match decoder.decode(&packet) {
            Ok(decoded) => decoded,
            Err(SymphoniaError::ResetRequired) => {
                decoder.reset();
                match decoder.decode(&packet) {
                    Ok(decoded) => decoded,
                    // Undecodable packet: drop it and move on.
                    Err(SymphoniaError::DecodeError(_)) => continue,
                    Err(e) => return Err(SampleSourceError::AudioError(e)),
                }
            }
            // Undecodable packet: drop it and move on.
            Err(SymphoniaError::DecodeError(_)) => continue,
            Err(e) => return Err(SampleSourceError::AudioError(e)),
        };
        let (samples, channels) = Self::decode_buffer_to_f32(decoded)?;
        // Packets that decode to nothing (e.g. zero frames) are skipped.
        if channels > 0 && !samples.is_empty() {
            return Ok(Some((samples, channels)));
        }
    }
}
/// Copies up to `buf.len()` interleaved samples into `buf`, refilling from the
/// decoder as needed.
///
/// Returns the number of samples written. A value smaller than `buf.len()`
/// means the stream ended, and every later call returns `0`. An empty `buf`
/// always yields `0`.
///
/// # Errors
///
/// Propagates errors from refilling the internal buffer.
pub fn read_samples(&mut self, buf: &mut [f32]) -> Result<usize, SampleSourceError> {
    if buf.is_empty() || self.is_finished {
        return Ok(0);
    }
    let mut filled = 0;
    while filled < buf.len() {
        let pending = &self.sample_buffer[self.buffer_position..];
        if pending.is_empty() {
            self.refill_buffer()?;
            if self.sample_buffer.is_empty() {
                // The decoder has nothing left to give.
                self.is_finished = true;
                break;
            }
            continue;
        }
        let chunk = pending.len().min(buf.len() - filled);
        buf[filled..filled + chunk].copy_from_slice(&pending[..chunk]);
        self.buffer_position += chunk;
        filled += chunk;
    }
    Ok(filled)
}
/// Refills `sample_buffer` with up to `buffer_size * channels` interleaved
/// samples and rewinds `buffer_position` to 0.
///
/// Samples carried over in `leftover_samples` are consumed first. Packets are
/// then decoded until one of the following happens:
/// * the target is reached;
/// * a packet overflows the target (its excess goes to `leftover_samples`);
/// * a packet yields fewer than 32 samples;
/// * the stream ends.
///
/// If nothing was read and no leftovers remain, `is_finished` is set.
///
/// # Errors
///
/// A read/decode error is returned only if this call had already gathered
/// samples. An error hit before any sample was read is swallowed and treated
/// like end of stream.
/// NOTE(review): when an error is returned, the samples already gathered stay
/// in `sample_buffer` (position 0) and are handed out if the caller keeps
/// reading.
fn refill_buffer(&mut self) -> Result<(), SampleSourceError> {
self.sample_buffer.clear();
self.buffer_position = 0;
let mut samples_read = 0;
// Frame-aligned target: whole frames of `channels` samples each.
let target_samples = self.buffer_size * self.channels as usize;
// Serve samples left over from the previous packet before decoding more.
if !self.leftover_samples.is_empty() {
let to_take = target_samples.min(self.leftover_samples.len());
self.sample_buffer
.extend_from_slice(&self.leftover_samples[..to_take]);
samples_read += to_take;
if self.leftover_samples.len() > to_take {
self.leftover_samples.drain(..to_take);
} else {
self.leftover_samples.clear();
}
if samples_read >= target_samples {
return Ok(());
}
}
loop {
// The per-packet channel count is ignored.
// NOTE(review): assumes the channel layout never changes mid-stream.
let (samples, _decoded_channels) = match Self::read_and_decode_next_packet_for_track(
self.format_reader.as_mut(),
self.decoder.as_mut(),
self.track_id,
) {
Ok(Some((samples, ch))) => (samples, ch),
Ok(None) => break,
Err(e) => {
// Nothing gathered yet: treat the error like end of stream.
if samples_read == 0 && self.sample_buffer.is_empty() {
break;
}
return Err(e);
}
};
if !samples.is_empty() {
let remaining = target_samples.saturating_sub(samples_read);
let to_take = remaining.min(samples.len());
self.sample_buffer.extend_from_slice(&samples[..to_take]);
samples_read += to_take;
// Keep whatever did not fit for the next refill.
if samples.len() > to_take {
self.leftover_samples.extend_from_slice(&samples[to_take..]);
break;
}
// NOTE(review): heuristic - a very short packet ends this refill early
// (possibly meant to catch a final partial packet). The next refill
// simply continues decoding.
if samples.len() < 32 {
break;
}
}
if samples_read >= target_samples {
break;
}
}
if samples_read == 0 && self.leftover_samples.is_empty() {
self.is_finished = true;
}
Ok(())
}
/// Learns the channel count by decoding the first packet of `track_id`.
///
/// Returns the channel count together with that packet's interleaved samples,
/// so the caller can serve them instead of losing them.
///
/// # Errors
///
/// [`SampleSourceError::SampleConversionFailed`] if the stream ends before any
/// packet decodes. Reader and decoder errors are propagated as-is.
fn detect_channels_and_prime_buffer(
    format_reader: &mut dyn FormatReader,
    decoder: &mut dyn symphonia::core::codecs::Decoder,
    track_id: u32,
) -> Result<(u16, Vec<f32>), SampleSourceError> {
    let first_packet =
        Self::read_and_decode_next_packet_for_track(format_reader, decoder, track_id)?;
    let (primed, channel_count) = first_packet.ok_or_else(|| {
        SampleSourceError::SampleConversionFailed("Channels not specified".to_string())
    })?;
    Ok((channel_count as u16, primed))
}
#[cfg_attr(test, allow(dead_code))]
pub(crate) fn decode_buffer_to_f32(
decoded: AudioBufferRef,
) -> Result<(Vec<f32>, usize), SampleSourceError> {
match decoded {
AudioBufferRef::F32(buf) => Ok(Self::interleave_planar_samples(&buf, |sample| sample)),
AudioBufferRef::F64(buf) => Ok(Self::interleave_planar_samples(&buf, |sample| {
sample as f32
})),
AudioBufferRef::S8(buf) => Ok(Self::interleave_planar_samples(&buf, |sample| {
Self::scale_s8(sample)
})),
AudioBufferRef::S16(buf) => Ok(Self::interleave_planar_samples(&buf, |sample| {
Self::scale_s16(sample)
})),
AudioBufferRef::S24(buf) => Ok(Self::interleave_planar_samples(&buf, |sample| {
Self::scale_s24(sample.inner())
})),
AudioBufferRef::S32(buf) => Ok(Self::interleave_planar_samples(&buf, |sample| {
Self::scale_s32(sample)
})),
AudioBufferRef::U8(buf) => Ok(Self::interleave_planar_samples(&buf, |sample| {
Self::scale_u8(sample)
})),
AudioBufferRef::U16(buf) => Ok(Self::interleave_planar_samples(&buf, |sample| {
Self::scale_u16(sample)
})),
AudioBufferRef::U24(buf) => Ok(Self::interleave_planar_samples(&buf, |sample| {
Self::scale_u24(sample.inner())
})),
AudioBufferRef::U32(buf) => Ok(Self::interleave_planar_samples(&buf, |sample| {
Self::scale_u32(sample)
})),
}
}
/// Interleaves a planar buffer into frame-major order
/// (`f0c0, f0c1, …, f1c0, f1c1, …`), converting each sample with `convert`.
///
/// Returns the interleaved samples and the channel count from the buffer's
/// signal spec.
fn interleave_planar_samples<T, F>(buf: &AudioBuffer<T>, convert: F) -> (Vec<f32>, usize)
where
    T: symphonia::core::sample::Sample,
    F: Fn(T) -> f32,
{
    let channel_count = buf.spec().channels.count();
    let frame_count = buf.frames();
    let plane_view = buf.planes();
    // One slice per channel; looked up once instead of per sample.
    let planes = plane_view.planes();
    let mut interleaved = Vec::with_capacity(frame_count * channel_count);
    for frame in 0..frame_count {
        interleaved.extend((0..channel_count).map(|ch| convert(planes[ch][frame])));
    }
    (interleaved, channel_count)
}
/// Converts a signed 8-bit sample to `f32` in `[-1.0, 1.0)` (`i8::MIN` -> `-1.0`).
#[inline]
pub(crate) fn scale_s8(sample: i8) -> f32 {
    const FULL_SCALE: f32 = 128.0; // 2^7
    f32::from(sample) / FULL_SCALE
}
/// Converts a signed 16-bit sample to `f32` in `[-1.0, 1.0)` (`i16::MIN` -> `-1.0`).
#[inline]
pub(crate) fn scale_s16(sample: i16) -> f32 {
    const FULL_SCALE: f32 = 32_768.0; // 2^15
    f32::from(sample) / FULL_SCALE
}
/// Converts a signed 24-bit sample (held in an `i32`) to `f32` in `[-1.0, 1.0)`.
#[inline]
pub(crate) fn scale_s24(sample: i32) -> f32 {
    const FULL_SCALE: f32 = 8_388_608.0; // 2^23
    sample as f32 / FULL_SCALE
}
/// Converts a signed 32-bit sample to `f32` (`i32::MIN` -> `-1.0`).
///
/// The integer is rounded to `f32` before scaling, so values near the top of
/// the range may round up to `1.0`.
#[inline]
pub(crate) fn scale_s32(sample: i32) -> f32 {
    const FULL_SCALE: f32 = 2_147_483_648.0; // 2^31
    sample as f32 / FULL_SCALE
}
/// Converts an unsigned 8-bit (offset-binary) sample to `f32` in `[-1.0, 1.0)`.
///
/// Unsigned PCM stores silence at the midpoint `128`. The midpoint is
/// subtracted before dividing by `2^7`, so `128` maps to exactly `0.0` and the
/// result matches `scale_s8` applied to the equivalent signed value.
/// The former `(s / 255) * 2 - 1` mapping put silence at about `+0.0039`.
#[inline]
pub(crate) fn scale_u8(sample: u8) -> f32 {
    const MIDPOINT: f32 = 128.0; // 2^7
    (f32::from(sample) - MIDPOINT) / MIDPOINT
}
/// Converts an unsigned 16-bit (offset-binary) sample to `f32` in `[-1.0, 1.0)`.
///
/// The midpoint `32768` (silence) maps to exactly `0.0`. Scaling is by `2^15`,
/// matching `scale_s16` on the equivalent signed value. The former
/// `(s / 65535) * 2 - 1` mapping left a DC offset at silence.
#[inline]
pub(crate) fn scale_u16(sample: u16) -> f32 {
    const MIDPOINT: f32 = 32_768.0; // 2^15
    (f32::from(sample) - MIDPOINT) / MIDPOINT
}
/// Converts an unsigned 24-bit (offset-binary) sample, held in a `u32`, to
/// `f32` in `[-1.0, 1.0)`.
///
/// The midpoint `2^23` (silence) maps to exactly `0.0`. Scaling is by `2^23`,
/// matching `scale_s24`. Every 24-bit value is exactly representable in
/// `f32`, so the cast is lossless for in-range input.
#[inline]
pub(crate) fn scale_u24(sample: u32) -> f32 {
    const MIDPOINT: f32 = 8_388_608.0; // 2^23
    (sample as f32 - MIDPOINT) / MIDPOINT
}
/// Converts an unsigned 32-bit (offset-binary) sample to `f32` in `[-1.0, 1.0]`.
///
/// The midpoint `2^31` (silence) maps to exactly `0.0`. The midpoint is
/// subtracted in `i64` so it cannot overflow. The result is then rounded to
/// `f32` and scaled by `2^31`, exactly like `scale_s32` on the equivalent
/// signed value.
#[inline]
pub(crate) fn scale_u32(sample: u32) -> f32 {
    const HALF_RANGE: i64 = 1 << 31;
    (i64::from(sample) - HALF_RANGE) as f32 / HALF_RANGE as f32
}
}
#[cfg(test)]
// Test-only hook that exposes the end-of-stream flag to unit tests.
impl SampleSourceTestExt for AudioSampleSource {
// `true` once a read has found no more samples to return.
fn is_finished(&self) -> bool {
self.is_finished
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::audio::sample_source::traits::SampleSource;
    use crate::testutil::write_wav;

    /// Environment variable that makes `from_file` (in test builds) detect the
    /// channel count by decoding the first packet.
    const FORCE_DETECT_VAR: &str = "MTRACK_FORCE_DETECT_CHANNELS";

    /// Sets an environment variable and removes it again on drop.
    ///
    /// The removal also happens when the test panics, so the process-wide
    /// setting cannot leak into tests that run afterwards.
    struct EnvVarGuard(&'static str);

    impl EnvVarGuard {
        fn set(key: &'static str, value: &str) -> Self {
            std::env::set_var(key, value);
            EnvVarGuard(key)
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            std::env::remove_var(self.0);
        }
    }

    #[test]
    fn from_file_nonexistent_returns_error() {
        let result = AudioSampleSource::from_file("/no/such/file.wav", None, 1024);
        assert!(result.is_err());
    }

    #[test]
    fn from_file_reads_wav_successfully() {
        let dir = tempfile::tempdir().unwrap();
        let wav_path = dir.path().join("test.wav");
        write_wav(wav_path.clone(), vec![vec![0.5f32; 100]], 44100).unwrap();
        let source = AudioSampleSource::from_file(&wav_path, None, 1024).unwrap();
        assert_eq!(source.channel_count(), 1);
        assert_eq!(source.sample_rate(), 44100);
        assert!(source.duration().is_some());
    }

    #[test]
    fn from_file_stereo() {
        let dir = tempfile::tempdir().unwrap();
        let wav_path = dir.path().join("stereo.wav");
        write_wav(
            wav_path.clone(),
            vec![vec![0.5f32; 100], vec![0.3f32; 100]],
            44100,
        )
        .unwrap();
        let source = AudioSampleSource::from_file(&wav_path, None, 1024).unwrap();
        assert_eq!(source.channel_count(), 2);
    }

    #[test]
    fn next_sample_reads_all_samples() {
        let dir = tempfile::tempdir().unwrap();
        let wav_path = dir.path().join("test.wav");
        let num_samples = 50;
        write_wav(wav_path.clone(), vec![vec![0.25f32; num_samples]], 44100).unwrap();
        let mut source = AudioSampleSource::from_file(&wav_path, None, 1024).unwrap();
        let mut count = 0;
        while let Ok(Some(_)) = source.next_sample() {
            count += 1;
            // Guard against a runaway source producing too many samples.
            if count > num_samples + 10 {
                break;
            }
        }
        assert_eq!(count, num_samples);
    }

    #[test]
    fn from_file_with_start_time() {
        let dir = tempfile::tempdir().unwrap();
        let wav_path = dir.path().join("seek.wav");
        write_wav(wav_path.clone(), vec![vec![0.5f32; 44100]], 44100).unwrap();
        let mut source = AudioSampleSource::from_file(
            &wav_path,
            Some(std::time::Duration::from_millis(500)),
            1024,
        )
        .unwrap();
        let mut count = 0;
        while let Ok(Some(_)) = source.next_sample() {
            count += 1;
            if count > 44100 {
                break;
            }
        }
        assert!(
            count < 44100,
            "seeking to 0.5s should produce fewer samples than full file"
        );
        assert!(
            count > 10000,
            "seeking to 0.5s should still produce many samples, got {}",
            count
        );
    }

    #[test]
    fn sample_format_is_int_for_pcm_wav() {
        let dir = tempfile::tempdir().unwrap();
        let wav_path = dir.path().join("int.wav");
        crate::testutil::write_wav_with_bits(wav_path.clone(), vec![vec![1000i16; 50]], 44100, 16)
            .unwrap();
        let source = AudioSampleSource::from_file(&wav_path, None, 1024).unwrap();
        assert_eq!(source.sample_format(), crate::audio::SampleFormat::Int);
        assert_eq!(source.bits_per_sample(), 16);
    }

    #[test]
    fn sample_format_is_float_for_float_wav() {
        let dir = tempfile::tempdir().unwrap();
        let wav_path = dir.path().join("float.wav");
        write_wav(wav_path.clone(), vec![vec![0.5f32; 50]], 44100).unwrap();
        let source = AudioSampleSource::from_file(&wav_path, None, 1024).unwrap();
        assert_eq!(source.sample_format(), crate::audio::SampleFormat::Float);
    }

    #[test]
    fn detect_channels_via_env_var() {
        let dir = tempfile::tempdir().unwrap();
        let wav_path = dir.path().join("detect.wav");
        write_wav(wav_path.clone(), vec![vec![0.5f32; 200]], 44100).unwrap();
        // Keep the variable set only while opening the file. The guard removes
        // it even if `from_file` panics.
        let source = {
            let _force_detect = EnvVarGuard::set(FORCE_DETECT_VAR, "1");
            AudioSampleSource::from_file(&wav_path, None, 1024).unwrap()
        };
        assert_eq!(source.channel_count(), 1);
        let mut src = source;
        let mut count = 0;
        while let Ok(Some(_)) = src.next_sample() {
            count += 1;
            if count > 300 {
                break;
            }
        }
        assert!(count > 0, "should read samples after channel detection");
    }

    #[test]
    fn refill_buffer_small_file() {
        let dir = tempfile::tempdir().unwrap();
        let wav_path = dir.path().join("small.wav");
        write_wav(wav_path.clone(), vec![vec![0.1f32; 3]], 44100).unwrap();
        let mut source = AudioSampleSource::from_file(&wav_path, None, 1024).unwrap();
        let mut samples = Vec::new();
        while let Ok(Some(s)) = source.next_sample() {
            samples.push(s);
            if samples.len() > 100 {
                break;
            }
        }
        assert_eq!(samples.len(), 3);
    }

    #[test]
    fn is_finished_tracking() {
        let dir = tempfile::tempdir().unwrap();
        let wav_path = dir.path().join("fin.wav");
        write_wav(wav_path.clone(), vec![vec![0.5f32; 5]], 44100).unwrap();
        let mut source = AudioSampleSource::from_file(&wav_path, None, 1024).unwrap();
        assert!(!SampleSourceTestExt::is_finished(&source));
        while let Ok(Some(_)) = source.next_sample() {}
        assert!(SampleSourceTestExt::is_finished(&source));
    }

    #[test]
    fn next_sample_after_finished_returns_none() {
        let dir = tempfile::tempdir().unwrap();
        let wav_path = dir.path().join("tiny.wav");
        write_wav(wav_path.clone(), vec![vec![0.5f32; 2]], 44100).unwrap();
        let mut source = AudioSampleSource::from_file(&wav_path, None, 1024).unwrap();
        while let Ok(Some(_)) = source.next_sample() {}
        assert_eq!(source.next_sample().unwrap(), None);
        assert_eq!(source.next_sample().unwrap(), None);
    }

    #[test]
    fn from_file_invalid_format() {
        let dir = tempfile::tempdir().unwrap();
        let bad_path = dir.path().join("bad.wav");
        std::fs::write(&bad_path, b"not a wav file").unwrap();
        let result = AudioSampleSource::from_file(&bad_path, None, 1024);
        assert!(result.is_err());
    }

    #[test]
    fn read_samples_matches_next_sample() {
        let dir = tempfile::tempdir().unwrap();
        let wav_path = dir.path().join("bulk.wav");
        let num_samples = 200;
        let data: Vec<f32> = (0..num_samples)
            .map(|i| i as f32 / num_samples as f32)
            .collect();
        write_wav(wav_path.clone(), vec![data], 44100).unwrap();

        let mut source1 = AudioSampleSource::from_file(&wav_path, None, 64).unwrap();
        let mut expected = Vec::new();
        while let Ok(Some(s)) = source1.next_sample() {
            expected.push(s);
        }

        let mut source2 = AudioSampleSource::from_file(&wav_path, None, 64).unwrap();
        let mut actual = vec![0.0_f32; num_samples + 10];
        let n = source2.read_samples(&mut actual).unwrap();
        actual.truncate(n);

        assert_eq!(expected.len(), actual.len());
        for (i, (a, b)) in expected.iter().zip(actual.iter()).enumerate() {
            assert!(
                (a - b).abs() < f32::EPSILON,
                "sample {} differs: {} vs {}",
                i,
                a,
                b
            );
        }
    }

    #[test]
    fn read_samples_partial_buffer() {
        let dir = tempfile::tempdir().unwrap();
        let wav_path = dir.path().join("partial.wav");
        write_wav(wav_path.clone(), vec![vec![0.5f32; 100]], 44100).unwrap();
        let mut source = AudioSampleSource::from_file(&wav_path, None, 1024).unwrap();
        let mut all = Vec::new();
        // Deliberately small buffer that does not divide the sample count.
        let mut buf = [0.0_f32; 7];
        loop {
            let n = source.read_samples(&mut buf).unwrap();
            if n == 0 {
                break;
            }
            all.extend_from_slice(&buf[..n]);
        }
        assert_eq!(all.len(), 100);
    }

    #[test]
    fn read_samples_returns_zero_after_eof() {
        let dir = tempfile::tempdir().unwrap();
        let wav_path = dir.path().join("eof.wav");
        write_wav(wav_path.clone(), vec![vec![0.5f32; 5]], 44100).unwrap();
        let mut source = AudioSampleSource::from_file(&wav_path, None, 1024).unwrap();
        let mut buf = [0.0_f32; 100];
        let n = source.read_samples(&mut buf).unwrap();
        assert_eq!(n, 5);
        assert_eq!(source.read_samples(&mut buf).unwrap(), 0);
        assert_eq!(source.read_samples(&mut buf).unwrap(), 0);
    }

    #[test]
    fn read_samples_empty_buffer() {
        let dir = tempfile::tempdir().unwrap();
        let wav_path = dir.path().join("empty_buf.wav");
        write_wav(wav_path.clone(), vec![vec![0.5f32; 10]], 44100).unwrap();
        let mut source = AudioSampleSource::from_file(&wav_path, None, 1024).unwrap();
        let mut buf = [];
        assert_eq!(source.read_samples(&mut buf).unwrap(), 0);
    }

    #[test]
    fn read_samples_tiny_file() {
        let dir = tempfile::tempdir().unwrap();
        let wav_path = dir.path().join("tiny.wav");
        write_wav(wav_path.clone(), vec![vec![0.25f32; 1]], 44100).unwrap();
        let mut source = AudioSampleSource::from_file(&wav_path, None, 1024).unwrap();
        let mut buf = [0.0_f32; 100];
        let n = source.read_samples(&mut buf).unwrap();
        assert_eq!(n, 1);
        assert!((buf[0] - 0.25).abs() < 0.01);
    }

    #[test]
    fn different_sample_rates() {
        for rate in &[22050u32, 44100, 48000, 96000] {
            let dir = tempfile::tempdir().unwrap();
            let wav_path = dir.path().join(format!("rate_{}.wav", rate));
            write_wav(wav_path.clone(), vec![vec![0.5f32; 100]], *rate).unwrap();
            let source = AudioSampleSource::from_file(&wav_path, None, 1024).unwrap();
            assert_eq!(source.sample_rate(), *rate);
        }
    }
}