use minimp3::{Decoder as Mp3DecoderInner, Error as Mp3Error, Frame as Mp3Frame};
use crate::audio::{AudioDecoder, AudioError, AudioFrame};
/// Upper bound on PCM samples per channel that one decoded frame may carry.
/// MPEG-1 Layer III frames hold exactly this many; MPEG-2/2.5 Layer III frames hold 576.
const MP3_FRAME_SAMPLES_MAX_PER_CHANNEL: usize = 1_152;
/// Growable in-memory byte queue that feeds compressed data to minimp3.
///
/// Bytes appended with [`ByteCursor::extend`] are handed back, in order, by the
/// [`std::io::Read`] impl. `pos` is the index of the first unread byte in `inner`.
struct ByteCursor {
    inner: Vec<u8>,
    pos: usize,
}

impl ByteCursor {
    /// Creates an empty cursor with nothing to read.
    fn new() -> Self {
        ByteCursor {
            inner: Vec::new(),
            pos: 0,
        }
    }

    /// Appends `bytes` after the unread tail.
    ///
    /// The already-consumed prefix is dropped first whenever it makes up at least
    /// half of the buffer. This keeps the dead prefix from outgrowing the unread data.
    fn extend(&mut self, bytes: &[u8]) {
        let consumed = self.pos;
        let prefix_is_large = consumed != 0 && consumed >= self.inner.len() / 2;
        if prefix_is_large {
            self.inner.drain(..consumed);
            self.pos = 0;
        }
        self.inner.extend_from_slice(bytes);
    }
}

impl std::io::Read for ByteCursor {
    /// Copies as many unread bytes as fit into `buf` and advances past them.
    ///
    /// Returns `Ok(0)` when nothing is pending or `buf` is empty. It never fails.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let unread = self.inner.get(self.pos..).unwrap_or(&[]);
        let count = unread.len().min(buf.len());
        if count == 0 {
            return Ok(0);
        }
        buf[..count].copy_from_slice(&unread[..count]);
        self.pos += count;
        Ok(count)
    }
}
/// MP3 decoder that turns compressed packets into interleaved `f32` PCM
/// [`AudioFrame`]s using minimp3.
pub struct Mp3Decoder {
    /// minimp3 decoder pulling compressed bytes from an in-memory queue.
    inner: Mp3DecoderInner<ByteCursor>,
    /// Sample rate (always >= 1) used when a decoded frame reports a rate of 0.
    declared_sample_rate: u32,
    // Validated in `new` but otherwise only read by tests: decoded frames carry
    // the channel count minimp3 reports, not this declared value.
    #[allow(dead_code)]
    declared_channels: u8,
    /// Timestamp (µs) that the next decoded frame receives; `None` until seeded.
    next_pts_us: Option<i64>,
    /// Fractional microsecond left over from the last duration computation,
    /// in units of `1 / pts_remainder_rate` µs. It is carried forward so that
    /// per-frame integer truncation does not accumulate into timestamp drift.
    pts_remainder: u64,
    /// Sample rate that `pts_remainder` is expressed against (0 before the first frame).
    pts_remainder_rate: u32,
}

impl Mp3Decoder {
    /// Creates a decoder for a stream declared as `sample_rate` Hz with `channels` channels.
    ///
    /// `sample_rate` is clamped to at least 1 and is only used as a fallback
    /// when a decoded frame reports a rate of 0. Decoded frames otherwise report
    /// the rate and channel count found in the bitstream.
    ///
    /// # Errors
    ///
    /// Returns [`AudioError::Unsupported`] unless `channels` is 1 or 2.
    pub fn new(sample_rate: u32, channels: u8) -> Result<Self, AudioError> {
        if channels == 0 || channels > 2 {
            return Err(AudioError::Unsupported(format!(
                "mp3 channel count {channels}"
            )));
        }
        Ok(Self {
            inner: Mp3DecoderInner::new(ByteCursor::new()),
            declared_sample_rate: sample_rate.max(1),
            declared_channels: channels,
            next_pts_us: None,
            pts_remainder: 0,
            pts_remainder_rate: 0,
        })
    }

    /// Scales signed 16-bit PCM to `f32` by dividing by 32768, giving values in `[-1.0, 1.0)`.
    fn convert_i16_to_f32(samples: &[i16]) -> Vec<f32> {
        samples.iter().map(|s| (*s as f32) / 32768.0).collect()
    }

    /// Returns the whole-microsecond duration of `samples_per_channel` samples
    /// at `sample_rate` Hz. The sub-microsecond remainder is carried into the
    /// next call, so summed durations stay exact instead of drifting.
    ///
    /// `sample_rate` must be non-zero. Callers pass either a positive
    /// frame rate or `declared_sample_rate`, which is clamped to >= 1.
    fn advance_pts(&mut self, samples_per_channel: usize, sample_rate: u32) -> i64 {
        if self.pts_remainder_rate != sample_rate {
            // A remainder measured against another rate is meaningless here.
            // Dropping it costs < 1 µs once per rate change.
            self.pts_remainder = 0;
            self.pts_remainder_rate = sample_rate;
        }
        let rate = u64::from(sample_rate);
        let scaled = samples_per_channel as u64 * 1_000_000 + self.pts_remainder;
        self.pts_remainder = scaled % rate;
        // At most MP3_FRAME_SAMPLES_MAX_PER_CHANNEL * 1e6, so this fits in i64.
        (scaled / rate) as i64
    }

    /// Decodes every complete frame currently buffered and returns them in order.
    ///
    /// The first `Some` value of `seed_pts_us` anchors the timeline; later seeds
    /// are ignored. Each frame is stamped with the anchor plus the exact
    /// accumulated duration of all earlier frames. With no seed at all, the
    /// timeline starts at 0.
    ///
    /// # Errors
    ///
    /// - [`AudioError::Unsupported`] if a frame reports a channel count other than 1 or 2.
    /// - [`AudioError::Decode`] if a frame yields 0 samples per channel or more than
    ///   [`MP3_FRAME_SAMPLES_MAX_PER_CHANNEL`], or if minimp3 reports an I/O error.
    fn drain_frames(&mut self, seed_pts_us: Option<i64>) -> Result<Vec<AudioFrame>, AudioError> {
        if self.next_pts_us.is_none() {
            self.next_pts_us = seed_pts_us;
        }
        let mut out = Vec::new();
        loop {
            let Mp3Frame {
                data,
                sample_rate,
                channels,
                ..
            } = match self.inner.next_frame() {
                Ok(frame) => frame,
                // Not enough buffered bytes for another frame: wait for more input.
                Err(Mp3Error::InsufficientData | Mp3Error::Eof) => break,
                // minimp3 skipped bytes it could not decode; try the next frame.
                Err(Mp3Error::SkippedData) => continue,
                Err(Mp3Error::Io(e)) => {
                    return Err(AudioError::Decode(format!("mp3 io: {e}")));
                }
            };

            if channels == 0 || channels > 2 {
                return Err(AudioError::Unsupported(format!(
                    "mp3 frame channel count {channels}"
                )));
            }
            let sample_rate_u32 = if sample_rate > 0 {
                sample_rate as u32
            } else {
                self.declared_sample_rate
            };
            // `data` is interleaved, so its length divided by the channel count
            // is the number of samples per channel.
            let samples_per_channel = data.len() / channels;
            if samples_per_channel == 0 || samples_per_channel > MP3_FRAME_SAMPLES_MAX_PER_CHANNEL
            {
                return Err(AudioError::Decode(format!(
                    "mp3 frame produced {samples_per_channel} samples per channel — outside MPEG layer III bounds"
                )));
            }

            let pts_us = self.next_pts_us.unwrap_or(0);
            let frame_us = self.advance_pts(samples_per_channel, sample_rate_u32);
            self.next_pts_us = Some(pts_us.saturating_add(frame_us));
            out.push(AudioFrame {
                samples: Self::convert_i16_to_f32(&data),
                sample_rate: sample_rate_u32,
                channels: channels as u8,
                pts: pts_us,
            });
        }
        Ok(out)
    }
}
impl AudioDecoder for Mp3Decoder {
    /// Queues `packet` and returns every frame that can now be decoded.
    ///
    /// `pts` (µs) seeds the output timeline only when it arrives with actual
    /// bytes. An empty packet carries no audio, so it drains whatever is already
    /// buffered without anchoring timestamps to a PTS that describes no data.
    fn decode(&mut self, packet: &[u8], pts: i64) -> Result<Vec<AudioFrame>, AudioError> {
        if packet.is_empty() {
            return self.drain_frames(None);
        }
        self.inner.reader_mut().extend(packet);
        self.drain_frames(Some(pts))
    }

    /// Returns any frames still decodable from already-buffered bytes.
    ///
    /// No new input is added and decoder state is not reset.
    fn flush(&mut self) -> Result<Vec<AudioFrame>, AudioError> {
        self.drain_frames(None)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty 10-byte ID3v2.3 tag followed by a lone 4-byte MPEG-1 Layer III
    /// frame header with no payload. This is not decodable audio; it only checks
    /// that a tag prefix plus a truncated frame does not produce an error.
    const MP3_SILENCE_FIXTURE: &[u8] = &[
        0x49, 0x44, 0x33, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0xFF, 0xFB, 0x90, 0x64,
    ];

    /// Returns the first optional on-disk MP3 fixture found relative to the
    /// working directory, if any.
    fn find_test_mp3() -> Option<std::path::PathBuf> {
        [
            "test_media/sample.mp3",
            "test_media/silence.mp3",
            "../../test_media/sample.mp3",
            "../../../test_media/sample.mp3",
        ]
        .into_iter()
        .map(std::path::PathBuf::from)
        .find(|p| p.exists())
    }

    #[test]
    fn mp3_decoder_constructs_for_stereo_44100() {
        let dec = Mp3Decoder::new(44100, 2).expect("constructs");
        assert_eq!(dec.declared_sample_rate, 44100);
        assert_eq!(dec.declared_channels, 2);
        assert!(dec.next_pts_us.is_none());
    }

    #[test]
    fn mp3_decoder_rejects_zero_or_too_many_channels() {
        assert!(Mp3Decoder::new(44100, 0).is_err());
        assert!(Mp3Decoder::new(44100, 6).is_err());
    }

    #[test]
    fn mp3_decode_handles_garbage_input_gracefully() {
        let mut dec = Mp3Decoder::new(44100, 2).expect("constructs");
        let garbage = vec![0u8; 4096];
        let frames = dec.decode(&garbage, 0).expect("no error on garbage");
        assert!(
            frames.is_empty(),
            "no valid MP3 frames should decode from zeros"
        );
    }

    #[test]
    fn mp3_decode_returns_empty_on_empty_packet() {
        let mut dec = Mp3Decoder::new(44100, 2).expect("constructs");
        let frames = dec.decode(&[], 12345).expect("no error on empty");
        assert!(frames.is_empty());
    }

    #[test]
    fn mp3_pts_seeded_on_first_nonempty_decode() {
        let mut dec = Mp3Decoder::new(44100, 2).expect("constructs");
        let _ = dec.decode(&[0u8; 1024], 42_000).expect("no error");
        // No frames decode from zeros, so the timeline stays at the seed value.
        assert_eq!(
            dec.next_pts_us,
            Some(42_000),
            "first non-empty decode must seed the pts timeline"
        );
    }

    #[test]
    fn mp3_integration_decodes_real_mp3_if_fixture_present() {
        let Some(path) = find_test_mp3() else {
            eprintln!("mp3_integration: no test_media mp3 fixture found — skipping");
            return;
        };
        let bytes = std::fs::read(&path).expect("read mp3 fixture");
        let mut dec = Mp3Decoder::new(44100, 2).expect("constructs");
        let frames = dec.decode(&bytes, 0).expect("decode real mp3");
        assert!(
            !frames.is_empty(),
            "real mp3 fixture should yield >0 frames"
        );
        let f = &frames[0];
        let per_channel = f.samples.len() / f.channels as usize;
        assert!(
            per_channel == 1152 || per_channel == 576,
            "unexpected mp3 frame size {per_channel} samples/channel"
        );
        assert!(matches!(f.channels, 1 | 2));
        assert!(f.sample_rate > 0);
        assert_eq!(f.pts, 0, "first frame seeds at caller-supplied pts");
        for pair in frames.windows(2) {
            assert!(pair[1].pts > pair[0].pts, "pts must strictly increase");
        }
        for frame in &frames {
            for s in &frame.samples {
                assert!(
                    *s >= -1.0 && *s <= 1.0,
                    "sample {s} out of [-1, 1] after i16→f32 divide by 32768"
                );
            }
        }
    }

    #[test]
    fn mp3_decode_handles_id3_prefix_without_error() {
        let mut dec = Mp3Decoder::new(44100, 2).expect("constructs");
        let _ = dec.decode(MP3_SILENCE_FIXTURE, 0).expect("no error");
    }
}