#![forbid(unsafe_code)]
use std::collections::HashMap;
/// One parsed 188-byte transport stream packet, as produced by [`parse_ts_packet`].
#[derive(Debug, Clone)]
pub struct TsPacket {
    /// 13-bit packet identifier taken from the packet header.
    pub pid: u16,
    /// `payload_unit_start_indicator` flag from the header.
    pub payload_unit_start: bool,
    /// 4-bit continuity counter from the header.
    pub continuity_counter: u8,
    /// Whether the header announces an adaptation field.
    pub has_adaptation: bool,
    /// Whether the header announces a payload.
    pub has_payload: bool,
    /// Copy of the payload bytes; empty when the packet carries no payload.
    pub payload: Vec<u8>,
    /// Presentation timestamp found in a PES header starting in this packet, if any.
    pub pts: Option<u64>,
    /// Decoding timestamp found in a PES header starting in this packet, if any.
    pub dts: Option<u64>,
}
/// Program Association Table contents, as returned by [`parse_pat`].
#[derive(Debug, Clone)]
pub struct Pat {
    /// Transport stream id read from the section header.
    pub transport_stream_id: u16,
    /// Program entries listed in the table; entries with program number 0
    /// are not included.
    pub programs: Vec<PatEntry>,
}
/// A single program entry of a [`Pat`].
#[derive(Debug, Clone)]
pub struct PatEntry {
    /// Program number of the entry.
    pub program_number: u16,
    /// 13-bit PID on which the program's PMT is carried.
    pub pmt_pid: u16,
}
/// Program Map Table contents, as returned by [`parse_pmt`].
#[derive(Debug, Clone)]
pub struct Pmt {
    /// Program number read from the section header.
    pub program_number: u16,
    /// 13-bit PID announced as carrying the program clock reference.
    pub pcr_pid: u16,
    /// Elementary streams listed in the table, in section order.
    pub streams: Vec<PmtStream>,
}
/// One elementary stream entry of a [`Pmt`].
#[derive(Debug, Clone)]
pub struct PmtStream {
    /// Stream type byte of the entry.
    pub stream_type: u8,
    /// 13-bit PID carrying the elementary stream.
    pub elementary_pid: u16,
    /// Raw descriptor bytes attached to the entry; left empty by [`parse_pmt`]
    /// when the declared descriptor length would run past the stream loop.
    pub descriptors: Vec<u8>,
}
/// Running statistics collected for a single PID.
#[derive(Debug, Clone, Default)]
pub struct PidInfo {
    /// PID these statistics belong to.
    pub pid: u16,
    /// Stream type announced for this PID by a parsed PMT, if one was seen.
    pub stream_type: Option<u8>,
    /// Number of parsed packets seen on this PID.
    pub packet_count: u64,
    /// Sum of the payload lengths of those packets, in bytes.
    pub total_bytes: u64,
    /// First PTS observed on this PID.
    pub pts_first: Option<u64>,
    /// Most recently observed PTS on this PID.
    pub pts_last: Option<u64>,
}
/// Everything a demuxer has learned about the stream so far.
#[derive(Debug, Clone, Default)]
pub struct TsStreamInfo {
    /// Most recently parsed PAT, if any.
    pub pat: Option<Pat>,
    /// Parsed PMTs, keyed by program number.
    pub pmts: HashMap<u16, Pmt>,
    /// Per-PID statistics, keyed by PID.
    pub pids: HashMap<u16, PidInfo>,
}
/// Parses the 188-byte transport stream packet that starts at `offset` in `data`.
///
/// Returns `None` when fewer than 188 bytes are available at `offset`, when the
/// first byte is not the sync byte `0x47`, or when the packet has its
/// transport error indicator set.
///
/// The payload is copied out of `data`. If the packet has the payload unit
/// start flag set and at least 9 payload bytes, the payload is also probed for
/// PES timestamps.
#[must_use]
pub fn parse_ts_packet(data: &[u8], offset: usize) -> Option<TsPacket> {
    const TS_SIZE: usize = 188;
    const SYNC: u8 = 0x47;

    // `checked_add` + `get` reject every out-of-range offset, including ones
    // close to `usize::MAX` where `offset + TS_SIZE` would overflow.
    let end = offset.checked_add(TS_SIZE)?;
    let pkt = data.get(offset..end)?;
    if pkt[0] != SYNC {
        return None;
    }
    // Packets flagged with the transport error indicator are dropped.
    if (pkt[1] & 0x80) != 0 {
        return None;
    }

    let payload_unit_start = (pkt[1] & 0x40) != 0;
    let pid = u16::from_be_bytes([pkt[1] & 0x1F, pkt[2]]);
    let adaptation_field_control = (pkt[3] >> 4) & 0x03;
    let continuity_counter = pkt[3] & 0x0F;
    let has_adaptation = (adaptation_field_control & 0x02) != 0;
    let has_payload = (adaptation_field_control & 0x01) != 0;

    // The payload follows the 4-byte header and, when present, the adaptation
    // field (one length byte plus that many bytes). An oversized length is
    // clamped so the payload simply comes out empty.
    let payload_start = if has_adaptation {
        (5 + usize::from(pkt[4])).min(TS_SIZE)
    } else {
        4
    };
    let payload = if has_payload {
        pkt[payload_start..].to_vec()
    } else {
        Vec::new()
    };

    let (pts, dts) = if payload_unit_start && payload.len() >= 9 {
        extract_pes_timestamps(&payload)
    } else {
        (None, None)
    };

    Some(TsPacket {
        pid,
        payload_unit_start,
        continuity_counter,
        has_adaptation,
        has_payload,
        payload,
        pts,
        dts,
    })
}
/// Parses a complete Program Association Table section.
///
/// `payload` must start at the `table_id` byte (i.e. after any pointer field).
/// Returns `None` when the buffer is shorter than 8 bytes, the table id is not
/// `0x00`, the buffer is shorter than the declared section, or the declared
/// section length is too small to hold the fixed header plus the 4-byte CRC.
///
/// Entries with program number 0 are skipped. The CRC is not verified.
#[must_use]
pub fn parse_pat(payload: &[u8]) -> Option<Pat> {
    const TABLE_ID_PAT: u8 = 0x00;
    // table_id through last_section_number.
    const HEADER_LEN: usize = 8;
    const CRC_LEN: usize = 4;
    const ENTRY_LEN: usize = 4;

    if payload.len() < HEADER_LEN || payload[0] != TABLE_ID_PAT {
        return None;
    }
    let section_length = usize::from(u16::from_be_bytes([payload[1] & 0x0F, payload[2]]));
    // section_length counts the bytes that follow the 3-byte prefix.
    let section_end = 3 + section_length;
    if payload.len() < section_end {
        return None;
    }

    // The program loop ends where the CRC begins. A section too short to hold
    // the header and CRC is malformed; rejecting it also avoids the
    // subtraction underflowing for tiny `section_length` values.
    let entries_end = section_end
        .checked_sub(CRC_LEN)
        .filter(|&end| end >= HEADER_LEN)?;

    let transport_stream_id = u16::from_be_bytes([payload[3], payload[4]]);
    let programs = payload[HEADER_LEN..entries_end]
        .chunks_exact(ENTRY_LEN)
        .filter_map(|entry| {
            let program_number = u16::from_be_bytes([entry[0], entry[1]]);
            let pmt_pid = u16::from_be_bytes([entry[2] & 0x1F, entry[3]]);
            if program_number == 0 {
                None
            } else {
                Some(PatEntry {
                    program_number,
                    pmt_pid,
                })
            }
        })
        .collect();

    Some(Pat {
        transport_stream_id,
        programs,
    })
}
/// Parses a complete Program Map Table section.
///
/// `payload` must start at the `table_id` byte (i.e. after any pointer field).
/// Returns `None` when the buffer is shorter than 12 bytes, the table id is not
/// `0x02`, the buffer is shorter than the declared section, or the declared
/// lengths leave no room for the stream loop before the 4-byte CRC.
///
/// A stream entry whose descriptor length would run past the stream loop is
/// still reported, with empty `descriptors`, and ends the loop. The CRC is not
/// verified.
#[must_use]
pub fn parse_pmt(payload: &[u8]) -> Option<Pmt> {
    const TABLE_ID_PMT: u8 = 0x02;
    // table_id through program_info_length.
    const FIXED_LEN: usize = 12;
    const CRC_LEN: usize = 4;
    // stream_type + elementary PID + ES_info_length.
    const STREAM_HEADER_LEN: usize = 5;

    if payload.len() < FIXED_LEN || payload[0] != TABLE_ID_PMT {
        return None;
    }
    let section_length = usize::from(u16::from_be_bytes([payload[1] & 0x0F, payload[2]]));
    // section_length counts the bytes that follow the 3-byte prefix.
    let section_end = 3 + section_length;
    if payload.len() < section_end {
        return None;
    }

    let program_number = u16::from_be_bytes([payload[3], payload[4]]);
    let pcr_pid = u16::from_be_bytes([payload[8] & 0x1F, payload[9]]);
    let program_info_length =
        usize::from(u16::from_be_bytes([payload[10] & 0x0F, payload[11]]));

    let streams_start = FIXED_LEN + program_info_length;
    // `checked_sub` turns a section too short to even hold the CRC into `None`
    // instead of an arithmetic underflow.
    let streams_end = section_end.checked_sub(CRC_LEN)?;
    if streams_start > streams_end {
        return None;
    }

    let mut streams = Vec::new();
    let mut offset = streams_start;
    while offset + STREAM_HEADER_LEN <= streams_end {
        let stream_type = payload[offset];
        let elementary_pid =
            u16::from_be_bytes([payload[offset + 1] & 0x1F, payload[offset + 2]]);
        let es_info_length =
            usize::from(u16::from_be_bytes([payload[offset + 3] & 0x0F, payload[offset + 4]]));

        let desc_start = offset + STREAM_HEADER_LEN;
        let desc_end = desc_start + es_info_length;
        let descriptors = if desc_end <= streams_end {
            payload[desc_start..desc_end].to_vec()
        } else {
            Vec::new()
        };

        streams.push(PmtStream {
            stream_type,
            elementary_pid,
            descriptors,
        });
        offset = desc_end;
    }

    Some(Pmt {
        program_number,
        pcr_pid,
        streams,
    })
}
fn extract_pes_timestamps(payload: &[u8]) -> (Option<u64>, Option<u64>) {
if payload.len() < 9 {
return (None, None);
}
if payload[0] != 0x00 || payload[1] != 0x00 || payload[2] != 0x01 {
return (None, None);
}
if payload.len() < 14 {
return (None, None);
}
let pts_dts_flags = (payload[7] >> 6) & 0x03;
let header_data_length = payload[8] as usize;
if payload.len() < 9 + header_data_length {
return (None, None);
}
let mut pts: Option<u64> = None;
let mut dts: Option<u64> = None;
if pts_dts_flags >= 0x02 && payload.len() >= 9 + 5 {
pts = Some(decode_timestamp(&payload[9..14]));
}
if pts_dts_flags == 0x03 && payload.len() >= 14 + 5 {
dts = Some(decode_timestamp(&payload[14..19]));
}
(pts, dts)
}
/// Reassembles a 33-bit timestamp from its 5-byte encoding.
///
/// The value is spread over the bytes as 3 + 15 + 15 bits, each group followed
/// by a marker bit in the least significant position; the marker bits and the
/// upper nibble of the first byte are ignored.
///
/// Panics if `b` holds fewer than 5 bytes.
fn decode_timestamp(b: &[u8]) -> u64 {
    // Bits 32..30 sit in bits 3..1 of the first byte.
    let high = u64::from((b[0] >> 1) & 0x07) << 30;
    // Bits 29..15: a full byte followed by the upper 7 bits of the next one.
    let middle = ((u64::from(b[1]) << 7) | u64::from(b[2] >> 1)) << 15;
    // Bits 14..0: same layout again.
    let low = (u64::from(b[3]) << 7) | u64::from(b[4] >> 1);
    high | middle | low
}
/// Reassembles one PSI section from the payloads of consecutive packets on a
/// single PID.
#[derive(Debug, Default)]
struct SectionBuf {
    /// Bytes collected for the section currently being assembled; empty when
    /// no section is in progress.
    data: Vec<u8>,
    /// Total section size (3-byte prefix + `section_length`), known once the
    /// first three bytes have been collected.
    expected: Option<usize>,
}

impl SectionBuf {
    /// Feeds one packet payload and returns the section if it is now complete.
    ///
    /// With `unit_start` set, any partially collected section is discarded and
    /// a new one starts after the pointer field (first payload byte plus that
    /// many skipped bytes). Without it, the payload continues the section in
    /// progress; if none is in progress the payload is ignored, so data seen
    /// before the first section start can never be mistaken for a section.
    ///
    /// Bytes following a completed section in the same payload are dropped.
    fn push(&mut self, payload: &[u8], unit_start: bool) -> Option<Vec<u8>> {
        if unit_start {
            self.data.clear();
            self.expected = None;
            let pointer = usize::from(*payload.first()?);
            let section_bytes = payload.get(1 + pointer..)?;
            if section_bytes.is_empty() {
                return None;
            }
            self.data.extend_from_slice(section_bytes);
        } else {
            // `data` is only non-empty between a section start and its
            // completion, so an empty buffer means there is nothing to continue.
            if self.data.is_empty() {
                return None;
            }
            self.data.extend_from_slice(payload);
        }

        if self.expected.is_none() && self.data.len() >= 3 {
            let section_length =
                usize::from(u16::from_be_bytes([self.data[1] & 0x0F, self.data[2]]));
            self.expected = Some(section_length + 3);
        }

        match self.expected {
            Some(total) if self.data.len() >= total => {
                let section = self.data[..total].to_vec();
                self.data.clear();
                self.expected = None;
                Some(section)
            }
            _ => None,
        }
    }
}
/// PID on which the demuxer looks for the Program Association Table.
const PAT_PID: u16 = 0;
/// Stateful demuxer that tracks PAT/PMT tables and per-PID statistics across
/// successive chunks of transport stream data.
#[derive(Debug, Default)]
pub struct TsDemuxer {
    /// Everything learned about the stream so far.
    info: TsStreamInfo,
    /// Section reassembly state, one buffer per PSI PID.
    section_bufs: HashMap<u16, SectionBuf>,
    /// PIDs announced by the PAT as carrying PMTs.
    pmt_pids: Vec<u16>,
}
impl TsDemuxer {
    /// Creates a demuxer that has not seen any data yet.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Parses every transport packet found in `data` and updates the stream
    /// information with it.
    ///
    /// Bytes that are not a sync byte are skipped one at a time until a
    /// candidate packet start is found. A candidate that fails to parse is
    /// skipped as a whole packet. Trailing bytes that cannot form a complete
    /// 188-byte packet are ignored; nothing is carried over to the next call.
    ///
    /// Returns the successfully parsed packets in stream order.
    pub fn feed(&mut self, data: &[u8]) -> Vec<TsPacket> {
        const TS_SIZE: usize = 188;
        const SYNC: u8 = 0x47;

        let mut packets = Vec::new();
        let mut offset = 0usize;
        while offset + TS_SIZE <= data.len() {
            if data[offset] != SYNC {
                offset += 1;
                continue;
            }
            if let Some(pkt) = parse_ts_packet(data, offset) {
                self.ingest_packet(&pkt);
                packets.push(pkt);
            }
            offset += TS_SIZE;
        }
        packets
    }

    /// Returns the tables and per-PID statistics gathered so far.
    #[must_use]
    pub fn stream_info(&self) -> &TsStreamInfo {
        &self.info
    }

    /// Estimates the stream duration in milliseconds from the observed PTS
    /// values.
    ///
    /// Uses the smallest "first PTS" and the largest "last PTS" across all
    /// PIDs and divides the difference by 90 (PTS ticks per millisecond).
    /// Returns `None` when no PTS has been seen or the range is empty.
    /// Timestamp wrap-around is not compensated.
    #[must_use]
    pub fn duration_ms(&self) -> Option<u64> {
        let pids = &self.info.pids;
        let first = pids.values().filter_map(|info| info.pts_first).min()?;
        let last = pids.values().filter_map(|info| info.pts_last).max()?;
        if last > first {
            Some((last - first) / 90)
        } else {
            None
        }
    }

    /// Updates the per-PID statistics for `pkt` and routes PAT/PMT packets to
    /// their section handlers.
    fn ingest_packet(&mut self, pkt: &TsPacket) {
        let pid_info = self.info.pids.entry(pkt.pid).or_insert_with(|| PidInfo {
            pid: pkt.pid,
            ..Default::default()
        });
        pid_info.packet_count += 1;
        pid_info.total_bytes += pkt.payload.len() as u64;
        if let Some(pts) = pkt.pts {
            if pid_info.pts_first.is_none() {
                pid_info.pts_first = Some(pts);
            }
            pid_info.pts_last = Some(pts);
        }

        if pkt.pid == PAT_PID {
            self.handle_pat(pkt);
        } else if self.pmt_pids.contains(&pkt.pid) {
            self.handle_pmt(pkt);
        }
    }

    /// Feeds a PAT packet into the section buffer and, once a section is
    /// complete and parses, records the PAT and any newly announced PMT PIDs.
    fn handle_pat(&mut self, pkt: &TsPacket) {
        let buf = self.section_bufs.entry(PAT_PID).or_default();
        let section = match buf.push(&pkt.payload, pkt.payload_unit_start) {
            Some(section) => section,
            None => return,
        };
        if let Some(pat) = parse_pat(&section) {
            for entry in &pat.programs {
                if !self.pmt_pids.contains(&entry.pmt_pid) {
                    self.pmt_pids.push(entry.pmt_pid);
                }
            }
            self.info.pat = Some(pat);
        }
    }

    /// Feeds a PMT packet into its PID's section buffer and, once a section is
    /// complete and parses, records the PMT and the stream type of every
    /// elementary PID it lists.
    fn handle_pmt(&mut self, pkt: &TsPacket) {
        let buf = self.section_bufs.entry(pkt.pid).or_default();
        let section = match buf.push(&pkt.payload, pkt.payload_unit_start) {
            Some(section) => section,
            None => return,
        };
        if let Some(pmt) = parse_pmt(&section) {
            for stream in &pmt.streams {
                let pid_info = self
                    .info
                    .pids
                    .entry(stream.elementary_pid)
                    .or_insert_with(|| PidInfo {
                        pid: stream.elementary_pid,
                        ..Default::default()
                    });
                pid_info.stream_type = Some(stream.stream_type);
            }
            self.info.pmts.insert(pmt.program_number, pmt);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 188-byte payload-only packet; `payload` is truncated to 184
    /// bytes and the remainder of the packet is zero-filled.
    fn make_ts_packet(pid: u16, pusi: bool, cc: u8, payload: &[u8]) -> Vec<u8> {
        let mut pkt = vec![0u8; 188];
        pkt[0] = 0x47;
        let pusi_bit: u8 = if pusi { 0x40 } else { 0x00 };
        pkt[1] = pusi_bit | ((pid >> 8) as u8 & 0x1F);
        pkt[2] = (pid & 0xFF) as u8;
        // adaptation_field_control = 0b01: payload only.
        pkt[3] = 0x10 | (cc & 0x0F);
        let copy_len = payload.len().min(184);
        pkt[4..4 + copy_len].copy_from_slice(&payload[..copy_len]);
        pkt
    }

    /// Builds a PAT section with a single program entry and a zeroed CRC.
    fn make_pat_section(tsid: u16, program_number: u16, pmt_pid: u16) -> Vec<u8> {
        let mut s = Vec::new();
        s.push(0x00);
        // 5 header bytes after the length field + one 4-byte entry + 4-byte CRC.
        let sl: u16 = 13;
        s.push(0xB0 | ((sl >> 8) as u8 & 0x0F));
        s.push((sl & 0xFF) as u8);
        s.push((tsid >> 8) as u8);
        s.push((tsid & 0xFF) as u8);
        s.push(0xC1);
        s.push(0x00);
        s.push(0x00);
        s.push((program_number >> 8) as u8);
        s.push((program_number & 0xFF) as u8);
        s.push(0xE0 | ((pmt_pid >> 8) as u8 & 0x1F));
        s.push((pmt_pid & 0xFF) as u8);
        s.extend_from_slice(&[0u8; 4]);
        s
    }

    /// Builds a PMT section with a single stream, no descriptors and a zeroed CRC.
    fn make_pmt_section(
        program_number: u16,
        pcr_pid: u16,
        stream_type: u8,
        es_pid: u16,
    ) -> Vec<u8> {
        let mut s = Vec::new();
        s.push(0x02);
        // 9 header bytes after the length field + one 5-byte stream entry + 4-byte CRC.
        let sl: u16 = 18;
        s.push(0xB0 | ((sl >> 8) as u8 & 0x0F));
        s.push((sl & 0xFF) as u8);
        s.push((program_number >> 8) as u8);
        s.push((program_number & 0xFF) as u8);
        s.push(0xC1);
        s.push(0x00);
        s.push(0x00);
        s.push(0xE0 | ((pcr_pid >> 8) as u8 & 0x1F));
        s.push((pcr_pid & 0xFF) as u8);
        s.push(0xF0);
        s.push(0x00);
        s.push(stream_type);
        s.push(0xE0 | ((es_pid >> 8) as u8 & 0x1F));
        s.push((es_pid & 0xFF) as u8);
        s.push(0xF0);
        s.push(0x00);
        s.extend_from_slice(&[0u8; 4]);
        s
    }

    #[test]
    fn test_parse_ts_packet_bad_sync() {
        let mut data = vec![0u8; 188];
        data[0] = 0x00;
        assert!(parse_ts_packet(&data, 0).is_none());
    }

    #[test]
    fn test_parse_ts_packet_pid() {
        let pkt = make_ts_packet(0x0100, false, 0, &[]);
        let parsed = parse_ts_packet(&pkt, 0).expect("valid packet");
        assert_eq!(parsed.pid, 0x0100);
    }

    #[test]
    fn test_parse_ts_packet_pusi() {
        let pkt = make_ts_packet(0x0200, true, 3, &[]);
        let parsed = parse_ts_packet(&pkt, 0).expect("valid packet");
        assert!(parsed.payload_unit_start);
        assert_eq!(parsed.continuity_counter, 3);
    }

    #[test]
    fn test_parse_ts_packet_offset() {
        let mut buf = vec![0xFFu8; 200];
        let pkt = make_ts_packet(0x0042, false, 7, &[0xAB; 10]);
        buf[12..200].copy_from_slice(&pkt);
        let parsed = parse_ts_packet(&buf, 12).expect("valid packet");
        assert_eq!(parsed.pid, 0x0042);
        assert_eq!(parsed.continuity_counter, 7);
    }

    #[test]
    fn test_parse_pat_round_trip() {
        let section = make_pat_section(0x0001, 1, 0x0100);
        let pat = parse_pat(&section).expect("valid PAT");
        assert_eq!(pat.transport_stream_id, 0x0001);
        assert_eq!(pat.programs.len(), 1);
        assert_eq!(pat.programs[0].program_number, 1);
        assert_eq!(pat.programs[0].pmt_pid, 0x0100);
    }

    #[test]
    fn test_parse_pat_bad_table_id() {
        let mut section = make_pat_section(1, 1, 0x0100);
        section[0] = 0xFF;
        assert!(parse_pat(&section).is_none());
    }

    #[test]
    fn test_parse_pmt_round_trip() {
        let section = make_pmt_section(1, 0x01FF, 0x81, 0x0101);
        let pmt = parse_pmt(&section).expect("valid PMT");
        assert_eq!(pmt.program_number, 1);
        assert_eq!(pmt.pcr_pid, 0x01FF);
        assert_eq!(pmt.streams.len(), 1);
        assert_eq!(pmt.streams[0].stream_type, 0x81);
        assert_eq!(pmt.streams[0].elementary_pid, 0x0101);
    }

    #[test]
    fn test_parse_pmt_bad_table_id() {
        let mut section = make_pmt_section(1, 0x01FF, 0x81, 0x0101);
        section[0] = 0x00;
        assert!(parse_pmt(&section).is_none());
    }

    #[test]
    fn test_demuxer_feed_count() {
        let mut demux = TsDemuxer::new();
        let mut stream = Vec::new();
        for i in 0..5u16 {
            stream.extend(make_ts_packet(0x0100 + i, false, 0, &[]));
        }
        let pkts = demux.feed(&stream);
        assert_eq!(pkts.len(), 5);
    }

    #[test]
    fn test_demuxer_discovers_pat() {
        let mut demux = TsDemuxer::new();
        let section = make_pat_section(1, 1, 0x0100);
        // Leading zero byte is the pointer field: the section starts right after it.
        let mut payload = vec![0x00u8];
        payload.extend_from_slice(&section);
        let pkt = make_ts_packet(0x0000, true, 0, &payload);
        let _ = demux.feed(&pkt);
        assert!(demux.stream_info().pat.is_some());
    }

    #[test]
    fn test_demuxer_duration_none_without_range() {
        let demux = TsDemuxer::new();
        assert!(demux.duration_ms().is_none());
    }

    #[test]
    fn test_decode_timestamp_known_value() {
        let pts_val: u64 = 0x1_2345_6789;
        // Encode with the '0010' prefix nibble and a marker bit after each group.
        let b0: u8 = 0x20 | (((pts_val >> 29) & 0x0E) as u8) | 0x01;
        let b1: u8 = ((pts_val >> 22) & 0xFF) as u8;
        let b2: u8 = (((pts_val >> 14) & 0xFE) as u8) | 0x01;
        let b3: u8 = ((pts_val >> 7) & 0xFF) as u8;
        let b4: u8 = (((pts_val << 1) & 0xFE) as u8) | 0x01;
        let decoded = decode_timestamp(&[b0, b1, b2, b3, b4]);
        assert_eq!(decoded, pts_val);
    }
}