use rskit_errors::{AppError, AppResult, ErrorCode};
/// Format parameters taken from the `fmt ` chunk of a WAV file.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct WavSpec {
    /// Number of interleaved channels per frame.
    pub channels: u16,
    /// Frames per second.
    pub sample_rate: u32,
    /// Width of one stored sample, in bits.
    pub bits_per_sample: u16,
}
/// A WAV file that has been parsed and decoded entirely in memory.
#[derive(Clone, Debug)]
pub struct WavReader {
    /// Format description read from the file's `fmt ` chunk.
    pub spec: WavSpec,
    /// Decoded samples as `f32`, interleaved by channel
    /// (frame 0 channel 0, frame 0 channel 1, …, frame 1 channel 0, …).
    pub samples: Vec<f32>,
}
impl WavReader {
    /// Length of the `RIFF<size>WAVE` preamble; the first chunk starts right after it.
    const RIFF_HEADER_LEN: usize = 12;
    /// Length of a chunk header: four-byte id followed by a little-endian `u32` payload size.
    const CHUNK_HEADER_LEN: usize = 8;
    /// Shortest accepted file: preamble, 16-byte `fmt ` chunk and an empty `data` chunk.
    const MIN_FILE_LEN: usize = 44;
    /// Number of `fmt ` payload bytes this parser reads.
    const MIN_FMT_LEN: usize = 16;
    /// `fmt ` audio-format tag for integer PCM.
    const FORMAT_PCM: u16 = 1;
    /// `fmt ` audio-format tag for IEEE floating point.
    const FORMAT_IEEE_FLOAT: u16 = 3;

    /// Parses a complete RIFF/WAVE file held in memory and decodes every sample to `f32`.
    ///
    /// Supported encodings are integer PCM at 8, 16, 24 or 32 bits and 32-bit IEEE float.
    /// Decoded samples are clamped to `[-1.0, 1.0]`.
    ///
    /// # Errors
    ///
    /// Returns `ErrorCode::InvalidInput` when the buffer is shorter than 44 bytes, lacks
    /// the `RIFF`/`WAVE` markers, has no usable `fmt ` or `data` chunk, declares zero
    /// channels, or uses an unsupported audio format or sample width.
    pub fn from_bytes(data: &[u8]) -> AppResult<Self> {
        if data.len() < Self::MIN_FILE_LEN {
            return Err(AppError::new(ErrorCode::InvalidInput, "WAV file too small"));
        }
        if &data[0..4] != b"RIFF" || &data[8..12] != b"WAVE" {
            return Err(AppError::new(
                ErrorCode::InvalidInput,
                "Not a valid WAV file (missing RIFF/WAVE header)",
            ));
        }
        let (spec, audio_format) = Self::parse_fmt_chunk(data)?;
        // Walk the chunk list from the top again instead of resuming at the end of the
        // `fmt ` payload: an odd-sized `fmt ` chunk is followed by a pad byte, and resuming
        // at the unpadded end would read every following chunk header one byte early.
        let (data_offset, data_len) = Self::find_chunk(data, b"data", Self::RIFF_HEADER_LEN)?;
        let payload = &data[data_offset..data_offset + data_len];
        let samples = Self::decode_samples(payload, &spec, audio_format)?;
        Ok(Self { spec, samples })
    }

    /// Length of the audio in seconds: `frame_count() / sample_rate`.
    ///
    /// The result is not finite when `spec.sample_rate` is zero.
    #[must_use]
    pub fn duration_secs(&self) -> f64 {
        self.frame_count() as f64 / f64::from(self.spec.sample_rate)
    }

    /// Number of complete frames (one sample per channel) in `samples`.
    ///
    /// Returns 0 when `spec.channels` is zero; the fields are public, so such a value can
    /// be constructed by hand even though `from_bytes` rejects it.
    #[must_use]
    pub fn frame_count(&self) -> usize {
        self.samples
            .len()
            .checked_div(usize::from(self.spec.channels))
            .unwrap_or(0)
    }

    /// Copies out the samples belonging to one channel of the interleaved buffer.
    ///
    /// Returns an empty vector when `channel` is not less than `spec.channels`.
    #[must_use]
    pub fn channel_samples(&self, channel: usize) -> Vec<f32> {
        let channel_count = usize::from(self.spec.channels);
        // This guard also covers `channel_count == 0`, which `step_by` would panic on.
        if channel >= channel_count {
            return Vec::new();
        }
        self.samples
            .iter()
            .skip(channel)
            .step_by(channel_count)
            .copied()
            .collect()
    }

    /// Locates the `fmt ` chunk and reads the fields this reader needs from it.
    ///
    /// Returns the parsed [`WavSpec`] together with the raw audio-format tag, which the
    /// sample decoder needs to tell integer data from float data.
    fn parse_fmt_chunk(data: &[u8]) -> AppResult<(WavSpec, u16)> {
        let (offset, chunk_len) = Self::find_chunk(data, b"fmt ", Self::RIFF_HEADER_LEN)?;
        // `find_chunk` caps the length at what the buffer holds, so this check also
        // guarantees the slice below is in bounds.
        if chunk_len < Self::MIN_FMT_LEN {
            return Err(AppError::new(
                ErrorCode::InvalidInput,
                "WAV fmt chunk too small",
            ));
        }
        let fmt = &data[offset..offset + Self::MIN_FMT_LEN];

        let audio_format = u16::from_le_bytes([fmt[0], fmt[1]]);
        if audio_format != Self::FORMAT_PCM && audio_format != Self::FORMAT_IEEE_FLOAT {
            return Err(AppError::new(
                ErrorCode::InvalidInput,
                format!("Unsupported WAV audio format: {audio_format} (only PCM/float supported)"),
            ));
        }

        let channels = u16::from_le_bytes([fmt[2], fmt[3]]);
        if channels == 0 {
            // Zero channels would make every per-frame computation divide by zero.
            return Err(AppError::new(
                ErrorCode::InvalidInput,
                "WAV fmt chunk declares zero channels",
            ));
        }

        let sample_rate = u32::from_le_bytes([fmt[4], fmt[5], fmt[6], fmt[7]]);
        // Bytes 8..14 hold the byte rate and block alignment, which are not used here.
        let bits_per_sample = u16::from_le_bytes([fmt[14], fmt[15]]);

        Ok((
            WavSpec {
                channels,
                sample_rate,
                bits_per_sample,
            },
            audio_format,
        ))
    }

    /// Walks the chunk list beginning at byte `start` and returns
    /// `(payload_offset, payload_len)` of the first chunk whose id equals `id`.
    ///
    /// `payload_len` is the declared size capped at the bytes actually present, so the
    /// returned range is always in bounds even for a truncated file.
    fn find_chunk(data: &[u8], id: &[u8; 4], start: usize) -> AppResult<(usize, usize)> {
        let mut pos = start;
        loop {
            // Stop as soon as a full chunk header no longer fits in the buffer.
            let payload_start = match pos.checked_add(Self::CHUNK_HEADER_LEN) {
                Some(end) if end <= data.len() => end,
                _ => break,
            };
            let header = &data[pos..payload_start];
            // `u32` to `usize` is lossless on 32- and 64-bit targets.
            let declared =
                u32::from_le_bytes([header[4], header[5], header[6], header[7]]) as usize;

            if header[..4] == id[..] {
                let available = data.len() - payload_start;
                return Ok((payload_start, declared.min(available)));
            }

            // Chunk payloads are padded to an even length. Checked arithmetic keeps a
            // hostile size field from overflowing `usize` on 32-bit targets.
            let next = declared
                .checked_add(declared & 1)
                .and_then(|padded| payload_start.checked_add(padded));
            match next {
                Some(next) => pos = next,
                None => break,
            }
        }
        Err(AppError::new(
            ErrorCode::InvalidInput,
            format!(
                "WAV chunk '{}' not found",
                std::str::from_utf8(id).unwrap_or("????")
            ),
        ))
    }

    /// Decodes the raw `data` chunk payload into `f32` samples clamped to `[-1.0, 1.0]`.
    ///
    /// `audio_format` is the tag from the `fmt ` chunk; it decides whether 32-bit samples
    /// are signed integers or IEEE floats. Trailing bytes that do not form a whole sample
    /// are ignored.
    fn decode_samples(data: &[u8], spec: &WavSpec, audio_format: u16) -> AppResult<Vec<f32>> {
        /// On-disk layout of one sample.
        #[derive(Clone, Copy)]
        enum Encoding {
            U8,
            I16,
            I24,
            I32,
            F32,
        }

        let bps = spec.bits_per_sample;
        let bytes_per_sample = usize::from(bps / 8);
        if bytes_per_sample == 0 {
            return Err(AppError::new(
                ErrorCode::InvalidInput,
                "Invalid bits_per_sample",
            ));
        }

        // Validate once up front so an unsupported width is reported even when the data
        // chunk is empty.
        let encoding = match (audio_format, bps) {
            (Self::FORMAT_PCM, 8) => Encoding::U8,
            (Self::FORMAT_PCM, 16) => Encoding::I16,
            (Self::FORMAT_PCM, 24) => Encoding::I24,
            (Self::FORMAT_PCM, 32) => Encoding::I32,
            (Self::FORMAT_IEEE_FLOAT, 32) => Encoding::F32,
            _ => {
                return Err(AppError::new(
                    ErrorCode::InvalidInput,
                    format!("Unsupported bits_per_sample: {bps}"),
                ));
            }
        };

        // `chunks_exact` yields slices of exactly `bytes_per_sample` bytes, which matches
        // the width implied by `encoding`, so the indexing below is always in bounds.
        let samples = data
            .chunks_exact(bytes_per_sample)
            .map(|bytes| {
                let value = match encoding {
                    // 8-bit WAV samples are unsigned with the zero level at 128.
                    Encoding::U8 => (f32::from(bytes[0]) - 128.0) / 128.0,
                    Encoding::I16 => {
                        f32::from(i16::from_le_bytes([bytes[0], bytes[1]])) / f32::from(i16::MAX)
                    }
                    Encoding::I24 => {
                        // Load the three bytes into the top of an i32, then shift back
                        // down arithmetically so the sign bit is extended.
                        let wide = i32::from_le_bytes([0, bytes[0], bytes[1], bytes[2]]) >> 8;
                        wide as f32 / 8_388_607.0
                    }
                    Encoding::I32 => {
                        let wide = i32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
                        // Divide in f64 so the scale factor is exact before narrowing.
                        (f64::from(wide) / f64::from(i32::MAX)) as f32
                    }
                    Encoding::F32 => {
                        f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
                    }
                };
                value.clamp(-1.0, 1.0)
            })
            .collect();
        Ok(samples)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a WAV file with a 16-byte `fmt ` chunk followed by a `data` chunk holding
    /// `data` verbatim.
    fn make_wav(
        audio_format: u16,
        channels: u16,
        sample_rate: u32,
        bps: u16,
        data: &[u8],
    ) -> Vec<u8> {
        let data_size = data.len() as u32;
        let bytes_per_sample = u32::from(bps / 8).max(1);
        let byte_rate = sample_rate * u32::from(channels) * bytes_per_sample;
        let block_align = channels * (bps / 8).max(1);
        let file_size = 36 + data_size;
        let mut buf = Vec::with_capacity(file_size as usize + 8);
        buf.extend_from_slice(b"RIFF");
        buf.extend_from_slice(&file_size.to_le_bytes());
        buf.extend_from_slice(b"WAVE");
        buf.extend_from_slice(b"fmt ");
        buf.extend_from_slice(&16u32.to_le_bytes());
        buf.extend_from_slice(&audio_format.to_le_bytes());
        buf.extend_from_slice(&channels.to_le_bytes());
        buf.extend_from_slice(&sample_rate.to_le_bytes());
        buf.extend_from_slice(&byte_rate.to_le_bytes());
        buf.extend_from_slice(&block_align.to_le_bytes());
        buf.extend_from_slice(&bps.to_le_bytes());
        buf.extend_from_slice(b"data");
        buf.extend_from_slice(&data_size.to_le_bytes());
        buf.extend_from_slice(data);
        buf
    }

    /// Little-endian byte image of a run of 16-bit samples.
    fn i16_bytes(samples: &[i16]) -> Vec<u8> {
        samples.iter().flat_map(|s| s.to_le_bytes()).collect()
    }

    /// Builds a mono 16-bit integer PCM WAV file from `samples`.
    fn make_wav_16bit_mono(sample_rate: u32, samples: &[i16]) -> Vec<u8> {
        make_wav(1, 1, sample_rate, 16, &i16_bytes(samples))
    }

    #[test]
    fn parse_valid_wav() {
        let samples = vec![0, 16383, -16384, 32767, -32768];
        let wav_data = make_wav_16bit_mono(44100, &samples);
        let reader = WavReader::from_bytes(&wav_data).unwrap();
        assert_eq!(reader.spec.channels, 1);
        assert_eq!(reader.spec.sample_rate, 44100);
        assert_eq!(reader.spec.bits_per_sample, 16);
        assert_eq!(reader.frame_count(), 5);
        assert_eq!(reader.samples[0], 0.0);
        assert_eq!(reader.samples[3], 1.0);
        assert_eq!(reader.samples[4], -1.0);
    }

    #[test]
    fn duration_calculation() {
        let samples = vec![0i16; 44100];
        let wav_data = make_wav_16bit_mono(44100, &samples);
        let reader = WavReader::from_bytes(&wav_data).unwrap();
        assert!((reader.duration_secs() - 1.0).abs() < 0.001);
    }

    #[test]
    fn rejects_non_wav() {
        let result = WavReader::from_bytes(b"not a wav file at all!!!!!!!!!!!!!!!!!!!!!!!!!!");
        assert!(result.is_err());
    }

    #[test]
    fn rejects_too_small() {
        let result = WavReader::from_bytes(b"tiny");
        assert!(result.is_err());
    }

    #[test]
    fn channel_samples_returns_requested_channel_or_empty() {
        let wav_data = make_wav_16bit_mono(44_100, &[1, 2, 3, 4]);
        let reader = WavReader::from_bytes(&wav_data).unwrap();
        assert_eq!(reader.channel_samples(0).len(), 4);
        assert!(reader.channel_samples(1).is_empty());
    }

    #[test]
    fn channel_samples_deinterleaves_stereo() {
        let wav = make_wav(1, 2, 8_000, 16, &i16_bytes(&[100, -100, 200, -200]));
        let reader = WavReader::from_bytes(&wav).unwrap();
        assert_eq!(reader.frame_count(), 2);
        let left = reader.channel_samples(0);
        let right = reader.channel_samples(1);
        assert_eq!(left.len(), 2);
        assert_eq!(right.len(), 2);
        assert!(left.iter().all(|&s| s > 0.0));
        assert!(right.iter().all(|&s| s < 0.0));
        assert!(reader.channel_samples(2).is_empty());
    }

    #[test]
    fn rejects_unsupported_format_and_short_fmt_chunk() {
        let unsupported = make_wav(6, 1, 8_000, 16, &[0, 0]);
        let err = WavReader::from_bytes(&unsupported).unwrap_err();
        assert!(err.message().contains("Unsupported WAV audio format"));
        let mut short = Vec::new();
        short.extend_from_slice(b"RIFF");
        short.extend_from_slice(&28u32.to_le_bytes());
        short.extend_from_slice(b"WAVE");
        short.extend_from_slice(b"fmt ");
        short.extend_from_slice(&4u32.to_le_bytes());
        short.extend_from_slice(&[0, 0, 0, 0]);
        short.extend_from_slice(b"data");
        short.extend_from_slice(&0u32.to_le_bytes());
        let err = WavReader::from_bytes(&short).unwrap_err();
        assert!(
            err.message().contains("fmt chunk too small")
                || err.message().contains("WAV file too small")
        );
    }

    #[test]
    fn rejects_zero_channels() {
        let wav = make_wav(1, 0, 8_000, 16, &[0, 0]);
        let err = WavReader::from_bytes(&wav).unwrap_err();
        assert!(err.message().contains("zero channels"));
    }

    #[test]
    fn hand_built_zero_channel_reader_does_not_panic() {
        let reader = WavReader {
            spec: WavSpec {
                channels: 0,
                sample_rate: 8_000,
                bits_per_sample: 16,
            },
            samples: vec![0.0; 4],
        };
        assert_eq!(reader.frame_count(), 0);
        assert_eq!(reader.duration_secs(), 0.0);
        assert!(reader.channel_samples(0).is_empty());
    }

    #[test]
    fn rejects_missing_data_chunk_after_aligned_unknown_chunk() {
        let mut wav = make_wav(1, 1, 8_000, 16, &[0, 0]);
        wav.truncate(36);
        wav.extend_from_slice(b"JUNK");
        wav.extend_from_slice(&3u32.to_le_bytes());
        wav.extend_from_slice(&[1, 2, 3, 0]);
        let err = WavReader::from_bytes(&wav).unwrap_err();
        assert!(err.message().contains("data"));
    }

    #[test]
    fn finds_data_chunk_after_odd_sized_fmt_chunk() {
        let mut wav = Vec::new();
        wav.extend_from_slice(b"RIFF");
        wav.extend_from_slice(&40u32.to_le_bytes());
        wav.extend_from_slice(b"WAVE");
        wav.extend_from_slice(b"fmt ");
        wav.extend_from_slice(&17u32.to_le_bytes());
        wav.extend_from_slice(&1u16.to_le_bytes());
        wav.extend_from_slice(&1u16.to_le_bytes());
        wav.extend_from_slice(&8_000u32.to_le_bytes());
        wav.extend_from_slice(&16_000u32.to_le_bytes());
        wav.extend_from_slice(&2u16.to_le_bytes());
        wav.extend_from_slice(&16u16.to_le_bytes());
        // Seventeenth payload byte followed by the pad byte that restores even alignment.
        wav.extend_from_slice(&[0, 0]);
        wav.extend_from_slice(b"data");
        wav.extend_from_slice(&2u32.to_le_bytes());
        wav.extend_from_slice(&0x4000i16.to_le_bytes());
        let reader = WavReader::from_bytes(&wav).unwrap();
        assert_eq!(reader.frame_count(), 1);
        assert!((reader.samples[0] - 0.5).abs() < 1e-3);
    }

    #[test]
    fn decodes_supported_sample_widths_and_clamps_float() {
        let eight = make_wav(1, 1, 8_000, 8, &[0, 128, 255]);
        let eight = WavReader::from_bytes(&eight).unwrap();
        assert_eq!(eight.samples[0], -1.0);
        assert_eq!(eight.samples[1], 0.0);

        // 24-bit little-endian: 0x7fffff (max), 0x800000 (min), 0x007fff (= 32767).
        let twenty_four = make_wav(
            1,
            1,
            8_000,
            24,
            &[0xff, 0xff, 0x7f, 0x00, 0x00, 0x80, 0xff, 0x7f, 0x00],
        );
        let twenty_four = WavReader::from_bytes(&twenty_four).unwrap();
        assert_eq!(twenty_four.samples[0], 1.0);
        assert_eq!(twenty_four.samples[1], -1.0);
        assert!((twenty_four.samples[2] - 32_767.0 / 8_388_607.0).abs() < 1e-6);

        let mut float_data = Vec::new();
        float_data.extend_from_slice(&2.0f32.to_le_bytes());
        float_data.extend_from_slice(&(-2.0f32).to_le_bytes());
        let float = make_wav(3, 1, 8_000, 32, &float_data);
        let float = WavReader::from_bytes(&float).unwrap();
        assert_eq!(float.samples, vec![1.0, -1.0]);
    }

    #[test]
    fn decodes_32_bit_integer_pcm_as_integers() {
        let mut data = Vec::new();
        data.extend_from_slice(&i32::MAX.to_le_bytes());
        data.extend_from_slice(&i32::MIN.to_le_bytes());
        data.extend_from_slice(&0i32.to_le_bytes());
        data.extend_from_slice(&(1i32 << 30).to_le_bytes());
        let wav = make_wav(1, 1, 8_000, 32, &data);
        let reader = WavReader::from_bytes(&wav).unwrap();
        assert_eq!(reader.samples[0], 1.0);
        assert_eq!(reader.samples[1], -1.0);
        assert_eq!(reader.samples[2], 0.0);
        assert!((reader.samples[3] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn rejects_invalid_or_unsupported_bits_per_sample() {
        let invalid = make_wav(1, 1, 8_000, 0, &[0]);
        let err = WavReader::from_bytes(&invalid).unwrap_err();
        assert!(err.message().contains("Invalid bits_per_sample"));
        let unsupported = make_wav(1, 1, 8_000, 12, &[0, 0]);
        let err = WavReader::from_bytes(&unsupported).unwrap_err();
        assert!(err.message().contains("Unsupported bits_per_sample"));
        // The width is validated even when there is nothing to decode.
        let empty = make_wav(1, 1, 8_000, 12, &[]);
        let err = WavReader::from_bytes(&empty).unwrap_err();
        assert!(err.message().contains("Unsupported bits_per_sample"));
        // Float data is only supported at 32 bits.
        let float16 = make_wav(3, 1, 8_000, 16, &[0, 0]);
        let err = WavReader::from_bytes(&float16).unwrap_err();
        assert!(err.message().contains("Unsupported bits_per_sample"));
    }
}