use alloc::format;
use alloc::vec::Vec;
use serde::{Deserialize, Serialize};
use crate::error::{Result, ShravanError};
use crate::format::{AudioFormat, FormatInfo};
/// MPEG audio version, decoded from the two version bits (byte 1, bits 4-3)
/// of a frame header.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum MpegVersion {
    /// MPEG-1 (version bits `0b11`); 32/44.1/48 kHz sample rates.
    V1,
    /// MPEG-2 (version bits `0b10`); 16/22.05/24 kHz sample rates.
    V2,
    /// MPEG-2.5 (version bits `0b00`); 8/11.025/12 kHz sample rates.
    V25,
}
/// MPEG audio layer, decoded from the two layer bits (byte 1, bits 2-1)
/// of a frame header.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum MpegLayer {
    /// Layer I (layer bits `0b11`).
    I,
    /// Layer II (layer bits `0b10`).
    II,
    /// Layer III (layer bits `0b01`), the layer used by ".mp3" files.
    III,
}
/// Channel configuration, decoded from the two mode bits (byte 3, bits 7-6)
/// of a frame header.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum ChannelMode {
    /// Mode bits `0b00`.
    Stereo,
    /// Mode bits `0b01`.
    JointStereo,
    /// Mode bits `0b10`.
    DualChannel,
    /// Mode bits `0b11`; the only mode reported as a single channel.
    Mono,
}
/// Everything decoded from a single MPEG audio frame header.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Mp3FrameInfo {
    /// MPEG version signalled by the header.
    pub version: MpegVersion,
    /// MPEG layer signalled by the header.
    pub layer: MpegLayer,
    /// Bitrate in kilobits per second (never 0 for a parsed frame).
    pub bitrate: u32,
    /// Sample rate in hertz.
    pub sample_rate: u32,
    /// Channel configuration of the frame.
    pub channel_mode: ChannelMode,
    /// Total frame length in bytes, counted from the first sync byte and
    /// including the 4-byte header and any padding.
    pub frame_size: usize,
    /// Samples per channel carried by the frame (384, 576 or 1152).
    pub samples_per_frame: u32,
    /// Whether the header's padding bit is set (one extra slot in the frame).
    pub padding: bool,
}
/// MPEG-1 Layer III bitrates in kbps, indexed by the 4-bit bitrate index.
///
/// Index 0 (free format) and index 15 (reserved) are stored as 0 so that a
/// single `== 0` check rejects both.
const BITRATE_V1_L3: [u32; 16] = [
    0, // free format
    32, 40, 48, 56, 64, 80, 96, //
    112, 128, 160, 192, 224, 256, 320, //
    0, // reserved
];

/// MPEG-2 / MPEG-2.5 Layer III bitrates in kbps, indexed like
/// [`BITRATE_V1_L3`], with the same 0 sentinels at indices 0 and 15.
const BITRATE_V2_L3: [u32; 16] = [
    0, // free format
    8, 16, 24, 32, 40, 48, 56, //
    64, 80, 96, 112, 128, 144, 160, //
    0, // reserved
];

/// MPEG-1 sample rates in Hz, indexed by the 2-bit sample-rate index
/// (index 3 is reserved and has no entry).
const SAMPLE_RATE_V1: [u32; 3] = [44_100, 48_000, 32_000];

/// MPEG-2 sample rates in Hz, indexed like [`SAMPLE_RATE_V1`].
const SAMPLE_RATE_V2: [u32; 3] = [22_050, 24_000, 16_000];

/// MPEG-2.5 sample rates in Hz, indexed like [`SAMPLE_RATE_V1`].
const SAMPLE_RATE_V25: [u32; 3] = [11_025, 12_000, 8_000];
/// Decodes a big-endian ID3v2 "syncsafe" integer from the first 4 bytes of
/// `data`.
///
/// Each byte contributes only its low 7 bits, giving a 28-bit result. Bit 7
/// is masked off rather than trusted: it is always zero in a well-formed tag,
/// and letting it through would corrupt the bits of the neighbouring byte.
///
/// # Panics
///
/// Panics if `data` is shorter than 4 bytes.
#[must_use]
fn syncsafe_to_u32(data: &[u8]) -> u32 {
    data[..4]
        .iter()
        .fold(0u32, |acc, &byte| (acc << 7) | u32::from(byte & 0x7F))
}
/// Returns the length in bytes of an ID3v2 tag at the start of `data`, or 0
/// when `data` does not begin with one (or is shorter than a tag header).
///
/// The length covers:
/// - the 10-byte header;
/// - the syncsafe-encoded body size stored in header bytes 6..10;
/// - for ID3v2.4 tags whose "footer present" flag is set, the trailing
///   10-byte footer.
///
/// The value is taken from the header as-is and may exceed `data.len()` for
/// a truncated file, so callers must bounds-check before indexing with it.
#[must_use]
fn id3v2_skip(data: &[u8]) -> usize {
    // Both the fixed header and the optional v2.4 footer are 10 bytes long.
    const HEADER_LEN: usize = 10;
    const FOOTER_LEN: usize = 10;
    // Bit 4 of the flags byte (byte 5); only ID3v2.4 defines it.
    const FOOTER_PRESENT: u8 = 0x10;

    if data.len() < HEADER_LEN || !data.starts_with(b"ID3") {
        return 0;
    }
    let major_version = data[3];
    let flags = data[5];
    let body_len = syncsafe_to_u32(&data[6..10]) as usize;
    let footer_len = if major_version == 4 && (flags & FOOTER_PRESENT) != 0 {
        FOOTER_LEN
    } else {
        0
    };
    HEADER_LEN + body_len + footer_len
}
#[must_use = "parsed frame info is returned and should not be discarded"]
pub fn parse_frame_header(header: &[u8; 4]) -> Result<Mp3FrameInfo> {
if header[0] != 0xFF || (header[1] & 0xE0) != 0xE0 {
return Err(ShravanError::InvalidHeader("invalid sync word".into()));
}
let version_bits = (header[1] >> 3) & 0x03;
let version = match version_bits {
0b00 => MpegVersion::V25,
0b10 => MpegVersion::V2,
0b11 => MpegVersion::V1,
_ => return Err(ShravanError::InvalidHeader("reserved MPEG version".into())),
};
let layer_bits = (header[1] >> 1) & 0x03;
let layer = match layer_bits {
0b01 => MpegLayer::III,
0b10 => MpegLayer::II,
0b11 => MpegLayer::I,
_ => return Err(ShravanError::InvalidHeader("reserved MPEG layer".into())),
};
let bitrate_index = ((header[2] >> 4) & 0x0F) as usize;
let bitrate = match (version, layer) {
(MpegVersion::V1, MpegLayer::III) => BITRATE_V1_L3[bitrate_index],
(MpegVersion::V2 | MpegVersion::V25, MpegLayer::III) => BITRATE_V2_L3[bitrate_index],
_ => {
BITRATE_V1_L3[bitrate_index]
}
};
if bitrate == 0 {
return Err(ShravanError::InvalidHeader(format!(
"invalid bitrate index: {bitrate_index}"
)));
}
let sr_index = ((header[2] >> 2) & 0x03) as usize;
if sr_index >= 3 {
return Err(ShravanError::InvalidHeader(
"reserved sample rate index".into(),
));
}
let sample_rate = match version {
MpegVersion::V1 => SAMPLE_RATE_V1[sr_index],
MpegVersion::V2 => SAMPLE_RATE_V2[sr_index],
MpegVersion::V25 => SAMPLE_RATE_V25[sr_index],
};
let padding = (header[2] >> 1) & 0x01 == 1;
let channel_mode = match (header[3] >> 6) & 0x03 {
0b00 => ChannelMode::Stereo,
0b01 => ChannelMode::JointStereo,
0b10 => ChannelMode::DualChannel,
_ => ChannelMode::Mono,
};
let samples_per_frame = match (version, layer) {
(MpegVersion::V1, MpegLayer::I) => 384,
(MpegVersion::V1, MpegLayer::II | MpegLayer::III) => 1152,
(MpegVersion::V2 | MpegVersion::V25, MpegLayer::I) => 384,
(MpegVersion::V2 | MpegVersion::V25, MpegLayer::II) => 1152,
(MpegVersion::V2 | MpegVersion::V25, MpegLayer::III) => 576,
};
let padding_bytes: usize = if padding { 1 } else { 0 };
let frame_size = match layer {
MpegLayer::I => (12 * bitrate as usize * 1000 / sample_rate as usize + padding_bytes) * 4,
MpegLayer::II | MpegLayer::III => {
let spf = match (layer, version) {
(MpegLayer::III, MpegVersion::V2 | MpegVersion::V25) => 576,
_ => 1152,
};
spf * bitrate as usize * 1000 / (8 * sample_rate as usize) + padding_bytes
}
};
Ok(Mp3FrameInfo {
version,
layer,
bitrate,
sample_rate,
channel_mode,
frame_size,
samples_per_frame,
padding,
})
}
/// Walks `data` and collects every MPEG audio frame header it finds, starting
/// just after any leading ID3v2 tag.
///
/// At each position the next four bytes are tried as a header. On success the
/// cursor jumps over the whole frame; otherwise it advances a single byte to
/// resynchronise. A final frame whose declared size runs past the end of
/// `data` is still reported.
///
/// # Errors
///
/// Returns [`ShravanError::InvalidHeader`] when not a single frame header is
/// found.
#[must_use = "scanned frame list is returned and should not be discarded"]
pub fn scan_frames(data: &[u8]) -> Result<Vec<Mp3FrameInfo>> {
    let mut found: Vec<Mp3FrameInfo> = Vec::new();
    let mut cursor = id3v2_skip(data);

    // `get` yields None once fewer than four bytes remain (or the ID3 tag
    // claims to extend past the end of the buffer).
    while let Some(&[b0, b1, b2, b3]) = data.get(cursor..cursor + 4) {
        let parsed = if b0 == 0xFF && (b1 & 0xE0) == 0xE0 {
            parse_frame_header(&[b0, b1, b2, b3])
                .ok()
                .filter(|info| info.frame_size > 0)
        } else {
            None
        };

        match parsed {
            Some(info) => {
                cursor += info.frame_size;
                found.push(info);
            }
            None => cursor += 1,
        }
    }

    if found.is_empty() {
        Err(ShravanError::InvalidHeader(
            "no valid MP3 frames found".into(),
        ))
    } else {
        Ok(found)
    }
}
/// Scans an MP3 byte stream and reports its format metadata.
///
/// Sample rate and channel count are taken from the first frame. The
/// `total_samples` field and the duration sum `samples_per_frame` over every
/// frame found. No PCM is produced yet: the returned sample vector is always
/// empty and `bit_depth` is a hard-coded 16.
///
/// # Errors
///
/// Propagates the error from [`scan_frames`] when no frame can be found.
#[must_use = "decoded audio data is returned and should not be discarded"]
pub fn decode(data: &[u8]) -> Result<(FormatInfo, Vec<f32>)> {
    let frames = scan_frames(data)?;

    let total_samples: u64 = frames
        .iter()
        .map(|frame| u64::from(frame.samples_per_frame))
        .sum();

    // `scan_frames` errors out on an empty result, so index 0 always exists.
    let head = &frames[0];
    let sample_rate = head.sample_rate;
    let channels: u16 = if matches!(head.channel_mode, ChannelMode::Mono) {
        1
    } else {
        2
    };
    let duration_secs = total_samples as f64 / f64::from(sample_rate);

    let info = FormatInfo {
        format: AudioFormat::Mp3,
        sample_rate,
        channels,
        bit_depth: 16,
        duration_secs,
        total_samples,
    };
    Ok((info, Vec::new()))
}
#[cfg(test)]
// `unwrap` is the intended failure mode here: a panic fails the test.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// MPEG-1 Layer III, 128 kbps, 44.1 kHz, no padding, stereo, no CRC.
    fn make_valid_header() -> [u8; 4] {
        [0xFF, 0xFB, 0x90, 0x00]
    }

    /// Concatenates `count` zero-filled frames, each starting with `header`
    /// and sized according to that header's own `frame_size`.
    fn frame_stream(header: [u8; 4], count: usize) -> Vec<u8> {
        let frame_size = parse_frame_header(&header).unwrap().frame_size;
        let mut stream = Vec::with_capacity(frame_size * count);
        for _ in 0..count {
            stream.extend_from_slice(&header);
            stream.resize(stream.len() + frame_size - 4, 0);
        }
        stream
    }

    /// Serializes `value` to JSON and parses it straight back.
    fn json_roundtrip<T>(value: &T) -> T
    where
        T: serde::Serialize + serde::de::DeserializeOwned,
    {
        let json = serde_json::to_string(value).unwrap();
        serde_json::from_str(&json).unwrap()
    }

    #[test]
    fn parse_valid_header() {
        let info = parse_frame_header(&make_valid_header()).unwrap();
        assert_eq!(info.version, MpegVersion::V1);
        assert_eq!(info.layer, MpegLayer::III);
        assert_eq!(info.bitrate, 128);
        assert_eq!(info.sample_rate, 44100);
        assert_eq!(info.channel_mode, ChannelMode::Stereo);
        assert_eq!(info.samples_per_frame, 1152);
        assert!(!info.padding);
        assert_eq!(info.frame_size, 417);
    }

    #[test]
    fn reject_bad_sync() {
        assert!(parse_frame_header(&[0x00, 0x00, 0x90, 0x00]).is_err());
    }

    #[test]
    fn reject_reserved_version() {
        // 0xE9: version bits 0b01 (reserved).
        assert!(parse_frame_header(&[0xFF, 0xE9, 0x90, 0x00]).is_err());
    }

    #[test]
    fn reject_reserved_layer() {
        // 0xF1: layer bits 0b00 (reserved).
        assert!(parse_frame_header(&[0xFF, 0xF1, 0x90, 0x00]).is_err());
    }

    #[test]
    fn reject_zero_bitrate() {
        // Bitrate index 0 (free format) is not supported.
        assert!(parse_frame_header(&[0xFF, 0xFB, 0x00, 0x00]).is_err());
    }

    #[test]
    fn reject_reserved_sample_rate() {
        // 0x9C: sample-rate index 0b11 (reserved).
        assert!(parse_frame_header(&[0xFF, 0xFB, 0x9C, 0x00]).is_err());
    }

    #[test]
    fn id3v2_skip_no_tag() {
        assert_eq!(id3v2_skip(&make_valid_header()), 0);
    }

    #[test]
    fn id3v2_skip_with_tag() {
        // "ID3", version 4.0, no flags, syncsafe size 100, then tag body.
        let mut data: Vec<u8> = b"ID3".to_vec();
        data.extend_from_slice(&[4, 0, 0, 0, 0, 0, 100]);
        data.resize(data.len() + 200, 0);
        assert_eq!(id3v2_skip(&data), 110);
    }

    #[test]
    fn frame_size_with_padding() {
        // 0x92: the valid header with the padding bit set.
        let info = parse_frame_header(&[0xFF, 0xFB, 0x92, 0x00]).unwrap();
        assert!(info.padding);
        assert_eq!(info.frame_size, 418);
    }

    #[test]
    fn scan_multiple_frames() {
        let stream = frame_stream(make_valid_header(), 3);
        assert_eq!(scan_frames(&stream).unwrap().len(), 3);
    }

    #[test]
    fn scan_frames_empty() {
        assert!(scan_frames(&[0u8; 100]).is_err());
    }

    #[test]
    fn mpeg_version_serde_roundtrip() {
        assert_eq!(json_roundtrip(&MpegVersion::V1), MpegVersion::V1);
    }

    #[test]
    fn mpeg_layer_serde_roundtrip() {
        assert_eq!(json_roundtrip(&MpegLayer::III), MpegLayer::III);
    }

    #[test]
    fn channel_mode_serde_roundtrip() {
        assert_eq!(
            json_roundtrip(&ChannelMode::JointStereo),
            ChannelMode::JointStereo
        );
    }

    #[test]
    fn mp3_frame_info_serde_roundtrip() {
        let info = parse_frame_header(&make_valid_header()).unwrap();
        assert_eq!(json_roundtrip(&info), info);
    }

    #[test]
    fn decode_produces_format_info() {
        let stream = frame_stream(make_valid_header(), 10);
        let (info, samples) = decode(&stream).unwrap();
        assert_eq!(info.format, AudioFormat::Mp3);
        assert_eq!(info.sample_rate, 44100);
        assert_eq!(info.channels, 2);
        assert!(samples.is_empty());
        assert!(info.duration_secs > 0.0);
    }

    #[test]
    fn mono_channel_count() {
        // 0xC0: channel-mode bits 0b11 select mono.
        let info = parse_frame_header(&[0xFF, 0xFB, 0x90, 0xC0]).unwrap();
        assert_eq!(info.channel_mode, ChannelMode::Mono);
    }
}