use std::collections::HashMap;
/// Length in bytes of every MPEG transport-stream packet.
const TS_PACKET_SIZE: usize = 188;
/// Value that every transport packet starts with; used to (re)synchronise on the stream.
const SYNC_BYTE: u8 = 0x47;
/// PID on which the Program Association Table is carried.
const PAT_PID: u16 = 0x0000;
/// Elementary-stream type, decoded from the `stream_type` byte of a PMT entry.
///
/// `Hash` is derived alongside `Eq` so the type can key maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StreamType {
    /// H.264 / AVC video (`stream_type` 0x1B).
    H264,
    /// H.265 / HEVC video (`stream_type` 0x24).
    H265,
    /// AAC audio (`stream_type` 0x0F or 0x11).
    Aac,
    /// Any other `stream_type`; the raw byte is preserved.
    Unknown(u8),
}

impl StreamType {
    /// Maps a PMT `stream_type` byte to a [`StreamType`], keeping unrecognised values
    /// in [`StreamType::Unknown`].
    fn from_byte(b: u8) -> Self {
        match b {
            0x1B => Self::H264,
            0x24 => Self::H265,
            0x0F | 0x11 => Self::Aac,
            other => Self::Unknown(other),
        }
    }
}
/// A complete PES packet reassembled from transport packets of one PID.
#[derive(Debug, Clone)]
pub struct PesPacket {
    /// PID of the transport packets that carried this PES packet.
    pub pid: u16,
    /// Stream type recorded for `pid` from the PMT when this packet started.
    pub stream_type: StreamType,
    /// Presentation timestamp (33-bit value from the PES header), if present.
    pub pts: Option<u64>,
    /// Decoding timestamp (33-bit value from the PES header), if present.
    pub dts: Option<u64>,
    /// Elementary-stream bytes that follow the PES header.
    pub payload: Vec<u8>,
}
/// Per-PID accumulation state for a PES packet that is still being received.
#[derive(Debug)]
struct PesBuffer {
    /// Stream type recorded when the current PES packet started.
    stream_type: StreamType,
    /// Raw PES bytes gathered so far, PES header included.
    buf: Vec<u8>,
    /// Set once a payload-unit-start has been seen; until then continuation payloads are dropped.
    started: bool,
}
/// Incremental MPEG transport-stream demuxer.
///
/// It follows the first non-zero program announced in the PAT, learns that program's
/// elementary streams from its PMT and reassembles their PES packets.
#[derive(Debug)]
pub struct TsDemuxer {
    /// Input bytes not yet consumed (an incomplete packet) carried over between `feed` calls.
    remainder: Vec<u8>,
    /// PID of the program map table, once learned from the PAT.
    pmt_pid: Option<u16>,
    /// Elementary-stream PID -> stream type, as listed by the latest PMT.
    streams: HashMap<u16, StreamType>,
    /// Per-PID PES reassembly state.
    pes_bufs: HashMap<u16, PesBuffer>,
}
impl Default for TsDemuxer {
fn default() -> Self {
Self::new()
}
}
impl TsDemuxer {
    /// `table_id` of a program association section.
    const PAT_TABLE_ID: u8 = 0x00;
    /// `table_id` of a TS program map section.
    const PMT_TABLE_ID: u8 = 0x02;
    /// Smallest valid PAT `section_length`: 5 header bytes after the length field + 4-byte CRC.
    const PAT_MIN_SECTION_LENGTH: usize = 9;
    /// Smallest valid PMT `section_length`: 9 header bytes after the length field + 4-byte CRC.
    const PMT_MIN_SECTION_LENGTH: usize = 13;
    /// Initial capacity reserved for each per-PID PES reassembly buffer.
    const PES_BUF_CAPACITY: usize = 64 * 1024;

    /// Creates a demuxer that has not yet seen a PAT or PMT.
    pub fn new() -> Self {
        Self {
            remainder: Vec::new(),
            pmt_pid: None,
            streams: HashMap::new(),
            pes_bufs: HashMap::new(),
        }
    }

    /// Feeds a chunk of raw transport-stream bytes and returns every PES packet it completed.
    ///
    /// Chunks may split packets at any byte offset: an incomplete trailing packet is kept
    /// and finished by the next call. Bytes that precede a sync byte are skipped. A PES
    /// packet is emitted when the next payload-unit-start for its PID arrives, or when
    /// [`TsDemuxer::flush`] is called.
    pub fn feed(&mut self, data: &[u8]) -> Vec<PesPacket> {
        let mut out = Vec::new();
        if self.remainder.is_empty() {
            // Fast path: parse straight from the caller's slice, copying only an unfinished tail.
            let consumed = self.scan_packets(data, &mut out);
            self.remainder.extend_from_slice(&data[consumed..]);
        } else {
            // Move the carried-over bytes out of `self` so they can be scanned while the
            // packet handlers borrow `self` mutably; the allocation is put back afterwards.
            let mut buf = std::mem::take(&mut self.remainder);
            buf.extend_from_slice(data);
            let consumed = self.scan_packets(&buf, &mut out);
            buf.drain(..consumed);
            self.remainder = buf;
        }
        out
    }

    /// Emits the PES packets still being accumulated for every PID, e.g. at end of input.
    ///
    /// Packets are returned in ascending PID order. Afterwards each buffer waits for a new
    /// payload-unit-start before accumulating again.
    pub fn flush(&mut self) -> Vec<PesPacket> {
        let mut pids: Vec<u16> = self.pes_bufs.keys().copied().collect();
        pids.sort_unstable();
        let mut out = Vec::new();
        for pid in pids {
            if let Some(buf) = self.pes_bufs.get_mut(&pid) {
                if buf.started {
                    out.extend(Self::finish_pes(pid, buf));
                }
                buf.started = false;
                buf.buf.clear();
            }
        }
        out
    }

    /// Processes every complete packet in `buf` and returns how many leading bytes were consumed.
    ///
    /// Unconsumed bytes are an incomplete packet that starts at a sync byte and should be
    /// retried once more data arrives. If no further sync byte exists, the whole slice
    /// counts as consumed, because it cannot contain the start of a packet.
    fn scan_packets(&mut self, buf: &[u8], out: &mut Vec<PesPacket>) -> usize {
        let mut pos = 0;
        while let Some(skip) = buf[pos..].iter().position(|&b| b == SYNC_BYTE) {
            pos += skip;
            let pkt: &[u8; TS_PACKET_SIZE] = match buf.get(pos..pos + TS_PACKET_SIZE) {
                Some(bytes) => bytes
                    .try_into()
                    .expect("slice is exactly TS_PACKET_SIZE bytes long"),
                None => return pos,
            };
            self.process_packet(pkt, out);
            pos += TS_PACKET_SIZE;
        }
        buf.len()
    }

    /// Routes one transport packet to the PAT, PMT or PES handler according to its PID.
    ///
    /// Packets without a payload are ignored. That covers adaptation-field-only or reserved
    /// `adaptation_field_control` values, and adaptation fields that fill the whole packet.
    fn process_packet(&mut self, pkt: &[u8; TS_PACKET_SIZE], out: &mut Vec<PesPacket>) {
        let pid = u16::from_be_bytes([pkt[1] & 0x1F, pkt[2]]);
        let pusi = pkt[1] & 0x40 != 0;
        // adaptation_field_control: 01 = payload only, 11 = adaptation field then payload.
        let payload_offset = match (pkt[3] >> 4) & 0x03 {
            0b01 => 4,
            0b11 => 5 + pkt[4] as usize,
            _ => return,
        };
        if payload_offset >= TS_PACKET_SIZE {
            return;
        }
        let payload = &pkt[payload_offset..];
        if pid == PAT_PID {
            self.parse_pat(payload, pusi);
        } else if Some(pid) == self.pmt_pid {
            self.parse_pmt(payload, pusi);
        } else if self.streams.contains_key(&pid) {
            self.push_pes(pid, payload, pusi, out);
        }
    }

    /// Extracts one complete PSI section with the expected `table_id` from a packet payload.
    ///
    /// Only packets with PUSI set carry a `pointer_field` and a section start. Continuation
    /// packets return `None`, because sections spanning several packets are not reassembled.
    /// The returned slice runs from `table_id` through the trailing CRC_32, which is not verified.
    fn psi_section(
        payload: &[u8],
        pusi: bool,
        table_id: u8,
        min_section_length: usize,
    ) -> Option<&[u8]> {
        if !pusi {
            return None;
        }
        let (&pointer, rest) = payload.split_first()?;
        let data = rest.get(pointer as usize..)?;
        if data.len() < 3 || data[0] != table_id {
            return None;
        }
        let section_length = usize::from(u16::from_be_bytes([data[1] & 0x0F, data[2]]));
        if section_length < min_section_length {
            return None;
        }
        data.get(..3 + section_length)
    }

    /// Parses a PAT section and records the PMT PID of its first non-zero program.
    fn parse_pat(&mut self, payload: &[u8], pusi: bool) {
        let section = match Self::psi_section(
            payload,
            pusi,
            Self::PAT_TABLE_ID,
            Self::PAT_MIN_SECTION_LENGTH,
        ) {
            Some(section) => section,
            None => return,
        };
        // The program loop sits between the 8-byte section header and the 4-byte CRC_32.
        let programs = &section[8..section.len() - 4];
        let pmt_pid = programs.chunks_exact(4).find_map(|entry| {
            let program_number = u16::from_be_bytes([entry[0], entry[1]]);
            // program_number 0 points at the network PID, not at a PMT.
            (program_number != 0).then(|| u16::from_be_bytes([entry[2] & 0x1F, entry[3]]))
        });
        if let Some(pid) = pmt_pid {
            self.pmt_pid = Some(pid);
        }
    }

    /// Parses a PMT section and replaces the PID -> stream type map with its ES entries.
    ///
    /// PES buffers for PIDs that the new PMT no longer lists are discarded.
    fn parse_pmt(&mut self, payload: &[u8], pusi: bool) {
        let section = match Self::psi_section(
            payload,
            pusi,
            Self::PMT_TABLE_ID,
            Self::PMT_MIN_SECTION_LENGTH,
        ) {
            Some(section) => section,
            None => return,
        };
        let program_info_length =
            usize::from(u16::from_be_bytes([section[10] & 0x0F, section[11]]));
        let loop_end = section.len() - 4;
        let mut i = 12 + program_info_length;
        self.streams.clear();
        // Each ES entry: stream_type (1), elementary_PID (2), ES_info_length (2), descriptors.
        while i + 5 <= loop_end {
            let stream_type = StreamType::from_byte(section[i]);
            let es_pid = u16::from_be_bytes([section[i + 1] & 0x1F, section[i + 2]]);
            let es_info_length =
                usize::from(u16::from_be_bytes([section[i + 3] & 0x0F, section[i + 4]]));
            self.streams.insert(es_pid, stream_type);
            i += 5 + es_info_length;
        }
        let streams = &self.streams;
        self.pes_bufs.retain(|pid, _| streams.contains_key(pid));
    }

    /// Appends one packet payload to the PES buffer of `pid`.
    ///
    /// A payload-unit-start first completes (and emits into `out`) the PES packet that
    /// was being accumulated for this PID.
    fn push_pes(&mut self, pid: u16, payload: &[u8], pusi: bool, out: &mut Vec<PesPacket>) {
        let stream_type = self
            .streams
            .get(&pid)
            .copied()
            .unwrap_or(StreamType::Unknown(0));
        if pusi {
            let entry = self.pes_bufs.entry(pid).or_insert_with(|| PesBuffer {
                stream_type,
                buf: Vec::with_capacity(Self::PES_BUF_CAPACITY),
                started: false,
            });
            if entry.started && !entry.buf.is_empty() {
                out.extend(Self::finish_pes(pid, entry));
            }
            entry.buf.clear();
            entry.buf.extend_from_slice(payload);
            entry.started = true;
            entry.stream_type = stream_type;
        } else if let Some(buf) = self.pes_bufs.get_mut(&pid) {
            // Continuation bytes only matter once a unit start has been seen.
            if buf.started {
                buf.extend(payload);
            }
        }
    }

    /// Parses the accumulated bytes of `buf` as one PES packet.
    ///
    /// Returns `None` in several cases: the start code is missing, the header is truncated,
    /// the declared lengths are inconsistent, or there is no elementary-stream payload.
    /// NOTE(review): the optional PES header (flags at byte 7, header length at byte 8) is
    /// assumed present. That is true for audio/video stream_ids but not for e.g. padding
    /// or private_stream_2.
    fn finish_pes(pid: u16, buf: &mut PesBuffer) -> Option<PesPacket> {
        let data = &buf.buf;
        // Fixed part: start code (3), stream_id (1), PES_packet_length (2), flags (2),
        // PES_header_data_length (1).
        if data.len() < 9 || !data.starts_with(&[0x00, 0x00, 0x01]) {
            return None;
        }
        let pes_packet_length = usize::from(u16::from_be_bytes([data[4], data[5]]));
        let flags = data[7];
        let header_data_len = usize::from(data[8]);
        let es_start = 9 + header_data_len;
        if es_start > data.len() {
            return None;
        }
        // PTS_DTS_flags: 10 = PTS only, 11 = PTS and DTS; DTS is never present without PTS.
        let has_pts = flags & 0x80 != 0;
        let has_dts = has_pts && flags & 0x40 != 0;
        let pts = (has_pts && header_data_len >= 5).then(|| parse_ts_timestamp(&data[9..14]));
        let dts = (has_dts && header_data_len >= 10).then(|| parse_ts_timestamp(&data[14..19]));
        // PES_packet_length counts the bytes after the length field; 0 means "unbounded".
        let es_end = if pes_packet_length > 0 {
            (6 + pes_packet_length).min(data.len())
        } else {
            data.len()
        };
        // `get` instead of indexing: a corrupt length can place es_end before es_start.
        let payload = data
            .get(es_start..es_end)
            .filter(|es| !es.is_empty())?
            .to_vec();
        Some(PesPacket {
            pid,
            stream_type: buf.stream_type,
            pts,
            dts,
            payload,
        })
    }
}
impl PesBuffer {
    /// Appends continuation payload bytes to the PES data gathered so far.
    fn extend(&mut self, data: &[u8]) {
        self.buf.extend(data.iter().copied());
    }
}
/// Decodes a 33-bit PTS/DTS value from its 5-byte PES header encoding.
///
/// Bit layout (`T` = timestamp bit, MSB first; `1` = marker bit; `p` = prefix):
/// `pppp TTT1 | TTTTTTTT | TTTTTTT1 | TTTTTTTT | TTTTTTT1`.
/// Prefix and marker bits are ignored. Only the first five bytes are read, and a
/// shorter slice panics on indexing.
fn parse_ts_timestamp(b: &[u8]) -> u64 {
    let byte = |i: usize| u64::from(b[i]);
    let top = (byte(0) >> 1) & 0x07; // bits 32..30
    let middle = (byte(1) << 7) | (byte(2) >> 1); // bits 29..15
    let bottom = (byte(3) << 7) | (byte(4) >> 1); // bits 14..0
    (top << 30) | (middle << 15) | bottom
}
#[cfg(test)]
mod tests {
use super::*;
// Builds one 188-byte TS packet for `pid` with the PUSI flag set as requested and no
// adaptation field. `payload` is truncated to the 184 bytes available after the 4-byte
// header; any unused bytes keep the 0xFF fill value.
fn make_ts_packet(pid: u16, pusi: bool, payload: &[u8]) -> [u8; 188] {
let mut pkt = [0xFFu8; 188];
pkt[0] = SYNC_BYTE;
// Byte 1: PUSI flag (0x40) plus the top 5 bits of the 13-bit PID; byte 2: low 8 PID bits.
pkt[1] = if pusi { 0x40 } else { 0x00 } | ((pid >> 8) as u8 & 0x1F);
pkt[2] = pid as u8;
// 0x10: adaptation_field_control = 01 (payload only), continuity counter 0.
pkt[3] = 0x10; let copy_len = payload.len().min(184);
pkt[4..4 + copy_len].copy_from_slice(&payload[..copy_len]);
pkt
}
// Builds a PAT packet payload (pointer_field + one section) that announces program 1
// on `pmt_pid`. The CRC_32 is zero-filled because the demuxer does not read it.
fn minimal_pat(pmt_pid: u16) -> Vec<u8> {
let mut data = vec![
// pointer_field 0; table_id 0x00; syntax flag + section_length 13 (0xB0 0x0D);
// transport_stream_id 1; version 0 / current_next 1 (0xC1); section_number 0;
// last_section_number 0; program_number 1.
0x00, 0x00, 0xB0, 0x0D, 0x00, 0x01, 0xC1, 0x00, 0x00, 0x00, 0x01, ];
// Reserved bits (0xE0) plus the 13-bit program_map_PID.
data.push(0xE0 | ((pmt_pid >> 8) as u8 & 0x1F));
data.push(pmt_pid as u8);
// Placeholder CRC_32, then return the payload.
data.extend_from_slice(&[0x00; 4]); data
}
// Builds a PMT packet payload listing one H.264 stream (stream_type 0x1B) on `video_pid`
// and one AAC stream (stream_type 0x0F) on `audio_pid`, each with no ES descriptors.
fn minimal_pmt(video_pid: u16, audio_pid: u16) -> Vec<u8> {
let mut data = vec![
// pointer_field 0; table_id 0x02; section_length 23 (0xB0 0x17); program_number 1;
// version/current_next 0xC1; section numbers 0/0; PCR_PID 0x100 (0xE1 0x00);
// program_info_length 0 (0xF0 0x00).
0x00, 0x02, 0xB0, 0x17, 0x00, 0x01, 0xC1, 0x00, 0x00, 0xE1, 0x00, 0xF0, 0x00, ];
// Video entry: stream_type 0x1B, elementary_PID, ES_info_length 0 (0xF0 0x00).
data.push(0x1B); data.push(0xE0 | ((video_pid >> 8) as u8 & 0x1F));
data.push(video_pid as u8);
data.push(0xF0);
// End of the video entry, then audio entry: stream_type 0x0F, elementary_PID.
data.push(0x00); data.push(0x0F); data.push(0xE0 | ((audio_pid >> 8) as u8 & 0x1F));
data.push(audio_pid as u8);
data.push(0xF0);
// Audio ES_info_length low byte, placeholder CRC_32, then return the payload.
data.push(0x00); data.extend_from_slice(&[0x00; 4]); data
}
// Builds a PES packet on stream_id 0xE0 carrying only a PTS (no DTS) and `es_payload`.
fn minimal_pes(pts_90k: u64, es_payload: &[u8]) -> Vec<u8> {
// PES_packet_length counts the bytes after itself: 2 flag bytes + header length byte,
// the 5-byte PTS, and the payload.
let pes_len = (3 + 5 + es_payload.len()) as u16;
let mut data = vec![
0x00,
0x00,
// Start code prefix 00 00 01, stream_id 0xE0, PES_packet_length (big-endian).
0x01, 0xE0, (pes_len >> 8) as u8,
pes_len as u8,
// Marker bits '10' (0x80); PTS_DTS_flags = 10 (PTS only, 0x80);
// PES_header_data_length 5.
0x80, 0x80, 0x05, ];
// Encode the 33-bit PTS with the '0010' prefix and marker bits.
let pts = pts_90k & 0x1_FFFF_FFFF;
data.push(0x21 | ((pts >> 29) as u8 & 0x0E));
data.push((pts >> 22) as u8);
data.push(0x01 | ((pts >> 14) as u8 & 0xFE));
data.push((pts >> 7) as u8);
data.push(0x01 | ((pts << 1) as u8 & 0xFE));
data.extend_from_slice(es_payload);
data
}
// PAT -> PMT -> PES end-to-end. The first PES on a PID is only emitted once the next
// payload-unit-start for that PID arrives.
#[test]
fn demux_discovers_streams_and_yields_pes() {
let mut demux = TsDemuxer::new();
let video_pid = 0x100;
let audio_pid = 0x101;
let pmt_pid = 0x1000;
let pat = make_ts_packet(PAT_PID, true, &minimal_pat(pmt_pid));
assert!(demux.feed(&pat).is_empty());
assert_eq!(demux.pmt_pid, Some(pmt_pid));
let pmt = make_ts_packet(pmt_pid, true, &minimal_pmt(video_pid, audio_pid));
assert!(demux.feed(&pmt).is_empty());
assert_eq!(demux.streams.len(), 2);
assert_eq!(demux.streams[&video_pid], StreamType::H264);
assert_eq!(demux.streams[&audio_pid], StreamType::Aac);
let pes = minimal_pes(90_000, b"nalunalunalu");
let pkt = make_ts_packet(video_pid, true, &pes);
// The first PES stays buffered: nothing has terminated it yet.
assert!(demux.feed(&pkt).is_empty());
let pes2 = minimal_pes(180_000, b"nalu2");
let pkt2 = make_ts_packet(video_pid, true, &pes2);
// The second unit start completes the first PES.
let packets = demux.feed(&pkt2);
assert_eq!(packets.len(), 1);
assert_eq!(packets[0].pid, video_pid);
assert_eq!(packets[0].stream_type, StreamType::H264);
assert_eq!(packets[0].pts, Some(90_000));
assert_eq!(packets[0].payload, b"nalunalunalu");
}
// Leading bytes without a sync byte are skipped before the first packet.
#[test]
fn sync_recovery_skips_garbage() {
let mut demux = TsDemuxer::new();
let pmt_pid = 0x1000;
let mut data = vec![0xDE, 0xAD, 0xBE, 0xEF];
data.extend_from_slice(&make_ts_packet(PAT_PID, true, &minimal_pat(pmt_pid)));
demux.feed(&data);
assert_eq!(demux.pmt_pid, Some(pmt_pid));
}
// A packet split across two `feed` calls is buffered and processed once complete.
#[test]
fn cross_call_buffering_handles_partial_packets() {
let mut demux = TsDemuxer::new();
let pmt_pid = 0x1000;
let full = make_ts_packet(PAT_PID, true, &minimal_pat(pmt_pid));
demux.feed(&full[..100]);
assert_eq!(demux.pmt_pid, None);
demux.feed(&full[100..]);
assert_eq!(demux.pmt_pid, Some(pmt_pid));
}
// Decoding inverts the same 5-byte PTS encoding used by `minimal_pes`.
#[test]
fn parse_ts_timestamp_round_trips() {
let pts: u64 = 123_456_789;
let mut buf = [0u8; 5];
buf[0] = 0x21 | ((pts >> 29) as u8 & 0x0E);
buf[1] = (pts >> 22) as u8;
buf[2] = 0x01 | ((pts >> 14) as u8 & 0xFE);
buf[3] = (pts >> 7) as u8;
buf[4] = 0x01 | ((pts << 1) as u8 & 0xFE);
assert_eq!(parse_ts_timestamp(&buf), pts);
}
}