mod page;
mod stream;
pub use page::{OggPage, PageFlags};
pub use stream::{identify_codec, LogicalStream};
use async_trait::async_trait;
use bytes::Bytes;
use oximedia_core::{OxiError, OxiResult, Rational, Timestamp};
use oximedia_io::MediaSource;
use std::collections::HashMap;
use crate::demux::Demuxer;
use crate::{CodecParams, ContainerFormat, Packet, PacketFlags, ProbeResult, StreamInfo};
/// Size in bytes of the read buffer allocated by `OggDemuxer::new`.
const READ_BUFFER_SIZE: usize = 8192;
/// Maximum number of logical bitstreams accepted; a BOS page beyond this
/// limit makes `handle_bos_page` fail with an "unsupported" error.
const MAX_STREAMS: usize = 64;
/// Demuxer for the Ogg container format.
///
/// Pages are read from `source` through an internal byte buffer. Each logical
/// bitstream is tracked by its Ogg serial number, and page payloads are turned
/// into [`Packet`]s once that stream's header packets have been collected.
pub struct OggDemuxer<R> {
    /// Underlying byte source.
    source: R,
    /// Read buffer; bytes in `buffer_pos..buffer_len` have not been consumed.
    buffer: Vec<u8>,
    /// Offset of the first unconsumed byte in `buffer`.
    buffer_pos: usize,
    /// Number of valid bytes currently held in `buffer`.
    buffer_len: usize,
    /// Public stream descriptions, indexed by stream index.
    streams: Vec<StreamInfo>,
    /// Per-bitstream state, keyed by Ogg serial number.
    logical_streams: HashMap<u32, LogicalStream>,
    /// Packets produced by page processing that have not been returned yet.
    pending_packets: Vec<Packet>,
    /// Running count of bytes read from the source; set to the target offset on seek.
    position: u64,
    /// Set once a read from the source returned zero bytes.
    eof: bool,
    /// Set by `probe`; seeking is refused until then.
    headers_parsed: bool,
}
impl<R> OggDemuxer<R> {
    /// Wraps `source` in a demuxer with an empty read buffer and no known streams.
    ///
    /// Stream discovery happens later, in `Demuxer::probe`.
    #[must_use]
    pub fn new(source: R) -> Self {
        let buffer = vec![0u8; READ_BUFFER_SIZE];
        Self {
            source,
            buffer,
            buffer_pos: 0,
            buffer_len: 0,
            streams: Vec::default(),
            logical_streams: HashMap::default(),
            pending_packets: Vec::default(),
            position: 0,
            eof: false,
            headers_parsed: false,
        }
    }

    /// Borrows the wrapped source.
    #[must_use]
    pub const fn source(&self) -> &R {
        &self.source
    }

    /// Mutably borrows the wrapped source.
    pub fn source_mut(&mut self) -> &mut R {
        &mut self.source
    }

    /// Consumes the demuxer and hands back the wrapped source.
    #[must_use]
    #[allow(dead_code)]
    pub fn into_source(self) -> R {
        let Self { source, .. } = self;
        source
    }

    /// Number of streams discovered so far.
    #[must_use]
    pub fn stream_count(&self) -> usize {
        let streams = &self.streams;
        streams.len()
    }
}
impl<R: MediaSource> OggDemuxer<R> {
async fn fill_buffer(&mut self) -> OxiResult<usize> {
if self.buffer_pos > 0 {
let remaining = self.buffer_len - self.buffer_pos;
if remaining > 0 {
self.buffer.copy_within(self.buffer_pos..self.buffer_len, 0);
}
self.buffer_len = remaining;
self.buffer_pos = 0;
}
let bytes_read = self
.source
.read(&mut self.buffer[self.buffer_len..])
.await?;
if bytes_read == 0 {
self.eof = true;
}
self.buffer_len += bytes_read;
self.position += bytes_read as u64;
Ok(bytes_read)
}
async fn ensure_buffer(&mut self, min_bytes: usize) -> OxiResult<bool> {
while self.buffer_len - self.buffer_pos < min_bytes {
if self.eof {
return Ok(false);
}
let read = self.fill_buffer().await?;
if read == 0 {
return Ok(false);
}
}
Ok(true)
}
async fn read_page(&mut self) -> OxiResult<Option<OggPage>> {
if !self.ensure_buffer(page::MIN_HEADER_SIZE).await? {
return Ok(None);
}
let sync_pos = self.sync_to_page().await?;
if sync_pos.is_none() {
return Ok(None);
}
let segment_count = self.buffer[self.buffer_pos + 26] as usize;
let header_size = page::MIN_HEADER_SIZE + segment_count;
if !self.ensure_buffer(header_size).await? {
return Ok(None);
}
let segment_table = &self.buffer[self.buffer_pos + 27..self.buffer_pos + header_size];
let data_size: usize = segment_table.iter().map(|&s| usize::from(s)).sum();
let page_size = header_size + data_size;
if !self.ensure_buffer(page_size).await? {
return Ok(None);
}
let page_data = &self.buffer[self.buffer_pos..self.buffer_pos + page_size];
match OggPage::parse(page_data) {
Ok((page, consumed)) => {
self.buffer_pos += consumed;
Ok(Some(page))
}
Err(e) => {
self.buffer_pos += 1;
Err(e)
}
}
}
async fn sync_to_page(&mut self) -> OxiResult<Option<usize>> {
loop {
let search_end = self.buffer_len.saturating_sub(3);
for i in self.buffer_pos..search_end {
if &self.buffer[i..i + 4] == page::OGG_MAGIC {
self.buffer_pos = i;
return Ok(Some(i));
}
}
if self.eof {
return Ok(None);
}
let keep = 3.min(self.buffer_len - self.buffer_pos);
self.buffer_pos = self.buffer_len - keep;
if self.fill_buffer().await? == 0 {
return Ok(None);
}
}
}
fn process_page(&mut self, page: &OggPage) -> OxiResult<()> {
let serial = page.serial_number;
if page.is_bos() {
return self.handle_bos_page(page);
}
if !self.logical_streams.contains_key(&serial) {
return Ok(()); }
let packets = page.packets();
let is_continuation = page.is_continuation();
let granule_position = page.granule_position;
let has_granule = page.has_granule();
let mut headers_completed = false;
let mut packets_to_create = Vec::new();
for (i, (data, complete)) in packets.into_iter().enumerate() {
let Some(stream) = self.logical_streams.get_mut(&serial) else {
break; };
let packet_data = if i == 0 && is_continuation {
if stream.has_incomplete_packet() {
let mut buffer = stream.take_buffer();
buffer.extend_from_slice(&data);
buffer
} else {
continue;
}
} else {
data
};
if !complete {
stream.append_to_buffer(&packet_data);
continue;
}
if !stream.headers_complete() {
stream.add_header(packet_data);
if stream.headers_complete() {
headers_completed = true;
}
continue;
}
let stream_index = stream.stream_index;
let sample_rate = stream.sample_rate();
let last_granule = stream.last_granule;
packets_to_create.push((stream_index, packet_data, sample_rate, last_granule));
}
if headers_completed {
self.update_stream_info(serial);
}
for (stream_index, packet_data, sample_rate, last_granule) in packets_to_create {
let timebase = Rational::new(1, i64::from(sample_rate.max(1)));
let pts = if let Some(stream) = self.logical_streams.get(&serial) {
if has_granule {
stream.granule_to_timebase(granule_position, timebase)
} else {
stream.granule_to_timebase(last_granule, timebase)
}
} else {
0i64
};
let packet = Packet::new(
stream_index,
Bytes::from(packet_data),
Timestamp::new(pts, timebase),
PacketFlags::KEYFRAME,
);
self.pending_packets.push(packet);
}
if has_granule {
if let Some(stream) = self.logical_streams.get_mut(&serial) {
stream.last_granule = granule_position;
}
}
Ok(())
}
fn handle_bos_page(&mut self, page: &OggPage) -> OxiResult<()> {
let serial = page.serial_number;
if self.logical_streams.contains_key(&serial) {
return Ok(());
}
if self.logical_streams.len() >= MAX_STREAMS {
return Err(OxiError::unsupported("Too many streams in Ogg container"));
}
let packets = page.packets();
let first_packet = packets.first().map(|(data, _)| data.as_slice());
let codec = first_packet
.and_then(identify_codec)
.ok_or_else(|| OxiError::unsupported("Unknown Ogg codec"))?;
let stream_index = self.streams.len();
let mut logical_stream = LogicalStream::new(serial, codec, stream_index);
if let Some((data, _)) = packets.into_iter().next() {
logical_stream.add_header(data);
}
let timebase = Rational::new(1, i64::from(logical_stream.sample_rate().max(1)));
let stream_info = StreamInfo::new(stream_index, codec, timebase);
self.streams.push(stream_info);
self.logical_streams.insert(serial, logical_stream);
Ok(())
}
fn update_stream_info(&mut self, serial: u32) {
let Some(stream) = self.logical_streams.get(&serial) else {
return;
};
if stream.stream_index >= self.streams.len() {
return;
}
let info = &mut self.streams[stream.stream_index];
let sample_rate = stream.sample_rate();
if sample_rate > 0 {
info.timebase = Rational::new(1, i64::from(sample_rate));
}
info.codec_params = CodecParams::audio(sample_rate, 2);
if stream.codec == oximedia_core::CodecId::Opus {
if let Some(header) = stream.headers.first() {
if header.len() >= 10 {
info.codec_params.channels = Some(header[9]);
}
}
}
if !stream.headers.is_empty() {
let mut extradata = Vec::new();
for header in &stream.headers {
#[allow(clippy::cast_possible_truncation)]
let len = header.len() as u16;
extradata.extend_from_slice(&len.to_le_bytes());
extradata.extend_from_slice(header);
}
info.codec_params.extradata = Some(Bytes::from(extradata));
}
}
}
#[async_trait]
impl<R: MediaSource> Demuxer for OggDemuxer<R> {
    /// Detects the Ogg container and collects stream headers.
    ///
    /// First consumes the leading run of BOS pages; the first non-BOS page is
    /// also processed. Then keeps reading pages until every logical stream
    /// reports its headers complete, or the source ends.
    ///
    /// # Errors
    /// Returns `OxiError::UnknownFormat` if the source ends before any stream
    /// is registered. Also propagates page-parsing and stream-registration
    /// errors.
    async fn probe(&mut self) -> OxiResult<ProbeResult> {
        loop {
            let Some(page) = self.read_page().await? else {
                if self.streams.is_empty() {
                    return Err(OxiError::UnknownFormat);
                }
                break;
            };
            self.process_page(&page)?;
            if !page.is_bos() {
                break;
            }
        }
        while !self.all_headers_complete() {
            let Some(page) = self.read_page().await? else {
                break;
            };
            self.process_page(&page)?;
        }
        self.headers_parsed = true;
        Ok(ProbeResult::new(ContainerFormat::Ogg, 0.99))
    }

    /// Returns the next data packet in stream order.
    ///
    /// `process_page` appends packets to `pending_packets` in the order they
    /// appear in the page. They are handed out first-in-first-out so that
    /// several packets from one page keep their order. (`Vec::pop` would
    /// return them reversed.)
    ///
    /// # Errors
    /// Returns `OxiError::Eof` when no further pages are available.
    async fn read_packet(&mut self) -> OxiResult<Packet> {
        loop {
            if !self.pending_packets.is_empty() {
                return Ok(self.pending_packets.remove(0));
            }
            let Some(page) = self.read_page().await? else {
                return Err(OxiError::Eof);
            };
            self.process_page(&page)?;
        }
    }

    /// Streams discovered so far, indexed by stream index.
    fn streams(&self) -> &[StreamInfo] {
        &self.streams
    }

    /// Seeks to `target`; see `perform_seek`.
    async fn seek(&mut self, target: crate::SeekTarget) -> OxiResult<()> {
        self.perform_seek(target).await
    }

    /// Seeking needs a seekable source and a completed `probe`.
    fn is_seekable(&self) -> bool {
        self.source.is_seekable() && self.headers_parsed
    }
}
impl<R: MediaSource> OggDemuxer<R> {
fn all_headers_complete(&self) -> bool {
self.logical_streams
.values()
.all(LogicalStream::headers_complete)
}
async fn seek_to_position(&mut self, position: u64) -> OxiResult<()> {
use std::io::SeekFrom;
self.source.seek(SeekFrom::Start(position)).await?;
self.position = position;
self.buffer_pos = 0;
self.buffer_len = 0;
self.eof = false;
self.pending_packets.clear();
Ok(())
}
#[allow(clippy::cast_possible_wrap)]
async fn read_granule_at(&mut self, position: u64, serial: u32) -> OxiResult<Option<i64>> {
let original_pos = self.source.position();
self.seek_to_position(position).await?;
for _ in 0..10 {
if let Some(page) = self.read_page().await? {
if page.serial_number == serial && page.has_granule() {
let granule = page.granule_position as i64;
self.seek_to_position(original_pos).await?;
return Ok(Some(granule));
}
} else {
break;
}
}
self.seek_to_position(original_pos).await?;
Ok(None)
}
#[allow(clippy::cast_possible_wrap, clippy::cast_sign_loss)]
async fn bisect_seek(
&mut self,
target_granule: i64,
serial: u32,
file_size: u64,
) -> OxiResult<()> {
let mut left = 0u64;
let mut right = file_size;
let mut best_pos = 0u64;
while left < right {
let mid = left + (right - left) / 2;
let aligned_mid = self.align_to_page(mid).await?;
match self.read_granule_at(aligned_mid, serial).await? {
Some(granule) => match granule.cmp(&target_granule) {
std::cmp::Ordering::Less => {
best_pos = aligned_mid;
left = mid + 1;
}
std::cmp::Ordering::Greater => {
right = mid;
}
std::cmp::Ordering::Equal => {
best_pos = aligned_mid;
break;
}
},
None => {
right = mid;
}
}
if right.saturating_sub(left) < 4096 {
break;
}
}
self.seek_to_position(best_pos).await?;
Ok(())
}
async fn align_to_page(&mut self, position: u64) -> OxiResult<u64> {
use std::io::SeekFrom;
let search_start = position.saturating_sub(65536);
self.source.seek(SeekFrom::Start(search_start)).await?;
let mut search_buf = vec![0u8; (position - search_start).min(65536) as usize];
let n = self.source.read(&mut search_buf).await?;
for i in (0..n.saturating_sub(4)).rev() {
if &search_buf[i..i + 4] == page::OGG_MAGIC {
return Ok(search_start + i as u64);
}
}
Ok(position)
}
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
async fn perform_seek(&mut self, target: crate::SeekTarget) -> OxiResult<()> {
use oximedia_core::OxiError;
if !self.source.is_seekable() {
return Err(OxiError::unsupported("Source is not seekable"));
}
if !self.headers_parsed {
return Err(OxiError::InvalidData(
"Cannot seek before probing".to_string(),
));
}
if self.streams.is_empty() {
return Err(OxiError::InvalidData("No streams available".to_string()));
}
let stream_index = target.stream_index.unwrap_or(0);
if stream_index >= self.streams.len() {
return Err(OxiError::InvalidData(format!(
"Stream index {stream_index} out of range"
)));
}
let (serial, logical_stream) = self
.logical_streams
.iter()
.find(|(_, ls)| ls.stream_index == stream_index)
.ok_or_else(|| OxiError::InvalidData("Stream not found".to_string()))?;
let serial = *serial;
let sample_rate = logical_stream.sample_rate();
#[allow(clippy::cast_possible_truncation)]
let target_granule = (target.position * f64::from(sample_rate)) as i64;
let file_size = self.source.len().unwrap_or(u64::MAX);
self.bisect_seek(target_granule, serial, file_size).await?;
self.pending_packets.clear();
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use oximedia_io::MemorySource;

    /// Builds one Ogg page that carries `data` as a single complete packet.
    ///
    /// Granule position, page sequence number and CRC are written as zero.
    /// `data` must be shorter than 255 * 255 bytes so the lacing table fits
    /// in one page.
    fn create_ogg_page(serial: u32, flags: PageFlags, data: &[u8]) -> Vec<u8> {
        let mut page = Vec::new();
        page.extend_from_slice(page::OGG_MAGIC);
        page.push(0); // stream structure version
        page.push(flags.bits());
        page.extend_from_slice(&0u64.to_le_bytes()); // granule position
        page.extend_from_slice(&serial.to_le_bytes());
        page.extend_from_slice(&0u32.to_le_bytes()); // page sequence number
        page.extend_from_slice(&0u32.to_le_bytes()); // CRC (not computed)
        // Lacing: one 255 per full 255-byte chunk, then a final value < 255.
        // A packet whose length is a multiple of 255 (including an empty
        // packet) must therefore end with an explicit 0, or it would be read
        // as continuing onto the next page.
        let mut segments: Vec<u8> = Vec::new();
        let mut remaining = data.len();
        while remaining >= 255 {
            segments.push(255);
            remaining -= 255;
        }
        #[allow(clippy::cast_possible_truncation)]
        segments.push(remaining as u8);
        #[allow(clippy::cast_possible_truncation)]
        page.push(segments.len() as u8);
        page.extend_from_slice(&segments);
        page.extend_from_slice(data);
        page
    }

    /// 19-byte OpusHead: version 1, 2 channels, pre-skip 0, 48 kHz input
    /// rate, gain 0, mapping family 0.
    fn create_opus_head() -> Vec<u8> {
        vec![
            b'O', b'p', b'u', b's', b'H', b'e', b'a', b'd', 1, 2, 0x00, 0x00, 0x80, 0xBB, 0x00,
            0x00, 0x00, 0x00, 0,
        ]
    }

    /// Minimal OpusTags: empty vendor string and zero user comments.
    fn create_opus_tags() -> Vec<u8> {
        let mut tags = vec![b'O', b'p', b'u', b's', b'T', b'a', b'g', b's'];
        tags.extend_from_slice(&0u32.to_le_bytes());
        tags.extend_from_slice(&0u32.to_le_bytes());
        tags
    }

    #[tokio::test]
    async fn test_ogg_demuxer_new() {
        let source = MemorySource::new(Bytes::new());
        let demuxer = OggDemuxer::new(source);
        assert!(!demuxer.headers_parsed);
        assert!(demuxer.streams.is_empty());
    }

    #[tokio::test]
    async fn test_ogg_demuxer_probe_empty() {
        let source = MemorySource::new(Bytes::new());
        let mut demuxer = OggDemuxer::new(source);
        let result = demuxer.probe().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_ogg_demuxer_probe_invalid() {
        let source = MemorySource::new(Bytes::from_static(b"not an ogg file"));
        let mut demuxer = OggDemuxer::new(source);
        let result = demuxer.probe().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_ogg_demuxer_probe_opus() {
        let mut data = Vec::new();
        let opus_head = create_opus_head();
        data.extend(create_ogg_page(1, PageFlags::BOS, &opus_head));
        let opus_tags = create_opus_tags();
        data.extend(create_ogg_page(1, PageFlags::empty(), &opus_tags));
        let source = MemorySource::new(Bytes::from(data));
        let mut demuxer = OggDemuxer::new(source);
        let result = demuxer.probe().await;
        assert!(result.is_ok());
        let probe = result.expect("operation should succeed");
        assert_eq!(probe.format, ContainerFormat::Ogg);
        assert_eq!(demuxer.streams.len(), 1);
        assert_eq!(demuxer.streams[0].codec, oximedia_core::CodecId::Opus);
    }

    #[tokio::test]
    async fn test_ogg_demuxer_read_packet() {
        let mut data = Vec::new();
        let opus_head = create_opus_head();
        data.extend(create_ogg_page(1, PageFlags::BOS, &opus_head));
        let opus_tags = create_opus_tags();
        data.extend(create_ogg_page(1, PageFlags::empty(), &opus_tags));
        let audio_data = vec![0u8; 100];
        data.extend(create_ogg_page(1, PageFlags::empty(), &audio_data));
        let source = MemorySource::new(Bytes::from(data));
        let mut demuxer = OggDemuxer::new(source);
        demuxer.probe().await.expect("probe should succeed");
        let result = demuxer.read_packet().await;
        assert!(result.is_ok());
        let packet = result.expect("operation should succeed");
        assert_eq!(packet.stream_index, 0);
        assert_eq!(packet.size(), 100);
    }

    #[tokio::test]
    async fn test_ogg_demuxer_eof() {
        let mut data = Vec::new();
        let opus_head = create_opus_head();
        data.extend(create_ogg_page(1, PageFlags::BOS, &opus_head));
        let opus_tags = create_opus_tags();
        data.extend(create_ogg_page(1, PageFlags::empty(), &opus_tags));
        let source = MemorySource::new(Bytes::from(data));
        let mut demuxer = OggDemuxer::new(source);
        demuxer.probe().await.expect("probe should succeed");
        let result = demuxer.read_packet().await;
        assert!(matches!(result, Err(OxiError::Eof)));
    }

    #[test]
    fn test_stream_count() {
        let source = MemorySource::new(Bytes::new());
        let demuxer = OggDemuxer::new(source);
        assert_eq!(demuxer.stream_count(), 0);
    }
}