mod packet;
mod psi;
pub mod scte35;
use async_trait::async_trait;
use bytes::Bytes;
use oximedia_core::{OxiError, OxiResult, Rational, Timestamp};
use oximedia_io::MediaSource;
use std::collections::HashMap;
use crate::{
CodecParams, ContainerFormat, Demuxer, Metadata, Packet, PacketFlags, ProbeResult, StreamInfo,
};
use packet::{ContinuityTracker, TsPacket, PAT_PID, TS_PACKET_SIZE};
use psi::{ElementaryStreamInfo, ProgramAssociationTable, ProgramMapTable, SectionAssembler};
/// Timebase of every transport-stream PTS/DTS value: 90 kHz ticks.
const MPEG_TS_TIMEBASE: Rational = Rational { num: 1, den: 90_000 };
/// The 24-bit `packet_start_code_prefix` (`00 00 01`) that opens each PES packet.
const PES_START_CODE_PREFIX: u32 = 0x00_00_01;
/// Largest reassembled PES packet accepted (16 MiB); bigger ones are rejected.
const MAX_PES_SIZE: usize = 16 << 20;
#[derive(Debug, Clone)]
struct PesStreamState {
stream_index: usize,
#[allow(dead_code)]
es_info: ElementaryStreamInfo,
buffer: Vec<u8>,
expected_length: usize,
pts: Option<i64>,
dts: Option<i64>,
is_keyframe: bool,
}
impl PesStreamState {
const fn new(stream_index: usize, es_info: ElementaryStreamInfo) -> Self {
Self {
stream_index,
es_info,
buffer: Vec::new(),
expected_length: 0,
pts: None,
dts: None,
is_keyframe: false,
}
}
fn reset(&mut self) {
self.buffer.clear();
self.expected_length = 0;
self.pts = None;
self.dts = None;
self.is_keyframe = false;
}
}
pub struct MpegTsDemuxer<S: MediaSource> {
source: S,
streams: Vec<StreamInfo>,
pat: Option<ProgramAssociationTable>,
pmts: HashMap<u16, ProgramMapTable>,
section_assemblers: HashMap<u16, SectionAssembler>,
pes_streams: HashMap<u16, PesStreamState>,
continuity_tracker: ContinuityTracker,
pcr_base: Option<u64>,
probed: bool,
}
impl<S: MediaSource> MpegTsDemuxer<S> {
#[must_use]
pub fn new(source: S) -> Self {
Self {
source,
streams: Vec::new(),
pat: None,
pmts: HashMap::new(),
section_assemblers: HashMap::new(),
pes_streams: HashMap::new(),
continuity_tracker: ContinuityTracker::new(),
pcr_base: None,
probed: false,
}
}
async fn read_ts_packet(&mut self) -> OxiResult<TsPacket> {
let mut buffer = [0u8; TS_PACKET_SIZE];
let mut bytes_read = 0;
while bytes_read < TS_PACKET_SIZE {
let n = self.source.read(&mut buffer[bytes_read..]).await?;
if n == 0 {
return Err(OxiError::Eof);
}
bytes_read += n;
}
TsPacket::parse(&buffer)
}
fn process_pat(&mut self, packet: &TsPacket) -> OxiResult<()> {
let assembler = self.section_assemblers.entry(PAT_PID).or_default();
if let Some(section_data) = assembler.push(&packet.payload, packet.payload_unit_start) {
let pat = ProgramAssociationTable::parse(§ion_data)?;
for &pmt_pid in pat.programs.values() {
self.section_assemblers.entry(pmt_pid).or_default();
}
self.pat = Some(pat);
}
Ok(())
}
fn process_pmt(&mut self, pid: u16, packet: &TsPacket) -> OxiResult<()> {
let assembler = self.section_assemblers.entry(pid).or_default();
if let Some(section_data) = assembler.push(&packet.payload, packet.payload_unit_start) {
let pmt = ProgramMapTable::parse(§ion_data)?;
for es_info in &pmt.streams {
if let Some(codec_id) = es_info.codec_id {
if !self.pes_streams.contains_key(&es_info.pid) {
let stream_index = self.streams.len();
let stream_info = StreamInfo {
index: stream_index,
codec: codec_id,
media_type: codec_id.media_type(),
timebase: MPEG_TS_TIMEBASE,
duration: None,
codec_params: CodecParams::default(),
metadata: Metadata::default(),
};
self.streams.push(stream_info);
let pes_state = PesStreamState::new(stream_index, es_info.clone());
self.pes_streams.insert(es_info.pid, pes_state);
}
}
}
self.pmts.insert(pmt.program_number, pmt);
}
Ok(())
}
fn process_pes(&mut self, pid: u16, packet: &TsPacket) -> OxiResult<Option<Packet>> {
let pes_state = self
.pes_streams
.get_mut(&pid)
.ok_or_else(|| OxiError::InvalidData(format!("Unknown PES PID: 0x{pid:04X}")))?;
if packet.payload_unit_start {
let has_data = !pes_state.buffer.is_empty();
let _ = pes_state;
let result = if has_data {
self.finalize_pes_packet(pid)
} else {
Ok(None)
};
let pes_state = self
.pes_streams
.get_mut(&pid)
.ok_or_else(|| OxiError::InvalidData(format!("Unknown PES PID: 0x{pid:04X}")))?;
pes_state.reset();
if packet.payload.len() >= 6 {
let start_code = (u32::from(packet.payload[0]) << 16)
| (u32::from(packet.payload[1]) << 8)
| u32::from(packet.payload[2]);
if start_code == PES_START_CODE_PREFIX {
let pes_packet_length =
(u16::from(packet.payload[4]) << 8) | u16::from(packet.payload[5]);
pes_state.expected_length = if pes_packet_length == 0 {
0 } else {
pes_packet_length as usize + 6
};
if packet.payload.len() >= 9 {
let pts_dts_flags = (packet.payload[7] >> 6) & 0x03;
let header_data_length = packet.payload[8] as usize;
let mut offset = 9;
if pts_dts_flags >= 0x02 && packet.payload.len() >= offset + 5 {
pes_state.pts = Some(Self::parse_timestamp(&packet.payload[offset..]));
offset += 5;
}
if pts_dts_flags == 0x03 && packet.payload.len() >= offset + 5 {
pes_state.dts = Some(Self::parse_timestamp(&packet.payload[offset..]));
}
offset = 9 + header_data_length;
if offset < packet.payload.len() {
pes_state
.buffer
.extend_from_slice(&packet.payload[offset..]);
}
}
if packet.is_random_access() {
pes_state.is_keyframe = true;
}
}
}
return result;
}
pes_state.buffer.extend_from_slice(&packet.payload);
if pes_state.buffer.len() > MAX_PES_SIZE {
return Err(OxiError::InvalidData(format!(
"PES packet too large: {} bytes",
pes_state.buffer.len()
)));
}
if pes_state.expected_length > 0 && pes_state.buffer.len() >= pes_state.expected_length - 6
{
return self.finalize_pes_packet(pid);
}
Ok(None)
}
fn finalize_pes_packet(&mut self, pid: u16) -> OxiResult<Option<Packet>> {
let pes_state = self
.pes_streams
.get_mut(&pid)
.ok_or_else(|| OxiError::InvalidData(format!("Unknown PES PID: 0x{pid:04X}")))?;
if pes_state.buffer.is_empty() {
return Ok(None);
}
let data = Bytes::copy_from_slice(&pes_state.buffer);
let stream_index = pes_state.stream_index;
let mut timestamp = Timestamp::new(
pes_state.pts.unwrap_or(0),
self.streams[stream_index].timebase,
);
timestamp.dts = pes_state.dts;
let mut flags = PacketFlags::empty();
if pes_state.is_keyframe {
flags |= PacketFlags::KEYFRAME;
}
Ok(Some(Packet::new(stream_index, data, timestamp, flags)))
}
fn parse_timestamp(data: &[u8]) -> i64 {
((i64::from(data[0]) & 0x0E) << 29)
| (i64::from(data[1]) << 22)
| ((i64::from(data[2]) & 0xFE) << 14)
| (i64::from(data[3]) << 7)
| (i64::from(data[4]) >> 1)
}
fn is_pmt_pid(&self, pid: u16) -> bool {
if let Some(ref pat) = self.pat {
pat.programs.values().any(|&pmt_pid| pmt_pid == pid)
} else {
false
}
}
}
#[async_trait]
impl<S: MediaSource> Demuxer for MpegTsDemuxer<S> {
    /// Scans up to 1000 transport packets for a PAT and a PMT that declare
    /// at least one elementary stream with a known codec.
    ///
    /// # Errors
    ///
    /// Returns `OxiError::InvalidData` if no stream was found within the probe
    /// window. Source and parse errors (including `OxiError::Eof`) are propagated.
    async fn probe(&mut self) -> OxiResult<ProbeResult> {
        const MAX_PROBE_PACKETS: usize = 1000;
        if self.probed {
            return Ok(ProbeResult::new(ContainerFormat::MpegTs, 0.95));
        }
        for _ in 0..MAX_PROBE_PACKETS {
            let ts_packet = self.read_ts_packet().await?;
            if self.pcr_base.is_none() {
                self.pcr_base = ts_packet.pcr();
            }
            if ts_packet.is_pat() {
                self.process_pat(&ts_packet)?;
            } else if self.is_pmt_pid(ts_packet.pid) {
                self.process_pmt(ts_packet.pid, &ts_packet)?;
            }
            if !self.streams.is_empty() {
                break;
            }
        }
        if self.streams.is_empty() {
            return Err(OxiError::InvalidData(
                "No valid streams found in MPEG-TS".to_string(),
            ));
        }
        self.probed = true;
        Ok(ProbeResult::new(ContainerFormat::MpegTs, 0.95))
    }

    /// Returns the next demuxed packet.
    ///
    /// PSI packets are consumed internally and null packets are skipped. When
    /// the source is exhausted, PES packets still being assembled are emitted
    /// (lowest stream index first) before `OxiError::Eof` is returned.
    ///
    /// # Errors
    ///
    /// Returns `OxiError::InvalidData` if `probe()` has not succeeded. Returns
    /// `OxiError::Eof` once all data has been delivered. Source and parse
    /// errors are propagated.
    async fn read_packet(&mut self) -> OxiResult<Packet> {
        if !self.probed {
            return Err(OxiError::InvalidData(
                "Must call probe() before reading packets".to_string(),
            ));
        }
        loop {
            let ts_packet = match self.read_ts_packet().await {
                Ok(ts_packet) => ts_packet,
                // The final PES packet of each stream is never followed by
                // another payload_unit_start, so flush what is still buffered.
                Err(OxiError::Eof) => loop {
                    let pending_pid = self
                        .pes_streams
                        .iter()
                        .filter(|(_, state)| !state.buffer.is_empty())
                        .min_by_key(|(_, state)| state.stream_index)
                        .map(|(&pid, _)| pid);
                    let pid = match pending_pid {
                        Some(pid) => pid,
                        None => return Err(OxiError::Eof),
                    };
                    let flushed = self.finalize_pes_packet(pid)?;
                    // Clear explicitly so the same bytes can never be flushed twice.
                    if let Some(state) = self.pes_streams.get_mut(&pid) {
                        state.reset();
                    }
                    if let Some(packet) = flushed {
                        return Ok(packet);
                    }
                },
                Err(err) => return Err(err),
            };

            // Null packets carry nothing and their continuity counter is
            // undefined, so skip them before continuity checking.
            if ts_packet.is_null() {
                continue;
            }
            let has_payload = ts_packet.adaptation_field_control.has_payload();
            if !self.continuity_tracker.check(
                ts_packet.pid,
                ts_packet.continuity_counter,
                has_payload,
            ) {
                // NOTE(review): library code writing to stderr; consider the
                // crate's logging facility if one exists.
                eprintln!(
                    "Continuity error on PID 0x{:04X}, CC={}",
                    ts_packet.pid, ts_packet.continuity_counter
                );
            }
            if self.pcr_base.is_none() {
                self.pcr_base = ts_packet.pcr();
            }
            if ts_packet.is_pat() {
                self.process_pat(&ts_packet)?;
                continue;
            }
            if self.is_pmt_pid(ts_packet.pid) {
                self.process_pmt(ts_packet.pid, &ts_packet)?;
                continue;
            }
            if self.pes_streams.contains_key(&ts_packet.pid) {
                if let Some(packet) = self.process_pes(ts_packet.pid, &ts_packet)? {
                    return Ok(packet);
                }
            }
        }
    }

    /// Streams discovered so far, indexed by stream index.
    fn streams(&self) -> &[StreamInfo] {
        &self.streams
    }

    /// Delegates to the underlying source.
    fn is_seekable(&self) -> bool {
        self.source.is_seekable()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use oximedia_core::CodecId;
    use oximedia_io::MemorySource;

    /// PTS/DTS decoding must strip the prefix and marker bits exactly.
    #[test]
    fn test_parse_timestamp() {
        // '0010' prefix and marker bits set, every timestamp bit clear.
        let data = [0x21, 0x00, 0x01, 0x00, 0x01];
        assert_eq!(MpegTsDemuxer::<MemorySource>::parse_timestamp(&data), 0);

        // Bits 32..30 clear, bits 29..0 set.
        let data = [0x31, 0xFF, 0xFF, 0xFF, 0xFE];
        assert_eq!(
            MpegTsDemuxer::<MemorySource>::parse_timestamp(&data),
            (1_i64 << 30) - 1
        );

        // All 33 timestamp bits set: the largest representable value.
        let data = [0x2F, 0xFF, 0xFF, 0xFF, 0xFF];
        assert_eq!(
            MpegTsDemuxer::<MemorySource>::parse_timestamp(&data),
            (1_i64 << 33) - 1
        );
    }

    /// `reset` must discard buffered data and all per-packet header fields.
    #[test]
    fn test_pes_stream_state() {
        let es_info = ElementaryStreamInfo {
            stream_type: psi::StreamType::Av1,
            pid: 0x100,
            codec_id: Some(CodecId::Av1),
            descriptors: Vec::new(),
        };
        let mut state = PesStreamState::new(0, es_info);
        assert_eq!(state.stream_index, 0);
        assert!(state.buffer.is_empty());

        state.buffer.extend_from_slice(&[1, 2, 3, 4]);
        state.pts = Some(90_000);
        state.dts = Some(87_000);
        state.is_keyframe = true;
        assert_eq!(state.buffer.len(), 4);

        state.reset();
        assert!(state.buffer.is_empty());
        assert!(state.pts.is_none());
        assert!(state.dts.is_none());
        assert!(!state.is_keyframe);
    }
}