use oxideav_core::{Error, Result};
/// Result of decoding WAV/PCM audio in this module: a mono `f32` sample
/// stream together with the format parameters of the source it came from.
#[derive(Clone, Debug)]
pub struct DecodedPcm {
/// One value per source frame: the arithmetic mean of that frame's channels.
/// Integer PCM is divided by its full-scale magnitude (128, 32 768,
/// 8 388 608 or 2 147 483 648); IEEE-float input is passed through unscaled.
pub samples: Vec<f32>,
/// Sample rate copied from the source format, in frames per second.
pub sample_rate: u32,
/// Channel count of the *source* data. `samples` itself is always the
/// single-channel mixdown, so this is not the channel count of `samples`.
pub channels: u16,
/// Bit depth of the source samples as declared by the source format.
pub bits_per_sample: u16,
}
pub fn decode_wav(bytes: &[u8]) -> Result<DecodedPcm> {
if bytes.len() < 12 {
return Err(Error::invalid(
"WAV: file shorter than 12 bytes (no RIFF header)",
));
}
if &bytes[0..4] != b"RIFF" {
return Err(Error::invalid("WAV: missing RIFF magic at offset 0"));
}
if &bytes[8..12] != b"WAVE" {
return Err(Error::invalid("WAV: missing WAVE form-type at offset 8"));
}
let mut fmt: Option<Fmt> = None;
let mut data_bytes: Option<&[u8]> = None;
let mut pos = 12usize;
while pos + 8 <= bytes.len() {
let tag = &bytes[pos..pos + 4];
let size = u32::from_le_bytes(bytes[pos + 4..pos + 8].try_into().unwrap()) as usize;
let body_start = pos + 8;
if body_start.saturating_add(size) > bytes.len() {
return Err(Error::invalid(format!(
"WAV: chunk '{}' length {} exceeds remaining bytes",
tag_str(tag),
size,
)));
}
let body = &bytes[body_start..body_start + size];
match tag {
b"fmt " => fmt = Some(parse_fmt(body)?),
b"data" => data_bytes = Some(body),
_ => {}
}
pos = body_start + size + (size & 1);
}
let fmt = fmt.ok_or_else(|| Error::invalid("WAV: missing 'fmt ' chunk"))?;
let data = data_bytes.ok_or_else(|| Error::invalid("WAV: missing 'data' chunk"))?;
decode_data(&fmt, data)
}
/// Decodes headerless sample bytes whose format is supplied by the caller.
///
/// * `data` - interleaved little-endian sample bytes; a trailing partial
///   frame is ignored.
/// * `sample_rate` - frames per second; must be non-zero.
/// * `channels` - interleaved channel count; must be non-zero.
/// * `bits_per_sample` - 8, 16, 24 or 32 for integer PCM; 32 for float.
/// * `format_tag` - `1` or `0xFFFE` for integer PCM, `3` for IEEE float.
///
/// # Errors
///
/// Returns an invalid-data error for any combination the sample decoder
/// rejects (zero channels or rate, unsupported tag or bit depth, too many
/// frames).
pub fn decode_pcm_bytes(
    data: &[u8],
    sample_rate: u32,
    channels: u16,
    bits_per_sample: u16,
    format_tag: u16,
) -> Result<DecodedPcm> {
    decode_data(
        &Fmt {
            format_tag,
            channels,
            sample_rate,
            bits_per_sample,
        },
        data,
    )
}
/// The subset of the WAV `fmt ` chunk that the sample decoder needs.
#[derive(Clone, Copy, Debug)]
struct Fmt {
/// Format code; the decoder accepts 1 and 0xFFFE as integer PCM and 3 as
/// IEEE float, and rejects everything else.
format_tag: u16,
/// Number of interleaved channels per frame.
channels: u16,
/// Frames per second.
sample_rate: u32,
/// Bits per sample as declared in the chunk (bytes 14..16).
bits_per_sample: u16,
}
/// Parses the body of a `fmt ` chunk into the fields the decoder needs.
///
/// Reads the format tag (bytes 0..2), channel count (2..4), sample rate
/// (4..8) and bits per sample (14..16), all little-endian.
///
/// When the tag is `WAVE_FORMAT_EXTENSIBLE` (0xFFFE) and the chunk is long
/// enough to carry the 16-byte SubFormat GUID at bytes 24..40, the GUID's
/// leading format code replaces the tag, provided the rest of the GUID is the
/// standard `xxxxxxxx-0000-0010-8000-00AA00389B71` base. Without this an
/// extensible file holding IEEE-float samples would be decoded as 32-bit
/// integers. If the extension is absent or uses a non-standard GUID the tag
/// is left as 0xFFFE.
///
/// # Errors
///
/// Returns an invalid-data error when `body` is shorter than the 16-byte
/// minimum `fmt ` layout.
fn parse_fmt(body: &[u8]) -> Result<Fmt> {
    const WAVE_FORMAT_EXTENSIBLE: u16 = 0xFFFE;
    // Bytes 2..16 of the KSDATAFORMAT_SUBTYPE_* base GUID as laid out in the file:
    // upper half of Data1, Data2 = 0x0000, Data3 = 0x0010, then the 8 bytes of Data4.
    const SUBFORMAT_GUID_TAIL: [u8; 14] = [
        0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x80, 0x00, 0x00, 0xAA, 0x00, 0x38, 0x9B, 0x71,
    ];

    if body.len() < 16 {
        return Err(Error::invalid(format!(
            "WAV: 'fmt ' chunk {} bytes < 16 minimum",
            body.len(),
        )));
    }

    let u16_at = |offset: usize| u16::from_le_bytes([body[offset], body[offset + 1]]);

    let mut format_tag = u16_at(0);
    if format_tag == WAVE_FORMAT_EXTENSIBLE {
        // The real encoding lives in the SubFormat GUID; resolve it when present.
        if let Some(sub_format) = body.get(24..40) {
            if sub_format[2..] == SUBFORMAT_GUID_TAIL {
                format_tag = u16::from_le_bytes([sub_format[0], sub_format[1]]);
            }
        }
    }

    Ok(Fmt {
        format_tag,
        channels: u16_at(2),
        sample_rate: u32::from_le_bytes([body[4], body[5], body[6], body[7]]),
        bits_per_sample: u16_at(14),
    })
}
/// Converts interleaved little-endian sample bytes into a mono `f32` stream.
///
/// Every frame is reduced to the arithmetic mean of its channels. Integer
/// PCM is normalised by its full-scale magnitude (8-bit is unsigned with a
/// 128 bias; 16/24/32-bit are signed); IEEE-float samples are used as-is.
/// Bytes left over after the last whole frame are ignored.
///
/// # Errors
///
/// Returns an invalid-data error, checked in this order, for: zero channels,
/// zero sample rate, zero bytes per frame, a format tag other than 1, 3 or
/// 0xFFFE, float data that is not 32-bit, integer data that is not
/// 8/16/24/32-bit, or more than `MAX_FRAMES` frames.
fn decode_data(fmt: &Fmt, data: &[u8]) -> Result<DecodedPcm> {
    // Upper bound on decoded frames, checked before the output is allocated.
    const MAX_FRAMES: usize = 64 * 1024 * 1024;

    if fmt.channels == 0 {
        return Err(Error::invalid("WAV: zero channels"));
    }
    if fmt.sample_rate == 0 {
        return Err(Error::invalid("WAV: zero sample rate"));
    }

    let channel_count = usize::from(fmt.channels);
    let bps = usize::from(fmt.bits_per_sample);
    let sample_width = bps.div_ceil(8);
    let frame_width = sample_width * channel_count;
    if frame_width == 0 {
        return Err(Error::invalid("WAV: zero bytes per frame"));
    }

    // Tag 3 is IEEE float; 1 and 0xFFFE are both handled as integer PCM.
    let is_float = match fmt.format_tag {
        3 => true,
        1 | 0xFFFE => false,
        other => {
            return Err(Error::invalid(format!(
                "WAV: unsupported format tag 0x{:04X} (only PCM and IEEE_FLOAT)",
                other,
            )));
        }
    };

    if is_float {
        if bps != 32 {
            return Err(Error::invalid(format!(
                "WAV: IEEE_FLOAT requires 32 bits/sample, got {bps}",
            )));
        }
    } else if !matches!(bps, 8 | 16 | 24 | 32) {
        return Err(Error::invalid(format!(
            "WAV: unsupported PCM bit depth {bps} (need 8/16/24/32)",
        )));
    }

    let frame_count = data.len() / frame_width;
    if frame_count > MAX_FRAMES {
        return Err(Error::invalid(format!(
            "WAV: {frame_count} frames exceeds {MAX_FRAMES} cap",
        )));
    }

    // Decodes one sample of exactly `sample_width` bytes.
    let decode_sample = |raw: &[u8]| -> f32 {
        match (is_float, bps) {
            (true, _) => f32::from_le_bytes([raw[0], raw[1], raw[2], raw[3]]),
            (false, 8) => (f32::from(raw[0]) - 128.0) / 128.0,
            (false, 16) => f32::from(i16::from_le_bytes([raw[0], raw[1]])) / 32_768.0,
            (false, 24) => {
                // Load the three bytes into the top of an i32, then shift back
                // down arithmetically so the sign bit is extended.
                let widened = i32::from_le_bytes([0, raw[0], raw[1], raw[2]]) >> 8;
                widened as f32 / 8_388_608.0
            }
            (false, 32) => {
                i32::from_le_bytes([raw[0], raw[1], raw[2], raw[3]]) as f32 / 2_147_483_648.0
            }
            _ => unreachable!(),
        }
    };

    let divisor = channel_count as f32;
    let mut mono = Vec::with_capacity(frame_count);
    mono.extend(data.chunks_exact(frame_width).map(|frame| {
        // Explicit fold from 0.0 keeps the channel summation order fixed.
        let total = frame
            .chunks_exact(sample_width)
            .fold(0.0f32, |sum, raw| sum + decode_sample(raw));
        total / divisor
    }));

    Ok(DecodedPcm {
        samples: mono,
        sample_rate: fmt.sample_rate,
        channels: fmt.channels,
        bits_per_sample: fmt.bits_per_sample,
    })
}
/// Renders a chunk tag for use in error messages.
///
/// A tag made only of printable ASCII (graphic characters or spaces) is
/// returned verbatim; anything else is rendered as upper-case hex, two digits
/// per byte. Every byte is encoded, so a slice of any length is accepted and
/// short slices cannot cause an out-of-bounds panic.
fn tag_str(tag: &[u8]) -> String {
    let printable = tag.iter().all(|b| b.is_ascii_graphic() || *b == b' ');
    if printable {
        String::from_utf8_lossy(tag).into_owned()
    } else {
        tag.iter().map(|b| format!("{b:02X}")).collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal 16-bit mono PCM WAV file: a RIFF header, a 16-byte
    /// `fmt ` chunk and a `data` chunk holding `samples`.
    fn build_wav_16bit_mono(samples: &[i16], rate: u32) -> Vec<u8> {
        let mut bytes = Vec::new();
        bytes.extend_from_slice(b"RIFF");
        let data_size = samples.len() * 2;
        // RIFF size = "WAVE" (4) + fmt chunk (8 + 16) + data chunk header (8) + data.
        bytes.extend_from_slice(&((36u32 + data_size as u32).to_le_bytes()));
        bytes.extend_from_slice(b"WAVE");
        bytes.extend_from_slice(b"fmt ");
        bytes.extend_from_slice(&16u32.to_le_bytes()); // fmt chunk size
        bytes.extend_from_slice(&1u16.to_le_bytes()); // format tag: PCM
        bytes.extend_from_slice(&1u16.to_le_bytes()); // channels
        bytes.extend_from_slice(&rate.to_le_bytes()); // sample rate
        bytes.extend_from_slice(&(rate * 2).to_le_bytes()); // byte rate
        bytes.extend_from_slice(&2u16.to_le_bytes()); // block align
        bytes.extend_from_slice(&16u16.to_le_bytes()); // bits per sample
        bytes.extend_from_slice(b"data");
        bytes.extend_from_slice(&(data_size as u32).to_le_bytes());
        for s in samples {
            bytes.extend_from_slice(&s.to_le_bytes());
        }
        bytes
    }

    #[test]
    fn decodes_16bit_mono_wav_to_f32() {
        let in_samples: Vec<i16> = (0..8).map(|i| (i * 4096) as i16 - 16384).collect();
        let wav = build_wav_16bit_mono(&in_samples, 44_100);
        let pcm = decode_wav(&wav).expect("decode");
        assert_eq!(pcm.sample_rate, 44_100);
        assert_eq!(pcm.channels, 1);
        assert_eq!(pcm.bits_per_sample, 16);
        assert_eq!(pcm.samples.len(), 8);
        // First input sample is -16384, i.e. half of negative full scale.
        assert!(
            (pcm.samples[0] - -0.5).abs() < 1e-4,
            "got {}",
            pcm.samples[0]
        );
    }

    #[test]
    fn decodes_8bit_unsigned_pcm() {
        let body = [0u8, 64, 128, 192];
        let pcm = decode_pcm_bytes(&body, 22_050, 1, 8, 1).expect("decode");
        assert_eq!(pcm.samples.len(), 4);
        assert!((pcm.samples[0] - -1.0).abs() < 1e-4);
        assert!((pcm.samples[1] - -0.5).abs() < 1e-3);
        assert!((pcm.samples[2] - 0.0).abs() < 1e-4);
        assert!((pcm.samples[3] - 0.5).abs() < 1e-3);
    }

    #[test]
    fn decodes_stereo_to_mono_average() {
        let mut body: Vec<u8> = Vec::new();
        for _ in 0..4 {
            body.extend_from_slice(&(-32768i16).to_le_bytes());
            body.extend_from_slice(&32766i16.to_le_bytes());
        }
        let pcm = decode_pcm_bytes(&body, 44_100, 2, 16, 1).expect("decode");
        // Four stereo frames collapse to four mono samples.
        assert_eq!(pcm.samples.len(), 4);
        for s in &pcm.samples {
            assert!(s.abs() < 0.01, "expected near-zero from mono mix, got {s}");
        }
    }

    #[test]
    fn rejects_non_riff_input() {
        // 9 bytes: rejected by the minimum-length guard before the magic is read.
        let err = decode_wav(b"not a wav").unwrap_err();
        assert!(matches!(err, Error::InvalidData(_)));

        // Long enough to pass the length guard, so the RIFF magic check rejects it.
        let err = decode_wav(b"not a wav file at all").unwrap_err();
        assert!(matches!(err, Error::InvalidData(_)));

        // Valid RIFF magic but a non-WAVE form type.
        let err = decode_wav(b"RIFF\0\0\0\0AVI LIST").unwrap_err();
        assert!(matches!(err, Error::InvalidData(_)));
    }
}