use anyhow::{Context, Result};
use bytes::Bytes;
use symphonia::core::audio::SampleBuffer;
use symphonia::core::codecs::DecoderOptions;
use symphonia::core::formats::FormatOptions;
use symphonia::core::io::{MediaSource, MediaSourceStream};
use symphonia::core::meta::MetadataOptions;
use symphonia::core::probe::Hint;
use super::{HOP_LENGTH, N_FFT};
// MAX_BUFFER_SAMPLES: cap, in samples, on the streaming buffer kept by `prepare_audio_buffer`
// (16000 * 5; the truncation warning there calls it the "5s limit"). The `allow(dead_code)`
// attribute below applies to this constant only.
// MAX_DURATION_S: longest decoded audio, in seconds, that `decode_audio_inner` accepts before
// bailing out with a "too long" error.
#[allow(dead_code)]
const MAX_BUFFER_SAMPLES: usize = 16000 * 5; const MAX_DURATION_S: f64 = 600.0;
/// In-memory media source: a `Bytes` buffer paired with a read cursor.
///
/// The `Read`, `Seek` and `MediaSource` impls in this file let it be boxed into a
/// `MediaSourceStream`.
pub(crate) struct BytesMediaSource {
    /// Backing bytes served by `read`.
    data: Bytes,
    /// Cursor offset from the start of `data`, in bytes. May sit past the end after a seek.
    pos: u64,
}

impl BytesMediaSource {
    /// Wraps `data` with the cursor positioned at offset 0.
    pub(crate) fn new(data: Bytes) -> Self {
        let pos = 0;
        Self { data, pos }
    }
}
impl std::io::Read for BytesMediaSource {
    /// Copies as many bytes as fit into `buf`, starting at the cursor, and advances the cursor.
    ///
    /// Returns `Ok(0)` once the cursor is at or past the end of the data; never returns an error.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let total = self.data.len();
        if self.pos >= total as u64 {
            return Ok(0);
        }
        // The guard above ensures the cursor is below `total`, so it fits in a `usize`.
        let offset = self.pos as usize;
        let remaining = &self.data[offset..];
        let count = remaining.len().min(buf.len());
        buf[..count].copy_from_slice(&remaining[..count]);
        self.pos += count as u64;
        Ok(count)
    }
}
impl std::io::Seek for BytesMediaSource {
    /// Moves the cursor to the position described by `pos` and returns the new offset.
    ///
    /// Seeking past the end of the data is allowed; later reads simply report EOF.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::InvalidInput` when the target offset would be negative or
    /// would not fit in a `u64`. The cursor is left unchanged in both cases.
    fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result<u64> {
        use std::io::{Error, ErrorKind, SeekFrom};

        let len = self.data.len() as u64;
        // i128 holds any u64 base plus any i64 offset without overflowing.
        let target: i128 = match pos {
            SeekFrom::Start(n) => i128::from(n),
            SeekFrom::End(off) => i128::from(len) + i128::from(off),
            SeekFrom::Current(off) => i128::from(self.pos) + i128::from(off),
        };
        if target < 0 {
            return Err(Error::new(
                ErrorKind::InvalidInput,
                "seek before start of buffer",
            ));
        }
        // Without this check the cast below would silently wrap to a small offset.
        if target > i128::from(u64::MAX) {
            return Err(Error::new(
                ErrorKind::InvalidInput,
                "seek position overflows u64",
            ));
        }
        self.pos = target as u64;
        Ok(self.pos)
    }
}
impl MediaSource for BytesMediaSource {
    /// Always `true`: this source reports itself as seekable.
    fn is_seekable(&self) -> bool {
        true
    }

    /// Total length of the backing buffer in bytes; always `Some`.
    fn byte_len(&self) -> Option<u64> {
        let total = self.data.len() as u64;
        Some(total)
    }
}
pub fn decode_audio_file(path: &str) -> Result<Vec<f32>> {
let file =
std::fs::File::open(path).with_context(|| format!("Failed to open audio file: {path}"))?;
let mss = MediaSourceStream::new(Box::new(file), Default::default());
let mut hint = Hint::new();
if let Some(ext) = std::path::Path::new(path)
.extension()
.and_then(|e| e.to_str())
{
hint.with_extension(ext);
}
let source_label = format!(
"format={}",
std::path::Path::new(path)
.extension()
.unwrap_or_default()
.to_string_lossy()
);
decode_audio_inner(mss, hint, &source_label)
}
/// Decodes an in-memory audio file held in a borrowed slice.
///
/// The slice is copied into an owned `Bytes` and handed to `decode_audio_bytes_shared`;
/// callers that already own a `Bytes` can call that function directly to skip the copy.
pub fn decode_audio_bytes(data: &[u8]) -> Result<Vec<f32>> {
    let owned = Bytes::copy_from_slice(data);
    decode_audio_bytes_shared(owned)
}
/// Decodes an in-memory audio file held in an owned `Bytes` buffer.
///
/// The buffer is wrapped in a `BytesMediaSource` and decoded by `decode_audio_inner` with an
/// empty probe hint and the log label `"bytes"`.
pub fn decode_audio_bytes_shared(data: Bytes) -> Result<Vec<f32>> {
    let stream = MediaSourceStream::new(
        Box::new(BytesMediaSource::new(data)),
        Default::default(),
    );
    decode_audio_inner(stream, Hint::new(), "bytes")
}
fn decode_audio_inner(mss: MediaSourceStream, hint: Hint, source_label: &str) -> Result<Vec<f32>> {
let probed = symphonia::default::get_probe()
.format(
&hint,
mss,
&FormatOptions::default(),
&MetadataOptions::default(),
)
.context("Unsupported audio format")?;
let mut format = probed.format;
let track = format.default_track().context("No audio track found")?;
let track_id = track.id;
let sample_rate = track
.codec_params
.sample_rate
.context("Unknown sample rate")?;
let channels = track.codec_params.channels.map(|c| c.count()).unwrap_or(1);
let n_frames_hint = track.codec_params.n_frames;
tracing::info!("Audio ({source_label}): {sample_rate}Hz, {channels}ch");
let mut decoder = symphonia::default::get_codecs()
.make(&track.codec_params, &DecoderOptions::default())
.context("Unsupported audio codec")?;
let mut all_samples: Vec<f32> = match n_frames_hint {
Some(n) if n > 0 && n <= (MAX_DURATION_S as u64 + 1) * sample_rate as u64 => {
Vec::with_capacity(n as usize)
}
_ => Vec::new(),
};
let max_samples: usize = (MAX_DURATION_S * sample_rate as f64) as usize;
loop {
let packet = match format.next_packet() {
Ok(p) => p,
Err(symphonia::core::errors::Error::IoError(ref e))
if e.kind() == std::io::ErrorKind::UnexpectedEof =>
{
break;
}
Err(e) => return Err(anyhow::anyhow!("Error reading packet: {e}")),
};
if packet.track_id() != track_id {
continue;
}
let decoded = decoder.decode(&packet).context("Decode error")?;
let spec = *decoded.spec();
let num_frames = decoded.frames();
let mut sample_buf = SampleBuffer::<f32>::new(num_frames as u64, spec);
sample_buf.copy_interleaved_ref(decoded);
let samples = sample_buf.samples();
if spec.channels.count() > 1 {
let ch = spec.channels.count();
for frame in 0..num_frames {
let mut sum = 0.0_f32;
for c in 0..ch {
sum += samples[frame * ch + c];
}
all_samples.push(sum / ch as f32);
}
} else {
all_samples.extend_from_slice(samples);
}
if all_samples.len() > max_samples {
let observed_s = all_samples.len() as f64 / sample_rate as f64;
anyhow::bail!(
"Audio file too long ({:.0}s). Maximum supported: {MAX_DURATION_S:.0}s.",
observed_s
);
}
}
let duration_s = all_samples.len() as f64 / sample_rate as f64;
tracing::info!(
"Decoded {} samples at {}Hz ({:.1}s)",
all_samples.len(),
sample_rate,
duration_s
);
if sample_rate != 16000 {
all_samples = resample(&all_samples, sample_rate, 16000).context("Resampling failed")?;
tracing::info!("Resampled to 16kHz: {} samples", all_samples.len());
}
Ok(all_samples)
}
/// Resamples mono `samples` from `from_rate` Hz to `to_rate` Hz with a sinc interpolator.
///
/// Returns an empty vector when the input is empty or either rate is zero, and a plain copy
/// of the input when the rates are equal. Non-finite input samples (NaN, +/-infinity) are
/// replaced with `0.0` before resampling. The whole input is processed as a single chunk.
///
/// # Errors
///
/// Fails when the resampler cannot be constructed or when processing fails.
pub fn resample(samples: &[f32], from_rate: u32, to_rate: u32) -> Result<Vec<f32>> {
    use rubato::{
        Resampler, SincFixedIn, SincInterpolationParameters, SincInterpolationType, WindowFunction,
    };

    if samples.is_empty() || from_rate == 0 || to_rate == 0 {
        return Ok(Vec::new());
    }
    if from_rate == to_rate {
        return Ok(samples.to_vec());
    }

    // Replace NaN/inf up front so they cannot reach the interpolation filter.
    let sanitized: Vec<f32> = samples
        .iter()
        .copied()
        .map(|s| if s.is_finite() { s } else { 0.0 })
        .collect();
    let chunk_size = sanitized.len();

    let interpolation = SincInterpolationParameters {
        sinc_len: 256,
        f_cutoff: 0.95,
        interpolation: SincInterpolationType::Linear,
        oversampling_factor: 256,
        window: WindowFunction::BlackmanHarris2,
    };
    let mut resampler = SincFixedIn::<f32>::new(
        to_rate as f64 / from_rate as f64,
        2.0,
        interpolation,
        chunk_size,
        1,
    )
    .map_err(|e| anyhow::anyhow!("Resampler init failed: {e}"))?;

    // One channel in, one channel out.
    let channels_in = vec![sanitized];
    let mut channels_out = resampler
        .process(&channels_in, None)
        .map_err(|e| anyhow::anyhow!("Resampling failed: {e}"))?;
    Ok(channels_out.remove(0))
}
/// Appends `new_samples` to the carried-over `buffer` and splits off the part that forms
/// whole frames.
///
/// * `new_samples` - samples to append to what `buffer` already holds.
/// * `buffer` - carry-over storage; on return it holds the samples that were not handed out.
///
/// If the combined length exceeds `MAX_BUFFER_SAMPLES`, the oldest samples are dropped (with
/// a warning) so that only the most recent `MAX_BUFFER_SAMPLES` remain.
///
/// Returns `Some(prefix)` where `prefix` is the longest leading run that is exactly covered
/// by frames of `N_FFT` samples spaced `HOP_LENGTH` apart; the remainder stays in `buffer`.
/// Returns `None`, leaving everything in `buffer`, when fewer than `N_FFT` samples are
/// available.
#[allow(dead_code)]
pub(crate) fn prepare_audio_buffer(new_samples: &[f32], buffer: &mut Vec<f32>) -> Option<Vec<f32>> {
    let mut pending = std::mem::take(buffer);
    pending.extend_from_slice(new_samples);

    if pending.len() > MAX_BUFFER_SAMPLES {
        tracing::warn!("Audio buffer exceeded 5s limit, truncating");
        // Drop the oldest samples in place instead of copying the tail into a new Vec.
        let excess = pending.len() - MAX_BUFFER_SAMPLES;
        pending.drain(..excess);
    }

    let usable = if pending.len() >= N_FFT {
        let num_frames = (pending.len() - N_FFT) / HOP_LENGTH + 1;
        (num_frames - 1) * HOP_LENGTH + N_FFT
    } else {
        0
    };
    if usable == 0 {
        *buffer = pending;
        return None;
    }

    // `usable <= pending.len()` by construction, so `split_off` cannot panic. It moves the
    // leftover tail into the carry-over buffer and leaves the usable prefix in `pending`.
    *buffer = pending.split_off(usable);
    Some(pending)
}
// Unit tests for the resampler, the streaming buffer helper, the in-memory media source and
// the byte-based decode entry points.
#[cfg(test)]
mod tests {
use super::*;
// 4800 samples at 48 kHz resampled to 16 kHz: output must be non-empty and its length must
// fall strictly between 1400 and 1700.
#[test]
fn test_resample_downsample_length() {
let input: Vec<f32> = (0..4800).map(|i| (i as f32).sin()).collect();
let output = resample(&input, 48000, 16000).unwrap();
assert!(!output.is_empty());
assert!(
output.len() > 1400 && output.len() < 1700,
"Unexpected output length: {}",
output.len()
);
}
// 800 samples at 8 kHz resampled to 16 kHz: output length must fall strictly between
// 1200 and 1700.
#[test]
fn test_resample_upsample_length() {
let input: Vec<f32> = (0..800).map(|i| (i as f32).sin()).collect();
let output = resample(&input, 8000, 16000).unwrap();
assert!(!output.is_empty());
assert!(
output.len() > 1200 && output.len() < 1700,
"Unexpected output length: {}",
output.len()
);
}
// A constant 0.5 input must stay within 0.05 of 0.5 after resampling. The first and last
// tenth of the output are excluded from the check.
#[test]
fn test_resample_preserves_dc() {
let input = vec![0.5_f32; 4800];
let output = resample(&input, 48000, 16000).unwrap();
let start = output.len() / 10;
let end = output.len() - start;
for &sample in &output[start..end] {
assert!(
(sample - 0.5).abs() < 0.05,
"DC signal not preserved: {sample}"
);
}
}
// Empty input yields empty output.
#[test]
fn test_resample_empty() {
let output = resample(&[], 48000, 16000).unwrap();
assert!(output.is_empty());
}
// A zero source or target rate yields empty output rather than an error.
#[test]
fn test_resample_zero_rate_returns_empty() {
let input = vec![1.0, 2.0, 3.0];
assert!(resample(&input, 0, 16000).unwrap().is_empty());
assert!(resample(&input, 16000, 0).unwrap().is_empty());
}
// Equal source and target rates return the input unchanged (within 1e-5).
#[test]
fn test_resample_same_rate() {
let input = vec![1.0, 2.0, 3.0, 4.0];
let output = resample(&input, 16000, 16000).unwrap();
assert_eq!(output.len(), input.len());
for (a, b) in input.iter().zip(output.iter()) {
assert!((a - b).abs() < 1e-5);
}
}
// 100 samples into an empty buffer: nothing is returned and all 100 samples are retained.
// NOTE(review): assumes N_FFT > 100 — confirm against the parent module.
#[test]
fn test_buffer_short_input_returns_none() {
let new_samples = vec![0.0; 100];
let mut buffer = Vec::new();
let result = prepare_audio_buffer(&new_samples, &mut buffer);
assert!(result.is_none());
assert_eq!(buffer.len(), 100);
}
// Exactly N_FFT samples are returned in full and nothing is left in the buffer.
#[test]
fn test_buffer_exact_frame() {
let new_samples = vec![1.0; N_FFT];
let mut buffer = Vec::new();
let result = prepare_audio_buffer(&new_samples, &mut buffer);
assert!(result.is_some());
assert_eq!(result.unwrap().len(), N_FFT);
assert!(buffer.is_empty());
}
// N_FFT + 50 samples: N_FFT are returned and the extra 50 stay in the buffer.
// NOTE(review): assumes HOP_LENGTH > 50 — confirm against the parent module.
#[test]
fn test_buffer_leftover_correct() {
let new_samples = vec![1.0; N_FFT + 50];
let mut buffer = Vec::new();
let result = prepare_audio_buffer(&new_samples, &mut buffer);
assert!(result.is_some());
let usable = result.unwrap();
assert_eq!(usable.len(), N_FFT); assert_eq!(buffer.len(), 50);
}
// Two calls that together supply exactly N_FFT samples: the first returns nothing, the
// second returns the full N_FFT and empties the buffer.
#[test]
fn test_buffer_accumulates_across_calls() {
let mut buffer = Vec::new();
let half = N_FFT / 2;
let result = prepare_audio_buffer(&vec![1.0; half], &mut buffer);
assert!(result.is_none());
assert_eq!(buffer.len(), half);
let result = prepare_audio_buffer(&vec![2.0; N_FFT - half], &mut buffer);
assert!(result.is_some());
let usable = result.unwrap();
assert_eq!(usable.len(), N_FFT);
assert_eq!(buffer.len(), 0);
}
// 90000 buffered + 1000 new samples exceeds MAX_BUFFER_SAMPLES; the returned and retained
// samples together must not exceed that cap.
#[test]
fn test_buffer_truncation_at_5s() {
let mut buffer = vec![0.0; 90000];
let new_samples = vec![1.0; 1000];
let result = prepare_audio_buffer(&new_samples, &mut buffer);
assert!(result.is_some());
let usable = result.unwrap();
assert!(usable.len() + buffer.len() <= MAX_BUFFER_SAMPLES);
}
// N_FFT + HOP_LENGTH samples are returned in full with nothing left over.
#[test]
fn test_buffer_multi_frame() {
let new_samples = vec![1.0; N_FFT + HOP_LENGTH];
let mut buffer = Vec::new();
let result = prepare_audio_buffer(&new_samples, &mut buffer);
assert!(result.is_some());
assert_eq!(result.unwrap().len(), N_FFT + HOP_LENGTH);
assert!(buffer.is_empty());
}
// All-NaN input must produce non-empty, finite output.
#[test]
fn test_resample_nan_input() {
let input = vec![f32::NAN; 1000];
let output = resample(&input, 48000, 16000).unwrap();
assert!(!output.is_empty());
for &s in &output {
assert!(s.is_finite(), "NaN should be sanitized to zero, got {s}");
}
}
// All-infinity input must produce non-empty, finite output.
#[test]
fn test_resample_infinity_input() {
let input = vec![f32::INFINITY; 500];
let output = resample(&input, 48000, 16000).unwrap();
assert!(!output.is_empty());
for &s in &output {
assert!(
s.is_finite(),
"Infinity should be sanitized to zero, got {s}"
);
}
}
// Mostly-valid input with one NaN and one -infinity must still give finite output.
#[test]
fn test_resample_mixed_nan_normal() {
let mut input = vec![0.5_f32; 480];
input[100] = f32::NAN;
input[200] = f32::NEG_INFINITY;
let output = resample(&input, 48000, 16000).unwrap();
assert!(!output.is_empty());
for &s in &output {
assert!(s.is_finite(), "Non-finite values should be sanitized");
}
}
// No new samples: the existing 100-sample buffer is left as is and nothing is returned.
// NOTE(review): assumes N_FFT > 100 — confirm against the parent module.
#[test]
fn test_prepare_buffer_empty_input() {
let mut buffer = vec![1.0; 100];
let result = prepare_audio_buffer(&[], &mut buffer);
assert!(result.is_none());
assert_eq!(buffer.len(), 100);
}
// Input of exactly MAX_BUFFER_SAMPLES: returned plus retained samples stay within the cap.
#[test]
fn test_prepare_buffer_exactly_max() {
let new_samples = vec![1.0; MAX_BUFFER_SAMPLES];
let mut buffer = Vec::new();
let result = prepare_audio_buffer(&new_samples, &mut buffer);
assert!(result.is_some());
let usable = result.unwrap();
assert!(usable.len() + buffer.len() <= MAX_BUFFER_SAMPLES);
}
// Input one sample over MAX_BUFFER_SAMPLES: returned plus retained samples stay within the cap.
#[test]
fn test_prepare_buffer_one_over_max() {
let new_samples = vec![1.0; MAX_BUFFER_SAMPLES + 1];
let mut buffer = Vec::new();
let result = prepare_audio_buffer(&new_samples, &mut buffer);
assert!(result.is_some());
let usable = result.unwrap();
assert!(usable.len() + buffer.len() <= MAX_BUFFER_SAMPLES);
}
// Test helper: builds an in-memory mono 16-bit PCM WAV file from `samples` at `sample_rate`.
// Layout written below: "RIFF" + size (36 + data bytes) + "WAVE", then a 16-byte "fmt " chunk
// (format tag 1, 1 channel, sample rate, byte rate = rate * 2, block align 2, 16 bits per
// sample), then the "data" chunk with the little-endian samples.
fn make_wav_bytes(samples: &[i16], sample_rate: u32) -> Vec<u8> {
let data_size = (samples.len() * 2) as u32;
let file_size = 36 + data_size;
let mut buf = Vec::new();
buf.extend_from_slice(b"RIFF");
buf.extend_from_slice(&file_size.to_le_bytes());
buf.extend_from_slice(b"WAVE");
buf.extend_from_slice(b"fmt ");
// fmt chunk size, format tag, channel count, sample rate.
buf.extend_from_slice(&16u32.to_le_bytes()); buf.extend_from_slice(&1u16.to_le_bytes()); buf.extend_from_slice(&1u16.to_le_bytes()); buf.extend_from_slice(&sample_rate.to_le_bytes());
// byte rate, block align, bits per sample, then the data chunk tag.
buf.extend_from_slice(&(sample_rate * 2).to_le_bytes()); buf.extend_from_slice(&2u16.to_le_bytes()); buf.extend_from_slice(&16u16.to_le_bytes()); buf.extend_from_slice(b"data");
buf.extend_from_slice(&data_size.to_le_bytes());
for &s in samples {
buf.extend_from_slice(&s.to_le_bytes());
}
buf
}
// Decoding an empty byte slice must fail.
#[test]
fn test_decode_audio_bytes_empty() {
let result = decode_audio_bytes(&[]);
assert!(result.is_err(), "Expected error for empty input, got Ok");
}
// Decoding 128 bytes that are not an audio file must fail.
#[test]
fn test_decode_audio_bytes_invalid_data() {
let garbage: Vec<u8> = (0u8..128).collect();
let result = decode_audio_bytes(&garbage);
assert!(
result.is_err(),
"Expected error for invalid audio data, got Ok"
);
}
// 16000 samples of 16 kHz silence decode to roughly 16000 samples (within 100).
#[test]
fn test_decode_audio_bytes_wav() {
let silence: Vec<i16> = vec![0; 16000]; let wav = make_wav_bytes(&silence, 16000);
let samples = decode_audio_bytes(&wav).unwrap();
assert!(!samples.is_empty());
assert!((samples.len() as i64 - 16000).unsigned_abs() <= 100);
}
// Traits needed to call `read` / `seek` on `BytesMediaSource` in the tests below.
use std::io::{Read, Seek, SeekFrom};
// One read with a buffer as large as the data returns everything; the next read reports EOF.
#[test]
fn bytes_media_source_read_full() {
let data = Bytes::from_static(b"hello world");
let mut src = BytesMediaSource::new(data.clone());
let mut buf = vec![0u8; data.len()];
let n = src.read(&mut buf).unwrap();
assert_eq!(n, data.len());
assert_eq!(buf, data.as_ref());
let mut more = [0u8; 4];
assert_eq!(src.read(&mut more).unwrap(), 0);
}
// Seeking to End(0) lands on the data length, and reading from there reports EOF.
#[test]
fn bytes_media_source_seek_end() {
let data = Bytes::from_static(b"abcdefgh");
let mut src = BytesMediaSource::new(data);
let pos = src.seek(SeekFrom::End(0)).unwrap();
assert_eq!(pos, 8);
let mut buf = [0u8; 4];
assert_eq!(src.read(&mut buf).unwrap(), 0);
}
// Seeking past the end succeeds; a following read reports EOF.
#[test]
fn bytes_media_source_seek_past_end_ok() {
let data = Bytes::from_static(b"abc");
let mut src = BytesMediaSource::new(data);
let pos = src.seek(SeekFrom::Start(42)).unwrap();
assert_eq!(pos, 42);
let mut buf = [0u8; 4];
assert_eq!(src.read(&mut buf).unwrap(), 0);
}
// A relative seek that would land before offset 0 must fail.
// Note: the first `err` binding holds the successful position (2), not an error.
#[test]
fn bytes_media_source_seek_before_start_err() {
let data = Bytes::from_static(b"abc");
let mut src = BytesMediaSource::new(data);
let err = src.seek(SeekFrom::Start(2)).unwrap();
assert_eq!(err, 2);
let result = src.seek(SeekFrom::Current(-100));
assert!(result.is_err(), "seek before start should error");
}
// Reading in 3-byte chunks until EOF reassembles the original data.
#[test]
fn bytes_media_source_partial_read_progress() {
let data = Bytes::from_static(b"abcdefghij");
let mut src = BytesMediaSource::new(data.clone());
let mut out = Vec::new();
let mut chunk = [0u8; 3];
loop {
let n = src.read(&mut chunk).unwrap();
if n == 0 {
break;
}
out.extend_from_slice(&chunk[..n]);
}
assert_eq!(out, data.as_ref());
}
// `byte_len` matches the data length and the source reports itself as seekable.
#[test]
fn bytes_media_source_byte_len_matches() {
use symphonia::core::io::MediaSource as _;
let data = Bytes::from_static(b"0123456789");
let src = BytesMediaSource::new(data.clone());
assert_eq!(src.byte_len(), Some(data.len() as u64));
assert!(src.is_seekable());
}
// The slice-based entry point and the `Bytes`-based one produce the same samples.
#[test]
fn decode_audio_shim_matches_shared() {
let silence: Vec<i16> = vec![0; 16000];
let wav = make_wav_bytes(&silence, 16000);
let via_shim = decode_audio_bytes(&wav).unwrap();
let via_shared = decode_audio_bytes_shared(Bytes::copy_from_slice(&wav)).unwrap();
assert_eq!(via_shim.len(), via_shared.len());
for (a, b) in via_shim.iter().zip(via_shared.iter()) {
assert!((a - b).abs() < f32::EPSILON);
}
}
// 12 minutes of 16 kHz silence exceeds MAX_DURATION_S (600 s) and must be rejected with an
// error whose message contains "too long".
#[test]
fn test_decode_duration_cap_streaming() {
let duration_s: usize = 12 * 60;
let silence: Vec<i16> = vec![0; duration_s * 16000];
let wav = make_wav_bytes(&silence, 16000);
let result = decode_audio_bytes_shared(Bytes::from(wav));
let err = result.expect_err("12-minute audio must be rejected");
let msg = format!("{err:#}");
assert!(
msg.to_lowercase().contains("too long"),
"error should mention 'too long', got: {msg}"
);
}
}