mod chunks;
pub use chunks::{FmtChunk, FmtExtension, RiffChunk, WavFormat};
use async_trait::async_trait;
use bytes::Bytes;
use oximedia_core::{CodecId, OxiError, OxiResult, Rational, Timestamp};
use std::io::SeekFrom;
use crate::demux::Demuxer;
use crate::{CodecParams, ContainerFormat, Metadata, Packet, PacketFlags, ProbeResult, StreamInfo};
/// Four-byte tag expected at bytes 0..4 of the file header.
pub const RIFF_MAGIC: &[u8; 4] = b"RIFF";
/// Four-byte form type expected at bytes 8..12 of the file header.
pub const WAVE_FORM: &[u8; 4] = b"WAVE";
/// Four-byte `RF64` tag.
// NOTE(review): exported but not referenced by the header parser in this file;
// presumably intended for RF64 (64-bit RIFF) detection — confirm with callers.
pub const RF64_MAGIC: &[u8; 4] = b"RF64";
/// Chunk id of the format chunk (note the trailing space).
const FMT_CHUNK: &[u8; 4] = b"fmt ";
/// Chunk id of the chunk holding the audio payload.
const DATA_CHUNK: &[u8; 4] = b"data";
/// Number of `block_align`-sized units requested per packet read.
const DEFAULT_PACKET_SAMPLES: u64 = 4096;
/// Demuxer for RIFF/WAVE input, generic over the underlying source `R`.
///
/// A freshly constructed value holds no format information; the format,
/// stream list and data-chunk location are filled in when the header is
/// parsed during `probe()`.
pub struct WavDemuxer<R> {
/// Underlying source that header bytes and audio data are read from.
source: R,
/// Scratch buffer each packet is read into before being copied out.
buffer: Vec<u8>,
/// Parsed `fmt ` chunk; `None` until the header has been parsed.
fmt: Option<FmtChunk>,
/// Stream descriptions returned by `streams()`.
streams: Vec<StreamInfo>,
/// Absolute offset of the first payload byte of the `data` chunk.
data_start: u64,
/// Payload size in bytes as declared in the `data` chunk header.
data_size: u64,
/// Payload bytes consumed so far, counted from `data_start`.
position: u64,
/// Running count of `block_align`-sized units delivered; used as the next
/// packet's presentation timestamp.
samples_read: u64,
/// Set once the end of the payload (or of the source) has been reached.
eof: bool,
}
impl<R> WavDemuxer<R> {
    /// Wraps `source` in a demuxer that has not parsed anything yet.
    ///
    /// The internal packet buffer is pre-allocated with 8 KiB of capacity;
    /// every other field starts out empty or zero.
    #[must_use]
    pub fn new(source: R) -> Self {
        let buffer = Vec::with_capacity(8192);
        let streams = Vec::new();
        Self {
            source,
            buffer,
            streams,
            fmt: None,
            eof: false,
            data_start: 0,
            data_size: 0,
            position: 0,
            samples_read: 0,
        }
    }

    /// Returns the parsed `fmt ` chunk, or `None` if no header has been
    /// parsed yet.
    #[must_use]
    pub fn format_info(&self) -> Option<&FmtChunk> {
        self.fmt.as_ref()
    }

    /// Returns the payload duration in seconds, computed as
    /// `data_size / byte_rate`.
    ///
    /// Yields `None` when the format is unknown or the byte rate is zero.
    #[must_use]
    #[allow(clippy::cast_precision_loss, clippy::manual_checked_ops)]
    pub fn duration_seconds(&self) -> Option<f64> {
        match self.fmt.as_ref().map(|fmt| fmt.byte_rate) {
            None | Some(0) => None,
            Some(rate) => Some(self.data_size as f64 / f64::from(rate)),
        }
    }

    /// Returns the number of whole `block_align`-sized units in the payload.
    ///
    /// Yields `None` when the format is unknown or `block_align` is zero.
    #[must_use]
    pub fn total_samples(&self) -> Option<u64> {
        self.fmt
            .as_ref()
            .and_then(|fmt| self.data_size.checked_div(u64::from(fmt.block_align)))
    }

    /// Returns how many payload bytes have not been consumed yet, clamped at
    /// zero.
    #[must_use]
    pub fn bytes_remaining(&self) -> u64 {
        let Self {
            data_size,
            position,
            ..
        } = self;
        data_size.saturating_sub(*position)
    }

    /// Reports whether the demuxer has flagged end of data.
    #[must_use]
    pub fn is_eof(&self) -> bool {
        self.eof
    }
}
impl<R: oximedia_io::MediaSource> WavDemuxer<R> {
    /// Reads exactly `n` bytes from the source into a freshly allocated
    /// buffer.
    ///
    /// # Errors
    ///
    /// Returns [`OxiError::UnexpectedEof`] if the source reports a
    /// zero-length read before `n` bytes have arrived, and propagates any
    /// error returned by the source itself.
    async fn read_exact(&mut self, n: usize) -> OxiResult<Vec<u8>> {
        let mut buf = vec![0u8; n];
        let mut read = 0;
        while read < n {
            let chunk = self.source.read(&mut buf[read..]).await?;
            if chunk == 0 {
                return Err(OxiError::UnexpectedEof);
            }
            read += chunk;
        }
        Ok(buf)
    }

    /// Parses the RIFF/WAVE header and walks the chunk list until the `data`
    /// chunk header has been read.
    ///
    /// On success `self.fmt`, `self.data_start` and `self.data_size` are
    /// populated. Unknown chunks are skipped by seeking past them. The walk
    /// stops at the `data` chunk, so a `fmt ` chunk located after it is not
    /// seen.
    ///
    /// # Errors
    ///
    /// Returns [`OxiError::Parse`] when the RIFF/WAVE tags are wrong or when
    /// the `fmt ` or `data` chunk is missing, [`OxiError::InvalidData`] when a
    /// chunk declares an implausible size, and propagates source and chunk
    /// parsing errors.
    async fn parse_header(&mut self) -> OxiResult<()> {
        // A WAVEFORMATEX header is 18 bytes and its `cbSize` extension length
        // is a 16-bit field, so no well-formed `fmt ` payload is larger than
        // this. The bound keeps a corrupt or hostile size field from driving
        // a multi-gigabyte allocation.
        const MAX_FMT_PAYLOAD: u64 = 18 + 0xFFFF;

        let header = self.read_exact(12).await?;
        if &header[0..4] != RIFF_MAGIC {
            return Err(OxiError::Parse {
                offset: 0,
                message: "Not a RIFF file".into(),
            });
        }
        if &header[8..12] != WAVE_FORM {
            return Err(OxiError::Parse {
                offset: 8,
                message: "Not a WAVE file".into(),
            });
        }

        // Absolute offset of the chunk header about to be read.
        let mut offset = 12u64;
        let mut found_fmt = false;
        let mut found_data = false;
        while !found_data {
            let chunk_header = match self.read_exact(RiffChunk::HEADER_SIZE).await {
                Ok(h) => h,
                // Running out of chunks is reported below as a missing chunk.
                Err(OxiError::UnexpectedEof) => break,
                Err(e) => return Err(e),
            };
            let chunk = RiffChunk::parse(&chunk_header)?;
            let payload_size = u64::from(chunk.size);
            // RIFF chunks are word aligned: an odd-sized payload is followed
            // by one pad byte that is not counted in the chunk size.
            let has_pad = chunk.size % 2 == 1;

            if chunk.is(FMT_CHUNK) {
                if payload_size > MAX_FMT_PAYLOAD {
                    return Err(OxiError::InvalidData("fmt chunk too large".into()));
                }
                let fmt_len = usize::try_from(payload_size)
                    .map_err(|_| OxiError::InvalidData("fmt chunk too large".into()))?;
                let fmt_data = self.read_exact(fmt_len).await?;
                self.fmt = Some(FmtChunk::parse(&fmt_data)?);
                found_fmt = true;
                if has_pad {
                    // Consume the pad byte so the source stays in step with
                    // `offset`; otherwise the next chunk header would be read
                    // one byte early.
                    self.source.seek(SeekFrom::Current(1)).await?;
                }
            } else if chunk.is(DATA_CHUNK) {
                self.data_start = offset + RiffChunk::HEADER_SIZE as u64;
                self.data_size = payload_size;
                found_data = true;
            } else {
                // Fail loudly instead of silently seeking by zero, which
                // would reinterpret the payload as chunk headers.
                let skip = i64::try_from(payload_size + u64::from(has_pad))
                    .map_err(|_| OxiError::InvalidData("Chunk size out of range".into()))?;
                self.source.seek(SeekFrom::Current(skip)).await?;
            }

            offset += RiffChunk::HEADER_SIZE as u64 + payload_size + u64::from(has_pad);
        }

        if !found_fmt {
            return Err(OxiError::Parse {
                offset: 0,
                message: "Missing fmt chunk".into(),
            });
        }
        if !found_data {
            return Err(OxiError::Parse {
                offset: 0,
                message: "Missing data chunk".into(),
            });
        }
        Ok(())
    }

    /// Appends a single PCM stream description derived from the parsed
    /// `fmt ` chunk; does nothing when no format has been parsed.
    ///
    /// The stream uses a `1 / sample_rate` timebase and a duration of
    /// `data_size / block_align`, left as `None` when `block_align` is zero
    /// or the quotient does not fit in `i64`. A channel count that does not
    /// fit in `u8` is reported as 2.
    fn build_stream_info(&mut self) {
        let Some(fmt) = &self.fmt else {
            return;
        };
        let codec_params =
            CodecParams::audio(fmt.sample_rate, u8::try_from(fmt.channels).unwrap_or(2));
        let duration = self
            .data_size
            .checked_div(u64::from(fmt.block_align))
            .and_then(|v| i64::try_from(v).ok());
        let timebase = Rational::new(1, i64::from(fmt.sample_rate));
        let mut stream = StreamInfo::new(0, CodecId::Pcm, timebase);
        stream.duration = duration;
        stream.codec_params = codec_params;
        stream.metadata = Metadata::new();
        self.streams.push(stream);
    }
}
#[async_trait]
impl<R: oximedia_io::MediaSource> Demuxer for WavDemuxer<R> {
    /// Parses the header, registers the audio stream and positions the
    /// source at the first payload byte of the `data` chunk.
    ///
    /// # Errors
    ///
    /// Propagates header parsing and seek failures.
    async fn probe(&mut self) -> OxiResult<ProbeResult> {
        self.parse_header().await?;
        self.build_stream_info();
        self.source.seek(SeekFrom::Start(self.data_start)).await?;
        self.position = 0;
        Ok(ProbeResult {
            format: ContainerFormat::Wav,
            confidence: 1.0,
        })
    }

    /// Reads the next packet of up to `DEFAULT_PACKET_SAMPLES` block-aligned
    /// units from the `data` chunk.
    ///
    /// The packet's PTS is the number of units delivered so far and its
    /// duration is the number of whole units in this packet, both in a
    /// `1 / sample_rate` timebase. If the source ends before the declared
    /// payload does, the bytes read so far are returned and the following
    /// call reports end of stream.
    ///
    /// # Errors
    ///
    /// Returns [`OxiError::Eof`] once the payload is exhausted,
    /// [`OxiError::InvalidData`] when no format has been parsed or the
    /// format declares a block alignment of zero, and propagates source read
    /// errors.
    async fn read_packet(&mut self) -> OxiResult<Packet> {
        if self.eof {
            return Err(OxiError::Eof);
        }
        let Some(fmt) = &self.fmt else {
            return Err(OxiError::InvalidData(
                "Format not parsed. Call probe() first.".into(),
            ));
        };
        // Copy out what is needed so no borrow of `self.fmt` is held across
        // the awaits below.
        let bytes_per_sample = u64::from(fmt.block_align);
        let sample_rate = i64::from(fmt.sample_rate);

        let remaining = self.data_size.saturating_sub(self.position);
        if remaining == 0 {
            self.eof = true;
            return Err(OxiError::Eof);
        }
        // With a zero block alignment the read size would be zero and
        // `position` would never advance, so every call would return an empty
        // packet and callers looping until `Eof` would spin forever.
        if bytes_per_sample == 0 {
            return Err(OxiError::InvalidData("Invalid block alignment: 0".into()));
        }

        let wanted_bytes = (DEFAULT_PACKET_SAMPLES * bytes_per_sample).min(remaining);
        let wanted = usize::try_from(wanted_bytes)
            .map_err(|_| OxiError::InvalidData("Invalid size".into()))?;
        self.buffer.resize(wanted, 0);

        let mut read = 0;
        while read < wanted {
            let chunk = self.source.read(&mut self.buffer[read..]).await?;
            if chunk == 0 {
                if read == 0 {
                    self.eof = true;
                    return Err(OxiError::Eof);
                }
                // Source ended early: hand back what was read.
                break;
            }
            read += chunk;
        }
        self.buffer.truncate(read);

        let pts = i64::try_from(self.samples_read).unwrap_or(0);
        let samples_in_packet = (read as u64) / bytes_per_sample;
        let duration = i64::try_from(samples_in_packet).ok();
        let timebase = Rational::new(1, sample_rate);
        let timestamp = Timestamp::with_dts(pts, None, timebase, duration);

        self.position += read as u64;
        self.samples_read += samples_in_packet;

        Ok(Packet::new(
            0,
            Bytes::copy_from_slice(&self.buffer),
            timestamp,
            PacketFlags::KEYFRAME,
        ))
    }

    /// Returns the streams registered so far (empty until the header has
    /// been parsed).
    fn streams(&self) -> &[StreamInfo] {
        &self.streams
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Placeholder source: the accessors under test never touch it.
    struct MockSource;

    /// Format block for 44.1 kHz, 16-bit stereo PCM.
    fn stereo_16bit_fmt() -> FmtChunk {
        FmtChunk {
            format: WavFormat::Pcm,
            channels: 2,
            sample_rate: 44_100,
            byte_rate: 176_400,
            block_align: 4,
            bits_per_sample: 16,
            extension: None,
        }
    }

    /// Builds a demuxer whose data size and format are preset.
    fn demuxer_with(data_size: u64, fmt: Option<FmtChunk>) -> WavDemuxer<MockSource> {
        let mut demuxer = WavDemuxer::new(MockSource);
        demuxer.data_size = data_size;
        demuxer.fmt = fmt;
        demuxer
    }

    #[test]
    fn test_wav_demuxer_new() {
        let demuxer = WavDemuxer::new(MockSource);
        assert!(demuxer.fmt.is_none());
        assert!(demuxer.streams.is_empty());
        assert!(!demuxer.eof);
    }

    #[test]
    fn test_bytes_remaining() {
        let mut demuxer = demuxer_with(1000, None);
        demuxer.position = 400;
        assert_eq!(demuxer.bytes_remaining(), 600);
    }

    #[test]
    fn test_duration_seconds() {
        // One second of audio: data size equals the byte rate.
        let demuxer = demuxer_with(176_400, Some(stereo_16bit_fmt()));
        let duration = demuxer
            .duration_seconds()
            .expect("operation should succeed");
        assert!((duration - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_total_samples() {
        let demuxer = demuxer_with(176_400, Some(stereo_16bit_fmt()));
        assert_eq!(demuxer.total_samples(), Some(44_100));
    }

    #[test]
    fn test_duration_zero_byte_rate() {
        let zero_rate = FmtChunk {
            byte_rate: 0,
            ..stereo_16bit_fmt()
        };
        let demuxer = demuxer_with(1000, Some(zero_rate));
        assert!(demuxer.duration_seconds().is_none());
    }
}