pub mod ebml;
pub mod matroska_v4;
pub mod parser;
pub mod types;
use std::collections::HashMap;
use std::io::SeekFrom;
use async_trait::async_trait;
use bytes::Bytes;
use oximedia_core::{OxiError, OxiResult, Rational, Timestamp};
use oximedia_io::MediaSource;
use crate::demux::Demuxer;
use crate::DecodeSkipCursor;
use crate::{CodecParams, ContainerFormat, Metadata, Packet, PacketFlags, ProbeResult, StreamInfo};
use ebml::element_id;
use parser::MatroskaParser;
#[allow(clippy::wildcard_imports)]
use types::*;
/// Initial capacity of the demuxer's read buffer and the size of each chunk
/// requested from the source in `ensure_buffer`.
const DEFAULT_BUFFER_SIZE: usize = 64 * 1024;
/// Upper bound for the initial header read; `parse_headers` further caps the
/// amount it actually requests at 1 MiB.
const MAX_HEADER_SIZE: usize = 16 * 1024 * 1024;
/// Nanoseconds per millisecond; used as the timecode scale when no segment
/// info has been parsed.
const NS_PER_MS: u64 = 1_000_000;
/// Demuxer for Matroska/WebM containers reading from a source of type `R`.
///
/// Headers are parsed lazily by `probe`/`read_packet`; afterwards blocks are
/// pulled from clusters one at a time and converted into packets.
pub struct MatroskaDemuxer<R> {
/// Underlying media source that bytes are read from.
source: R,
/// Bytes already read from `source` but not yet consumed by the parser.
buffer: Vec<u8>,
/// Parsed EBML header, set by `parse_headers`.
ebml_header: Option<EbmlHeader>,
/// Parsed segment Info element, if one was found before the first cluster.
segment_info: Option<SegmentInfo>,
/// Track entries parsed from the Tracks element.
tracks: Vec<TrackEntry>,
/// Cue points parsed from the Cues element; consulted when seeking.
cues: Vec<CuePoint>,
/// Chapter editions parsed from the Chapters element.
editions: Vec<Edition>,
/// Tags parsed from the Tags element.
tags: Vec<Tag>,
/// Entries parsed from the SeekHead element.
seek_entries: Vec<SeekEntry>,
/// Stream descriptions built from `tracks` by `build_streams`.
streams: Vec<StreamInfo>,
/// Maps a Matroska track number to its index in `streams`.
track_to_stream: HashMap<u64, usize>,
/// EBML header length plus Segment header length; cue cluster positions
/// are added to this value to obtain a seek offset.
segment_start: u64,
/// State of the cluster currently being read, if any.
current_cluster: Option<ClusterState>,
/// Source offset just past the last byte read into `buffer` (or the target
/// of the most recent seek).
position: u64,
/// Set once the source returns a zero-length read; cleared on seek.
eof: bool,
/// Whether `parse_headers` has completed successfully.
header_parsed: bool,
/// Container format derived from the EBML DocType.
format: Option<ContainerFormat>,
/// When set, `read_packet` drops blocks whose timestamp is below this
/// value; installed by `seek_sample_accurate`.
skip_until: Option<u64>,
}
impl<R> MatroskaDemuxer<R> {
    /// Creates a demuxer that will read from `source`.
    ///
    /// No I/O happens here: every piece of parsed state starts out empty and
    /// the headers are considered unparsed.
    #[must_use]
    pub fn new(source: R) -> Self {
        let buffer = Vec::with_capacity(DEFAULT_BUFFER_SIZE);
        Self {
            // I/O state.
            source,
            buffer,
            position: 0,
            eof: false,
            // Header-level state, filled in while parsing the headers.
            header_parsed: false,
            format: None,
            ebml_header: None,
            segment_info: None,
            segment_start: 0,
            seek_entries: Vec::new(),
            tracks: Vec::new(),
            cues: Vec::new(),
            editions: Vec::new(),
            tags: Vec::new(),
            // Stream mapping derived from the tracks.
            streams: Vec::new(),
            track_to_stream: HashMap::new(),
            // Packet-reading state.
            current_cluster: None,
            skip_until: None,
        }
    }

    /// Returns a shared reference to the underlying source.
    #[must_use]
    pub const fn source(&self) -> &R {
        &self.source
    }

    /// Returns a mutable reference to the underlying source.
    pub fn source_mut(&mut self) -> &mut R {
        &mut self.source
    }

    /// Consumes the demuxer and hands back the underlying source.
    #[must_use]
    #[allow(dead_code)]
    pub fn into_source(self) -> R {
        let Self { source, .. } = self;
        source
    }

    /// Returns the parsed segment info, or `None` if none has been parsed.
    #[must_use]
    pub fn segment_info(&self) -> Option<&SegmentInfo> {
        self.segment_info.as_ref()
    }

    /// Returns the parsed track entries.
    #[must_use]
    pub fn tracks(&self) -> &[TrackEntry] {
        self.tracks.as_slice()
    }

    /// Returns the parsed cue points.
    #[must_use]
    pub fn cues(&self) -> &[CuePoint] {
        self.cues.as_slice()
    }

    /// Returns the parsed chapter editions.
    #[must_use]
    pub fn editions(&self) -> &[Edition] {
        self.editions.as_slice()
    }

    /// Returns the parsed tags.
    #[must_use]
    pub fn tags(&self) -> &[Tag] {
        self.tags.as_slice()
    }
}
impl<R: MediaSource> MatroskaDemuxer<R> {
/// Reads from the source until `buffer` holds at least `min_size` bytes or
/// the source is exhausted.
///
/// Reaching end of input is not an error: `eof` is set and the buffer may
/// end up shorter than `min_size`, so callers must check its length.
///
/// # Errors
///
/// Propagates any error returned by the source's `read`.
async fn ensure_buffer(&mut self, min_size: usize) -> OxiResult<()> {
    // Fast path: this is called once per element, and the data is usually
    // already buffered, so avoid allocating a scratch chunk at all.
    if self.buffer.len() >= min_size || self.eof {
        return Ok(());
    }
    // One scratch chunk reused for every read instead of a fresh zeroed
    // allocation per loop iteration.
    let mut chunk = vec![0u8; DEFAULT_BUFFER_SIZE];
    while self.buffer.len() < min_size && !self.eof {
        let n = self.source.read(&mut chunk).await?;
        if n == 0 {
            self.eof = true;
            break;
        }
        self.buffer.extend_from_slice(&chunk[..n]);
        self.position += n as u64;
    }
    Ok(())
}
/// Seeks the source to the absolute byte offset `position` and resets all
/// read state (buffer, EOF flag and current cluster).
///
/// # Errors
///
/// Propagates the source's seek error; in that case no state is modified.
async fn seek_to_position(&mut self, position: u64) -> OxiResult<()> {
    let target = SeekFrom::Start(position);
    self.source.seek(target).await?;
    // Buffered bytes and cluster state refer to the old offset.
    self.buffer.clear();
    self.current_cluster = None;
    self.eof = false;
    self.position = position;
    Ok(())
}
/// Discards the first `n` bytes of the read buffer.
///
/// # Panics
///
/// Panics if `n` exceeds the number of buffered bytes.
fn consume_buffer(&mut self, n: usize) {
    drop(self.buffer.drain(..n));
}
/// Reads exactly `size` bytes, removing them from the front of the buffer.
///
/// # Errors
///
/// Returns `OxiError::UnexpectedEof` when the source ends before `size`
/// bytes are available, and propagates read errors from the source.
#[allow(dead_code)]
async fn read_bytes(&mut self, size: usize) -> OxiResult<Vec<u8>> {
    self.ensure_buffer(size).await?;
    if size > self.buffer.len() {
        return Err(OxiError::UnexpectedEof);
    }
    // Draining moves the bytes out and consumes them in a single step.
    Ok(self.buffer.drain(..size).collect())
}
/// Seeks the source to the absolute byte offset `pos`, clearing the buffer
/// and the EOF flag.
///
/// Unlike `seek_to_position`, the current cluster state is left untouched.
///
/// # Errors
///
/// Propagates the source's seek error; in that case no state is modified.
#[allow(dead_code)]
async fn seek_to(&mut self, pos: u64) -> OxiResult<()> {
    let target = SeekFrom::Start(pos);
    self.source.seek(target).await?;
    self.buffer.clear();
    self.eof = false;
    self.position = pos;
    Ok(())
}
/// Parses the EBML header, the Segment element header and every top-level
/// segment child that precedes the first cluster, then builds the streams.
///
/// On success `header_parsed` is set and the buffer is positioned at the
/// first element that was not consumed as a header.
///
/// # Errors
///
/// Returns a parse error when the EBML header is invalid or the element
/// following it is not a Segment, and propagates source read errors.
async fn parse_headers(&mut self) -> OxiResult<()> {
    let initial_read = MAX_HEADER_SIZE.min(1024 * 1024);
    self.ensure_buffer(initial_read).await?;

    let (header, ebml_len) = parser::parse_ebml_header(&self.buffer)?;
    let format = match header.doc_type {
        DocType::WebM => ContainerFormat::WebM,
        DocType::Matroska => ContainerFormat::Matroska,
    };
    self.ebml_header = Some(header);
    self.consume_buffer(ebml_len);
    self.format = Some(format);

    // The Segment element is expected directly after the EBML header.
    self.ensure_buffer(12).await?;
    let mut reader = MatroskaParser::new(&self.buffer);
    let segment = reader.read_element()?;
    if segment.id != element_id::SEGMENT {
        let message = format!("Expected Segment, got 0x{:X}", segment.id);
        return Err(OxiError::Parse {
            offset: self.position,
            message,
        });
    }
    let segment_header_len = segment.header_size;
    self.segment_start = ebml_len as u64 + segment_header_len as u64;
    self.consume_buffer(segment_header_len);

    self.parse_segment_children().await?;
    self.build_streams();
    self.header_parsed = true;
    Ok(())
}
/// Walks the top-level children of the Segment, parsing the header
/// elements it recognises and skipping the rest.
///
/// The walk stops, without error, at the first cluster, at an element of
/// unknown size, when an element header cannot be read, or when the source
/// ends before an element is complete. The element that stopped the walk is
/// left in the buffer.
///
/// # Errors
///
/// Propagates source read errors and errors from the individual element
/// parsers.
async fn parse_segment_children(&mut self) -> OxiResult<()> {
    loop {
        self.ensure_buffer(12).await?;
        if self.buffer.is_empty() {
            return Ok(());
        }

        let mut reader = MatroskaParser::new(&self.buffer);
        let element = match reader.read_element() {
            Ok(element) => element,
            Err(_) => return Ok(()),
        };
        // Clusters mark the start of media data; unbounded elements cannot
        // be buffered as a whole.
        if element.id == element_id::CLUSTER || element.is_unbounded() {
            return Ok(());
        }

        let header_len = element.header_size;
        #[allow(clippy::cast_possible_truncation)]
        let payload_len = element.size as usize;
        let element_len = header_len + payload_len;
        self.ensure_buffer(element_len).await?;
        if self.buffer.len() < element_len {
            return Ok(());
        }

        let payload = &self.buffer[header_len..element_len];
        let declared_size = element.size;
        match element.id {
            element_id::SEEK_HEAD => {
                self.seek_entries = parser::parse_seek_head(payload, declared_size)?;
            }
            element_id::INFO => {
                let info = parser::parse_segment_info(payload, declared_size)?;
                self.segment_info = Some(info);
            }
            element_id::TRACKS => {
                self.tracks = parser::parse_tracks(payload, declared_size)?;
            }
            element_id::CUES => {
                self.cues = parser::parse_cues(payload, declared_size)?;
            }
            element_id::CHAPTERS => {
                self.editions = parser::parse_chapters(payload, declared_size)?;
            }
            element_id::TAGS => {
                self.tags = parser::parse_tags(payload, declared_size)?;
            }
            _ => {}
        }
        self.consume_buffer(element_len);
    }
}
/// Builds `streams` and the track-number → stream-index map from the
/// parsed track entries.
///
/// Tracks without a recognised codec (`oxi_codec` is `None`) are skipped.
/// Each stream is created with its position in `streams` as its index, so
/// the index stored in a `StreamInfo` always matches the `stream_index`
/// that packets carry via `track_to_stream`, even when earlier tracks were
/// skipped.
#[allow(clippy::cast_possible_wrap)]
fn build_streams(&mut self) {
    let timecode_scale = self
        .segment_info
        .as_ref()
        .map_or(NS_PER_MS, |info| info.timecode_scale);
    // Timestamps are expressed in units of `timecode_scale` nanoseconds.
    let timebase = Rational::new(timecode_scale as i64, 1_000_000_000);
    // The segment duration is the same for every stream; look it up once.
    let segment_duration = self.segment_info.as_ref().and_then(|info| info.duration);

    for track in &self.tracks {
        let Some(codec) = track.oxi_codec else {
            continue;
        };
        // Use the stream's own position, not the track's position: the two
        // differ as soon as one unsupported track has been skipped.
        let stream_index = self.streams.len();
        let mut stream = StreamInfo::new(stream_index, codec, timebase);

        if let Some(duration) = segment_duration {
            #[allow(clippy::cast_possible_truncation)]
            {
                stream.duration = Some(duration as i64);
            }
        }

        match track.track_type {
            TrackType::Video => {
                if let Some(ref video) = track.video {
                    stream.codec_params =
                        CodecParams::video(video.pixel_width, video.pixel_height);
                }
            }
            TrackType::Audio => {
                if let Some(ref audio) = track.audio {
                    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
                    {
                        stream.codec_params =
                            CodecParams::audio(audio.sampling_frequency as u32, audio.channels);
                    }
                }
            }
            _ => {}
        }

        if let Some(ref private) = track.codec_private {
            stream.codec_params.extradata = Some(Bytes::copy_from_slice(private));
        }
        stream.codec_params.block_addition_mappings = track.block_addition_mappings.clone();

        if let Some(ref name) = track.name {
            stream.metadata = Metadata::new().with_title(name);
        }
        // "und" is the undetermined-language code; treat it like no language.
        if !track.language.is_empty() && track.language != "und" {
            stream.metadata = stream.metadata.with_entry("language", &track.language);
        }

        self.track_to_stream.insert(track.number, stream_index);
        self.streams.push(stream);
    }
}
/// Buffers a complete element (header plus payload) and returns its total
/// length in bytes.
///
/// # Errors
///
/// Returns a parse error if the element length does not fit in `usize`,
/// `OxiError::UnexpectedEof` if the source ends before the whole element
/// is available, and propagates source read errors.
async fn buffer_whole_element(
    &mut self,
    header_size: usize,
    payload_size: u64,
) -> OxiResult<usize> {
    let total_size = usize::try_from(payload_size)
        .ok()
        .and_then(|payload| header_size.checked_add(payload))
        .ok_or_else(|| OxiError::Parse {
            offset: self.position,
            message: format!("Element size {payload_size} is too large"),
        })?;
    self.ensure_buffer(total_size).await?;
    // `ensure_buffer` stops quietly at end of input, so a truncated file
    // must be detected here before anything slices the buffer.
    if self.buffer.len() < total_size {
        return Err(OxiError::UnexpectedEof);
    }
    Ok(total_size)
}

/// Reads elements until the next SimpleBlock or BlockGroup and returns the
/// parsed block together with the timecode of the enclosing cluster (0 if
/// no cluster timestamp has been seen).
///
/// Cluster headers and cluster timestamps update `current_cluster`; every
/// other element is skipped.
///
/// # Errors
///
/// Returns `OxiError::Eof` when no more data is available,
/// `OxiError::UnexpectedEof` when the source ends in the middle of a
/// timestamp or block element, a parse error for malformed elements, and
/// propagates source read errors.
async fn read_next_block(&mut self) -> OxiResult<(Block, u64)> {
    loop {
        self.ensure_buffer(12).await?;
        if self.buffer.is_empty() {
            return Err(OxiError::Eof);
        }
        let mut parser = MatroskaParser::new(&self.buffer);
        let element = parser.read_element()?;
        match element.id {
            element_id::CLUSTER => {
                // Only the header is consumed; the children are read by the
                // following loop iterations.
                self.consume_buffer(element.header_size);
                self.current_cluster = Some(ClusterState {
                    timecode: 0,
                    position: self.position,
                    size: if element.is_unbounded() {
                        None
                    } else {
                        Some(element.size)
                    },
                    data_position: 0,
                });
            }
            element_id::TIMESTAMP => {
                let total_size = self
                    .buffer_whole_element(element.header_size, element.size)
                    .await?;
                let data = &self.buffer[element.header_size..total_size];
                let (_, timecode) = ebml::read_uint(data).map_err(|e| OxiError::Parse {
                    offset: self.position,
                    message: format!("Failed to parse cluster timestamp: {e:?}"),
                })?;
                if let Some(ref mut cluster) = self.current_cluster {
                    cluster.timecode = timecode;
                }
                self.consume_buffer(total_size);
            }
            element_id::SIMPLE_BLOCK => {
                let total_size = self
                    .buffer_whole_element(element.header_size, element.size)
                    .await?;
                let data = &self.buffer[element.header_size..total_size];
                let block = parser::parse_simple_block(data)?;
                let cluster_time = self.current_cluster.as_ref().map_or(0, |c| c.timecode);
                self.consume_buffer(total_size);
                return Ok((block, cluster_time));
            }
            element_id::BLOCK_GROUP => {
                let total_size = self
                    .buffer_whole_element(element.header_size, element.size)
                    .await?;
                let data = &self.buffer[element.header_size..total_size];
                let block = parser::parse_block_group(data, element.size)?;
                let cluster_time = self.current_cluster.as_ref().map_or(0, |c| c.timecode);
                self.consume_buffer(total_size);
                return Ok((block, cluster_time));
            }
            _ => {
                if element.is_unbounded() {
                    self.consume_buffer(element.header_size);
                } else {
                    // Best-effort skip: if the source ends early, drop
                    // whatever is buffered and let the next iteration
                    // report end of stream.
                    let payload = usize::try_from(element.size).unwrap_or(usize::MAX);
                    let total_size = element.header_size.saturating_add(payload);
                    self.ensure_buffer(total_size).await?;
                    self.consume_buffer(total_size.min(self.buffer.len()));
                }
            }
        }
    }
}
/// Converts a parsed block into a packet for the stream that its track
/// number maps to.
///
/// The timestamp is `cluster_time` plus the block's relative timecode,
/// expressed in the stream's timebase. Only the first frame of the block is
/// used as packet data; a block without frames yields an empty payload.
///
/// # Errors
///
/// Returns a parse error when the block's track number has no stream.
#[allow(clippy::cast_possible_wrap)]
fn block_to_packet(&self, block: &Block, cluster_time: u64) -> OxiResult<Packet> {
    let track_number = block.header.track_number;
    let Some(&stream_index) = self.track_to_stream.get(&track_number) else {
        return Err(OxiError::Parse {
            offset: 0,
            message: format!("Unknown track number: {}", track_number),
        });
    };
    let timebase = self.streams[stream_index].timebase;

    #[allow(clippy::cast_sign_loss)]
    let pts = cluster_time as i64 + i64::from(block.header.timecode);
    let timestamp = Timestamp::new(pts, timebase);

    let payload = match block.frames.first() {
        Some(frame) => Bytes::copy_from_slice(frame),
        None => Bytes::new(),
    };

    let mut flags = PacketFlags::empty();
    if block.is_keyframe() {
        flags |= PacketFlags::KEYFRAME;
    }
    if block.header.discardable {
        flags |= PacketFlags::DISCARD;
    }

    Ok(Packet::new(stream_index, payload, timestamp, flags))
}
/// Finds the cue point for `track_number` that is closest to `target_time`.
///
/// With `backward` set, the latest cue at or before `target_time` is
/// chosen; otherwise the earliest cue at or after it. When several cues
/// share the best time, the first one in `cues` wins.
///
/// Cue points that carry no position for `track_number` are skipped; they
/// do not abort the search.
///
/// Returns `(absolute cluster offset, cue time)`, or `None` when no cue
/// point qualifies (including when there are no cues at all).
fn find_cue_point(
    &self,
    target_time: u64,
    track_number: u64,
    backward: bool,
) -> Option<(u64, u64)> {
    let mut best_cue: Option<(u64, u64)> = None;
    for cue_point in &self.cues {
        // A cue point may reference only some tracks; skip it rather than
        // giving up on the whole search.
        let Some(track_pos) = cue_point
            .track_positions
            .iter()
            .find(|tp| tp.track == track_number)
        else {
            continue;
        };
        let cue_time = cue_point.time;

        let on_correct_side = if backward {
            cue_time <= target_time
        } else {
            cue_time >= target_time
        };
        if !on_correct_side {
            continue;
        }

        // Strict comparisons keep the first cue on ties.
        let improves = match best_cue {
            None => true,
            Some((_, best_time)) if backward => cue_time > best_time,
            Some((_, best_time)) => cue_time < best_time,
        };
        if improves {
            // Cue cluster positions are relative to the segment payload.
            let cluster_pos = self.segment_start + track_pos.cluster_position;
            best_cue = Some((cluster_pos, cue_time));
        }
    }
    best_cue
}
/// Seeks to the cue point at or before `target_pts` and works out how many
/// reference-track blocks lie between that point and the target.
///
/// The reference track is the first video track, else the first track,
/// else track number 1. Without a usable cue point the scan starts at
/// `segment_start`. After counting, the source is rewound to the scan start
/// and `skip_until` is set so that `read_packet` drops blocks before the
/// target.
///
/// # Errors
///
/// Fails when the source is not seekable, when the headers have not been
/// parsed yet, or when seeking or reading a block fails with anything other
/// than end of stream.
#[allow(clippy::cast_possible_wrap, clippy::cast_sign_loss)]
pub async fn seek_sample_accurate(&mut self, target_pts: u64) -> OxiResult<DecodeSkipCursor> {
    if !self.source.is_seekable() {
        return Err(OxiError::unsupported("Source is not seekable"));
    }
    if !self.header_parsed {
        return Err(OxiError::InvalidData(
            "Cannot seek before probing".to_string(),
        ));
    }

    let ref_track_number: u64 = self
        .tracks
        .iter()
        .find(|t| t.track_type == types::TrackType::Video)
        .or_else(|| self.tracks.first())
        .map_or(1, |t| t.number);

    let seek_pos = match self.find_cue_point(target_pts, ref_track_number, true) {
        Some((cluster_pos, _)) => cluster_pos,
        None => self.segment_start,
    };
    self.seek_to_position(seek_pos).await?;
    self.skip_until = None;

    // Count reference-track blocks that precede the target; hitting the
    // target or the end of the stream both end the scan.
    let mut skip_samples = 0u32;
    loop {
        match self.read_next_block().await {
            Ok((block, cluster_time)) => {
                if block.header.track_number != ref_track_number {
                    continue;
                }
                let block_pts = cluster_time as i64 + i64::from(block.header.timecode);
                // Negative timestamps are clamped to zero.
                let block_pts_u64 = u64::try_from(block_pts).unwrap_or(0);
                if block_pts_u64 >= target_pts {
                    break;
                }
                skip_samples = skip_samples.saturating_add(1);
            }
            Err(OxiError::Eof) => break,
            Err(e) => return Err(e),
        }
    }

    // Rewind so that decoding restarts from the scan start.
    self.seek_to_position(seek_pos).await?;
    self.skip_until = Some(target_pts);
    Ok(DecodeSkipCursor {
        byte_offset: seek_pos,
        sample_index: 0,
        skip_samples,
        target_pts: i64::try_from(target_pts).unwrap_or(i64::MAX),
    })
}
/// Seeks to the cue point that best matches `target`.
///
/// The stream is `target.stream_index` if given, otherwise the first video
/// stream, otherwise stream 0. `target.position` is converted to timecode
/// units using the segment's timecode scale. If a matching cue point
/// exists the source is moved to its cluster and, for frame-accurate
/// targets, blocks are scanned up to the target; otherwise the source is
/// moved to `segment_start`.
///
/// Any pending `skip_until` left behind by `seek_sample_accurate` is
/// cleared so that it cannot filter packets after this seek.
///
/// # Errors
///
/// Fails when the source is not seekable, the headers have not been parsed,
/// the requested stream index is out of range, or seeking/reading fails.
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
async fn perform_seek(&mut self, target: crate::SeekTarget) -> OxiResult<()> {
    if !self.source.is_seekable() {
        return Err(OxiError::unsupported("Source is not seekable"));
    }
    if !self.header_parsed {
        return Err(OxiError::InvalidData(
            "Cannot seek before probing".to_string(),
        ));
    }
    let stream_index = if let Some(idx) = target.stream_index {
        if idx >= self.streams.len() {
            return Err(OxiError::InvalidData(format!(
                "Stream index {idx} out of range"
            )));
        }
        idx
    } else {
        self.streams
            .iter()
            .position(StreamInfo::is_video)
            .unwrap_or(0)
    };

    // Resolve the stream back to its track number through the same map
    // that packets use. Indexing `tracks` by stream index is wrong once a
    // track without a supported codec has been skipped, so it is only kept
    // as a fallback for when no stream mapping exists.
    let track_number = self
        .track_to_stream
        .iter()
        .find_map(|(&number, &index)| (index == stream_index).then_some(number))
        .or_else(|| self.tracks.get(stream_index).map(|track| track.number))
        .unwrap_or(1);

    let timecode_scale = self
        .segment_info
        .as_ref()
        .map_or(NS_PER_MS, |info| info.timecode_scale);
    // seconds → nanoseconds → timecode units
    #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
    let target_timecode = ((target.position * 1_000_000_000.0) / timecode_scale as f64) as u64;

    // A threshold from an earlier sample-accurate seek must not outlive it.
    self.skip_until = None;

    if let Some((cluster_pos, _cue_time)) =
        self.find_cue_point(target_timecode, track_number, target.is_backward())
    {
        self.seek_to_position(cluster_pos).await?;
        if target.is_frame_accurate() {
            return self
                .scan_to_exact_position(target_timecode, stream_index)
                .await;
        }
        return Ok(());
    }
    // No usable cue point: fall back to the start of the segment payload.
    self.seek_to_position(self.segment_start).await?;
    Ok(())
}
/// Reads blocks until one belonging to `stream_index` has a timestamp at or
/// after `target_timecode`, or until the end of the stream.
///
/// Block timestamps are the cluster timecode plus the block's signed
/// relative timecode, clamped to zero. This avoids the sign-extension
/// overflow of casting a negative offset to `u64`.
///
/// NOTE(review): the block that satisfies the condition has already been
/// consumed when this returns, so the next read yields the block after it.
/// Confirm that this is the intended frame-accurate behaviour.
///
/// # Errors
///
/// Propagates any error from reading blocks other than end of stream.
async fn scan_to_exact_position(
    &mut self,
    target_timecode: u64,
    stream_index: usize,
) -> OxiResult<()> {
    loop {
        match self.read_next_block().await {
            Ok((block, cluster_time)) => {
                let belongs_to_stream = self
                    .track_to_stream
                    .get(&block.header.track_number)
                    .is_some_and(|&index| index == stream_index);
                if !belongs_to_stream {
                    continue;
                }
                let offset = i64::from(block.header.timecode);
                let block_time = if offset < 0 {
                    cluster_time.saturating_sub(offset.unsigned_abs())
                } else {
                    cluster_time.saturating_add(offset.unsigned_abs())
                };
                if block_time >= target_timecode {
                    return Ok(());
                }
            }
            Err(OxiError::Eof) => return Ok(()),
            Err(e) => return Err(e),
        }
    }
}
}
#[async_trait]
impl<R: MediaSource> Demuxer for MatroskaDemuxer<R> {
    /// Parses the headers if necessary and reports the detected container
    /// format together with a confidence value.
    ///
    /// The format defaults to Matroska when none was recorded. Confidence is
    /// 0.99 with an EBML header and at least one track, 0.95 with an EBML
    /// header only, and 0.5 otherwise.
    async fn probe(&mut self) -> OxiResult<ProbeResult> {
        if !self.header_parsed {
            self.parse_headers().await?;
        }
        let has_header = self.ebml_header.is_some();
        let has_tracks = !self.tracks.is_empty();
        let confidence = match (has_header, has_tracks) {
            (true, true) => 0.99,
            (true, false) => 0.95,
            (false, _) => 0.5,
        };
        let format = self.format.unwrap_or(ContainerFormat::Matroska);
        Ok(ProbeResult::new(format, confidence))
    }

    /// Returns the next packet, parsing the headers first if needed.
    ///
    /// Blocks for tracks without a stream and blocks flagged invisible are
    /// skipped. While `skip_until` is set, blocks with an earlier timestamp
    /// are dropped as well; the first block at or past it clears the
    /// threshold.
    #[allow(clippy::cast_sign_loss, clippy::cast_possible_wrap)]
    async fn read_packet(&mut self) -> OxiResult<Packet> {
        if !self.header_parsed {
            self.parse_headers().await?;
        }
        loop {
            let (block, cluster_time) = self.read_next_block().await?;

            let has_stream = self
                .track_to_stream
                .contains_key(&block.header.track_number);
            if !has_stream || block.header.invisible {
                continue;
            }

            if let Some(threshold) = self.skip_until {
                let block_pts = cluster_time as i64 + i64::from(block.header.timecode);
                // Negative timestamps are clamped to zero before comparing.
                let block_pts_u64 = u64::try_from(block_pts).unwrap_or(0);
                if block_pts_u64 < threshold {
                    continue;
                }
                self.skip_until = None;
            }

            return self.block_to_packet(&block, cluster_time);
        }
    }

    /// Returns the streams built from the parsed tracks.
    fn streams(&self) -> &[StreamInfo] {
        self.streams.as_slice()
    }

    /// Seeks to `target`; see `perform_seek`.
    async fn seek(&mut self, target: crate::SeekTarget) -> OxiResult<()> {
        self.perform_seek(target).await
    }

    /// Seeking requires parsed headers and a seekable source.
    fn is_seekable(&self) -> bool {
        self.header_parsed && self.source.is_seekable()
    }
}
#[cfg(test)]
mod tests {
use super::*;
use oximedia_io::MemorySource;
/// Builds a minimal in-memory WebM file: an EBML header, an unknown-size
/// Segment with Info and Tracks (one VP9 video track), and one Cluster
/// holding a single keyframe SimpleBlock with a 4-byte payload.
///
/// Element names in the comments below follow the Matroska/EBML element
/// ID tables.
fn create_webm_header() -> Vec<u8> {
let mut data = Vec::new();
// EBML header (ID 0x1A45DFA3), payload size 0x1F = 31 bytes.
data.extend_from_slice(&[
0x1A, 0x45, 0xDF, 0xA3, 0x9F, ]);
// EBMLVersion = 1
data.extend_from_slice(&[0x42, 0x86, 0x81, 0x01]);
// EBMLReadVersion = 1
data.extend_from_slice(&[0x42, 0xF7, 0x81, 0x01]);
// EBMLMaxIDLength = 4
data.extend_from_slice(&[0x42, 0xF2, 0x81, 0x04]);
// EBMLMaxSizeLength = 8
data.extend_from_slice(&[0x42, 0xF3, 0x81, 0x08]);
// DocType = "webm"
data.extend_from_slice(&[0x42, 0x82, 0x84, b'w', b'e', b'b', b'm']);
// DocTypeVersion = 4
data.extend_from_slice(&[0x42, 0x87, 0x81, 0x04]);
// DocTypeReadVersion = 2
data.extend_from_slice(&[0x42, 0x85, 0x81, 0x02]);
// Segment (ID 0x18538067) with an all-ones 8-byte size, i.e. unknown size.
data.extend_from_slice(&[
0x18, 0x53, 0x80, 0x67, 0x01, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, ]);
// Info (ID 0x1549A966), payload size 14 bytes.
data.extend_from_slice(&[
0x15, 0x49, 0xA9, 0x66, 0x8E, ]);
// TimestampScale = 0x0F4240 = 1_000_000 ns
data.extend_from_slice(&[
0x2A, 0xD7, 0xB1, 0x83, 0x0F, 0x42, 0x40, ]);
// Duration, 32-bit float 0x459C4000 = 5000.0
data.extend_from_slice(&[
0x44, 0x89, 0x84, 0x45, 0x9C, 0x40, 0x00, ]);
// Tracks (ID 0x1654AE6B), payload size 30 bytes.
data.extend_from_slice(&[
0x16, 0x54, 0xAE, 0x6B, 0x9E, ]);
// TrackEntry (ID 0xAE), payload size 28 bytes.
data.extend_from_slice(&[
0xAE, 0x9C, ]);
// TrackNumber = 1
data.extend_from_slice(&[0xD7, 0x81, 0x01]);
// TrackUID = 0x3039 = 12345
data.extend_from_slice(&[0x73, 0xC5, 0x82, 0x30, 0x39]);
// TrackType = 1 (video)
data.extend_from_slice(&[0x83, 0x81, 0x01]);
// CodecID = "V_VP9"
data.extend_from_slice(&[0x86, 0x85, b'V', b'_', b'V', b'P', b'9']);
// Video (ID 0xE0), payload size 8 bytes.
data.extend_from_slice(&[
0xE0, 0x88, ]);
// PixelWidth = 0x0780 = 1920
data.extend_from_slice(&[0xB0, 0x82, 0x07, 0x80]);
// PixelHeight = 0x0438 = 1080
data.extend_from_slice(&[0xBA, 0x82, 0x04, 0x38]);
// Cluster (ID 0x1F43B675) with unknown size.
data.extend_from_slice(&[
0x1F, 0x43, 0xB6, 0x75, 0x01, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, ]);
// Cluster Timestamp = 0
data.extend_from_slice(&[0xE7, 0x81, 0x00]);
// SimpleBlock (ID 0xA3), size 8: track 1 (0x81), relative timecode 0x0000,
// flags 0x80 (keyframe), followed by the 4-byte frame payload.
let block_data = vec![0x01, 0x02, 0x03, 0x04]; data.extend_from_slice(&[
0xA3, 0x88, 0x81, 0x00, 0x00, 0x80, ]);
data.extend_from_slice(&block_data);
data
}
/// A freshly constructed demuxer has parsed nothing and exposes no streams.
#[tokio::test]
async fn test_matroska_demuxer_new() {
let source = MemorySource::new(Bytes::new());
let demuxer = MatroskaDemuxer::new(source);
assert!(!demuxer.header_parsed);
assert!(demuxer.streams().is_empty());
}
/// Probing the fixture detects WebM with high confidence and marks the
/// headers as parsed.
#[tokio::test]
async fn test_matroska_demuxer_probe() {
let data = create_webm_header();
let source = MemorySource::new(Bytes::from(data));
let mut demuxer = MatroskaDemuxer::new(source);
let result = demuxer.probe().await.expect("probe should succeed");
assert_eq!(result.format, ContainerFormat::WebM);
assert!(result.confidence > 0.9);
assert!(demuxer.header_parsed);
}
/// The fixture's single track is exposed as a VP9 stream.
#[tokio::test]
async fn test_matroska_demuxer_streams() {
let data = create_webm_header();
let source = MemorySource::new(Bytes::from(data));
let mut demuxer = MatroskaDemuxer::new(source);
demuxer.probe().await.expect("probe should succeed");
let streams = demuxer.streams();
assert!(!streams.is_empty());
assert_eq!(streams[0].codec, oximedia_core::CodecId::Vp9);
}
/// The fixture's SimpleBlock is returned as a 4-byte keyframe packet on
/// stream 0.
#[tokio::test]
async fn test_matroska_demuxer_read_packet() {
let data = create_webm_header();
let source = MemorySource::new(Bytes::from(data));
let mut demuxer = MatroskaDemuxer::new(source);
demuxer.probe().await.expect("probe should succeed");
let packet = demuxer
.read_packet()
.await
.expect("operation should succeed");
assert_eq!(packet.stream_index, 0);
assert!(packet.is_keyframe());
assert_eq!(packet.size(), 4);
}
/// After the only block has been read, the next read reports `Eof`.
#[tokio::test]
async fn test_matroska_demuxer_eof() {
let data = create_webm_header();
let source = MemorySource::new(Bytes::from(data));
let mut demuxer = MatroskaDemuxer::new(source);
demuxer.probe().await.expect("probe should succeed");
let _ = demuxer
.read_packet()
.await
.expect("operation should succeed");
let result = demuxer.read_packet().await;
assert!(matches!(result, Err(OxiError::Eof)));
}
/// Segment info from the fixture carries the 1 ms timecode scale and a
/// duration.
#[tokio::test]
async fn test_segment_info() {
let data = create_webm_header();
let source = MemorySource::new(Bytes::from(data));
let mut demuxer = MatroskaDemuxer::new(source);
demuxer.probe().await.expect("probe should succeed");
let info = demuxer.segment_info().expect("operation should succeed");
assert_eq!(info.timecode_scale, 1_000_000);
assert!(info.duration.is_some());
}
/// The parsed track entry matches the fixture: track 1, V_VP9, 1920x1080.
#[tokio::test]
async fn test_tracks_info() {
let data = create_webm_header();
let source = MemorySource::new(Bytes::from(data));
let mut demuxer = MatroskaDemuxer::new(source);
demuxer.probe().await.expect("probe should succeed");
let tracks = demuxer.tracks();
assert!(!tracks.is_empty());
assert_eq!(tracks[0].number, 1);
assert_eq!(tracks[0].codec_id, "V_VP9");
assert!(tracks[0].video.is_some());
let video = tracks[0].video.as_ref().expect("operation should succeed");
assert_eq!(video.pixel_width, 1920);
assert_eq!(video.pixel_height, 1080);
}
}